"`):
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+ token instead.
+ bos_token (`str`, *optional*, defaults to `"<s>"`):
+ The beginning of sequence token that was used during pretraining. Can be used as a sequence classifier token.
+ eos_token (`str`, *optional*, defaults to `"</s>"`):
+ The end of sequence token.
+
+ <Tip>
+
+ When building a sequence using special tokens, this is not the token that is used for the end of sequence.
+ The token used is the `sep_token`.
+
+ </Tip>
+
+ prefix_token (`str`, *optional*, defaults to `"▁<PRE>"`):
+ Prefix token used for infilling.
+ middle_token (`str`, *optional*, defaults to `"▁<MID>"`):
+ Middle token used for infilling.
+ suffix_token (`str`, *optional*, defaults to `"▁<SUF>"`):
+ Suffix token used for infilling.
+ eot_token (`str`, *optional*, defaults to `"▁<EOT>"`):
+ End of text token used for infilling.
+ fill_token (`str`, *optional*, defaults to `"<FILL_ME>"`):
+ The token used to split the input between the prefix and suffix.
+ suffix_first (`bool`, *optional*, defaults to `False`):
+ Whether the input prompt and suffix should be formatted with the suffix first.
+ sp_model_kwargs (`dict`, *optional*):
+ Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for
+ SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things,
+ to set:
+
+ - `enable_sampling`: Enable subword regularization.
+ - `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout.
+
+ - `nbest_size = {0,1}`: No sampling is performed.
+ - `nbest_size > 1`: samples from the nbest_size results.
+ - `nbest_size < 0`: assuming that nbest_size is infinite and samples from all hypotheses (lattice)
+ using forward-filtering-and-backward-sampling algorithm.
+
+ - `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for
+ BPE-dropout.
+ add_bos_token (`bool`, *optional*, defaults to `True`):
+ Whether to add a beginning of sequence token at the start of sequences.
+ add_eos_token (`bool`, *optional*, defaults to `False`):
+ Whether to add an end of sequence token at the end of sequences.
+ clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
+ Whether or not to clean up the tokenization spaces.
+ additional_special_tokens (`list[str]`, *optional*):
+ Additional special tokens used by the tokenizer.
+ use_default_system_prompt (`bool`, *optional*, defaults to `False`):
+ Whether or not the default system prompt for Llama should be used.
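+
+ Example (an illustrative sketch rather than a reference; the checkpoint name is only an example and the
+ exact token strings follow the defaults documented above):
+
+ ```python
+ >>> from transformers import CodeLlamaTokenizer
+
+ >>> tokenizer = CodeLlamaTokenizer.from_pretrained("codellama/CodeLlama-7b-hf")  # doctest: +SKIP
+ >>> # a prompt containing `fill_token` is split into a prefix and a suffix for infilling
+ >>> tokens = tokenizer.tokenize("def add(a, b):<FILL_ME>return result")  # doctest: +SKIP
+ >>> tokens[0], tokens[-1]  # doctest: +SKIP
+ ('▁<PRE>', '▁<MID>')
+ ```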
+ """
+
+ vocab_files_names = VOCAB_FILES_NAMES
+ model_input_names = ["input_ids", "attention_mask"]
+
+ def __init__(
+ self,
+ vocab_file,
+ unk_token="<unk>",
+ bos_token="<s>",
+ eos_token="</s>",
+ prefix_token="▁<PRE>",
+ middle_token="▁<MID>",
+ suffix_token="▁<SUF>",
+ eot_token="▁<EOT>",
+ fill_token="<FILL_ME>",
+ suffix_first=False,
+ sp_model_kwargs: Optional[dict[str, Any]] = None,
+ add_bos_token=True,
+ add_eos_token=False,
+ clean_up_tokenization_spaces=False,
+ additional_special_tokens=None,
+ use_default_system_prompt=False,
+ **kwargs,
+ ):
+ requires_backends(self, "protobuf")
+ self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
+ bos_token = AddedToken(bos_token, normalized=False, special=True) if isinstance(bos_token, str) else bos_token
+ eos_token = AddedToken(eos_token, normalized=False, special=True) if isinstance(eos_token, str) else eos_token
+ unk_token = AddedToken(unk_token, normalized=False, special=True) if isinstance(unk_token, str) else unk_token
+
+ self.use_default_system_prompt = use_default_system_prompt
+ # mark tokens special to skip them
+ additional_special_tokens = additional_special_tokens or []
+ for token in [prefix_token, middle_token, suffix_token, eot_token]:
+ additional_special_tokens += [token] if token is not None else []
+
+ self.vocab_file = vocab_file
+ self.add_bos_token = add_bos_token
+ self.add_eos_token = add_eos_token
+ self._prefix_token = prefix_token
+ self._middle_token = middle_token
+ self._suffix_token = suffix_token
+ self._eot_token = eot_token
+ self.fill_token = fill_token
+ self.suffix_first = suffix_first
+ self.sp_model = self.get_spm_processor()
+
+ super().__init__(
+ bos_token=bos_token,
+ eos_token=eos_token,
+ unk_token=unk_token,
+ add_bos_token=add_bos_token,
+ add_eos_token=add_eos_token,
+ prefix_token=prefix_token,
+ middle_token=middle_token,
+ suffix_token=suffix_token,
+ eot_token=eot_token,
+ fill_token=fill_token,
+ sp_model_kwargs=self.sp_model_kwargs,
+ suffix_first=suffix_first,
+ clean_up_tokenization_spaces=clean_up_tokenization_spaces,
+ additional_special_tokens=additional_special_tokens,
+ use_default_system_prompt=use_default_system_prompt,
+ **kwargs,
+ )
+
+ @property
+ def unk_token_length(self):
+ return len(self.sp_model.encode(str(self.unk_token)))
+
+ def get_spm_processor(self):
+ tokenizer = spm.SentencePieceProcessor(**self.sp_model_kwargs)
+ with open(self.vocab_file, "rb") as f:
+ sp_model = f.read()
+ model_pb2 = import_protobuf()
+ model = model_pb2.ModelProto.FromString(sp_model)
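+ # Note (added comment): the serialized proto is patched so that SentencePiece does not prepend its own
+ # dummy prefix; `tokenize`/`_tokenize` below add and strip the prefix space manually instead.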
+ normalizer_spec = model_pb2.NormalizerSpec()
+ normalizer_spec.add_dummy_prefix = False
+ model.normalizer_spec.MergeFrom(normalizer_spec)
+ sp_model = model.SerializeToString()
+ tokenizer.LoadFromSerializedProto(sp_model)
+ return tokenizer
+
+ @property
+ def prefix_token(self):
+ return self._prefix_token
+
+ @property
+ def prefix_id(self):
+ if self._prefix_token is None:
+ return None
+ return self.convert_tokens_to_ids(self.prefix_token)
+
+ @property
+ def middle_token(self):
+ return self._middle_token
+
+ @property
+ def middle_id(self):
+ if self._middle_token is None:
+ return None
+ return self.convert_tokens_to_ids(self.middle_token)
+
+ @property
+ def suffix_token(self):
+ return self._suffix_token
+
+ @property
+ def suffix_id(self):
+ if self._suffix_token is None:
+ return None
+ return self.convert_tokens_to_ids(self.suffix_token)
+
+ @property
+ def eot_token(self):
+ return self._eot_token
+
+ @property
+ def eot_id(self):
+ if self._eot_token is None:
+ return None
+ return self.convert_tokens_to_ids(self.eot_token)
+
+ @property
+ def vocab_size(self):
+ """Returns vocab size"""
+ return self.sp_model.get_piece_size()
+
+ # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.get_vocab
+ def get_vocab(self):
+ """Returns vocab as a dict"""
+ vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
+ vocab.update(self.added_tokens_encoder)
+ return vocab
+
+ def tokenize(self, prefix, suffix=None, suffix_first=False, **kwargs) -> list[int]:
+ # add a prefix space to `prefix`
+ if self.fill_token is not None and self.fill_token in prefix and suffix is None:
+ prefix, suffix = prefix.split(self.fill_token)
+
+ if len(prefix) > 0:
+ prefix = SPIECE_UNDERLINE + prefix.replace(SPIECE_UNDERLINE, " ")
+
+ if suffix is None or len(suffix) < 1:
+ tokens = super().tokenize(prefix, **kwargs)
+ if len(tokens) > 1 and tokens[0] == SPIECE_UNDERLINE and tokens[1] in self.all_special_tokens:
+ tokens = tokens[1:]
+ return tokens
+
+ prefix_tokens = self._tokenize(prefix) # prefix has an extra `SPIECE_UNDERLINE`
+
+ if None in (self.prefix_id, self.middle_id, self.suffix_id):
+ raise ValueError(
+ "The input either includes a `prefix` and a `suffix` used for the infilling task,"
+ f" or can be split on the {self.fill_token} token, creating a suffix and prefix,"
+ " but the model does not support `infilling`."
+ )
+ suffix_tokens = self._tokenize(suffix) # make sure CodeLlama sp model does not mess up
+
+ suffix_first = suffix_first if suffix_first is not None else self.suffix_first
+ if suffix_first:
+ # format as " <PRE> <SUF>{suf} <MID> {pre}"
+ return [self.prefix_token, self.suffix_token] + suffix_tokens + [self.middle_token] + prefix_tokens
+ else:
+ # format as " <PRE> {pre} <SUF>{suf} <MID>"
+ return [self.prefix_token] + prefix_tokens + [self.suffix_token] + suffix_tokens + [self.middle_token]
+
+ def _tokenize(self, text, **kwargs):
+ """
+ Returns a tokenized string.
+
+ We de-activated the `add_dummy_prefix` option, thus the sentencepiece internals will always strip any
+ SPIECE_UNDERLINE. For example: `self.sp_model.encode(f"{SPIECE_UNDERLINE}Hey", out_type = str)` will give
+ `['H', 'e', 'y']` instead of `['▁He', 'y']`. Thus we always encode `f"{unk_token}text"` and strip the
+ `unk_token`. Here is an example with `unk_token = "<unk>"` and `unk_token_length = 4`.
+ `self.tokenizer.sp_model.encode("<unk> Hey", out_type = str)[4:]`.
+ """
+ tokens = self.sp_model.encode(text, out_type=str)
+ if not text.startswith((SPIECE_UNDERLINE, " ")):
+ return tokens
+ # 1. Encode string + prefix ex: " Hey"
+ tokens = self.sp_model.encode(self.unk_token + text, out_type=str)
+ # 2. Remove self.unk_token from ['<','unk','>', '▁Hey']
+ return tokens[self.unk_token_length :] if len(tokens) >= self.unk_token_length else tokens
+
+ # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer._convert_token_to_id
+ def _convert_token_to_id(self, token):
+ """Converts a token (str) in an id using the vocab."""
+ return self.sp_model.piece_to_id(token)
+
+ # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer._convert_id_to_token
+ def _convert_id_to_token(self, index):
+ """Converts an index (integer) in a token (str) using the vocab."""
+ token = self.sp_model.IdToPiece(index)
+ return token
+
+ def convert_tokens_to_string(self, tokens):
+ """Converts a sequence of tokens (string) in a single string."""
+ # since we manually add the prefix space, we have to remove it when decoding
+ if tokens[0].startswith(SPIECE_UNDERLINE):
+ tokens[0] = tokens[0][1:]
+
+ current_sub_tokens = []
+ out_string = ""
+ for _, token in enumerate(tokens):
+ # make sure that special tokens are not decoded using sentencepiece model
+ if token in self.all_special_tokens:
+ out_string += self.sp_model.decode(current_sub_tokens) + token
+ current_sub_tokens = []
+ else:
+ current_sub_tokens.append(token)
+ out_string += self.sp_model.decode(current_sub_tokens)
+ return out_string
+
+ # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.save_vocabulary
+ def save_vocabulary(self, save_directory, filename_prefix: Optional[str] = None) -> tuple[str]:
+ """
+ Save the vocabulary and special tokens file to a directory.
+
+ Args:
+ save_directory (`str`):
+ The directory in which to save the vocabulary.
+
+ Returns:
+ `Tuple(str)`: Paths to the files saved.
+ """
+ if not os.path.isdir(save_directory):
+ logger.error(f"Vocabulary path ({save_directory}) should be a directory")
+ return
+ out_vocab_file = os.path.join(
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
+ )
+
+ if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
+ copyfile(self.vocab_file, out_vocab_file)
+ elif not os.path.isfile(self.vocab_file):
+ with open(out_vocab_file, "wb") as fi:
+ content_spiece_model = self.sp_model.serialized_model_proto()
+ fi.write(content_spiece_model)
+
+ return (out_vocab_file,)
+
+ # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.build_inputs_with_special_tokens
+ def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
+ bos_token_id = [self.bos_token_id] if self.add_bos_token else []
+ eos_token_id = [self.eos_token_id] if self.add_eos_token else []
+
+ output = bos_token_id + token_ids_0 + eos_token_id
+
+ if token_ids_1 is not None:
+ output = output + bos_token_id + token_ids_1 + eos_token_id
+
+ return output
+
+ # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.get_special_tokens_mask
+ def get_special_tokens_mask(
+ self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None, already_has_special_tokens: bool = False
+ ) -> list[int]:
+ """
+ Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
+ special tokens using the tokenizer `prepare_for_model` method.
+
+ Args:
+ token_ids_0 (`list[int]`):
+ List of IDs.
+ token_ids_1 (`list[int]`, *optional*):
+ Optional second list of IDs for sequence pairs.
+ already_has_special_tokens (`bool`, *optional*, defaults to `False`):
+ Whether or not the token list is already formatted with special tokens for the model.
+
+ Returns:
+ `list[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
+ """
+ if already_has_special_tokens:
+ return super().get_special_tokens_mask(
+ token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
+ )
+
+ bos_token_id = [1] if self.add_bos_token else []
+ eos_token_id = [1] if self.add_eos_token else []
+
+ if token_ids_1 is None:
+ return bos_token_id + ([0] * len(token_ids_0)) + eos_token_id
+ return (
+ bos_token_id
+ + ([0] * len(token_ids_0))
+ + eos_token_id
+ + bos_token_id
+ + ([0] * len(token_ids_1))
+ + eos_token_id
+ )
+
+ # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.create_token_type_ids_from_sequences
+ def create_token_type_ids_from_sequences(
+ self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None
+ ) -> list[int]:
+ """
+ Creates a mask from the two sequences passed to be used in a sequence-pair classification task. An ALBERT
+ sequence pair mask has the following format:
+
+ ```
+ 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
+ | first sequence | second sequence |
+ ```
+
+ if token_ids_1 is None, only returns the first portion of the mask (0s).
+
+ Args:
+ token_ids_0 (`list[int]`):
+ List of ids.
+ token_ids_1 (`list[int]`, *optional*):
+ Optional second list of IDs for sequence pairs.
+
+ Returns:
+ `list[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s).
+ """
+ bos_token_id = [self.bos_token_id] if self.add_bos_token else []
+ eos_token_id = [self.eos_token_id] if self.add_eos_token else []
+
+ output = [0] * len(bos_token_id + token_ids_0 + eos_token_id)
+
+ if token_ids_1 is not None:
+ output += [1] * len(bos_token_id + token_ids_1 + eos_token_id)
+
+ return output
+
+ def __getstate__(self):
+ state = self.__dict__.copy()
+ state["sp_model"] = None
+ state["sp_model_proto"] = self.sp_model.serialized_model_proto()
+ return state
+
+ def __setstate__(self, d):
+ self.__dict__ = d
+ self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
+ self.sp_model.LoadFromSerializedProto(self.sp_model_proto)
+
+
+__all__ = ["CodeLlamaTokenizer"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/code_llama/tokenization_code_llama_fast.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/code_llama/tokenization_code_llama_fast.py
new file mode 100644
index 0000000000000000000000000000000000000000..b3978587e7f02512a5344f9ad0a33bf86b839757
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/code_llama/tokenization_code_llama_fast.py
@@ -0,0 +1,374 @@
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+from shutil import copyfile
+from typing import Optional
+
+from tokenizers import normalizers, processors
+
+from ...tokenization_utils_fast import PreTrainedTokenizerFast
+from ...utils import is_sentencepiece_available, logging
+
+
+if is_sentencepiece_available():
+ from .tokenization_code_llama import CodeLlamaTokenizer
+else:
+ CodeLlamaTokenizer = None
+
+logger = logging.get_logger(__name__)
+VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model", "tokenizer_file": "tokenizer.json"}
+
+SPIECE_UNDERLINE = "▁"
+
+
+B_INST, E_INST = "[INST]", "[/INST]"
+B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
+
+# fmt: off
+DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your \
+answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure\
+ that your responses are socially unbiased and positive in nature.
+
+If a question does not make any sense, or is not factually coherent, explain why instead of answering something not \
+correct. If you don't know the answer to a question, please don't share false information."""
+# fmt: on
+
+
+class CodeLlamaTokenizerFast(PreTrainedTokenizerFast):
+ """
+ Construct a Llama tokenizer. Based on byte-level Byte-Pair-Encoding.
+
+ This uses notably ByteFallback and no normalization.
+
+ ```python
+ >>> from transformers import CodeLlamaTokenizerFast
+
+ >>> tokenizer = CodeLlamaTokenizerFast.from_pretrained("hf-internal-testing/llama-tokenizer")
+ >>> tokenizer.encode("Hello this is a test")
+ [1, 15043, 445, 338, 263, 1243]
+ ```
+
+ If you want to change the `bos_token` or the `eos_token`, make sure to specify them when initializing the model, or
+ call `tokenizer.update_post_processor()` to make sure that the post-processing is correctly done (otherwise the
+ values of the first token and final token of an encoded sequence will not be correct). For more details, check
+ out the [post-processors](https://huggingface.co/docs/tokenizers/api/post-processors) documentation.
+
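+ For instance, a minimal sketch of that pattern (the replacement token `"<step>"` is purely hypothetical):
+
+ ```python
+ >>> tokenizer.eos_token = "<step>"  # doctest: +SKIP
+ >>> tokenizer.update_post_processor()  # doctest: +SKIP
+ ```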
+
+ This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should
+ refer to this superclass for more information regarding those methods. The default configuration matches that of
+ [meta-llama/CodeLlama-7b-Instruct-hf](https://huggingface.co/meta-llama/CodeLlama-7b-Instruct-hf/blob/main/tokenizer_config.json)
+ which supports prompt infilling.
+
+ Args:
+ vocab_file (`str`, *optional*):
+ [SentencePiece](https://github.com/google/sentencepiece) file (generally has a .model extension) that
+ contains the vocabulary necessary to instantiate a tokenizer.
+ tokenizer_file (`str`, *optional*):
+ [tokenizers](https://github.com/huggingface/tokenizers) file (generally has a .json extension) that
+ contains everything needed to load the tokenizer.
+ clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
+ Whether to clean up spaces after decoding; cleanup consists in removing potential artifacts like extra
+ spaces.
+ unk_token (`str`, *optional*, defaults to `"<unk>"`):
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+ token instead.
+ bos_token (`str`, *optional*, defaults to `"<s>"`):
+ The beginning of sequence token that was used during pretraining. Can be used as a sequence classifier token.
+ eos_token (`str`, *optional*, defaults to `"</s>"`):
+ The end of sequence token.
+ prefix_token (`str`, *optional*, defaults to `"▁<PRE>"`):
+ Prefix token used for infilling.
+ middle_token (`str`, *optional*, defaults to `"▁<MID>"`):
+ Middle token used for infilling.
+ suffix_token (`str`, *optional*, defaults to `"▁<SUF>"`):
+ Suffix token used for infilling.
+ eot_token (`str`, *optional*, defaults to `"▁<EOT>"`):
+ End of text token used for infilling.
+ fill_token (`str`, *optional*, defaults to `"<FILL_ME>"`):
+ The token used to split the input between the prefix and suffix.
+ additional_special_tokens (`list[str]`, *optional*):
+ Additional special tokens used by the tokenizer.
+ add_bos_token (`bool`, *optional*, defaults to `True`):
+ Whether to add a beginning of sequence token at the start of sequences.
+ add_eos_token (`bool`, *optional*, defaults to `False`):
+ Whether to add an end of sequence token at the end of sequences.
+ use_default_system_prompt (`bool`, *optional*, defaults to `False`):
+ Whether or not the default system prompt for Llama should be used.
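+
+ Example of the infilling entry point (an illustrative sketch; the checkpoint name is only an example and the
+ resulting IDs depend on the vocabulary):
+
+ ```python
+ >>> from transformers import CodeLlamaTokenizerFast
+
+ >>> tokenizer = CodeLlamaTokenizerFast.from_pretrained("codellama/CodeLlama-7b-hf")  # doctest: +SKIP
+ >>> # passing a prompt that contains `fill_token` splits it into a prefix and a suffix internally
+ >>> inputs = tokenizer("def add(a, b):<FILL_ME>return result")  # doctest: +SKIP
+ ```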
+ """
+
+ vocab_files_names = VOCAB_FILES_NAMES
+ slow_tokenizer_class = CodeLlamaTokenizer
+ padding_side = "left"
+ model_input_names = ["input_ids", "attention_mask"]
+
+ def __init__(
+ self,
+ vocab_file=None,
+ tokenizer_file=None,
+ clean_up_tokenization_spaces=False,
+ unk_token="<unk>",
+ bos_token="<s>",
+ eos_token="</s>",
+ prefix_token="▁<PRE>",
+ middle_token="▁<MID>",
+ suffix_token="▁<SUF>",
+ eot_token="▁<EOT>",
+ fill_token="<FILL_ME>",
+ additional_special_tokens=None,
+ add_bos_token=True,
+ add_eos_token=False,
+ use_default_system_prompt=False,
+ **kwargs,
+ ):
+ # mark tokens special to skip them
+ additional_special_tokens = additional_special_tokens or []
+ for token in [prefix_token, middle_token, suffix_token, eot_token]:
+ additional_special_tokens += [token] if token is not None else []
+ self.use_default_system_prompt = use_default_system_prompt
+
+ super().__init__(
+ vocab_file=vocab_file,
+ tokenizer_file=tokenizer_file,
+ clean_up_tokenization_spaces=clean_up_tokenization_spaces,
+ additional_special_tokens=additional_special_tokens,
+ unk_token=unk_token,
+ bos_token=bos_token,
+ eos_token=eos_token,
+ add_bos_token=add_bos_token,
+ add_eos_token=add_eos_token,
+ prefix_token=prefix_token,
+ middle_token=middle_token,
+ suffix_token=suffix_token,
+ eot_token=eot_token,
+ fill_token=fill_token,
+ use_default_system_prompt=use_default_system_prompt,
+ **kwargs,
+ )
+ self._add_bos_token = add_bos_token
+ self._add_eos_token = add_eos_token
+ self.update_post_processor()
+
+ self.vocab_file = vocab_file
+
+ self._prefix_token = prefix_token
+ self._middle_token = middle_token
+ self._suffix_token = suffix_token
+ self._eot_token = eot_token
+ self.fill_token = fill_token
+
+ # Copied from transformers.models.llama.tokenization_llama_fast.LlamaTokenizerFast.update_post_processor
+ def update_post_processor(self):
+ """
+ Updates the underlying post processor with the current `bos_token` and `eos_token`.
+ """
+ bos = self.bos_token
+ bos_token_id = self.bos_token_id
+ if bos is None and self.add_bos_token:
+ raise ValueError("add_bos_token = True but bos_token = None")
+
+ eos = self.eos_token
+ eos_token_id = self.eos_token_id
+ if eos is None and self.add_eos_token:
+ raise ValueError("add_eos_token = True but eos_token = None")
+
+ single = f"{(bos + ':0 ') if self.add_bos_token else ''}$A:0{(' ' + eos + ':0') if self.add_eos_token else ''}"
+ pair = f"{single}{(' ' + bos + ':1') if self.add_bos_token else ''} $B:1{(' ' + eos + ':1') if self.add_eos_token else ''}"
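+ # Illustrative example (added comment, not from the upstream docs): with bos_token="<s>",
+ # add_bos_token=True and add_eos_token=False, the templates above expand to
+ # single == "<s>:0 $A:0" and pair == "<s>:0 $A:0 <s>:1 $B:1".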
+
+ special_tokens = []
+ if self.add_bos_token:
+ special_tokens.append((bos, bos_token_id))
+ if self.add_eos_token:
+ special_tokens.append((eos, eos_token_id))
+ self._tokenizer.post_processor = processors.TemplateProcessing(
+ single=single, pair=pair, special_tokens=special_tokens
+ )
+
+ @property
+ def prefix_token(self):
+ return self._prefix_token
+
+ @property
+ def prefix_id(self):
+ if self._prefix_token is None:
+ return None
+ return self.convert_tokens_to_ids(self.prefix_token)
+
+ @property
+ def middle_token(self):
+ return self._middle_token
+
+ @property
+ def middle_id(self):
+ if self._middle_token is None:
+ return None
+ return self.convert_tokens_to_ids(self.middle_token)
+
+ @property
+ def suffix_token(self):
+ return self._suffix_token
+
+ @property
+ def suffix_id(self):
+ if self._suffix_token is None:
+ return None
+ return self.convert_tokens_to_ids(self.suffix_token)
+
+ @property
+ def eot_id(self):
+ if self._eot_token is None:
+ return None
+ return self.convert_tokens_to_ids(self.eot_token)
+
+ @property
+ def eot_token(self):
+ return self._eot_token
+
+ @property
+ def add_eos_token(self):
+ return self._add_eos_token
+
+ @property
+ def add_bos_token(self):
+ return self._add_bos_token
+
+ @add_eos_token.setter
+ def add_eos_token(self, value):
+ self._add_eos_token = value
+ self.update_post_processor()
+
+ @add_bos_token.setter
+ def add_bos_token(self, value):
+ self._add_bos_token = value
+ self.update_post_processor()
+
+ def set_infilling_processor(self, reset, suffix_first=False, add_special_tokens=True):
+ """
+ Updates the normalizer to make sure the prompt format for `infilling` is respected. The infilling format is the
+ following: if suffix_first
+ " <PRE> <SUF>{suf} <MID> {pre}"
+ else:
+ " <PRE> {pre} <SUF>{suf} <MID>"
+
+ If `reset` is set to `True`, the `normalizer` and `post_processor` are reset to their "normal" behaviour, which
+ is to add a prefix space for the normalizer, and add a `bos_token` to the input text for the `post_processor`.
+ """
+ if reset:
+ self._tokenizer.normalizer = normalizers.Sequence(
+ [
+ normalizers.Prepend(prepend="▁"),
+ normalizers.Replace(pattern=" ", content="▁"),
+ ]
+ )
+ self.update_post_processor()
+ return
+
+ self._tokenizer.normalizer = normalizers.Replace(pattern=" ", content="▁")
+ pair = [self.bos_token] if self.add_bos_token and add_special_tokens else []
+ special_tokens = [(self.bos_token, self.bos_token_id)] if self.add_bos_token and add_special_tokens else []
+ if suffix_first:
+ # format as " <PRE> <SUF>{suf} <MID> {pre}"
+ pair += [self.prefix_token, self.suffix_token, "$B", self.middle_token, "$A"]
+ special_tokens += [
+ (self.prefix_token, self.prefix_id),
+ (self.suffix_token, self.suffix_id),
+ (self.middle_token, self.middle_id),
+ ]
+ else:
+ # format as " <PRE> {pre} <SUF>{suf} <MID>"
+ pair += [self.prefix_token, "$A", self.suffix_token, "$B", self.middle_token]
+ special_tokens += [
+ (self.prefix_token, self.prefix_id),
+ (self.suffix_token, self.suffix_id),
+ (self.middle_token, self.middle_id),
+ ]
+
+ if self.add_eos_token and add_special_tokens:
+ pair += [self.eos_token]
+ special_tokens += [(self.eos_token, self.eos_token_id)]
+ self._tokenizer.post_processor = processors.TemplateProcessing(
+ single="$A", pair=pair, special_tokens=special_tokens
+ )
+
+ def encode_plus(self, text, text_pair=None, suffix_first=False, add_special_tokens=True, **kwargs):
+ # hack to make sure the input is pre-processed, but outside of the rust tokenizer
+ text_pair = kwargs.pop("suffix", text_pair)
+ if self.fill_token is not None and self.fill_token in text and text_pair is None:
+ text, text_pair = text.split(self.fill_token)
+
+ if text_pair is None or len(text_pair) < 1:
+ return super().encode_plus(text, text_pair, add_special_tokens=add_special_tokens, **kwargs)
+
+ if None in (self.prefix_id, self.middle_id, self.suffix_id):
+ raise ValueError(
+ "When the input includes a `prefix` and a `suffix` used for the infilling task,"
+ " the `prefix_id`, `middle_id` and `suffix_id` must all be initialized. Current"
+ f" values: {self.prefix_id, self.middle_id, self.suffix_id}"
+ )
+
+ self.set_infilling_processor(False, suffix_first=suffix_first, add_special_tokens=add_special_tokens)
+ tokens = super().encode_plus(" " + text, text_pair=text_pair, add_special_tokens=True, **kwargs)
+ self.set_infilling_processor(True)
+ return tokens
+
+ # Copied from transformers.models.llama.tokenization_llama_fast.LlamaTokenizerFast.save_vocabulary
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]:
+ if not self.can_save_slow_tokenizer:
+ raise ValueError(
+ "Your fast tokenizer does not have the necessary information to save the vocabulary for a slow "
+ "tokenizer."
+ )
+
+ if not os.path.isdir(save_directory):
+ logger.error(f"Vocabulary path ({save_directory}) should be a directory")
+ return
+ out_vocab_file = os.path.join(
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
+ )
+
+ if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
+ copyfile(self.vocab_file, out_vocab_file)
+
+ return (out_vocab_file,)
+
+ def build_inputs_with_special_tokens(
+ self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None
+ ) -> list[int]:
+ """
+ Build model inputs from a sequence or a pair of sequences by concatenating and adding special tokens.
+
+ A CodeLlama sequence has the following format, where `X` represents the sequence:
+
+ - single sequence: `<s> X </s>`
+ - pair of sequences: `<s> A B </s>`
+
+ Pairs of sequences are not the expected use case, but they will be handled without a separator.
+
+ Args:
+ token_ids_0 (`list[int]`):
+ List of IDs to which the special tokens will be added.
+ token_ids_1 (`list[int]`, *optional*):
+ Optional second list of IDs for sequence pairs.
+
+ Returns:
+ `list[int]`: list of [input IDs](../glossary#input-ids) with the appropriate special tokens.
+ """
+ if token_ids_1 is None:
+ return [self.bos_token_id] + token_ids_0 + [self.eos_token_id]
+ return [self.bos_token_id] + token_ids_0 + token_ids_1 + [self.eos_token_id]
+
+
+__all__ = ["CodeLlamaTokenizerFast"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/codegen/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/codegen/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ea2d9af11150f556b2cedfb78271f174256e64b0
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/codegen/__init__.py
@@ -0,0 +1,29 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_codegen import *
+ from .modeling_codegen import *
+ from .tokenization_codegen import *
+ from .tokenization_codegen_fast import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/codegen/configuration_codegen.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/codegen/configuration_codegen.py
new file mode 100644
index 0000000000000000000000000000000000000000..6a9ab842710cdd67911305ce324fe4c68dec173b
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/codegen/configuration_codegen.py
@@ -0,0 +1,231 @@
+# coding=utf-8
+# Copyright 2022 Salesforce authors, The EleutherAI, and HuggingFace Teams. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""CodeGen model configuration"""
+
+from collections import OrderedDict
+from collections.abc import Mapping
+from typing import Any, Optional
+
+from ... import PreTrainedTokenizer, TensorType, is_torch_available
+from ...configuration_utils import PretrainedConfig
+from ...onnx import OnnxConfigWithPast, PatchingSpec
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class CodeGenConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`CodeGenModel`]. It is used to instantiate a
+ CodeGen model according to the specified arguments, defining the model architecture. Instantiating a configuration
+ with the defaults will yield a similar configuration to that of the CodeGen
+ [Salesforce/codegen-2B-mono](https://huggingface.co/Salesforce/codegen-2B-mono) architecture. Configuration objects
+ inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the documentation from
+ [`PretrainedConfig`] for more information.
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 50400):
+ Vocabulary size of the CodeGen model. Defines the number of different tokens that can be represented by the
+ `inputs_ids` passed when calling [`CodeGenModel`].
+ n_positions (`int`, *optional*, defaults to 2048):
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
+ just in case (e.g., 512 or 1024 or 2048).
+ n_ctx (`int`, *optional*, defaults to 2048):
+ This attribute is used in `CodeGenModel.__init__` without any real effect.
+ n_embd (`int`, *optional*, defaults to 4096):
+ Dimensionality of the embeddings and hidden states.
+ n_layer (`int`, *optional*, defaults to 28):
+ Number of hidden layers in the Transformer encoder.
+ n_head (`int`, *optional*, defaults to 16):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ rotary_dim (`int`, *optional*, defaults to 64):
+ Number of dimensions in the embedding that Rotary Position Embedding is applied to.
+ n_inner (`int`, *optional*):
+ Dimensionality of the inner feed-forward layers. `None` will set it to 4 times n_embd
+ activation_function (`str`, *optional*, defaults to `"gelu_new"`):
+ Activation function, to be selected in the list `["relu", "silu", "gelu", "tanh", "gelu_new"]`.
+ resid_pdrop (`float`, *optional*, defaults to 0.0):
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+ embd_pdrop (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the embeddings.
+ attn_pdrop (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention.
+ layer_norm_epsilon (`float`, *optional*, defaults to 1e-05):
+ The epsilon to use in the layer normalization layers.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return the last key/values attentions (not used by all models).
+ bos_token_id (`int`, *optional*, defaults to 50256):
+ Beginning of stream token id.
+ eos_token_id (`int`, *optional*, defaults to 50256):
+ End of stream token id.
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+ Whether the model's input and output word embeddings should be tied. Note that this is only relevant if the
+ model has an output word embedding layer.
+
+ Example:
+
+ ```python
+ >>> from transformers import CodeGenConfig, CodeGenModel
+
+ >>> # Initializing a CodeGen 6B configuration
+ >>> configuration = CodeGenConfig()
+
+ >>> # Initializing a model (with random weights) from the configuration
+ >>> model = CodeGenModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "codegen"
+ attribute_map = {
+ "max_position_embeddings": "n_positions",
+ "hidden_size": "n_embd",
+ "num_attention_heads": "n_head",
+ "num_hidden_layers": "n_layer",
+ }
+
+ def __init__(
+ self,
+ vocab_size=50400,
+ n_positions=2048,
+ n_ctx=2048,
+ n_embd=4096,
+ n_layer=28,
+ n_head=16,
+ rotary_dim=64,
+ n_inner=None,
+ activation_function="gelu_new",
+ resid_pdrop=0.0,
+ embd_pdrop=0.0,
+ attn_pdrop=0.0,
+ layer_norm_epsilon=1e-5,
+ initializer_range=0.02,
+ use_cache=True,
+ bos_token_id=50256,
+ eos_token_id=50256,
+ tie_word_embeddings=False,
+ **kwargs,
+ ):
+ self.vocab_size = vocab_size
+ self.n_ctx = n_ctx
+ self.n_positions = n_positions
+ self.n_embd = n_embd
+ self.n_layer = n_layer
+ self.n_head = n_head
+ self.n_inner = n_inner
+ self.rotary_dim = rotary_dim
+ self.activation_function = activation_function
+ self.resid_pdrop = resid_pdrop
+ self.embd_pdrop = embd_pdrop
+ self.attn_pdrop = attn_pdrop
+ self.layer_norm_epsilon = layer_norm_epsilon
+ self.initializer_range = initializer_range
+ self.use_cache = use_cache
+
+ self.bos_token_id = bos_token_id
+ self.eos_token_id = eos_token_id
+
+ super().__init__(
+ bos_token_id=bos_token_id, eos_token_id=eos_token_id, tie_word_embeddings=tie_word_embeddings, **kwargs
+ )
+
+
+# Copied from transformers.models.gpt2.configuration_gpt2.GPT2OnnxConfig
+class CodeGenOnnxConfig(OnnxConfigWithPast):
+ def __init__(
+ self,
+ config: PretrainedConfig,
+ task: str = "default",
+ patching_specs: Optional[list[PatchingSpec]] = None,
+ use_past: bool = False,
+ ):
+ super().__init__(config, task=task, patching_specs=patching_specs, use_past=use_past)
+ if not getattr(self._config, "pad_token_id", None):
+ # TODO: how to do that better?
+ self._config.pad_token_id = 0
+
+ @property
+ def inputs(self) -> Mapping[str, Mapping[int, str]]:
+ common_inputs = OrderedDict({"input_ids": {0: "batch", 1: "sequence"}})
+ if self.use_past:
+ self.fill_with_past_key_values_(common_inputs, direction="inputs")
+ common_inputs["attention_mask"] = {0: "batch", 1: "past_sequence + sequence"}
+ else:
+ common_inputs["attention_mask"] = {0: "batch", 1: "sequence"}
+
+ return common_inputs
+
+ @property
+ def num_layers(self) -> int:
+ return self._config.n_layer
+
+ @property
+ def num_attention_heads(self) -> int:
+ return self._config.n_head
+
+ def generate_dummy_inputs(
+ self,
+ tokenizer: PreTrainedTokenizer,
+ batch_size: int = -1,
+ seq_length: int = -1,
+ is_pair: bool = False,
+ framework: Optional[TensorType] = None,
+ ) -> Mapping[str, Any]:
+ common_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs(
+ tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework
+ )
+
+ # We need to order the inputs in the way they appear in the forward()
+ ordered_inputs = OrderedDict({"input_ids": common_inputs["input_ids"]})
+
+ # Need to add the past_keys
+ if self.use_past:
+ if not is_torch_available():
+ raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.")
+ else:
+ import torch
+
+ batch, seqlen = common_inputs["input_ids"].shape
+ # Not using the same length for past_key_values
+ past_key_values_length = seqlen + 2
+ past_shape = (
+ batch,
+ self.num_attention_heads,
+ past_key_values_length,
+ self._config.hidden_size // self.num_attention_heads,
+ )
+ ordered_inputs["past_key_values"] = [
+ (torch.zeros(past_shape), torch.zeros(past_shape)) for _ in range(self.num_layers)
+ ]
+
+ ordered_inputs["attention_mask"] = common_inputs["attention_mask"]
+ if self.use_past:
+ mask_dtype = ordered_inputs["attention_mask"].dtype
+ ordered_inputs["attention_mask"] = torch.cat(
+ [ordered_inputs["attention_mask"], torch.ones(batch, past_key_values_length, dtype=mask_dtype)], dim=1
+ )
+
+ return ordered_inputs
+
+ @property
+ def default_onnx_opset(self) -> int:
+ return 13
+
+
+__all__ = ["CodeGenConfig", "CodeGenOnnxConfig"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/codegen/modeling_codegen.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/codegen/modeling_codegen.py
new file mode 100644
index 0000000000000000000000000000000000000000..887b400b479929c947ed2dfeadc436294033ad8a
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/codegen/modeling_codegen.py
@@ -0,0 +1,668 @@
+# coding=utf-8
+# Copyright 2022 Salesforce authors, The EleutherAI, and HuggingFace Teams. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch CodeGen model."""
+
+from typing import Optional, Union
+
+import torch
+from torch import nn
+
+from ...activations import ACT2FN
+from ...cache_utils import Cache, DynamicCache
+from ...generation import GenerationMixin
+from ...modeling_attn_mask_utils import AttentionMaskConverter
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
+from ...modeling_utils import PreTrainedModel
+from ...utils import (
+ auto_docstring,
+ is_torch_flex_attn_available,
+ logging,
+)
+from .configuration_codegen import CodeGenConfig
+
+
+if is_torch_flex_attn_available():
+ from torch.nn.attention.flex_attention import BlockMask
+
+ from ...integrations.flex_attention import make_flex_block_causal_mask
+
+
+logger = logging.get_logger(__name__)
+
+
+# Copied from transformers.models.gptj.modeling_gptj.create_sinusoidal_positions
+def create_sinusoidal_positions(num_pos: int, dim: int) -> torch.Tensor:
+ inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.int64) / dim))
+ sinusoid_inp = torch.einsum("i , j -> i j", torch.arange(num_pos, dtype=torch.int64).float(), inv_freq).float()
+ return torch.cat((torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)), dim=1)
+
+
+# Copied from transformers.models.gptj.modeling_gptj.rotate_every_two
+def rotate_every_two(x: torch.Tensor) -> torch.Tensor:
+ x1 = x[:, :, :, ::2]
+ x2 = x[:, :, :, 1::2]
+ x = torch.stack((-x2, x1), dim=-1)
+ return x.flatten(-2) # in einsum notation: rearrange(x, '... d j -> ... (d j)')
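+# Illustrative note (added comment, not from the original source): for a trailing dimension [x1, x2, x3, x4]
+# the function above returns [-x2, x1, -x4, x3]; `apply_rotary_pos_emb` below combines this with interleaved
+# sin/cos terms, i.e. the standard rotary position embedding applied to even/odd feature pairs.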
+
+
+# Copied from transformers.models.gptj.modeling_gptj.apply_rotary_pos_emb
+def apply_rotary_pos_emb(tensor: torch.Tensor, sin: torch.Tensor, cos: torch.Tensor) -> torch.Tensor:
+ sin = torch.repeat_interleave(sin[:, :, None, :], 2, 3)
+ cos = torch.repeat_interleave(cos[:, :, None, :], 2, 3)
+ return (tensor * cos) + (rotate_every_two(tensor) * sin)
+
+
+class CodeGenAttention(nn.Module):
+ def __init__(self, config, layer_idx=None):
+ super().__init__()
+
+ max_positions = config.max_position_embeddings
+ self.attn_dropout = nn.Dropout(config.attn_pdrop)
+ self.resid_dropout = nn.Dropout(config.resid_pdrop)
+ self.layer_idx = layer_idx
+ if layer_idx is None:
+ logger.warning_once(
+ f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
+ "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
+ "when creating this class."
+ )
+
+ self.embed_dim = config.hidden_size
+ self.num_attention_heads = config.num_attention_heads
+ self.head_dim = self.embed_dim // self.num_attention_heads
+ if self.head_dim * self.num_attention_heads != self.embed_dim:
+ raise ValueError(
+ f"embed_dim must be divisible by num_attention_heads (got `embed_dim`: {self.embed_dim} and"
+ f" `num_attention_heads`: {self.num_attention_heads})."
+ )
+ self.scale_attn = torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32)).to(torch.get_default_dtype())
+ self.qkv_proj = nn.Linear(self.embed_dim, self.embed_dim * 3, bias=False)
+
+ self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
+ self.rotary_dim = config.rotary_dim
+ pos_embd_dim = self.rotary_dim or self.embed_dim
+ self.embed_positions = create_sinusoidal_positions(max_positions, pos_embd_dim)
+
+ def _split_heads(self, x, n_head, dim_head, mp_num):
+ reshaped = x.reshape(x.shape[:-1] + (n_head // mp_num, dim_head))
+ reshaped = reshaped.reshape(x.shape[:-2] + (-1,) + reshaped.shape[-1:])
+ return reshaped
+
+ def _merge_heads(self, tensor, num_attention_heads, attn_head_size):
+ """
+ Merges attn_head_size dim and num_attn_heads dim into n_ctx
+ """
+ if len(tensor.shape) == 5:
+ tensor = tensor.permute(0, 1, 3, 2, 4).contiguous()
+ elif len(tensor.shape) == 4:
+ tensor = tensor.permute(0, 2, 1, 3).contiguous()
+ else:
+ raise ValueError(f"Input tensor rank should be one of [4, 5], but is: {len(tensor.shape)}")
+ new_shape = tensor.size()[:-2] + (num_attention_heads * attn_head_size,)
+ return tensor.view(new_shape)
+
+ def _attn(
+ self,
+ query,
+ key,
+ value,
+ attention_mask=None,
+ head_mask=None,
+ ):
+ # Keep the attention weights computation in fp32 to avoid overflow issues
+ query = query.to(torch.float32)
+ key = key.to(torch.float32)
+
+ attn_weights = torch.matmul(query, key.transpose(-1, -2))
+
+ if attention_mask is not None:
+ causal_mask = attention_mask[:, :, :, : key.shape[-2]]
+ attn_weights += causal_mask
+
+ attn_weights = attn_weights / self.scale_attn
+ attn_weights = nn.Softmax(dim=-1)(attn_weights)
+ attn_weights = attn_weights.to(value.dtype)
+ attn_weights = self.attn_dropout(attn_weights)
+
+ # Mask heads if we want to
+ if head_mask is not None:
+ attn_weights = attn_weights * head_mask
+
+ attn_output = torch.matmul(attn_weights, value)
+
+ return attn_output, attn_weights
+
+ def forward(
+ self,
+ hidden_states: Optional[torch.FloatTensor],
+ layer_past: Optional[Cache] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = False,
+ output_attentions: Optional[bool] = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ ) -> Union[
+ tuple[torch.Tensor, tuple[torch.Tensor]],
+ Optional[tuple[torch.Tensor, tuple[torch.Tensor], tuple[torch.Tensor, ...]]],
+ ]:
+ qkv = self.qkv_proj(hidden_states)
+ # TODO(enijkamp): factor out number of logical TPU-v4 cores or make forward pass agnostic
+ mp_num = 4
+ qkv_split = qkv.reshape(qkv.shape[:-1] + (mp_num, -1))
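+ # Illustrative shape note (added comment, not from the original source): qkv_split has shape
+ # (batch, seq_len, mp_num, 3 * embed_dim // mp_num); each of the mp_num shards is then split into
+ # query / value / key slices of `local_dim` features each, mirroring the original TPU layout.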
+
+ local_dim = self.head_dim * self.num_attention_heads // mp_num
+ query, value, key = torch.split(qkv_split, local_dim, dim=-1)
+ query = self._split_heads(query, self.num_attention_heads, self.head_dim, mp_num=mp_num)
+ key = self._split_heads(key, self.num_attention_heads, self.head_dim, mp_num=mp_num)
+
+ value = self._split_heads(value, self.num_attention_heads, self.head_dim, mp_num=mp_num)
+ value = value.permute(0, 2, 1, 3)
+
+ embed_positions = self.embed_positions
+ if embed_positions.device != position_ids.device:
+ embed_positions = embed_positions.to(position_ids.device)
+ self.embed_positions = embed_positions
+
+ sincos = embed_positions[position_ids]
+ sin, cos = torch.split(sincos, sincos.shape[-1] // 2, dim=-1)
+
+ if self.rotary_dim is not None:
+ k_rot = key[:, :, :, : self.rotary_dim]
+ k_pass = key[:, :, :, self.rotary_dim :]
+
+ q_rot = query[:, :, :, : self.rotary_dim]
+ q_pass = query[:, :, :, self.rotary_dim :]
+
+ k_rot = apply_rotary_pos_emb(k_rot, sin, cos)
+ q_rot = apply_rotary_pos_emb(q_rot, sin, cos)
+
+ key = torch.cat([k_rot, k_pass], dim=-1)
+ query = torch.cat([q_rot, q_pass], dim=-1)
+ else:
+ key = apply_rotary_pos_emb(key, sin, cos)
+ query = apply_rotary_pos_emb(query, sin, cos)
+
+ key = key.permute(0, 2, 1, 3)
+ query = query.permute(0, 2, 1, 3)
+
+ # Note that this cast is quite ugly, but is not implemented before ROPE as k_rot in the original codebase is always in fp32.
+ # Reference: https://github.com/salesforce/CodeGen/blob/f210c3bb1216c975ad858cd4132c0fdeabf4bfc2/codegen1/jaxformer/hf/codegen/modeling_codegen.py#L38
+ if layer_past is not None:
+ cache_kwargs = {
+ "sin": sin,
+ "cos": cos,
+ "partial_rotation_size": self.rotary_dim,
+ "cache_position": cache_position,
+ }
+ key, value = layer_past.update(key.to(hidden_states.dtype), value, self.layer_idx, cache_kwargs)
+
+ # compute self-attention: V x Softmax(QK^T)
+ attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
+
+ attn_output = self._merge_heads(attn_output, self.num_attention_heads, self.head_dim)
+ attn_output = self.out_proj(attn_output)
+ attn_output = self.resid_dropout(attn_output)
+ return attn_output, attn_weights
+
+
+# Copied from transformers.models.gptj.modeling_gptj.GPTJMLP with GPTJ->CodeGen
+class CodeGenMLP(nn.Module):
+ def __init__(self, intermediate_size, config): # in MLP: intermediate_size= 4 * embed_dim
+ super().__init__()
+ embed_dim = config.n_embd
+
+ self.fc_in = nn.Linear(embed_dim, intermediate_size)
+ self.fc_out = nn.Linear(intermediate_size, embed_dim)
+
+ self.act = ACT2FN[config.activation_function]
+ self.dropout = nn.Dropout(config.resid_pdrop)
+
+ def forward(self, hidden_states: Optional[torch.FloatTensor]) -> torch.FloatTensor:
+ hidden_states = self.fc_in(hidden_states)
+ hidden_states = self.act(hidden_states)
+ hidden_states = self.fc_out(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ return hidden_states
+
+
+# Copied from transformers.models.gptj.modeling_gptj.GPTJBlock with GPTJ->CodeGen
+class CodeGenBlock(GradientCheckpointingLayer):
+ # Ignore copy
+ def __init__(self, config, layer_idx=None):
+ super().__init__()
+ inner_dim = config.n_inner if config.n_inner is not None else 4 * config.n_embd
+ self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
+ self.attn = CodeGenAttention(config, layer_idx)
+ self.mlp = CodeGenMLP(inner_dim, config)
+
+ def forward(
+ self,
+ hidden_states: Optional[torch.FloatTensor],
+ layer_past: Optional[Cache] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = False,
+ output_attentions: Optional[bool] = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ ) -> Union[tuple[torch.Tensor], Optional[tuple[torch.Tensor, tuple[torch.FloatTensor, ...]]]]:
+ residual = hidden_states
+ hidden_states = self.ln_1(hidden_states)
+ attn_outputs, attn_weights = self.attn(
+ hidden_states=hidden_states,
+ layer_past=layer_past,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ cache_position=cache_position,
+ )
+ feed_forward_hidden_states = self.mlp(hidden_states)
+ hidden_states = attn_outputs + feed_forward_hidden_states + residual
+
+ return hidden_states, attn_weights
+
+
+@auto_docstring
+class CodeGenPreTrainedModel(PreTrainedModel):
+ config: CodeGenConfig
+ base_model_prefix = "transformer"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["CodeGenBlock"]
+ _skip_keys_device_placement = "past_key_values"
+
+ _can_compile_fullgraph = True
+
+ def __init__(self, *inputs, **kwargs):
+ super().__init__(*inputs, **kwargs)
+
+ def _init_weights(self, module):
+ """Initialize the weights."""
+ if isinstance(module, (nn.Linear,)):
+ # Slightly different from Mesh Transformer JAX which uses truncated_normal for initialization
+ # cf https://github.com/pytorch/pytorch/pull/5617
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+ elif isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+
+
+@auto_docstring
+class CodeGenModel(CodeGenPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+
+ self.embed_dim = config.n_embd
+ self.vocab_size = config.vocab_size
+ self.wte = nn.Embedding(config.vocab_size, self.embed_dim)
+ self.drop = nn.Dropout(config.embd_pdrop)
+ self.h = nn.ModuleList([CodeGenBlock(config, layer_idx=i) for i in range(config.n_layer)])
+ self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
+ self.rotary_dim = min(config.rotary_dim, config.n_ctx // config.num_attention_heads)
+
+ self.gradient_checkpointing = False
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.wte
+
+ def set_input_embeddings(self, new_embeddings):
+ self.wte = new_embeddings
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Union[Cache, tuple[tuple[torch.Tensor]]]] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs, # NOOP kwargs, for now
+ ) -> Union[tuple, BaseModelOutputWithPast]:
+ r"""
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_dim)`, *optional*):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+ is useful if you want more control over how to convert *input_ids* indices into associated vectors than the
+ model's internal embedding lookup matrix.
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ if self.gradient_checkpointing and self.training:
+ if use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+ )
+ use_cache = False
+
+ if inputs_embeds is None:
+ inputs_embeds = self.wte(input_ids)
+
+ if use_cache and past_key_values is None:
+ past_key_values = DynamicCache(config=self.config)
+
+ seq_length = inputs_embeds.shape[1]
+ if cache_position is None:
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ cache_position = torch.arange(past_seen_tokens, past_seen_tokens + seq_length, device=inputs_embeds.device)
+
+ if position_ids is None:
+ position_ids = cache_position.unsqueeze(0)
+
+ causal_mask = self._update_causal_mask(
+ attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
+ )
+
+ # Prepare head mask if needed
+ # 1.0 in head_mask indicate we keep the head
+ # attention_probs has shape bsz x num_attention_heads x N x N
+ # head_mask has shape n_layer x batch x num_attention_heads x N x N
+ head_mask = self.get_head_mask(head_mask, self.config.n_layer)
+ hidden_states = inputs_embeds
+
+ if token_type_ids is not None:
+ token_type_ids = token_type_ids.view(-1, seq_length)
+ token_type_embeds = self.wte(token_type_ids)
+ hidden_states = hidden_states + token_type_embeds
+
+ hidden_states = self.drop(hidden_states)
+ output_shape = (-1, seq_length, hidden_states.size(-1))
+
+ all_self_attentions = () if output_attentions else None
+ all_hidden_states = () if output_hidden_states else None
+ for i, block in enumerate(self.h):
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ outputs = block(
+ hidden_states,
+ layer_past=past_key_values,
+ attention_mask=causal_mask,
+ position_ids=position_ids,
+ head_mask=head_mask[i],
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ cache_position=cache_position,
+ )
+
+ hidden_states = outputs[0]
+ if output_attentions:
+ all_self_attentions = all_self_attentions + (outputs[1],)
+
+ hidden_states = self.ln_f(hidden_states)
+
+ hidden_states = hidden_states.view(output_shape)
+ # Add last hidden state
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(
+ v for v in [hidden_states, past_key_values, all_hidden_states, all_self_attentions] if v is not None
+ )
+
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=past_key_values,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attentions,
+ )
+
+ # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask
+ def _update_causal_mask(
+ self,
+ attention_mask: Union[torch.Tensor, "BlockMask"],
+ input_tensor: torch.Tensor,
+ cache_position: torch.Tensor,
+ past_key_values: Cache,
+ output_attentions: bool = False,
+ ):
+ if self.config._attn_implementation == "flash_attention_2":
+ if attention_mask is not None and (attention_mask == 0.0).any():
+ return attention_mask
+ return None
+ if self.config._attn_implementation == "flex_attention":
+ if isinstance(attention_mask, torch.Tensor):
+ attention_mask = make_flex_block_causal_mask(attention_mask)
+ return attention_mask
+
+ # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
+ # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
+ # to infer the attention mask.
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False
+
+ # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
+ if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions:
+ if AttentionMaskConverter._ignore_causal_mask_sdpa(
+ attention_mask,
+ inputs_embeds=input_tensor,
+ past_key_values_length=past_seen_tokens,
+ is_training=self.training,
+ ):
+ return None
+
+ dtype = input_tensor.dtype
+ sequence_length = input_tensor.shape[1]
+ if using_compilable_cache:
+ target_length = past_key_values.get_max_cache_shape()
+ else:
+ target_length = (
+ attention_mask.shape[-1]
+ if isinstance(attention_mask, torch.Tensor)
+ else past_seen_tokens + sequence_length + 1
+ )
+
+ # In case the provided `attention_mask` is 2D, we generate a causal mask here (4D).
+ causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
+ attention_mask,
+ sequence_length=sequence_length,
+ target_length=target_length,
+ dtype=dtype,
+ cache_position=cache_position,
+ batch_size=input_tensor.shape[0],
+ )
+
+ if (
+ self.config._attn_implementation == "sdpa"
+ and attention_mask is not None
+ and attention_mask.device.type in ["cuda", "xpu", "npu"]
+ and not output_attentions
+ ):
+ # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
+ # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
+ # Details: https://github.com/pytorch/pytorch/issues/110213
+ min_dtype = torch.finfo(dtype).min
+ causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
+
+ return causal_mask
+
+ @staticmethod
+ # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position
+ def _prepare_4d_causal_attention_mask_with_cache_position(
+ attention_mask: torch.Tensor,
+ sequence_length: int,
+ target_length: int,
+ dtype: torch.dtype,
+ cache_position: torch.Tensor,
+ batch_size: int,
+ **kwargs,
+ ):
+ """
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
+ `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
+
+ Args:
+ attention_mask (`torch.Tensor`):
+ A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
+ `(batch_size, 1, query_length, key_value_length)`.
+ sequence_length (`int`):
+ The sequence length being processed.
+ target_length (`int`):
+ The target length: when generating with a static cache, the mask should be as long as the static cache,
+ to account for the zero padding (the part of the cache that is not yet filled).
+ dtype (`torch.dtype`):
+ The dtype to use for the 4D attention mask.
+ cache_position (`torch.Tensor`):
+ Indices depicting the position of the input sequence tokens in the sequence.
+ batch_size (`int`):
+ Batch size.
+ """
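+ # A minimal worked sketch of the construction below (illustrative values, not from the library's tests):
+ # with sequence_length=2, target_length=4, cache_position=[2, 3] and a left-padded 2D mask [[0, 1, 1, 1]],
+ # the returned mask holds 0.0 where attention is allowed and `torch.finfo(dtype).min` elsewhere:
+ #   query at position 2 -> [min, 0, 0, min]  (cannot see the padded key 0 nor the future key 3)
+ #   query at position 3 -> [min, 0, 0, 0]    (cannot see the padded key 0)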
+ if attention_mask is not None and attention_mask.dim() == 4:
+ # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
+ causal_mask = attention_mask
+ else:
+ min_dtype = torch.finfo(dtype).min
+ causal_mask = torch.full(
+ (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
+ )
+ if sequence_length != 1:
+ causal_mask = torch.triu(causal_mask, diagonal=1)
+ causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
+ causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
+ if attention_mask is not None:
+ causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
+ mask_length = attention_mask.shape[-1]
+ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
+ causal_mask.device
+ )
+ padding_mask = padding_mask == 0
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+ padding_mask, min_dtype
+ )
+
+ return causal_mask
+
+
+@auto_docstring(
+ custom_intro="""
+ The CodeGen Model transformer with a language modeling head on top.
+ """
+)
+class CodeGenForCausalLM(CodeGenPreTrainedModel, GenerationMixin):
+ _tied_weights_keys = ["lm_head.weight"]
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.transformer = CodeGenModel(config)
+ self.lm_head = nn.Linear(config.n_embd, config.vocab_size)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Union[Cache, tuple[tuple[torch.Tensor]]]] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs,
+ ) -> Union[tuple, CausalLMOutputWithPast]:
+ r"""
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_dim)`, *optional*):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+ is useful if you want more control over how to convert *input_ids* indices into associated vectors than the
+ model's internal embedding lookup matrix.
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
+ `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size]`. All labels set to `-100`
+ are ignored (masked); the loss is only computed for labels in `[0, ..., config.vocab_size]`.
+ """
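+ # Hedged usage sketch (the checkpoint and batch names are assumptions for illustration): for causal LM
+ # training the labels are typically the input ids themselves, and the shift happens inside the model:
+ #   model = CodeGenForCausalLM.from_pretrained("Salesforce/codegen-350M-mono")
+ #   out = model(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["input_ids"])
+ #   out.loss.backward()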
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ transformer_outputs = self.transformer(
+ input_ids,
+ past_key_values=past_key_values,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ cache_position=cache_position,
+ )
+ hidden_states = transformer_outputs[0]
+
+ # make sure sampling in fp16 works correctly and
+ # compute loss in fp32 to match with mesh-tf version
+ # https://github.com/EleutherAI/gpt-neo/blob/89ce74164da2fb16179106f54e2269b5da8db333/models/gpt2/gpt2.py#L179
+ lm_logits = self.lm_head(hidden_states).to(torch.float32)
+
+ loss = None
+ if labels is not None:
+ # move labels to correct device to enable model parallelism
+ labels = labels.to(lm_logits.device)
+ # Flatten the tokens
+ loss = self.loss_function(
+ lm_logits,
+ labels,
+ vocab_size=self.config.vocab_size,
+ **kwargs,
+ )
+
+ loss = loss.to(hidden_states.dtype)
+
+ if not return_dict:
+ output = (lm_logits,) + transformer_outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return CausalLMOutputWithPast(
+ loss=loss,
+ logits=lm_logits,
+ past_key_values=transformer_outputs.past_key_values,
+ hidden_states=transformer_outputs.hidden_states,
+ attentions=transformer_outputs.attentions,
+ )
+
+
+__all__ = ["CodeGenForCausalLM", "CodeGenModel", "CodeGenPreTrainedModel"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/codegen/tokenization_codegen.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/codegen/tokenization_codegen.py
new file mode 100644
index 0000000000000000000000000000000000000000..152b1a84fc37d5ed6613072284133565ebc86cbf
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/codegen/tokenization_codegen.py
@@ -0,0 +1,390 @@
+# coding=utf-8
+# Copyright 2022 The Salesforce authors, The Open AI Team Authors and The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tokenization classes for CodeGen"""
+
+import json
+import os
+from functools import lru_cache
+from typing import TYPE_CHECKING, Optional, Union
+
+import numpy as np
+import regex as re
+
+from ...utils import is_tf_available, is_torch_available, logging, to_py_obj
+
+
+if TYPE_CHECKING:
+ if is_torch_available():
+ import torch
+ if is_tf_available():
+ import tensorflow as tf
+
+from ...tokenization_utils import AddedToken, PreTrainedTokenizer
+
+
+logger = logging.get_logger(__name__)
+
+VOCAB_FILES_NAMES = {
+ "vocab_file": "vocab.json",
+ "merges_file": "merges.txt",
+}
+
+
+@lru_cache
+def bytes_to_unicode():
+ """
+ Returns a mapping between utf-8 bytes and unicode strings. We specifically avoid mapping to whitespace/control
+ characters that the BPE code barfs on.
+
+ The reversible bpe codes work on unicode strings. This means you need a large # of unicode characters in your vocab
+ if you want to avoid UNKs. When you're at something like a 10B token dataset you end up needing around 5K for
+ decent coverage. This is a significant percentage of your normal, say, 32K bpe vocab. To avoid that, we want lookup
+ tables between utf-8 bytes and unicode strings.
+ """
+ bs = (
+ list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
+ )
+ cs = bs[:]
+ n = 0
+ for b in range(2**8):
+ if b not in bs:
+ bs.append(b)
+ cs.append(2**8 + n)
+ n += 1
+ cs = [chr(n) for n in cs]
+ return dict(zip(bs, cs))
+
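+# Illustrative values implied by the function above (not taken from an external table): printable bytes map
+# to themselves, e.g. ord("a") -> "a", while bytes outside the kept ranges are shifted past 255, e.g. the
+# space byte (32) -> chr(288) == "Ġ" and the newline byte (10) -> chr(266) == "Ċ".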
+
+def get_pairs(word):
+ """
+ Return set of symbol pairs in a word.
+
+ Word is represented as tuple of symbols (symbols being variable-length strings).
+ """
+ pairs = set()
+ prev_char = word[0]
+ for char in word[1:]:
+ pairs.add((prev_char, char))
+ prev_char = char
+ return pairs
+
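+# For example, get_pairs(("l", "o", "w")) returns {("l", "o"), ("o", "w")}.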
+
+class CodeGenTokenizer(PreTrainedTokenizer):
+ """
+ Construct a CodeGen tokenizer. Based on byte-level Byte-Pair-Encoding.
+
+ This tokenizer has been trained to treat spaces like parts of the tokens (a bit like sentencepiece), so a word will
+ be encoded differently depending on whether it is at the beginning of the sentence (without a preceding space) or not:
+
+ ```python
+ >>> from transformers import CodeGenTokenizer
+
+ >>> tokenizer = CodeGenTokenizer.from_pretrained("Salesforce/codegen-350M-mono")
+ >>> tokenizer("Hello world")["input_ids"]
+ [15496, 995]
+
+ >>> tokenizer(" Hello world")["input_ids"]
+ [18435, 995]
+ ```
+
+ You can get around that behavior by passing `add_prefix_space=True` when instantiating this tokenizer or when you
+ call it on some text, but since the model was not pretrained this way, it might yield a decrease in performance.
+
+
+
+ When used with `is_split_into_words=True`, this tokenizer will add a space before each word (even the first one).
+
+
+
+ This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
+ this superclass for more information regarding those methods.
+
+ Args:
+ vocab_file (`str`):
+ Path to the vocabulary file.
+ merges_file (`str`):
+ Path to the merges file.
+ errors (`str`, *optional*, defaults to `"replace"`):
+ Paradigm to follow when decoding bytes to UTF-8. See
+ [bytes.decode](https://docs.python.org/3/library/stdtypes.html#bytes.decode) for more information.
+ unk_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+ token instead.
+ bos_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
+ The beginning of sequence token.
+ eos_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
+ The end of sequence token.
+ pad_token (`str`, *optional*):
+ The token used for padding, for example when batching sequences of different lengths.
+ add_prefix_space (`bool`, *optional*, defaults to `False`):
+ Whether or not to add an initial space to the input. This allows treating the leading word just like any
+ other word. (The CodeGen tokenizer detects the beginning of words by the preceding space.)
+ add_bos_token (`bool`, *optional*, defaults to `False`):
+ Whether to add a beginning of sequence token at the start of sequences.
+ return_token_type_ids (`bool`, *optional*, defaults to `False`):
+ Whether to return token type IDs.
+ """
+
+ vocab_files_names = VOCAB_FILES_NAMES
+ model_input_names = ["input_ids", "attention_mask"]
+
+ def __init__(
+ self,
+ vocab_file,
+ merges_file,
+ errors="replace",
+ unk_token="<|endoftext|>",
+ bos_token="<|endoftext|>",
+ eos_token="<|endoftext|>",
+ pad_token=None,
+ add_prefix_space=False,
+ add_bos_token=False,
+ return_token_type_ids=False,
+ **kwargs,
+ ):
+ bos_token = AddedToken(bos_token, special=True) if isinstance(bos_token, str) else bos_token
+ eos_token = AddedToken(eos_token, special=True) if isinstance(eos_token, str) else eos_token
+ unk_token = AddedToken(unk_token, special=True) if isinstance(unk_token, str) else unk_token
+ pad_token = AddedToken(pad_token, special=True) if isinstance(pad_token, str) else pad_token
+ self.add_bos_token = add_bos_token
+ self.return_token_type_ids = return_token_type_ids
+ if self.return_token_type_ids:
+ self.model_input_names.append("token_type_ids")
+
+ with open(vocab_file, encoding="utf-8") as vocab_handle:
+ self.encoder = json.load(vocab_handle)
+ self.decoder = {v: k for k, v in self.encoder.items()}
+ self.errors = errors # how to handle errors in decoding
+ self.byte_encoder = bytes_to_unicode()
+ self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
+ with open(merges_file, encoding="utf-8") as merges_handle:
+ bpe_merges = merges_handle.read().split("\n")[1:-1]
+ bpe_merges = [tuple(merge.split()) for merge in bpe_merges]
+ self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
+ self.cache = {}
+ self.add_prefix_space = add_prefix_space
+
+ # Should have added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
+ self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
+ super().__init__(
+ errors=errors,
+ unk_token=unk_token,
+ bos_token=bos_token,
+ eos_token=eos_token,
+ pad_token=pad_token,
+ add_prefix_space=add_prefix_space,
+ add_bos_token=add_bos_token,
+ return_token_type_ids=return_token_type_ids,
+ **kwargs,
+ )
+
+ @property
+ def vocab_size(self):
+ return len(self.encoder)
+
+ def get_vocab(self):
+ return dict(self.encoder, **self.added_tokens_encoder)
+
+ def bpe(self, token):
+ if token in self.cache:
+ return self.cache[token]
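+ # Sketch of the merge loop below with hypothetical ranks (not the real CodeGen merges): with
+ # bpe_ranks = {("l", "o"): 0}, bpe("low") merges ("l", "o") into "lo", finds no rank for ("lo", "w"),
+ # and returns "lo w".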
+ word = tuple(token)
+ pairs = get_pairs(word)
+
+ if not pairs:
+ return token
+
+ while True:
+ bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
+ if bigram not in self.bpe_ranks:
+ break
+ first, second = bigram
+ new_word = []
+ i = 0
+ while i < len(word):
+ try:
+ j = word.index(first, i)
+ except ValueError:
+ new_word.extend(word[i:])
+ break
+ else:
+ new_word.extend(word[i:j])
+ i = j
+
+ if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
+ new_word.append(first + second)
+ i += 2
+ else:
+ new_word.append(word[i])
+ i += 1
+ new_word = tuple(new_word)
+ word = new_word
+ if len(word) == 1:
+ break
+ else:
+ pairs = get_pairs(word)
+ word = " ".join(word)
+ self.cache[token] = word
+ return word
+
+ def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
+ if self.add_bos_token:
+ bos_token_ids = [self.bos_token_id]
+ else:
+ bos_token_ids = []
+
+ output = bos_token_ids + token_ids_0
+
+ if token_ids_1 is None:
+ return output
+
+ return output + bos_token_ids + token_ids_1
+
+ def _tokenize(self, text):
+ """Tokenize a string."""
+ bpe_tokens = []
+ for token in re.findall(self.pat, text):
+ token = "".join(
+ self.byte_encoder[b] for b in token.encode("utf-8")
+ ) # Maps all our bytes to unicode strings, avoiding control tokens of the BPE (spaces in our case)
+ bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(" "))
+ return bpe_tokens
+
+ def _convert_token_to_id(self, token):
+ """Converts a token (str) to an id using the vocab."""
+ return self.encoder.get(token, self.encoder.get(self.unk_token))
+
+ def _convert_id_to_token(self, index):
+ """Converts an index (integer) to a token (str) using the vocab."""
+ return self.decoder.get(index)
+
+ def convert_tokens_to_string(self, tokens):
+ """Converts a sequence of tokens (string) into a single string."""
+ text = "".join(tokens)
+ text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors=self.errors)
+ return text
+
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]:
+ if not os.path.isdir(save_directory):
+ logger.error(f"Vocabulary path ({save_directory}) should be a directory")
+ return
+ vocab_file = os.path.join(
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
+ )
+ merge_file = os.path.join(
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["merges_file"]
+ )
+
+ with open(vocab_file, "w", encoding="utf-8") as f:
+ f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
+
+ index = 0
+ with open(merge_file, "w", encoding="utf-8") as writer:
+ writer.write("#version: 0.2\n")
+ for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
+ if index != token_index:
+ logger.warning(
+ f"Saving vocabulary to {merge_file}: BPE merge indices are not consecutive."
+ " Please check that the tokenizer is not corrupted!"
+ )
+ index = token_index
+ writer.write(" ".join(bpe_tokens) + "\n")
+ index += 1
+
+ return vocab_file, merge_file
+
+ def prepare_for_tokenization(self, text, is_split_into_words=False, **kwargs):
+ add_prefix_space = kwargs.pop("add_prefix_space", self.add_prefix_space)
+ if is_split_into_words or add_prefix_space:
+ text = " " + text
+ return (text, kwargs)
+
+ def decode(
+ self,
+ token_ids: Union[int, list[int], "np.ndarray", "torch.Tensor", "tf.Tensor"],
+ skip_special_tokens: bool = False,
+ clean_up_tokenization_spaces: Optional[bool] = None,
+ truncate_before_pattern: Optional[list[str]] = None,
+ **kwargs,
+ ) -> str:
+ """
+ Converts a sequence of ids into a string, using the tokenizer and vocabulary with options to remove special
+ tokens and clean up tokenization spaces.
+
+ Similar to doing `self.convert_tokens_to_string(self.convert_ids_to_tokens(token_ids))`.
+
+ Args:
+ token_ids (`Union[int, List[int], np.ndarray, torch.Tensor, tf.Tensor]`):
+ List of tokenized input ids. Can be obtained using the `__call__` method.
+ skip_special_tokens (`bool`, *optional*, defaults to `False`):
+ Whether or not to remove special tokens in the decoding.
+ clean_up_tokenization_spaces (`bool`, *optional*):
+ Whether or not to clean up the tokenization spaces. If `None`, will default to
+ `self.clean_up_tokenization_spaces` (available in the `tokenizer_config`).
+ truncate_before_pattern (`List[str]`, *optional*, defaults to `None`):
+ A list of regular expression strings that will be used to truncate the returned string. This can be
+ used to remove extra pieces of code (e.g. truncate if observing a comment symbol "#" at the beginning
+ of a new line). An example pattern could be `["^#", re.escape("<|endoftext|>"), "^'''", "\n\n\n"]`.
+ kwargs (additional keyword arguments, *optional*):
+ Will be passed to the underlying model specific decode method.
+
+ Returns:
+ `str`: The decoded sentence.
+ """
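+ # Hedged usage sketch (`output_ids` is a placeholder); the pattern list is the example from the docstring above:
+ #   text = tokenizer.decode(output_ids, truncate_before_pattern=["^#", re.escape("<|endoftext|>"), "^'''", "\n\n\n"])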
+
+ token_ids = to_py_obj(token_ids)
+
+ decoded_text = super()._decode(
+ token_ids=token_ids,
+ skip_special_tokens=skip_special_tokens,
+ clean_up_tokenization_spaces=clean_up_tokenization_spaces,
+ **kwargs,
+ )
+
+ if truncate_before_pattern is not None and len(truncate_before_pattern) > 0:
+ decoded_text = self.truncate(decoded_text, truncate_before_pattern)
+
+ return decoded_text
+
+ def truncate(self, completion, truncate_before_pattern):
+ def find_re(string, pattern, start_pos):
+ m = pattern.search(string, start_pos)
+ return m.start() if m else -1
+
+ terminals = [re.compile(pattern, re.MULTILINE) for pattern in truncate_before_pattern]
+
+ prints = list(re.finditer("^print", completion, re.MULTILINE))
+
+ if len(prints) > 1:
+ completion = completion[: prints[1].start()]
+
+ defs = list(re.finditer("^def", completion, re.MULTILINE))
+
+ if len(defs) > 1:
+ completion = completion[: defs[1].start()]
+
+ start_pos = 0
+
+ terminals_pos = [
+ pos for pos in [find_re(completion, terminal, start_pos) for terminal in terminals] if pos != -1
+ ]
+
+ if len(terminals_pos) > 0:
+ return completion[: min(terminals_pos)]
+ else:
+ return completion
+
+
+__all__ = ["CodeGenTokenizer"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/codegen/tokenization_codegen_fast.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/codegen/tokenization_codegen_fast.py
new file mode 100644
index 0000000000000000000000000000000000000000..7bac0db7de4e7c548ddea0eb2b3c498919c6196a
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/codegen/tokenization_codegen_fast.py
@@ -0,0 +1,235 @@
+# coding=utf-8
+# Copyright 2022 The Salesforce authors, The Open AI Team Authors and The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tokenization classes for OpenAI GPT."""
+
+import re
+from typing import TYPE_CHECKING, Optional, Union
+
+import numpy as np
+
+from ...utils import is_tf_available, is_torch_available, logging
+
+
+if TYPE_CHECKING:
+ if is_torch_available():
+ import torch
+ if is_tf_available():
+ import tensorflow as tf
+
+
+from ...tokenization_utils_base import BatchEncoding
+from ...tokenization_utils_fast import PreTrainedTokenizerFast
+from .tokenization_codegen import CodeGenTokenizer
+
+
+logger = logging.get_logger(__name__)
+
+VOCAB_FILES_NAMES = {"vocab_file": "vocab.json", "merges_file": "merges.txt", "tokenizer_file": "tokenizer.json"}
+
+
+class CodeGenTokenizerFast(PreTrainedTokenizerFast):
+ """
+ Construct a "fast" CodeGen tokenizer (backed by HuggingFace's *tokenizers* library). Based on byte-level
+ Byte-Pair-Encoding.
+
+ This tokenizer has been trained to treat spaces like parts of the tokens (a bit like sentencepiece), so a word will
+ be encoded differently depending on whether it is at the beginning of the sentence (without a preceding space) or not:
+
+ ```python
+ >>> from transformers import CodeGenTokenizerFast
+
+ >>> tokenizer = CodeGenTokenizerFast.from_pretrained("Salesforce/codegen-350M-mono")
+ >>> tokenizer("Hello world")["input_ids"]
+ [15496, 995]
+
+ >>> tokenizer(" Hello world")["input_ids"]
+ [18435, 995]
+ ```
+
+ You can get around that behavior by passing `add_prefix_space=True` when instantiating this tokenizer, but since
+ the model was not pretrained this way, it might yield a decrease in performance.
+
+
+
+ When used with `is_split_into_words=True`, this tokenizer needs to be instantiated with `add_prefix_space=True`.
+
+
+
+ This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should
+ refer to this superclass for more information regarding those methods.
+
+ Args:
+ vocab_file (`str`, *optional*):
+ Path to the vocabulary file.
+ merges_file (`str`, *optional*):
+ Path to the merges file.
+ tokenizer_file (`str`, *optional*):
+ Path to [tokenizers](https://github.com/huggingface/tokenizers) file (generally has a .json extension) that
+ contains everything needed to load the tokenizer.
+ unk_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+ token instead.
+ bos_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
+ The beginning of sequence token.
+ eos_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
+ The end of sequence token.
+ add_prefix_space (`bool`, *optional*, defaults to `False`):
+ Whether or not to add an initial space to the input. This allows treating the leading word just like any
+ other word. (The CodeGen tokenizer detects the beginning of words by the preceding space.)
+ return_token_type_ids (`bool`, *optional*, defaults to `False`):
+ Whether to return token type IDs.
+ """
+
+ vocab_files_names = VOCAB_FILES_NAMES
+ model_input_names = ["input_ids", "attention_mask"]
+ slow_tokenizer_class = CodeGenTokenizer
+
+ def __init__(
+ self,
+ vocab_file=None,
+ merges_file=None,
+ tokenizer_file=None,
+ unk_token="<|endoftext|>",
+ bos_token="<|endoftext|>",
+ eos_token="<|endoftext|>",
+ add_prefix_space=False,
+ return_token_type_ids=False,
+ **kwargs,
+ ):
+ self.return_token_type_ids = return_token_type_ids
+ if self.return_token_type_ids:
+ self.model_input_names.append("token_type_ids")
+
+ super().__init__(
+ vocab_file,
+ merges_file,
+ tokenizer_file=tokenizer_file,
+ unk_token=unk_token,
+ bos_token=bos_token,
+ eos_token=eos_token,
+ add_prefix_space=add_prefix_space,
+ return_token_type_ids=return_token_type_ids,
+ **kwargs,
+ )
+
+ if kwargs.pop("add_bos_token", False):
+ model_id = kwargs.pop("name_or_path", "")
+ raise ValueError(
+ "Currently GPT2's fast tokenizer does NOT support adding a BOS token. "
+ "Instead you should use GPT2's slow tokenizer class `CodeGenTokenizer` as follows: \n"
+ f"`CodeGenTokenizer.from_pretrained('{model_id}')`\nor\n"
+ f"`AutoTokenizer.from_pretrained('{model_id}', use_fast=False)`\n"
+ "This issue will be fixed soon (see https://github.com/huggingface/tokenizers/pull/1005) "
+ "so that the fast tokenizer works correctly."
+ )
+
+ def _batch_encode_plus(self, *args, **kwargs) -> BatchEncoding:
+ is_split_into_words = kwargs.get("is_split_into_words", False)
+ assert self.add_prefix_space or not is_split_into_words, (
+ f"You need to instantiate {self.__class__.__name__} with add_prefix_space=True "
+ "to use it with pretokenized inputs."
+ )
+
+ return super()._batch_encode_plus(*args, **kwargs)
+
+ def _encode_plus(self, *args, **kwargs) -> BatchEncoding:
+ is_split_into_words = kwargs.get("is_split_into_words", False)
+
+ assert self.add_prefix_space or not is_split_into_words, (
+ f"You need to instantiate {self.__class__.__name__} with add_prefix_space=True "
+ "to use it with pretokenized inputs."
+ )
+
+ return super()._encode_plus(*args, **kwargs)
+
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]:
+ files = self._tokenizer.model.save(save_directory, name=filename_prefix)
+ return tuple(files)
+
+ def decode(
+ self,
+ token_ids: Union[int, list[int], "np.ndarray", "torch.Tensor", "tf.Tensor"],
+ skip_special_tokens: bool = False,
+ clean_up_tokenization_spaces: Optional[bool] = None,
+ truncate_before_pattern: Optional[list[str]] = None,
+ **kwargs,
+ ) -> str:
+ """
+ Converts a sequence of ids into a string, using the tokenizer and vocabulary with options to remove special
+ tokens and clean up tokenization spaces.
+
+ Similar to doing `self.convert_tokens_to_string(self.convert_ids_to_tokens(token_ids))`.
+
+ Args:
+ token_ids (`Union[int, List[int], np.ndarray, torch.Tensor, tf.Tensor]`):
+ List of tokenized input ids. Can be obtained using the `__call__` method.
+ skip_special_tokens (`bool`, *optional*, defaults to `False`):
+ Whether or not to remove special tokens in the decoding.
+ clean_up_tokenization_spaces (`bool`, *optional*):
+ Whether or not to clean up the tokenization spaces. If `None`, will default to
+ `self.clean_up_tokenization_spaces` (available in the `tokenizer_config`).
+ truncate_before_pattern (`List[str]`, *optional*, defaults to `None`):
+ A list of regular expression strings that will be used to truncate the returned string. This can be
+ used to remove extra pieces of code (e.g. truncate if observing a comment symbol "#" at the beginning
+ of a new line). An example pattern could be `["^#", re.escape("<|endoftext|>"), "^'''", "\n\n\n"]`.
+ kwargs (additional keyword arguments, *optional*):
+ Will be passed to the underlying model specific decode method.
+
+ Returns:
+ `str`: The decoded sentence.
+ """
+
+ decoded_text = super().decode(
+ token_ids=token_ids,
+ skip_special_tokens=skip_special_tokens,
+ clean_up_tokenization_spaces=clean_up_tokenization_spaces,
+ **kwargs,
+ )
+
+ if truncate_before_pattern is not None and len(truncate_before_pattern) > 0:
+ decoded_text = self.truncate(decoded_text, truncate_before_pattern)
+
+ return decoded_text
+
+ def truncate(self, completion, truncate_before_pattern):
+ def find_re(string, pattern, start_pos):
+ m = pattern.search(string, start_pos)
+ return m.start() if m else -1
+
+ terminals = [re.compile(pattern, re.MULTILINE) for pattern in truncate_before_pattern]
+
+ prints = list(re.finditer("^print", completion, re.MULTILINE))
+
+ if len(prints) > 1:
+ completion = completion[: prints[1].start()]
+
+ defs = list(re.finditer("^def", completion, re.MULTILINE))
+
+ if len(defs) > 1:
+ completion = completion[: defs[1].start()]
+
+ start_pos = 0
+
+ terminals_pos = [
+ pos for pos in [find_re(completion, terminal, start_pos) for terminal in terminals] if pos != -1
+ ]
+
+ if len(terminals_pos) > 0:
+ return completion[: min(terminals_pos)]
+ else:
+ return completion
+
+
+__all__ = ["CodeGenTokenizerFast"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/cohere2/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/cohere2/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..1447f65935601f0fffd8a88dac25bc5916b35f83
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/cohere2/__init__.py
@@ -0,0 +1,27 @@
+# Copyright 2024 Cohere and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_cohere2 import *
+ from .modeling_cohere2 import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/cohere2/configuration_cohere2.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/cohere2/configuration_cohere2.py
new file mode 100644
index 0000000000000000000000000000000000000000..c92f63cad312651fd750fb12156166f95ce3d8d4
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/cohere2/configuration_cohere2.py
@@ -0,0 +1,232 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/cohere2/modular_cohere2.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_cohere2.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2024 Cohere Inc. HuggingFace Inc. team. All rights reserved.
+#
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from ...configuration_utils import PretrainedConfig, layer_type_validation
+from ...modeling_rope_utils import rope_config_validation
+
+
+class Cohere2Config(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`CohereModel`]. It is used to instantiate a Cohere
+ model according to the specified arguments, defining the model architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information. Instantiating a configuration
+ with the defaults will yield a similar configuration to that of the [CohereForAI/c4ai-command-r-v01](https://huggingface.co/CohereForAI/c4ai-command-r-v01) model.
+
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 256000):
+ Vocabulary size of the Cohere model. Defines the number of different tokens that can be represented by the
+ `input_ids` passed when calling [`CohereModel`].
+ hidden_size (`int`, *optional*, defaults to 8192):
+ Dimension of the hidden representations.
+ intermediate_size (`int`, *optional*, defaults to 22528):
+ Dimension of the MLP representations.
+ logit_scale (`float`, *optional*, defaults to 0.0625):
+ The scaling factor for the output logits.
+ num_hidden_layers (`int`, *optional*, defaults to 40):
+ Number of hidden layers in the Transformer decoder.
+ num_attention_heads (`int`, *optional*, defaults to 64):
+ Number of attention heads for each attention layer in the Transformer decoder.
+ num_key_value_heads (`int`, *optional*):
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
+ `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+ by meanpooling all the original heads within that group. For more details, check out [this
+ paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to
+ `num_attention_heads`.
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+ The non-linear activation function (function or string) in the decoder.
+ max_position_embeddings (`int`, *optional*, defaults to 8192):
+ The maximum sequence length that this model might ever be used with.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ layer_norm_eps (`float`, *optional*, defaults to 1e-05):
+ The epsilon used by the layer normalization.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
+ relevant if `config.is_decoder=True`.
+ pad_token_id (`int`, *optional*, defaults to 0):
+ Padding token id.
+ bos_token_id (`int`, *optional*, defaults to 5):
+ Beginning of stream token id.
+ eos_token_id (`int`, *optional*, defaults to 255001):
+ End of stream token id.
+ tie_word_embeddings (`bool`, *optional*, defaults to `True`):
+ Whether to tie weight embeddings
+ rope_theta (`float`, *optional*, defaults to 10000.0):
+ The base period of the RoPE embeddings.
+ rope_scaling (`Dict`, *optional*):
+ Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply a new rope type
+ and expect the model to work on a longer `max_position_embeddings`, we recommend updating this value
+ accordingly.
+ Expected contents:
+ `rope_type` (`str`):
+ The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
+ 'llama3'], with 'default' being the original RoPE implementation.
+ `factor` (`float`, *optional*):
+ Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
+ most scaling types, a `factor` of x will enable the model to handle sequences of length x *
+ original maximum pre-trained length.
+ `original_max_position_embeddings` (`int`, *optional*):
+ Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
+ pretraining.
+ `attention_factor` (`float`, *optional*):
+ Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
+ computation. If unspecified, it defaults to value recommended by the implementation, using the
+ `factor` field to infer the suggested value.
+ `beta_fast` (`float`, *optional*):
+ Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
+ ramp function. If unspecified, it defaults to 32.
+ `beta_slow` (`float`, *optional*):
+ Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
+ ramp function. If unspecified, it defaults to 1.
+ `short_factor` (`list[float]`, *optional*):
+ Only used with 'longrope'. The scaling factor to be applied to short contexts (<
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+ size divided by the number of attention heads divided by 2
+ `long_factor` (`list[float]`, *optional*):
+ Only used with 'longrope'. The scaling factor to be applied to long contexts (>
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+ size divided by the number of attention heads divided by 2
+ `low_freq_factor` (`float`, *optional*):
+ Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
+ `high_freq_factor` (`float`, *optional*):
+ Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
+ attention_bias (`bool`, *optional*, defaults to `False`):
+ Whether to use a bias in the query, key, value and output projection layers during self-attention.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ sliding_window (`int`, *optional*, defaults to 4096):
+ Size of the sliding window attention context.
+ layer_types (`list`, *optional*):
+ Attention pattern for each layer.
+
+ ```python
+ >>> from transformers import Cohere2Model, Cohere2Config
+
+ >>> # Initializing a Cohere2 model configuration
+ >>> configuration = Cohere2Config()
+
+ >>> # Initializing a model from the Cohere2 configuration
+ >>> model = Cohere2Model(configuration) # doctest: +SKIP
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config # doctest: +SKIP
+ ```
+ """
+
+ model_type = "cohere2"
+ keys_to_ignore_at_inference = ["past_key_values"]
+ base_model_tp_plan = {
+ "layers.*.self_attn.q_proj": "colwise",
+ "layers.*.self_attn.k_proj": "colwise",
+ "layers.*.self_attn.v_proj": "colwise",
+ "layers.*.self_attn.o_proj": "rowwise",
+ "layers.*.mlp.gate_proj": "colwise",
+ "layers.*.mlp.up_proj": "colwise",
+ "layers.*.mlp.down_proj": "rowwise",
+ }
+ base_model_pp_plan = {
+ "embed_tokens": (["input_ids"], ["inputs_embeds"]),
+ "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
+ "norm": (["hidden_states"], ["hidden_states"]),
+ }
+
+ def __init__(
+ self,
+ vocab_size=256000,
+ hidden_size=8192,
+ intermediate_size=22528,
+ logit_scale=0.0625,
+ num_hidden_layers=40,
+ num_attention_heads=64,
+ num_key_value_heads=None,
+ hidden_act="silu",
+ max_position_embeddings=8192,
+ initializer_range=0.02,
+ layer_norm_eps=1e-5,
+ use_cache=True,
+ pad_token_id=0,
+ bos_token_id=5,
+ eos_token_id=255001,
+ tie_word_embeddings=True,
+ rope_theta=10000.0,
+ rope_scaling=None,
+ attention_bias=False,
+ attention_dropout=0.0,
+ sliding_window=4096,
+ layer_types=None,
+ **kwargs,
+ ):
+ self.vocab_size = vocab_size
+ self.max_position_embeddings = max_position_embeddings
+ self.hidden_size = hidden_size
+ self.logit_scale = logit_scale
+ self.intermediate_size = intermediate_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+
+ # for backward compatibility
+ if num_key_value_heads is None:
+ num_key_value_heads = num_attention_heads
+
+ self.num_key_value_heads = num_key_value_heads
+ self.hidden_act = hidden_act
+ self.initializer_range = initializer_range
+ self.layer_norm_eps = layer_norm_eps
+ self.use_cache = use_cache
+ self.rope_theta = rope_theta
+ self.rope_scaling = rope_scaling
+ self.attention_bias = attention_bias
+ self.attention_dropout = attention_dropout
+ self.sliding_window = sliding_window
+ self.layer_types = layer_types
+ # Need to specify head_dim in the config so it can be used in the attention forward functions
+ self.head_dim = hidden_size // num_attention_heads
+
+ # Validate the correctness of rotary position embeddings parameters
+ rope_config_validation(self)
+
+ super().__init__(
+ pad_token_id=pad_token_id,
+ bos_token_id=bos_token_id,
+ eos_token_id=eos_token_id,
+ tie_word_embeddings=tie_word_embeddings,
+ **kwargs,
+ )
+
+ # BC -> the pattern used to be a simple int, and it's still present in configs on the Hub
+ self._sliding_window_pattern = kwargs.get("sliding_window_pattern", 4)
+
+ if self.layer_types is None:
+ self.layer_types = [
+ "sliding_attention" if bool((i + 1) % self._sliding_window_pattern) else "full_attention"
+ for i in range(self.num_hidden_layers)
+ ]
+ layer_type_validation(self.layer_types, self.num_hidden_layers)
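+ # For example, with the default sliding_window_pattern of 4 and num_hidden_layers=8, layer_types becomes
+ # ["sliding_attention", "sliding_attention", "sliding_attention", "full_attention"] repeated twice:
+ # every fourth layer uses full attention and the remaining layers use sliding-window attention.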
+
+
+__all__ = ["Cohere2Config"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/cohere2/modeling_cohere2.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/cohere2/modeling_cohere2.py
new file mode 100644
index 0000000000000000000000000000000000000000..bab804aab67ec2fa1006cb70ac83766e8f2e71aa
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/cohere2/modeling_cohere2.py
@@ -0,0 +1,513 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/cohere2/modular_cohere2.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_cohere2.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2024 Cohere Inc. HuggingFace Inc. team. All rights reserved.
+#
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Callable, Optional, Union
+
+import torch
+import torch.nn as nn
+
+from ...activations import ACT2FN
+from ...cache_utils import Cache, DynamicCache
+from ...generation import GenerationMixin
+from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
+from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
+from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
+from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
+from ...utils.deprecation import deprecate_kwarg
+from ...utils.generic import check_model_inputs
+from .configuration_cohere2 import Cohere2Config
+
+
+class Cohere2RotaryEmbedding(nn.Module):
+ inv_freq: torch.Tensor # fix linting for `register_buffer`
+
+ def __init__(self, config: Cohere2Config, device=None):
+ super().__init__()
+ # BC: "rope_type" was originally "type"
+ if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
+ else:
+ self.rope_type = "default"
+ self.max_seq_len_cached = config.max_position_embeddings
+ self.original_max_seq_len = config.max_position_embeddings
+
+ self.config = config
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
+ self.original_inv_freq = self.inv_freq
+
+ @torch.no_grad()
+ @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
+ def forward(self, x, position_ids):
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
+ position_ids_expanded = position_ids[:, None, :].float()
+
+ device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
+ with torch.autocast(device_type=device_type, enabled=False): # Force float32
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+ emb = torch.repeat_interleave(freqs, 2, dim=-1) # diff from Llama: we interleave() instead of cat()
+ cos = emb.cos() * self.attention_scaling
+ sin = emb.sin() * self.attention_scaling
+
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
+class Cohere2LayerNorm(nn.Module):
+ def __init__(self, hidden_size=None, eps=1e-5, bias=False):
+ """The hidden size can be a tuple or an int. The tuple is used for QKNorm to normalize across head_dim"""
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+ self.variance_epsilon = eps
+
+ def forward(self, hidden_states):
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+ mean = hidden_states.mean(-1, keepdim=True)
+ variance = (hidden_states - mean).pow(2).mean(-1, keepdim=True)
+ hidden_states = (hidden_states - mean) * torch.rsqrt(variance + self.variance_epsilon)
+ hidden_states = self.weight.to(torch.float32) * hidden_states
+ return hidden_states.to(input_dtype)
+
+
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+ """
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+ """
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+ if n_rep == 1:
+ return hidden_states
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
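+# For example, a (batch=2, num_key_value_heads=4, seq_len=16, head_dim=64) tensor repeated with n_rep=2
+# becomes (2, 8, 16, 64), i.e. num_attention_heads = num_key_value_heads * n_rep.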
+
+def eager_attention_forward(
+ module: nn.Module,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_mask: Optional[torch.Tensor],
+ scaling: float,
+ dropout: float = 0.0,
+ **kwargs: Unpack[TransformersKwargs],
+):
+ key_states = repeat_kv(key, module.num_key_value_groups)
+ value_states = repeat_kv(value, module.num_key_value_groups)
+
+ attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
+ if attention_mask is not None:
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+ attn_weights = attn_weights + causal_mask
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+ attn_output = torch.matmul(attn_weights, value_states)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ return attn_output, attn_weights
+
+
+def rotate_half(x):
+ # Split and rotate. Note that this function is different from e.g. Llama.
+ x1 = x[..., ::2]
+ x2 = x[..., 1::2]
+ rot_x = torch.stack([-x2, x1], dim=-1).flatten(-2)
+ return rot_x
+
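+# For example, rotate_half applied to [x0, x1, x2, x3] along the last dim gives [-x1, x0, -x3, x2]:
+# each interleaved (even, odd) pair is rotated, unlike Llama's split-in-half variant.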
+
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+ """Applies Rotary Position Embedding to the query and key tensors.
+
+ Args:
+ q (`torch.Tensor`): The query tensor.
+ k (`torch.Tensor`): The key tensor.
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
+ position_ids (`torch.Tensor`, *optional*):
+ Deprecated and unused.
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+ Returns:
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+ """
+ dtype = q.dtype
+ q = q.float()
+ k = k.float()
+ cos = cos.unsqueeze(unsqueeze_dim)
+ sin = sin.unsqueeze(unsqueeze_dim)
+ q_embed = (q * cos) + (rotate_half(q) * sin)
+ k_embed = (k * cos) + (rotate_half(k) * sin)
+ return q_embed.to(dtype=dtype), k_embed.to(dtype=dtype)
+
+
+class Cohere2Attention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(self, config: Cohere2Config, layer_idx: Optional[int] = None):
+ super().__init__()
+ self.config = config
+ self.layer_idx = layer_idx
+ self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+ self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
+ self.scaling = self.head_dim**-0.5
+ self.attention_dropout = config.attention_dropout
+ self.is_causal = True
+ self.sliding_window = config.sliding_window if config.layer_types[layer_idx] == "sliding_attention" else None
+
+ self.q_proj = nn.Linear(
+ config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
+ )
+ self.k_proj = nn.Linear(
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
+ )
+ self.v_proj = nn.Linear(
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
+ )
+ self.o_proj = nn.Linear(
+ config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
+ )
+
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
+ attention_mask: Optional[torch.Tensor],
+ past_key_values: Optional[Cache] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs: Unpack[FlashAttentionKwargs],
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
+ input_shape = hidden_states.shape[:-1]
+ hidden_shape = (*input_shape, -1, self.head_dim)
+
+ query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+ key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+ value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+
+ cos, sin = position_embeddings
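+ # Rotary embeddings are only applied on sliding-window layers; full-attention layers skip RoPE here.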
+ if self.sliding_window is not None:
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_values is not None:
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ attn_output, attn_weights = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ dropout=0.0 if not self.training else self.attention_dropout,
+ scaling=self.scaling,
+ sliding_window=self.sliding_window,
+ **kwargs,
+ )
+
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+ attn_output = self.o_proj(attn_output)
+ return attn_output, attn_weights
+
+
+class Cohere2MLP(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.hidden_size = config.hidden_size
+ self.intermediate_size = config.intermediate_size
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+ self.act_fn = ACT2FN[config.hidden_act]
+
+ def forward(self, x):
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+ return down_proj
+
+
+class Cohere2DecoderLayer(GradientCheckpointingLayer):
+ def __init__(self, config: Cohere2Config, layer_idx: int):
+ super().__init__()
+ self.hidden_size = config.hidden_size
+ self.self_attn = Cohere2Attention(config=config, layer_idx=layer_idx)
+ self.mlp = Cohere2MLP(config)
+ self.input_layernorm = Cohere2LayerNorm(hidden_size=(config.hidden_size), eps=config.layer_norm_eps)
+ self.attention_type = config.layer_types[layer_idx]
+
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
+ attention_mask: Optional[torch.Tensor] = None,
+ past_key_values: Optional[Cache] = None,
+ use_cache: Optional[bool] = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs: Unpack[FlashAttentionKwargs],
+ ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+ attention_mask (`torch.FloatTensor`, *optional*):
+ attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
+ query_sequence_length, key_sequence_length)` if default attention is used.
+ past_key_values (`Cache`, *optional*): cached past key and value projection states
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+ (see `past_key_values`).
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+ Indices depicting the position of the input sequence tokens in the sequence
+ position_embeddings (`tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
+ Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
+ with `head_dim` being the embedding dimension of each attention head.
+ """
+ residual = hidden_states
+ hidden_states = self.input_layernorm(hidden_states)
+ hidden_states_attention, _ = self.self_attn(
+ hidden_states=hidden_states,
+ position_embeddings=position_embeddings,
+ attention_mask=attention_mask,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
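+        # Parallel residual: the MLP consumes the same layer-norm output as attention,
+        # and both branch outputs are added back to the residual stream together below.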
+ hidden_states_mlp = self.mlp(hidden_states)
+ hidden_states = residual + hidden_states_attention + hidden_states_mlp
+ return hidden_states
+
+
+@auto_docstring
+class Cohere2PreTrainedModel(PreTrainedModel):
+ config: Cohere2Config
+ base_model_prefix = "model"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["Cohere2DecoderLayer"]
+ _skip_keys_device_placement = ["past_key_values"]
+ _supports_flash_attn = True
+ _supports_sdpa = True
+ _supports_flex_attn = True
+
+ _can_compile_fullgraph = True
+ _supports_attention_backend = True
+ _can_record_outputs = {
+ "hidden_states": Cohere2DecoderLayer,
+ "attentions": Cohere2Attention,
+ }
+
+
+@auto_docstring
+class Cohere2Model(Cohere2PreTrainedModel):
+ def __init__(self, config: Cohere2Config):
+ super().__init__(config)
+ self.padding_idx = config.pad_token_id
+ self.vocab_size = config.vocab_size
+
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+ self.layers = nn.ModuleList(
+ [Cohere2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+ )
+ self.norm = Cohere2LayerNorm(hidden_size=(config.hidden_size), eps=config.layer_norm_eps)
+ self.rotary_emb = Cohere2RotaryEmbedding(config=config)
+ self.gradient_checkpointing = False
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @check_model_inputs
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> BaseModelOutputWithPast:
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+
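+        # A cache is only created automatically at inference time; during training,
+        # callers must pass `past_key_values` explicitly if they want caching.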
+ if use_cache and past_key_values is None and not self.training:
+ past_key_values = DynamicCache(config=self.config)
+
+ if cache_position is None:
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ cache_position = torch.arange(
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+ )
+ if position_ids is None:
+ position_ids = cache_position.unsqueeze(0)
+
+ if not isinstance(causal_mask_mapping := attention_mask, dict):
+ mask_kwargs = {
+ "config": self.config,
+ "input_embeds": inputs_embeds,
+ "attention_mask": attention_mask,
+ "cache_position": cache_position,
+ "past_key_values": past_key_values,
+ "position_ids": position_ids,
+ }
+ causal_mask_mapping = {
+ "full_attention": create_causal_mask(**mask_kwargs),
+ "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs),
+ }
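+        # Each decoder layer later selects the mask matching its `attention_type`
+        # ("full_attention" or "sliding_attention") from this mapping.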
+
+ hidden_states = inputs_embeds
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+ for decoder_layer in self.layers:
+ hidden_states = decoder_layer(
+ hidden_states,
+ position_embeddings=position_embeddings,
+ attention_mask=causal_mask_mapping[decoder_layer.attention_type],
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+ hidden_states = self.norm(hidden_states)
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=past_key_values,
+ )
+
+
+@auto_docstring
+class Cohere2ForCausalLM(Cohere2PreTrainedModel, GenerationMixin):
+ _tied_weights_keys = ["lm_head.weight"]
+ _tp_plan = {"lm_head": "colwise_rep"}
+ _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.model = Cohere2Model(config)
+ self.vocab_size = config.vocab_size
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+ self.logit_scale = config.logit_scale
+ self.tie_word_embeddings = config.tie_word_embeddings
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ logits_to_keep: Union[int, torch.Tensor] = 0,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> CausalLMOutputWithPast:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+ Example:
+
+ ```python
+        >>> from transformers import AutoTokenizer, Cohere2ForCausalLM
+
+        >>> model = Cohere2ForCausalLM.from_pretrained("Cohere2ForAI/c4ai-command-r-v01")
+        >>> tokenizer = AutoTokenizer.from_pretrained("Cohere2ForAI/c4ai-command-r-v01")
+
+        >>> prompt = "Hey, are you conscious? Can you talk to me?"
+        >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+        >>> # Generate
+        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+ ```"""
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+ outputs: BaseModelOutputWithPast = self.model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+ hidden_states = outputs.last_hidden_state
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
+ logits = logits * self.logit_scale # main diff from Llama
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
+
+ return CausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+__all__ = ["Cohere2ForCausalLM", "Cohere2Model", "Cohere2PreTrainedModel"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/cohere2/modular_cohere2.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/cohere2/modular_cohere2.py
new file mode 100644
index 0000000000000000000000000000000000000000..91ed748e0361841b9e9df272a3feb6884fcd6fd9
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/cohere2/modular_cohere2.py
@@ -0,0 +1,444 @@
+# coding=utf-8
+# Copyright 2024 Cohere Inc. HuggingFace Inc. team. All rights reserved.
+#
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Callable, Optional
+
+import torch
+import torch.nn as nn
+
+from ...cache_utils import Cache, DynamicCache
+from ...configuration_utils import PretrainedConfig, layer_type_validation
+from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
+from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_outputs import BaseModelOutputWithPast
+from ...modeling_rope_utils import rope_config_validation
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
+from ...processing_utils import Unpack
+from ...utils import TransformersKwargs, logging
+from ...utils.deprecation import deprecate_kwarg
+from ..cohere.modeling_cohere import (
+ CohereAttention,
+ CohereDecoderLayer,
+ CohereForCausalLM,
+ CohereLayerNorm,
+ CoherePreTrainedModel,
+ CohereRotaryEmbedding,
+ apply_rotary_pos_emb,
+ eager_attention_forward,
+)
+from ..gemma2.modeling_gemma2 import Gemma2Model
+
+
+logger = logging.get_logger(__name__)
+
+
+class Cohere2Config(PretrainedConfig):
+ r"""
+    This is the configuration class to store the configuration of a [`Cohere2Model`]. It is used to instantiate a Cohere2
+    model according to the specified arguments, defining the model architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information. Instantiating a configuration
+ with the defaults will yield a similar configuration to that of the [CohereForAI/c4ai-command-r-v01](https://huggingface.co/CohereForAI/c4ai-command-r-v01) model.
+
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 256000):
+ Vocabulary size of the Cohere model. Defines the number of different tokens that can be represented by the
+ `inputs_ids` passed when calling [`CohereModel`]
+ hidden_size (`int`, *optional*, defaults to 8192):
+ Dimension of the hidden representations.
+ intermediate_size (`int`, *optional*, defaults to 22528):
+ Dimension of the MLP representations.
+ logit_scale (`float`, *optional*, defaults to 0.0625):
+ The scaling factor for the output logits.
+ num_hidden_layers (`int`, *optional*, defaults to 40):
+ Number of hidden layers in the Transformer decoder.
+ num_attention_heads (`int`, *optional*, defaults to 64):
+ Number of attention heads for each attention layer in the Transformer decoder.
+ num_key_value_heads (`int`, *optional*):
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
+ `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+ by meanpooling all the original heads within that group. For more details, check out [this
+ paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to
+ `num_attention_heads`.
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+ The non-linear activation function (function or string) in the decoder.
+ max_position_embeddings (`int`, *optional*, defaults to 8192):
+ The maximum sequence length that this model might ever be used with.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ layer_norm_eps (`float`, *optional*, defaults to 1e-05):
+ The epsilon used by the layer normalization.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
+ relevant if `config.is_decoder=True`.
+ pad_token_id (`int`, *optional*, defaults to 0):
+ Padding token id.
+ bos_token_id (`int`, *optional*, defaults to 5):
+ Beginning of stream token id.
+ eos_token_id (`int`, *optional*, defaults to 255001):
+ End of stream token id.
+ tie_word_embeddings (`bool`, *optional*, defaults to `True`):
+            Whether to tie the input and output word embeddings.
+ rope_theta (`float`, *optional*, defaults to 10000.0):
+ The base period of the RoPE embeddings.
+ rope_scaling (`Dict`, *optional*):
+ Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
+ and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
+ accordingly.
+ Expected contents:
+ `rope_type` (`str`):
+ The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
+ 'llama3'], with 'default' being the original RoPE implementation.
+ `factor` (`float`, *optional*):
+ Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
+ most scaling types, a `factor` of x will enable the model to handle sequences of length x *
+ original maximum pre-trained length.
+ `original_max_position_embeddings` (`int`, *optional*):
+ Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
+ pretraining.
+ `attention_factor` (`float`, *optional*):
+ Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
+ computation. If unspecified, it defaults to value recommended by the implementation, using the
+ `factor` field to infer the suggested value.
+ `beta_fast` (`float`, *optional*):
+ Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
+ ramp function. If unspecified, it defaults to 32.
+ `beta_slow` (`float`, *optional*):
+ Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
+ ramp function. If unspecified, it defaults to 1.
+ `short_factor` (`list[float]`, *optional*):
+ Only used with 'longrope'. The scaling factor to be applied to short contexts (<
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+ size divided by the number of attention heads divided by 2
+ `long_factor` (`list[float]`, *optional*):
+                    Only used with 'longrope'. The scaling factor to be applied to long contexts (>
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+ size divided by the number of attention heads divided by 2
+ `low_freq_factor` (`float`, *optional*):
+ Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
+ `high_freq_factor` (`float`, *optional*):
+ Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
+        attention_bias (`bool`, *optional*, defaults to `False`):
+ Whether to use a bias in the query, key, value and output projection layers during self-attention.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ sliding_window (`int`, *optional*, defaults to 4096):
+ Size of the sliding window attention context.
+ layer_types (`list`, *optional*):
+ Attention pattern for each layer.
+
+ ```python
+ >>> from transformers import Cohere2Model, Cohere2Config
+
+    >>> # Initializing a Cohere2 model configuration
+ >>> configuration = Cohere2Config()
+
+ >>> # Initializing a model from the Cohere2 configuration
+ >>> model = Cohere2Model(configuration) # doctest: +SKIP
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config # doctest: +SKIP
+ ```
+ """
+
+ model_type = "cohere2"
+ keys_to_ignore_at_inference = ["past_key_values"]
+ base_model_tp_plan = {
+ "layers.*.self_attn.q_proj": "colwise",
+ "layers.*.self_attn.k_proj": "colwise",
+ "layers.*.self_attn.v_proj": "colwise",
+ "layers.*.self_attn.o_proj": "rowwise",
+ "layers.*.mlp.gate_proj": "colwise",
+ "layers.*.mlp.up_proj": "colwise",
+ "layers.*.mlp.down_proj": "rowwise",
+ }
+ base_model_pp_plan = {
+ "embed_tokens": (["input_ids"], ["inputs_embeds"]),
+ "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
+ "norm": (["hidden_states"], ["hidden_states"]),
+ }
+
+ def __init__(
+ self,
+ vocab_size=256000,
+ hidden_size=8192,
+ intermediate_size=22528,
+ logit_scale=0.0625,
+ num_hidden_layers=40,
+ num_attention_heads=64,
+ num_key_value_heads=None,
+ hidden_act="silu",
+ max_position_embeddings=8192,
+ initializer_range=0.02,
+ layer_norm_eps=1e-5,
+ use_cache=True,
+ pad_token_id=0,
+ bos_token_id=5,
+ eos_token_id=255001,
+ tie_word_embeddings=True,
+ rope_theta=10000.0,
+ rope_scaling=None,
+ attention_bias=False,
+ attention_dropout=0.0,
+ sliding_window=4096,
+ layer_types=None,
+ **kwargs,
+ ):
+ self.vocab_size = vocab_size
+ self.max_position_embeddings = max_position_embeddings
+ self.hidden_size = hidden_size
+ self.logit_scale = logit_scale
+ self.intermediate_size = intermediate_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+
+ # for backward compatibility
+ if num_key_value_heads is None:
+ num_key_value_heads = num_attention_heads
+
+ self.num_key_value_heads = num_key_value_heads
+ self.hidden_act = hidden_act
+ self.initializer_range = initializer_range
+ self.layer_norm_eps = layer_norm_eps
+ self.use_cache = use_cache
+ self.rope_theta = rope_theta
+ self.rope_scaling = rope_scaling
+ self.attention_bias = attention_bias
+ self.attention_dropout = attention_dropout
+ self.sliding_window = sliding_window
+ self.layer_types = layer_types
+ # Need to specify head_dim in the config so it can be used in the attention forward functions
+ self.head_dim = hidden_size // num_attention_heads
+
+ # Validate the correctness of rotary position embeddings parameters
+ rope_config_validation(self)
+
+ super().__init__(
+ pad_token_id=pad_token_id,
+ bos_token_id=bos_token_id,
+ eos_token_id=eos_token_id,
+ tie_word_embeddings=tie_word_embeddings,
+ **kwargs,
+ )
+
+ # BC -> the pattern used to be a simple int, and it's still present in configs on the Hub
+ self._sliding_window_pattern = kwargs.get("sliding_window_pattern", 4)
+
+ if self.layer_types is None:
+ self.layer_types = [
+ "sliding_attention" if bool((i + 1) % self._sliding_window_pattern) else "full_attention"
+ for i in range(self.num_hidden_layers)
+ ]
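+            # With the default pattern of 4 this yields three "sliding_attention" layers
+            # followed by one "full_attention" layer, repeated across the stack
+            # (layer indices 3, 7, 11, ... use full attention).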
+ layer_type_validation(self.layer_types, self.num_hidden_layers)
+
+
+class Cohere2RotaryEmbedding(CohereRotaryEmbedding):
+ pass
+
+
+class Cohere2LayerNorm(CohereLayerNorm):
+ pass
+
+
+class Cohere2Attention(CohereAttention):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(self, config: Cohere2Config, layer_idx: Optional[int] = None):
+ nn.Module.__init__(self)
+ self.config = config
+ self.layer_idx = layer_idx
+ self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+ self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
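+        # e.g. 64 attention heads with 8 key/value heads gives 8 query heads per KV head (GQA);
+        # when num_key_value_heads == num_attention_heads this reduces to standard MHA.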
+ self.scaling = self.head_dim**-0.5
+ self.attention_dropout = config.attention_dropout
+ self.is_causal = True
+ self.sliding_window = config.sliding_window if config.layer_types[layer_idx] == "sliding_attention" else None
+
+ self.q_proj = nn.Linear(
+ config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
+ )
+ self.k_proj = nn.Linear(
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
+ )
+ self.v_proj = nn.Linear(
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
+ )
+ self.o_proj = nn.Linear(
+ config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
+ )
+
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
+ attention_mask: Optional[torch.Tensor],
+ past_key_values: Optional[Cache] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs: Unpack[FlashAttentionKwargs],
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
+ input_shape = hidden_states.shape[:-1]
+ hidden_shape = (*input_shape, -1, self.head_dim)
+
+ query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+ key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+ value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+
+ cos, sin = position_embeddings
+ if self.sliding_window is not None:
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_values is not None:
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ attn_output, attn_weights = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ dropout=0.0 if not self.training else self.attention_dropout,
+ scaling=self.scaling,
+ sliding_window=self.sliding_window,
+ **kwargs,
+ )
+
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+ attn_output = self.o_proj(attn_output)
+ return attn_output, attn_weights
+
+
+class Cohere2DecoderLayer(CohereDecoderLayer):
+ def __init__(self, config: Cohere2Config, layer_idx: int):
+ super().__init__(config, layer_idx)
+ self.attention_type = config.layer_types[layer_idx]
+
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
+ attention_mask: Optional[torch.Tensor] = None,
+ past_key_values: Optional[Cache] = None,
+ use_cache: Optional[bool] = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs: Unpack[FlashAttentionKwargs],
+ ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
+ residual = hidden_states
+ hidden_states = self.input_layernorm(hidden_states)
+ hidden_states_attention, _ = self.self_attn(
+ hidden_states=hidden_states,
+ position_embeddings=position_embeddings,
+ attention_mask=attention_mask,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+ hidden_states_mlp = self.mlp(hidden_states)
+ hidden_states = residual + hidden_states_attention + hidden_states_mlp
+ return hidden_states
+
+
+class Cohere2PreTrainedModel(CoherePreTrainedModel):
+ config: Cohere2Config
+
+
+class Cohere2Model(Gemma2Model):
+ def __init__(self, config: Cohere2Config):
+ super().__init__(config)
+ self.norm = Cohere2LayerNorm(hidden_size=(config.hidden_size), eps=config.layer_norm_eps)
+ self.rotary_emb = Cohere2RotaryEmbedding(config=config)
+
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> BaseModelOutputWithPast:
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+
+ if use_cache and past_key_values is None and not self.training:
+ past_key_values = DynamicCache(config=self.config)
+
+ if cache_position is None:
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ cache_position = torch.arange(
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+ )
+ if position_ids is None:
+ position_ids = cache_position.unsqueeze(0)
+
+ if not isinstance(causal_mask_mapping := attention_mask, dict):
+ mask_kwargs = {
+ "config": self.config,
+ "input_embeds": inputs_embeds,
+ "attention_mask": attention_mask,
+ "cache_position": cache_position,
+ "past_key_values": past_key_values,
+ "position_ids": position_ids,
+ }
+ causal_mask_mapping = {
+ "full_attention": create_causal_mask(**mask_kwargs),
+ "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs),
+ }
+
+ hidden_states = inputs_embeds
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+ for decoder_layer in self.layers:
+ hidden_states = decoder_layer(
+ hidden_states,
+ position_embeddings=position_embeddings,
+ attention_mask=causal_mask_mapping[decoder_layer.attention_type],
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+ hidden_states = self.norm(hidden_states)
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=past_key_values,
+ )
+
+
+class Cohere2ForCausalLM(CohereForCausalLM):
+ pass
+
+
+__all__ = ["Cohere2Config", "Cohere2ForCausalLM", "Cohere2Model", "Cohere2PreTrainedModel"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/convnext/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/convnext/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e2d826745f5b2e011997179ac0dd3d3cfc14389d
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/convnext/__init__.py
@@ -0,0 +1,31 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_convnext import *
+ from .feature_extraction_convnext import *
+ from .image_processing_convnext import *
+ from .image_processing_convnext_fast import *
+ from .modeling_convnext import *
+ from .modeling_tf_convnext import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/convnext/configuration_convnext.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/convnext/configuration_convnext.py
new file mode 100644
index 0000000000000000000000000000000000000000..f54cba58cf296e8c8e3bae70a9f2e2ab21e3c660
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/convnext/configuration_convnext.py
@@ -0,0 +1,142 @@
+# coding=utf-8
+# Copyright 2022 Meta Platforms, Inc. and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""ConvNeXT model configuration"""
+
+from collections import OrderedDict
+from collections.abc import Mapping
+
+from packaging import version
+
+from ...configuration_utils import PretrainedConfig
+from ...onnx import OnnxConfig
+from ...utils import logging
+from ...utils.backbone_utils import BackboneConfigMixin, get_aligned_output_features_output_indices
+
+
+logger = logging.get_logger(__name__)
+
+
+class ConvNextConfig(BackboneConfigMixin, PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`ConvNextModel`]. It is used to instantiate an
+ ConvNeXT model according to the specified arguments, defining the model architecture. Instantiating a configuration
+ with the defaults will yield a similar configuration to that of the ConvNeXT
+ [facebook/convnext-tiny-224](https://huggingface.co/facebook/convnext-tiny-224) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ num_channels (`int`, *optional*, defaults to 3):
+ The number of input channels.
+ patch_size (`int`, *optional*, defaults to 4):
+ Patch size to use in the patch embedding layer.
+ num_stages (`int`, *optional*, defaults to 4):
+ The number of stages in the model.
+ hidden_sizes (`list[int]`, *optional*, defaults to [96, 192, 384, 768]):
+ Dimensionality (hidden size) at each stage.
+ depths (`list[int]`, *optional*, defaults to [3, 3, 9, 3]):
+ Depth (number of blocks) for each stage.
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
+ The non-linear activation function (function or string) in each block. If string, `"gelu"`, `"relu"`,
+ `"selu"` and `"gelu_new"` are supported.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ layer_norm_eps (`float`, *optional*, defaults to 1e-12):
+ The epsilon used by the layer normalization layers.
+ layer_scale_init_value (`float`, *optional*, defaults to 1e-6):
+ The initial value for the layer scale.
+ drop_path_rate (`float`, *optional*, defaults to 0.0):
+ The drop rate for stochastic depth.
+ out_features (`list[str]`, *optional*):
+ If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc.
+ (depending on how many stages the model has). If unset and `out_indices` is set, will default to the
+ corresponding stages. If unset and `out_indices` is unset, will default to the last stage. Must be in the
+ same order as defined in the `stage_names` attribute.
+ out_indices (`list[int]`, *optional*):
+ If used as backbone, list of indices of features to output. Can be any of 0, 1, 2, etc. (depending on how
+ many stages the model has). If unset and `out_features` is set, will default to the corresponding stages.
+ If unset and `out_features` is unset, will default to the last stage. Must be in the
+ same order as defined in the `stage_names` attribute.
+
+ Example:
+ ```python
+ >>> from transformers import ConvNextConfig, ConvNextModel
+
+ >>> # Initializing a ConvNext convnext-tiny-224 style configuration
+ >>> configuration = ConvNextConfig()
+
+ >>> # Initializing a model (with random weights) from the convnext-tiny-224 style configuration
+ >>> model = ConvNextModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "convnext"
+
+ def __init__(
+ self,
+ num_channels=3,
+ patch_size=4,
+ num_stages=4,
+ hidden_sizes=None,
+ depths=None,
+ hidden_act="gelu",
+ initializer_range=0.02,
+ layer_norm_eps=1e-12,
+ layer_scale_init_value=1e-6,
+ drop_path_rate=0.0,
+ image_size=224,
+ out_features=None,
+ out_indices=None,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+
+ self.num_channels = num_channels
+ self.patch_size = patch_size
+ self.num_stages = num_stages
+ self.hidden_sizes = [96, 192, 384, 768] if hidden_sizes is None else hidden_sizes
+ self.depths = [3, 3, 9, 3] if depths is None else depths
+ self.hidden_act = hidden_act
+ self.initializer_range = initializer_range
+ self.layer_norm_eps = layer_norm_eps
+ self.layer_scale_init_value = layer_scale_init_value
+ self.drop_path_rate = drop_path_rate
+ self.image_size = image_size
+ self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, len(self.depths) + 1)]
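+        # e.g. ["stem", "stage1", "stage2", "stage3", "stage4"] for the default 4-stage depths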
+ self._out_features, self._out_indices = get_aligned_output_features_output_indices(
+ out_features=out_features, out_indices=out_indices, stage_names=self.stage_names
+ )
+
+
+class ConvNextOnnxConfig(OnnxConfig):
+ torch_onnx_minimum_version = version.parse("1.11")
+
+ @property
+ def inputs(self) -> Mapping[str, Mapping[int, str]]:
+ return OrderedDict(
+ [
+ ("pixel_values", {0: "batch", 1: "num_channels", 2: "height", 3: "width"}),
+ ]
+ )
+
+ @property
+ def atol_for_validation(self) -> float:
+ return 1e-5
+
+
+__all__ = ["ConvNextConfig", "ConvNextOnnxConfig"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/convnext/feature_extraction_convnext.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/convnext/feature_extraction_convnext.py
new file mode 100644
index 0000000000000000000000000000000000000000..1fbb5184cf37cb0aedcafeab3a6b363ab047d9a0
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/convnext/feature_extraction_convnext.py
@@ -0,0 +1,38 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Feature extractor class for ConvNeXT."""
+
+import warnings
+
+from ...utils import logging
+from ...utils.import_utils import requires
+from .image_processing_convnext import ConvNextImageProcessor
+
+
+logger = logging.get_logger(__name__)
+
+
+@requires(backends=("vision",))
+class ConvNextFeatureExtractor(ConvNextImageProcessor):
+ def __init__(self, *args, **kwargs) -> None:
+ warnings.warn(
+ "The class ConvNextFeatureExtractor is deprecated and will be removed in version 5 of Transformers."
+ " Please use ConvNextImageProcessor instead.",
+ FutureWarning,
+ )
+ super().__init__(*args, **kwargs)
+
+
+__all__ = ["ConvNextFeatureExtractor"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/convnext/image_processing_convnext.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/convnext/image_processing_convnext.py
new file mode 100644
index 0000000000000000000000000000000000000000..af89274500ddb33160e92b5591bc9ac83c55f24c
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/convnext/image_processing_convnext.py
@@ -0,0 +1,325 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Image processor class for ConvNeXT."""
+
+from typing import Optional, Union
+
+import numpy as np
+
+from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
+from ...image_transforms import (
+ center_crop,
+ get_resize_output_image_size,
+ resize,
+ to_channel_dimension_format,
+)
+from ...image_utils import (
+ IMAGENET_STANDARD_MEAN,
+ IMAGENET_STANDARD_STD,
+ ChannelDimension,
+ ImageInput,
+ PILImageResampling,
+ infer_channel_dimension_format,
+ is_scaled_image,
+ make_flat_list_of_images,
+ to_numpy_array,
+ valid_images,
+ validate_preprocess_arguments,
+)
+from ...utils import TensorType, filter_out_non_signature_kwargs, is_vision_available, logging
+from ...utils.import_utils import requires
+
+
+if is_vision_available():
+ import PIL
+
+
+logger = logging.get_logger(__name__)
+
+
+@requires(backends=("vision",))
+class ConvNextImageProcessor(BaseImageProcessor):
+ r"""
+ Constructs a ConvNeXT image processor.
+
+ Args:
+ do_resize (`bool`, *optional*, defaults to `True`):
+ Controls whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden
+ by `do_resize` in the `preprocess` method.
+        size (`dict[str, int]`, *optional*, defaults to `{"shortest_edge": 384}`):
+ Resolution of the output image after `resize` is applied. If `size["shortest_edge"]` >= 384, the image is
+ resized to `(size["shortest_edge"], size["shortest_edge"])`. Otherwise, the smaller edge of the image will
+ be matched to `int(size["shortest_edge"]/crop_pct)`, after which the image is cropped to
+ `(size["shortest_edge"], size["shortest_edge"])`. Only has an effect if `do_resize` is set to `True`. Can
+ be overridden by `size` in the `preprocess` method.
+        crop_pct (`float`, *optional*, defaults to `224 / 256`):
+ Percentage of the image to crop. Only has an effect if `do_resize` is `True` and size < 384. Can be
+ overridden by `crop_pct` in the `preprocess` method.
+ resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`):
+ Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method.
+ do_rescale (`bool`, *optional*, defaults to `True`):
+ Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in
+ the `preprocess` method.
+ rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
+ Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess`
+ method.
+ do_normalize (`bool`, *optional*, defaults to `True`):
+ Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
+ method.
+ image_mean (`float` or `list[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`):
+ Mean to use if normalizing the image. This is a float or list of floats the length of the number of
+ channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
+ image_std (`float` or `list[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`):
+ Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
+ number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
+ """
+
+ model_input_names = ["pixel_values"]
+
+ def __init__(
+ self,
+ do_resize: bool = True,
+ size: Optional[dict[str, int]] = None,
+ crop_pct: Optional[float] = None,
+ resample: PILImageResampling = PILImageResampling.BILINEAR,
+ do_rescale: bool = True,
+ rescale_factor: Union[int, float] = 1 / 255,
+ do_normalize: bool = True,
+ image_mean: Optional[Union[float, list[float]]] = None,
+ image_std: Optional[Union[float, list[float]]] = None,
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+ size = size if size is not None else {"shortest_edge": 384}
+ size = get_size_dict(size, default_to_square=False)
+
+ self.do_resize = do_resize
+ self.size = size
+ # Default value set here for backwards compatibility where the value in config is None
+ self.crop_pct = crop_pct if crop_pct is not None else 224 / 256
+ self.resample = resample
+ self.do_rescale = do_rescale
+ self.rescale_factor = rescale_factor
+ self.do_normalize = do_normalize
+ self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
+ self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
+
+ def resize(
+ self,
+ image: np.ndarray,
+ size: dict[str, int],
+ crop_pct: float,
+ resample: PILImageResampling = PILImageResampling.BICUBIC,
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ **kwargs,
+ ) -> np.ndarray:
+ """
+ Resize an image.
+
+ Args:
+ image (`np.ndarray`):
+ Image to resize.
+ size (`dict[str, int]`):
+ Dictionary of the form `{"shortest_edge": int}`, specifying the size of the output image. If
+ `size["shortest_edge"]` >= 384 image is resized to `(size["shortest_edge"], size["shortest_edge"])`.
+ Otherwise, the smaller edge of the image will be matched to `int(size["shortest_edge"] / crop_pct)`,
+ after which the image is cropped to `(size["shortest_edge"], size["shortest_edge"])`.
+ crop_pct (`float`):
+ Percentage of the image to crop. Only has an effect if size < 384.
+ resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
+ Resampling filter to use when resizing the image.
+ data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format of the image. If not provided, it will be the same as the input image.
+ input_data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format of the input image. If not provided, it will be inferred from the input
+ image.
+ """
+ size = get_size_dict(size, default_to_square=False)
+ if "shortest_edge" not in size:
+ raise ValueError(f"Size dictionary must contain 'shortest_edge' key. Got {size.keys()}")
+ shortest_edge = size["shortest_edge"]
+
+ if shortest_edge < 384:
+ # maintain same ratio, resizing shortest edge to shortest_edge/crop_pct
+ resize_shortest_edge = int(shortest_edge / crop_pct)
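+            # e.g. shortest_edge=224 with the default crop_pct of 224/256 resizes the shorter
+            # side to 256 pixels before the 224x224 center crop below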
+ resize_size = get_resize_output_image_size(
+ image, size=resize_shortest_edge, default_to_square=False, input_data_format=input_data_format
+ )
+ image = resize(
+ image=image,
+ size=resize_size,
+ resample=resample,
+ data_format=data_format,
+ input_data_format=input_data_format,
+ **kwargs,
+ )
+ # then crop to (shortest_edge, shortest_edge)
+ return center_crop(
+ image=image,
+ size=(shortest_edge, shortest_edge),
+ data_format=data_format,
+ input_data_format=input_data_format,
+ **kwargs,
+ )
+ else:
+ # warping (no cropping) when evaluated at 384 or larger
+ return resize(
+ image,
+ size=(shortest_edge, shortest_edge),
+ resample=resample,
+ data_format=data_format,
+ input_data_format=input_data_format,
+ **kwargs,
+ )
+
+ @filter_out_non_signature_kwargs()
+ def preprocess(
+ self,
+ images: ImageInput,
+ do_resize: Optional[bool] = None,
+ size: Optional[dict[str, int]] = None,
+ crop_pct: Optional[float] = None,
+ resample: Optional[PILImageResampling] = None,
+ do_rescale: Optional[bool] = None,
+ rescale_factor: Optional[float] = None,
+ do_normalize: Optional[bool] = None,
+ image_mean: Optional[Union[float, list[float]]] = None,
+ image_std: Optional[Union[float, list[float]]] = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ data_format: ChannelDimension = ChannelDimension.FIRST,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+    ) -> BatchFeature:
+ """
+ Preprocess an image or batch of images.
+
+ Args:
+ images (`ImageInput`):
+ Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
+ passing in images with pixel values between 0 and 1, set `do_rescale=False`.
+ do_resize (`bool`, *optional*, defaults to `self.do_resize`):
+ Whether to resize the image.
+ size (`dict[str, int]`, *optional*, defaults to `self.size`):
+ Size of the output image after `resize` has been applied. If `size["shortest_edge"]` >= 384, the image
+ is resized to `(size["shortest_edge"], size["shortest_edge"])`. Otherwise, the smaller edge of the
+ image will be matched to `int(size["shortest_edge"]/ crop_pct)`, after which the image is cropped to
+ `(size["shortest_edge"], size["shortest_edge"])`. Only has an effect if `do_resize` is set to `True`.
+ crop_pct (`float`, *optional*, defaults to `self.crop_pct`):
+ Percentage of the image to crop if size < 384.
+ resample (`int`, *optional*, defaults to `self.resample`):
+                Resampling filter to use if resizing the image. This can be one of the `PILImageResampling` filters. Only
+ has an effect if `do_resize` is set to `True`.
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
+ Whether to rescale the image values between [0 - 1].
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
+ Rescale factor to rescale the image by if `do_rescale` is set to `True`.
+ do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
+ Whether to normalize the image.
+ image_mean (`float` or `list[float]`, *optional*, defaults to `self.image_mean`):
+ Image mean.
+ image_std (`float` or `list[float]`, *optional*, defaults to `self.image_std`):
+ Image standard deviation.
+ return_tensors (`str` or `TensorType`, *optional*):
+ The type of tensors to return. Can be one of:
+ - Unset: Return a list of `np.ndarray`.
+ - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
+ - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
+ data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
+ The channel dimension format for the output image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - Unset: Use the channel dimension format of the input image.
+ input_data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
+ from the input image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
+ """
+ do_resize = do_resize if do_resize is not None else self.do_resize
+ crop_pct = crop_pct if crop_pct is not None else self.crop_pct
+ resample = resample if resample is not None else self.resample
+ do_rescale = do_rescale if do_rescale is not None else self.do_rescale
+ rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
+ do_normalize = do_normalize if do_normalize is not None else self.do_normalize
+ image_mean = image_mean if image_mean is not None else self.image_mean
+ image_std = image_std if image_std is not None else self.image_std
+
+ size = size if size is not None else self.size
+ size = get_size_dict(size, default_to_square=False)
+
+ images = make_flat_list_of_images(images)
+
+ if not valid_images(images):
+ raise ValueError(
+ "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
+ "torch.Tensor, tf.Tensor or jax.ndarray."
+ )
+
+ validate_preprocess_arguments(
+ do_rescale=do_rescale,
+ rescale_factor=rescale_factor,
+ do_normalize=do_normalize,
+ image_mean=image_mean,
+ image_std=image_std,
+ do_resize=do_resize,
+ size=size,
+ resample=resample,
+ )
+
+ # All transformations expect numpy arrays.
+ images = [to_numpy_array(image) for image in images]
+
+ if do_rescale and is_scaled_image(images[0]):
+ logger.warning_once(
+ "It looks like you are trying to rescale already rescaled images. If the input"
+ " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
+ )
+
+ if input_data_format is None:
+ # We assume that all images have the same channel dimension format.
+ input_data_format = infer_channel_dimension_format(images[0])
+
+ if do_resize:
+ images = [
+ self.resize(
+ image=image, size=size, crop_pct=crop_pct, resample=resample, input_data_format=input_data_format
+ )
+ for image in images
+ ]
+
+ if do_rescale:
+ images = [
+ self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
+ for image in images
+ ]
+
+ if do_normalize:
+ images = [
+ self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
+ for image in images
+ ]
+
+ images = [
+ to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
+ ]
+
+ data = {"pixel_values": images}
+ return BatchFeature(data=data, tensor_type=return_tensors)
+
+
+__all__ = ["ConvNextImageProcessor"]
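+
+# Illustrative usage sketch (kept as comments so the module is unchanged at import time);
+# the input image is a synthetic placeholder and the sizes assume the defaults documented above.
+#
+#     from PIL import Image
+#     from transformers import ConvNextImageProcessor
+#
+#     processor = ConvNextImageProcessor(size={"shortest_edge": 224})  # < 384, so crop_pct applies
+#     image = Image.new("RGB", (640, 480))  # placeholder image
+#     inputs = processor(images=image, return_tensors="pt")
+#     print(inputs["pixel_values"].shape)  # expected: torch.Size([1, 3, 224, 224])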
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/convnext/image_processing_convnext_fast.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/convnext/image_processing_convnext_fast.py
new file mode 100644
index 0000000000000000000000000000000000000000..3ab00c0fd091369715e636bac663fbbbdc9239a0
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/convnext/image_processing_convnext_fast.py
@@ -0,0 +1,180 @@
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Fast Image processor class for ConvNeXT."""
+
+from typing import Optional, Union
+
+import torch
+from torchvision.transforms.v2 import functional as F
+
+from ...image_processing_utils import BatchFeature
+from ...image_processing_utils_fast import (
+ BaseImageProcessorFast,
+ DefaultFastImageProcessorKwargs,
+ group_images_by_shape,
+ reorder_images,
+)
+from ...image_transforms import get_resize_output_image_size
+from ...image_utils import (
+ IMAGENET_STANDARD_MEAN,
+ IMAGENET_STANDARD_STD,
+ ChannelDimension,
+ ImageInput,
+ PILImageResampling,
+)
+from ...processing_utils import Unpack
+from ...utils import (
+ TensorType,
+ auto_docstring,
+)
+
+
+class ConvNextFastImageProcessorKwargs(DefaultFastImageProcessorKwargs):
+ """
+ crop_pct (`float`, *optional*):
+ Percentage of the image to crop. Only has an effect if size < 384. Can be
+ overridden by `crop_pct` in the`preprocess` method.
+ """
+
+ crop_pct: Optional[float]
+
+
+@auto_docstring
+class ConvNextImageProcessorFast(BaseImageProcessorFast):
+ resample = PILImageResampling.BILINEAR
+ image_mean = IMAGENET_STANDARD_MEAN
+ image_std = IMAGENET_STANDARD_STD
+ size = {"shortest_edge": 384}
+ default_to_square = False
+ do_resize = True
+ do_rescale = True
+ do_normalize = True
+ crop_pct = 224 / 256
+ valid_kwargs = ConvNextFastImageProcessorKwargs
+
+ def __init__(self, **kwargs: Unpack[ConvNextFastImageProcessorKwargs]):
+ super().__init__(**kwargs)
+
+ @auto_docstring
+ def preprocess(self, images: ImageInput, **kwargs: Unpack[ConvNextFastImageProcessorKwargs]) -> BatchFeature:
+ return super().preprocess(images, **kwargs)
+
+ def resize(
+ self,
+ image: "torch.Tensor",
+ size: dict[str, int],
+ crop_pct: float,
+ interpolation: PILImageResampling = PILImageResampling.BICUBIC,
+ **kwargs,
+ ) -> "torch.Tensor":
+ """
+ Resize an image.
+
+ Args:
+ image (`torch.Tensor`):
+ Image to resize.
+ size (`dict[str, int]`):
+ Dictionary of the form `{"shortest_edge": int}`, specifying the size of the output image. If
+ `size["shortest_edge"]` >= 384 image is resized to `(size["shortest_edge"], size["shortest_edge"])`.
+ Otherwise, the smaller edge of the image will be matched to `int(size["shortest_edge"] / crop_pct)`,
+ after which the image is cropped to `(size["shortest_edge"], size["shortest_edge"])`.
+ crop_pct (`float`):
+ Percentage of the image to crop. Only has an effect if size < 384.
+ resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
+ Resampling filter to use when resizing the image.
+
+ Returns:
+ `torch.Tensor`: Resized image.
+ """
+ if not size.shortest_edge:
+ raise ValueError(f"Size dictionary must contain 'shortest_edge' key. Got {size.keys()}")
+ shortest_edge = size["shortest_edge"]
+
+ if shortest_edge < 384:
+ # maintain same ratio, resizing shortest edge to shortest_edge/crop_pct
+ resize_shortest_edge = int(shortest_edge / crop_pct)
+ resize_size = get_resize_output_image_size(
+ image, size=resize_shortest_edge, default_to_square=False, input_data_format=ChannelDimension.FIRST
+ )
+ image = F.resize(
+ image,
+ resize_size,
+ interpolation=interpolation,
+ **kwargs,
+ )
+ # then crop to (shortest_edge, shortest_edge)
+ return F.center_crop(
+ image,
+ (shortest_edge, shortest_edge),
+ **kwargs,
+ )
+ else:
+ # warping (no cropping) when evaluated at 384 or larger
+ return F.resize(
+ image,
+ (shortest_edge, shortest_edge),
+ interpolation=interpolation,
+ **kwargs,
+ )
+
+ def _preprocess(
+ self,
+ images: list["torch.Tensor"],
+ do_resize: bool,
+ size: dict[str, int],
+ crop_pct: float,
+ interpolation: Optional["F.InterpolationMode"],
+ do_center_crop: bool,
+ crop_size: int,
+ do_rescale: bool,
+ rescale_factor: float,
+ do_normalize: bool,
+ image_mean: Optional[Union[float, list[float]]],
+ image_std: Optional[Union[float, list[float]]],
+ disable_grouping: Optional[bool],
+ return_tensors: Optional[Union[str, TensorType]],
+ **kwargs,
+ ) -> BatchFeature:
+ # Group images by size for batched resizing
+ grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
+ resized_images_grouped = {}
+ for shape, stacked_images in grouped_images.items():
+ if do_resize:
+ stacked_images = self.resize(
+ image=stacked_images, size=size, crop_pct=crop_pct, interpolation=interpolation
+ )
+ resized_images_grouped[shape] = stacked_images
+ resized_images = reorder_images(resized_images_grouped, grouped_images_index)
+
+ # Group images by size for further processing
+ # Needed in case do_resize is False, or resize returns images with different sizes
+ grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping)
+ processed_images_grouped = {}
+ for shape, stacked_images in grouped_images.items():
+ if do_center_crop:
+ stacked_images = self.center_crop(stacked_images, crop_size)
+ # Fused rescale and normalize
+ stacked_images = self.rescale_and_normalize(
+ stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std
+ )
+ processed_images_grouped[shape] = stacked_images
+
+ processed_images = reorder_images(processed_images_grouped, grouped_images_index)
+ processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
+
+ return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)
+
+
+__all__ = ["ConvNextImageProcessorFast"]
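+# A minimal usage sketch (illustrative; the checkpoint name and fast-processor availability are assumptions):
+#
+#   from transformers import AutoImageProcessor
+#   from PIL import Image
+#   import requests
+#
+#   processor = AutoImageProcessor.from_pretrained("facebook/convnext-tiny-224", use_fast=True)
+#   url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+#   image = Image.open(requests.get(url, stream=True).raw)
+#   inputs = processor(images=image, return_tensors="pt")  # {"pixel_values": tensor of shape (1, 3, 224, 224)}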
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/convnext/modeling_convnext.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/convnext/modeling_convnext.py
new file mode 100644
index 0000000000000000000000000000000000000000..3120c140d2ed41c16e659a780c7d4603dcfda7b7
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/convnext/modeling_convnext.py
@@ -0,0 +1,424 @@
+# coding=utf-8
+# Copyright 2022 Meta Platforms, Inc. and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch ConvNext model."""
+
+from typing import Optional
+
+import torch
+from torch import nn
+
+from ...activations import ACT2FN
+from ...modeling_outputs import (
+ BackboneOutput,
+ BaseModelOutputWithNoAttention,
+ BaseModelOutputWithPoolingAndNoAttention,
+ ImageClassifierOutputWithNoAttention,
+)
+from ...modeling_utils import PreTrainedModel
+from ...utils import auto_docstring, logging
+from ...utils.backbone_utils import BackboneMixin
+from ...utils.generic import can_return_tuple
+from .configuration_convnext import ConvNextConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+# Copied from transformers.models.beit.modeling_beit.drop_path
+def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
+ """
+ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+
+ Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
+ however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
+ See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
+ layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
+ argument.
+ """
+ if drop_prob == 0.0 or not training:
+ return input
+ keep_prob = 1 - drop_prob
+ shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
+ random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
+ random_tensor.floor_() # binarize
+ output = input.div(keep_prob) * random_tensor
+ return output
+
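+# Illustrative behaviour of `drop_path` (example values):
+#   x = torch.ones(4, 16, 8, 8); drop_path(x, drop_prob=0.25, training=True)
+#   -> each of the 4 samples is independently zeroed with probability 0.25, and surviving samples
+#      are rescaled by 1 / 0.75 so the expected value is preserved.
+#   With training=False (or drop_prob=0.0) the input is returned unchanged.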
+
+# Copied from transformers.models.beit.modeling_beit.BeitDropPath with Beit->ConvNext
+class ConvNextDropPath(nn.Module):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
+
+ def __init__(self, drop_prob: Optional[float] = None) -> None:
+ super().__init__()
+ self.drop_prob = drop_prob
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ return drop_path(hidden_states, self.drop_prob, self.training)
+
+ def extra_repr(self) -> str:
+ return f"p={self.drop_prob}"
+
+
+class ConvNextLayerNorm(nn.LayerNorm):
+ r"""LayerNorm that supports two data formats: channels_last (default) or channels_first.
+ The ordering of the dimensions in the inputs. channels_last corresponds to inputs with shape (batch_size, height,
+ width, channels) while channels_first corresponds to inputs with shape (batch_size, channels, height, width).
+ """
+
+ def __init__(self, normalized_shape, *, eps=1e-6, data_format="channels_last", **kwargs):
+ super().__init__(normalized_shape, eps=eps, **kwargs)
+ if data_format not in ["channels_last", "channels_first"]:
+ raise NotImplementedError(f"Unsupported data format: {data_format}")
+ self.data_format = data_format
+
+ def forward(self, features: torch.Tensor) -> torch.Tensor:
+ """
+ Args:
+ features: Tensor of shape (batch_size, channels, height, width) OR (batch_size, height, width, channels)
+ """
+ if self.data_format == "channels_first":
+ features = features.permute(0, 2, 3, 1)
+ features = super().forward(features)
+ features = features.permute(0, 3, 1, 2)
+ else:
+ features = super().forward(features)
+ return features
+
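+# Illustrative use of ConvNextLayerNorm (example shapes):
+#   norm = ConvNextLayerNorm(96, data_format="channels_first")
+#   y = norm(torch.randn(2, 96, 56, 56))  # normalizes over the 96 channels; output shape is unchanged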
+
+class ConvNextEmbeddings(nn.Module):
+ """This class is comparable to (and inspired by) the SwinEmbeddings class
+ found in src/transformers/models/swin/modeling_swin.py.
+ """
+
+ def __init__(self, config):
+ super().__init__()
+ self.patch_embeddings = nn.Conv2d(
+ config.num_channels, config.hidden_sizes[0], kernel_size=config.patch_size, stride=config.patch_size
+ )
+ self.layernorm = ConvNextLayerNorm(config.hidden_sizes[0], eps=1e-6, data_format="channels_first")
+ self.num_channels = config.num_channels
+
+ def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
+ num_channels = pixel_values.shape[1]
+ if num_channels != self.num_channels:
+ raise ValueError(
+ "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
+ )
+ embeddings = self.patch_embeddings(pixel_values)
+ embeddings = self.layernorm(embeddings)
+ return embeddings
+
+
+class ConvNextLayer(nn.Module):
+ """This corresponds to the `Block` class in the original implementation.
+
+    There are two equivalent implementations: (1) [DwConv, LayerNorm (channels_first), Conv, GELU, 1x1 Conv]; all in
+    (N, C, H, W); (2) [DwConv, Permute to (N, H, W, C), LayerNorm (channels_last), Linear, GELU, Linear]; Permute back
+
+ The authors used (2) as they find it slightly faster in PyTorch.
+
+ Args:
+ config ([`ConvNextConfig`]): Model configuration class.
+ dim (`int`): Number of input channels.
+ drop_path (`float`): Stochastic depth rate. Default: 0.0.
+ """
+
+ def __init__(self, config, dim, drop_path=0):
+ super().__init__()
+ self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv
+ self.layernorm = ConvNextLayerNorm(dim, eps=1e-6)
+ self.pwconv1 = nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers
+ self.act = ACT2FN[config.hidden_act]
+ self.pwconv2 = nn.Linear(4 * dim, dim)
+ self.layer_scale_parameter = (
+ nn.Parameter(config.layer_scale_init_value * torch.ones(dim), requires_grad=True)
+ if config.layer_scale_init_value > 0
+ else None
+ )
+ self.drop_path = ConvNextDropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+ def forward(self, features: torch.Tensor) -> torch.Tensor:
+ residual = features
+ features = self.dwconv(features)
+ features = features.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
+ features = self.layernorm(features)
+ features = self.pwconv1(features)
+ features = self.act(features)
+ features = self.pwconv2(features)
+ if self.layer_scale_parameter is not None:
+ features = self.layer_scale_parameter * features
+ features = features.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
+ features = residual + self.drop_path(features)
+ return features
+
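+# Shape trace through ConvNextLayer.forward for an example input of (N, C, H, W) = (2, 96, 56, 56):
+#   dwconv  -> (2, 96, 56, 56)   (7x7 depthwise conv, padding 3 keeps H and W)
+#   permute -> (2, 56, 56, 96), layernorm over the last (channel) dimension
+#   pwconv1 -> (2, 56, 56, 384), hidden activation (GELU by default), pwconv2 -> (2, 56, 56, 96)
+#   permute back to (2, 96, 56, 56), optional layer scale and drop path, then add the residual.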
+
+class ConvNextStage(nn.Module):
+ """ConvNeXT stage, consisting of an optional downsampling layer + multiple residual blocks.
+
+ Args:
+ config ([`ConvNextConfig`]): Model configuration class.
+ in_channels (`int`): Number of input channels.
+ out_channels (`int`): Number of output channels.
+ depth (`int`): Number of residual blocks.
+        drop_path_rates (`list[float]`): Stochastic depth rates for each layer.
+ """
+
+ def __init__(self, config, in_channels, out_channels, kernel_size=2, stride=2, depth=2, drop_path_rates=None):
+ super().__init__()
+
+ if in_channels != out_channels or stride > 1:
+ self.downsampling_layer = nn.ModuleList(
+ [
+ ConvNextLayerNorm(in_channels, eps=1e-6, data_format="channels_first"),
+ nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride),
+ ]
+ )
+ else:
+ self.downsampling_layer = nn.ModuleList()
+ drop_path_rates = drop_path_rates or [0.0] * depth
+ self.layers = nn.ModuleList(
+ [ConvNextLayer(config, dim=out_channels, drop_path=drop_path_rates[j]) for j in range(depth)]
+ )
+
+ def forward(self, features: torch.Tensor) -> torch.Tensor:
+ for layer in self.downsampling_layer:
+ features = layer(features)
+ for layer in self.layers:
+ features = layer(features)
+ return features
+
+
+class ConvNextEncoder(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.stages = nn.ModuleList()
+ drop_path_rates = [
+ x.tolist()
+ for x in torch.linspace(0, config.drop_path_rate, sum(config.depths), device="cpu").split(config.depths)
+ ]
+ prev_chs = config.hidden_sizes[0]
+ for i in range(config.num_stages):
+ out_chs = config.hidden_sizes[i]
+ stage = ConvNextStage(
+ config,
+ in_channels=prev_chs,
+ out_channels=out_chs,
+ stride=2 if i > 0 else 1,
+ depth=config.depths[i],
+ drop_path_rates=drop_path_rates[i],
+ )
+ self.stages.append(stage)
+ prev_chs = out_chs
+
+ def forward(
+ self, hidden_states: torch.Tensor, output_hidden_states: Optional[bool] = False
+ ) -> BaseModelOutputWithNoAttention:
+ all_hidden_states = [hidden_states] if output_hidden_states else None
+
+ for layer_module in self.stages:
+ hidden_states = layer_module(hidden_states)
+ if all_hidden_states is not None:
+ all_hidden_states.append(hidden_states)
+
+ return BaseModelOutputWithNoAttention(last_hidden_state=hidden_states, hidden_states=all_hidden_states)
+
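+# Note on the collected hidden states (illustrative): with the default 4 stages and output_hidden_states=True,
+# `hidden_states` contains 5 tensors -- the embedding output followed by each stage's output -- while
+# `last_hidden_state` equals the final stage's output.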
+
+@auto_docstring
+class ConvNextPreTrainedModel(PreTrainedModel):
+ config: ConvNextConfig
+ base_model_prefix = "convnext"
+ main_input_name = "pixel_values"
+ _no_split_modules = ["ConvNextLayer"]
+ _can_record_outputs = {} # hidden states are collected explicitly
+
+ def _init_weights(self, module):
+ """Initialize the weights"""
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
+ # Slightly different from the TF version which uses truncated_normal for initialization
+ # cf https://github.com/pytorch/pytorch/pull/5617
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, (nn.LayerNorm, ConvNextLayerNorm)):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+ elif isinstance(module, ConvNextLayer):
+ if module.layer_scale_parameter is not None:
+ module.layer_scale_parameter.data.fill_(self.config.layer_scale_init_value)
+
+
+@auto_docstring
+class ConvNextModel(ConvNextPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.config = config
+
+ self.embeddings = ConvNextEmbeddings(config)
+ self.encoder = ConvNextEncoder(config)
+
+ # final layernorm layer
+ self.layernorm = nn.LayerNorm(config.hidden_sizes[-1], eps=config.layer_norm_eps)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self, pixel_values: Optional[torch.FloatTensor] = None, output_hidden_states: Optional[bool] = None
+ ) -> BaseModelOutputWithPoolingAndNoAttention:
+ if output_hidden_states is None:
+ output_hidden_states = self.config.output_hidden_states
+
+ if pixel_values is None:
+ raise ValueError("You have to specify pixel_values")
+
+ embedding_output = self.embeddings(pixel_values)
+ encoder_outputs: BaseModelOutputWithNoAttention = self.encoder(
+ embedding_output, output_hidden_states=output_hidden_states
+ )
+ last_hidden_state = encoder_outputs.last_hidden_state
+
+ # global average pooling, (N, C, H, W) -> (N, C)
+ pooled_output = self.layernorm(last_hidden_state.mean([-2, -1]))
+
+ return BaseModelOutputWithPoolingAndNoAttention(
+ last_hidden_state=last_hidden_state,
+ pooler_output=pooled_output,
+ hidden_states=encoder_outputs.hidden_states,
+ )
+
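+# Minimal usage sketch for ConvNextModel (checkpoint name and shapes assume the tiny variant):
+#   model = ConvNextModel.from_pretrained("facebook/convnext-tiny-224")
+#   outputs = model(pixel_values=torch.randn(1, 3, 224, 224))
+#   outputs.last_hidden_state.shape  # torch.Size([1, 768, 7, 7])
+#   outputs.pooler_output.shape      # torch.Size([1, 768])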
+
+@auto_docstring(
+ custom_intro="""
+ ConvNext Model with an image classification head on top (a linear layer on top of the pooled features), e.g. for
+ ImageNet.
+ """
+)
+class ConvNextForImageClassification(ConvNextPreTrainedModel):
+ accepts_loss_kwargs = False
+
+ def __init__(self, config):
+ super().__init__(config)
+
+ self.num_labels = config.num_labels
+ self.convnext = ConvNextModel(config)
+
+ # Classifier head
+ if config.num_labels > 0:
+ self.classifier = nn.Linear(config.hidden_sizes[-1], config.num_labels)
+ else:
+ self.classifier = nn.Identity()
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self, pixel_values: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, **kwargs
+ ) -> ImageClassifierOutputWithNoAttention:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+ """
+ outputs: BaseModelOutputWithPoolingAndNoAttention = self.convnext(pixel_values, **kwargs)
+ pooled_output = outputs.pooler_output
+ logits = self.classifier(pooled_output)
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(labels=labels, pooled_logits=logits, config=self.config)
+
+ return ImageClassifierOutputWithNoAttention(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ )
+
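+# Minimal usage sketch for the classification head (checkpoint name is an assumption; it carries an ImageNet-1k head):
+#   model = ConvNextForImageClassification.from_pretrained("facebook/convnext-tiny-224")
+#   logits = model(pixel_values=torch.randn(1, 3, 224, 224)).logits  # shape (1, 1000)
+#   predicted_label = model.config.id2label[logits.argmax(-1).item()]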
+
+@auto_docstring(
+ custom_intro="""
+ ConvNeXt backbone, to be used with frameworks like DETR and MaskFormer.
+ """
+)
+class ConvNextBackbone(ConvNextPreTrainedModel, BackboneMixin):
+ has_attentions = False
+
+ def __init__(self, config):
+ super().__init__(config)
+ super()._init_backbone(config)
+
+ self.embeddings = ConvNextEmbeddings(config)
+ self.encoder = ConvNextEncoder(config)
+ self.num_features = [config.hidden_sizes[0]] + config.hidden_sizes
+
+ # Add layer norms to hidden states of out_features
+ hidden_states_norms = {}
+ for stage, num_channels in zip(self._out_features, self.channels):
+ hidden_states_norms[stage] = ConvNextLayerNorm(num_channels, data_format="channels_first")
+ self.hidden_states_norms = nn.ModuleDict(hidden_states_norms)
+
+ # initialize weights and apply final processing
+ self.post_init()
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ pixel_values: torch.Tensor,
+ output_hidden_states: Optional[bool] = None,
+ ) -> BackboneOutput:
+ r"""
+ Examples:
+
+ ```python
+ >>> from transformers import AutoImageProcessor, AutoBackbone
+ >>> import torch
+ >>> from PIL import Image
+ >>> import requests
+
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> processor = AutoImageProcessor.from_pretrained("facebook/convnext-tiny-224")
+ >>> model = AutoBackbone.from_pretrained("facebook/convnext-tiny-224")
+
+ >>> inputs = processor(image, return_tensors="pt")
+ >>> outputs = model(**inputs)
+ ```"""
+ if output_hidden_states is None:
+ output_hidden_states = self.config.output_hidden_states
+
+ embedding_output = self.embeddings(pixel_values)
+        outputs: BaseModelOutputWithNoAttention = self.encoder(embedding_output, output_hidden_states=True)
+ hidden_states = outputs.hidden_states
+
+ feature_maps = []
+ for stage, hidden_state in zip(self.stage_names, hidden_states):
+ if stage in self.out_features:
+ hidden_state = self.hidden_states_norms[stage](hidden_state)
+ feature_maps.append(hidden_state)
+
+ return BackboneOutput(
+ feature_maps=tuple(feature_maps),
+ hidden_states=hidden_states if output_hidden_states else None,
+ )
+
+
+__all__ = ["ConvNextForImageClassification", "ConvNextModel", "ConvNextPreTrainedModel", "ConvNextBackbone"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/convnext/modeling_tf_convnext.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/convnext/modeling_tf_convnext.py
new file mode 100644
index 0000000000000000000000000000000000000000..7306877466d9b793682d01d90645f08420930b59
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/convnext/modeling_tf_convnext.py
@@ -0,0 +1,667 @@
+# coding=utf-8
+# Copyright 2022 Meta Platforms Inc. and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""TF 2.0 ConvNext model."""
+
+from __future__ import annotations
+
+import numpy as np
+import tensorflow as tf
+
+from ...activations_tf import get_tf_activation
+from ...modeling_tf_outputs import TFBaseModelOutput, TFBaseModelOutputWithPooling, TFSequenceClassifierOutput
+from ...modeling_tf_utils import (
+ TFModelInputType,
+ TFPreTrainedModel,
+ TFSequenceClassificationLoss,
+ get_initializer,
+ keras,
+ keras_serializable,
+ unpack_inputs,
+)
+from ...tf_utils import shape_list
+from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
+from .configuration_convnext import ConvNextConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+_CONFIG_FOR_DOC = "ConvNextConfig"
+_CHECKPOINT_FOR_DOC = "facebook/convnext-tiny-224"
+
+
+class TFConvNextDropPath(keras.layers.Layer):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+ References:
+ (1) github.com:rwightman/pytorch-image-models
+ """
+
+ def __init__(self, drop_path: float, **kwargs):
+ super().__init__(**kwargs)
+ self.drop_path = drop_path
+
+ def call(self, x: tf.Tensor, training=None):
+ if training:
+ keep_prob = 1 - self.drop_path
+ shape = (tf.shape(x)[0],) + (1,) * (len(tf.shape(x)) - 1)
+ random_tensor = keep_prob + tf.random.uniform(shape, 0, 1)
+ random_tensor = tf.floor(random_tensor)
+ return (x / keep_prob) * random_tensor
+ return x
+
+
+class TFConvNextEmbeddings(keras.layers.Layer):
+ """This class is comparable to (and inspired by) the SwinEmbeddings class
+ found in src/transformers/models/swin/modeling_swin.py.
+ """
+
+ def __init__(self, config: ConvNextConfig, **kwargs):
+ super().__init__(**kwargs)
+ self.patch_embeddings = keras.layers.Conv2D(
+ filters=config.hidden_sizes[0],
+ kernel_size=config.patch_size,
+ strides=config.patch_size,
+ name="patch_embeddings",
+ kernel_initializer=get_initializer(config.initializer_range),
+ bias_initializer=keras.initializers.Zeros(),
+ )
+ self.layernorm = keras.layers.LayerNormalization(epsilon=1e-6, name="layernorm")
+ self.num_channels = config.num_channels
+ self.config = config
+
+ def call(self, pixel_values):
+ if isinstance(pixel_values, dict):
+ pixel_values = pixel_values["pixel_values"]
+
+ tf.debugging.assert_equal(
+ shape_list(pixel_values)[1],
+ self.num_channels,
+ message="Make sure that the channel dimension of the pixel values match with the one set in the configuration.",
+ )
+
+ # When running on CPU, `keras.layers.Conv2D` doesn't support `NCHW` format.
+ # So change the input format from `NCHW` to `NHWC`.
+ # shape = (batch_size, in_height, in_width, in_channels)
+ pixel_values = tf.transpose(pixel_values, perm=(0, 2, 3, 1))
+
+ embeddings = self.patch_embeddings(pixel_values)
+ embeddings = self.layernorm(embeddings)
+ return embeddings
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "patch_embeddings", None) is not None:
+ with tf.name_scope(self.patch_embeddings.name):
+ self.patch_embeddings.build([None, None, None, self.config.num_channels])
+ if getattr(self, "layernorm", None) is not None:
+ with tf.name_scope(self.layernorm.name):
+ self.layernorm.build([None, None, None, self.config.hidden_sizes[0]])
+
+
+class TFConvNextLayer(keras.layers.Layer):
+ """This corresponds to the `Block` class in the original implementation.
+
+    There are two equivalent implementations: (1) [DwConv, LayerNorm (channels_first), Conv, GELU, 1x1 Conv]; all in
+    (N, C, H, W); (2) [DwConv, Permute to (N, H, W, C), LayerNorm (channels_last), Linear, GELU, Linear]; Permute back
+
+ The authors used (2) as they find it slightly faster in PyTorch. Since we already permuted the inputs to follow
+ NHWC ordering, we can just apply the operations straight-away without the permutation.
+
+ Args:
+ config ([`ConvNextConfig`]): Model configuration class.
+ dim (`int`): Number of input channels.
+ drop_path (`float`): Stochastic depth rate. Default: 0.0.
+ """
+
+ def __init__(self, config, dim, drop_path=0.0, **kwargs):
+ super().__init__(**kwargs)
+ self.dim = dim
+ self.config = config
+ self.dwconv = keras.layers.Conv2D(
+ filters=dim,
+ kernel_size=7,
+ padding="same",
+ groups=dim,
+ kernel_initializer=get_initializer(config.initializer_range),
+ bias_initializer="zeros",
+ name="dwconv",
+ ) # depthwise conv
+ self.layernorm = keras.layers.LayerNormalization(
+ epsilon=1e-6,
+ name="layernorm",
+ )
+ self.pwconv1 = keras.layers.Dense(
+ units=4 * dim,
+ kernel_initializer=get_initializer(config.initializer_range),
+ bias_initializer="zeros",
+ name="pwconv1",
+ ) # pointwise/1x1 convs, implemented with linear layers
+ self.act = get_tf_activation(config.hidden_act)
+ self.pwconv2 = keras.layers.Dense(
+ units=dim,
+ kernel_initializer=get_initializer(config.initializer_range),
+ bias_initializer="zeros",
+ name="pwconv2",
+ )
+ # Using `layers.Activation` instead of `tf.identity` to better control `training`
+ # behaviour.
+ self.drop_path = (
+ TFConvNextDropPath(drop_path, name="drop_path")
+ if drop_path > 0.0
+ else keras.layers.Activation("linear", name="drop_path")
+ )
+
+ def build(self, input_shape: tf.TensorShape = None):
+ # PT's `nn.Parameters` must be mapped to a TF layer weight to inherit the same name hierarchy (and vice-versa)
+ self.layer_scale_parameter = (
+ self.add_weight(
+ shape=(self.dim,),
+ initializer=keras.initializers.Constant(value=self.config.layer_scale_init_value),
+ trainable=True,
+ name="layer_scale_parameter",
+ )
+ if self.config.layer_scale_init_value > 0
+ else None
+ )
+
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "dwconv", None) is not None:
+ with tf.name_scope(self.dwconv.name):
+ self.dwconv.build([None, None, None, self.dim])
+ if getattr(self, "layernorm", None) is not None:
+ with tf.name_scope(self.layernorm.name):
+ self.layernorm.build([None, None, None, self.dim])
+ if getattr(self, "pwconv1", None) is not None:
+ with tf.name_scope(self.pwconv1.name):
+ self.pwconv1.build([None, None, self.dim])
+ if getattr(self, "pwconv2", None) is not None:
+ with tf.name_scope(self.pwconv2.name):
+ self.pwconv2.build([None, None, 4 * self.dim])
+ if getattr(self, "drop_path", None) is not None:
+ with tf.name_scope(self.drop_path.name):
+ self.drop_path.build(None)
+
+ def call(self, hidden_states, training=False):
+ input = hidden_states
+ x = self.dwconv(hidden_states)
+ x = self.layernorm(x)
+ x = self.pwconv1(x)
+ x = self.act(x)
+ x = self.pwconv2(x)
+
+ if self.layer_scale_parameter is not None:
+ x = self.layer_scale_parameter * x
+
+ x = input + self.drop_path(x, training=training)
+ return x
+
+
+class TFConvNextStage(keras.layers.Layer):
+ """ConvNext stage, consisting of an optional downsampling layer + multiple residual blocks.
+
+ Args:
+        config (`ConvNextConfig`):
+ Model configuration class.
+ in_channels (`int`):
+ Number of input channels.
+ out_channels (`int`):
+ Number of output channels.
+ depth (`int`):
+ Number of residual blocks.
+        drop_path_rates (`list[float]`):
+ Stochastic depth rates for each layer.
+ """
+
+ def __init__(
+ self,
+ config: ConvNextConfig,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: int = 2,
+ stride: int = 2,
+ depth: int = 2,
+ drop_path_rates: list[float] | None = None,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ if in_channels != out_channels or stride > 1:
+ self.downsampling_layer = [
+ keras.layers.LayerNormalization(
+ epsilon=1e-6,
+ name="downsampling_layer.0",
+ ),
+ # Inputs to this layer will follow NHWC format since we
+ # transposed the inputs from NCHW to NHWC in the `TFConvNextEmbeddings`
+ # layer. All the outputs throughout the model will be in NHWC
+ # from this point on until the output where we again change to
+ # NCHW.
+ keras.layers.Conv2D(
+ filters=out_channels,
+ kernel_size=kernel_size,
+ strides=stride,
+ kernel_initializer=get_initializer(config.initializer_range),
+ bias_initializer=keras.initializers.Zeros(),
+ name="downsampling_layer.1",
+ ),
+ ]
+ else:
+ self.downsampling_layer = [tf.identity]
+
+ drop_path_rates = drop_path_rates or [0.0] * depth
+ self.layers = [
+ TFConvNextLayer(
+ config,
+ dim=out_channels,
+ drop_path=drop_path_rates[j],
+ name=f"layers.{j}",
+ )
+ for j in range(depth)
+ ]
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.stride = stride
+
+ def call(self, hidden_states):
+ for layer in self.downsampling_layer:
+ hidden_states = layer(hidden_states)
+ for layer in self.layers:
+ hidden_states = layer(hidden_states)
+ return hidden_states
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "layers", None) is not None:
+ for layer in self.layers:
+ with tf.name_scope(layer.name):
+ layer.build(None)
+ if self.in_channels != self.out_channels or self.stride > 1:
+ with tf.name_scope(self.downsampling_layer[0].name):
+ self.downsampling_layer[0].build([None, None, None, self.in_channels])
+ with tf.name_scope(self.downsampling_layer[1].name):
+ self.downsampling_layer[1].build([None, None, None, self.in_channels])
+
+
+class TFConvNextEncoder(keras.layers.Layer):
+ def __init__(self, config, **kwargs):
+ super().__init__(**kwargs)
+ self.stages = []
+ drop_path_rates = tf.linspace(0.0, config.drop_path_rate, sum(config.depths))
+ drop_path_rates = tf.split(drop_path_rates, config.depths)
+ drop_path_rates = [x.numpy().tolist() for x in drop_path_rates]
+ prev_chs = config.hidden_sizes[0]
+ for i in range(config.num_stages):
+ out_chs = config.hidden_sizes[i]
+ stage = TFConvNextStage(
+ config,
+ in_channels=prev_chs,
+ out_channels=out_chs,
+ stride=2 if i > 0 else 1,
+ depth=config.depths[i],
+ drop_path_rates=drop_path_rates[i],
+ name=f"stages.{i}",
+ )
+ self.stages.append(stage)
+ prev_chs = out_chs
+
+ def call(self, hidden_states, output_hidden_states=False, return_dict=True):
+ all_hidden_states = () if output_hidden_states else None
+
+ for i, layer_module in enumerate(self.stages):
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ hidden_states = layer_module(hidden_states)
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, all_hidden_states] if v is not None)
+
+ return TFBaseModelOutput(last_hidden_state=hidden_states, hidden_states=all_hidden_states)
+
+ def build(self, input_shape=None):
+ for stage in self.stages:
+ with tf.name_scope(stage.name):
+ stage.build(None)
+
+
+@keras_serializable
+class TFConvNextMainLayer(keras.layers.Layer):
+ config_class = ConvNextConfig
+
+ def __init__(self, config: ConvNextConfig, add_pooling_layer: bool = True, **kwargs):
+ super().__init__(**kwargs)
+
+ self.config = config
+ self.embeddings = TFConvNextEmbeddings(config, name="embeddings")
+ self.encoder = TFConvNextEncoder(config, name="encoder")
+ self.layernorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm")
+ # We are setting the `data_format` like so because from here on we will revert to the
+ # NCHW output format
+ self.pooler = keras.layers.GlobalAvgPool2D(data_format="channels_first") if add_pooling_layer else None
+
+ @unpack_inputs
+ def call(
+ self,
+ pixel_values: TFModelInputType | None = None,
+ output_hidden_states: bool | None = None,
+ return_dict: bool | None = None,
+ training: bool = False,
+ ) -> TFBaseModelOutputWithPooling | tuple[tf.Tensor]:
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if pixel_values is None:
+ raise ValueError("You have to specify pixel_values")
+
+ embedding_output = self.embeddings(pixel_values, training=training)
+
+ encoder_outputs = self.encoder(
+ embedding_output,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ training=training,
+ )
+
+ last_hidden_state = encoder_outputs[0]
+        # Change to NCHW output format to have uniformity in the modules
+ last_hidden_state = tf.transpose(last_hidden_state, perm=(0, 3, 1, 2))
+ pooled_output = self.layernorm(self.pooler(last_hidden_state))
+
+ # Change the other hidden state outputs to NCHW as well
+ if output_hidden_states:
+ hidden_states = tuple(tf.transpose(h, perm=(0, 3, 1, 2)) for h in encoder_outputs[1])
+
+ if not return_dict:
+ hidden_states = hidden_states if output_hidden_states else ()
+ return (last_hidden_state, pooled_output) + hidden_states
+
+ return TFBaseModelOutputWithPooling(
+ last_hidden_state=last_hidden_state,
+ pooler_output=pooled_output,
+ hidden_states=hidden_states if output_hidden_states else encoder_outputs.hidden_states,
+ )
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "embeddings", None) is not None:
+ with tf.name_scope(self.embeddings.name):
+ self.embeddings.build(None)
+ if getattr(self, "encoder", None) is not None:
+ with tf.name_scope(self.encoder.name):
+ self.encoder.build(None)
+ if getattr(self, "layernorm", None) is not None:
+ with tf.name_scope(self.layernorm.name):
+ self.layernorm.build([None, self.config.hidden_sizes[-1]])
+
+
+class TFConvNextPreTrainedModel(TFPreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = ConvNextConfig
+ base_model_prefix = "convnext"
+ main_input_name = "pixel_values"
+
+
+CONVNEXT_START_DOCSTRING = r"""
+ This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
+ etc.)
+
+ This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it
+ as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and
+ behavior.
+
+
+
+ TensorFlow models and layers in `transformers` accept two formats as input:
+
+ - having all inputs as keyword arguments (like PyTorch models), or
+ - having all inputs as a list, tuple or dict in the first positional argument.
+
+ The reason the second format is supported is that Keras methods prefer this format when passing inputs to models
+ and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just
+ pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second
+ format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with
+ the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first
+ positional argument:
+
+ - a single Tensor with `pixel_values` only and nothing else: `model(pixel_values)`
+ - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
+ `model([pixel_values, attention_mask])` or `model([pixel_values, attention_mask, token_type_ids])`
+ - a dictionary with one or several input Tensors associated to the input names given in the docstring:
+ `model({"pixel_values": pixel_values, "token_type_ids": token_type_ids})`
+
+ Note that when creating models and layers with
+ [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry
+ about any of this, as you can just pass inputs like you would to any other Python function!
+
+
+
+ Parameters:
+ config ([`ConvNextConfig`]): Model configuration class with all the parameters of the model.
+ Initializing with a config file does not load the weights associated with the model, only the
+ configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+CONVNEXT_INPUTS_DOCSTRING = r"""
+ Args:
+        pixel_values (`np.ndarray`, `tf.Tensor`, `list[tf.Tensor]`, `dict[str, tf.Tensor]` or `dict[str, np.ndarray]` and each example must have the shape `(batch_size, num_channels, height, width)`):
+ Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
+ [`ConvNextImageProcessor.__call__`] for details.
+
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail. This argument can be used only in eager mode, in graph mode the value in the config will be
+ used instead.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in
+ eager mode, in graph mode the value will always be set to True.
+"""
+
+
+@add_start_docstrings(
+ "The bare ConvNext model outputting raw features without any specific head on top.",
+ CONVNEXT_START_DOCSTRING,
+)
+class TFConvNextModel(TFConvNextPreTrainedModel):
+ def __init__(self, config, *inputs, add_pooling_layer=True, **kwargs):
+ super().__init__(config, *inputs, **kwargs)
+ self.convnext = TFConvNextMainLayer(config, add_pooling_layer=add_pooling_layer, name="convnext")
+
+ @unpack_inputs
+ @add_start_docstrings_to_model_forward(CONVNEXT_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=TFBaseModelOutputWithPooling, config_class=_CONFIG_FOR_DOC)
+ def call(
+ self,
+ pixel_values: TFModelInputType | None = None,
+ output_hidden_states: bool | None = None,
+ return_dict: bool | None = None,
+ training: bool = False,
+ ) -> TFBaseModelOutputWithPooling | tuple[tf.Tensor]:
+ r"""
+ Returns:
+
+ Examples:
+
+ ```python
+ >>> from transformers import AutoImageProcessor, TFConvNextModel
+ >>> from PIL import Image
+ >>> import requests
+
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> image_processor = AutoImageProcessor.from_pretrained("facebook/convnext-tiny-224")
+ >>> model = TFConvNextModel.from_pretrained("facebook/convnext-tiny-224")
+
+ >>> inputs = image_processor(images=image, return_tensors="tf")
+ >>> outputs = model(**inputs)
+ >>> last_hidden_states = outputs.last_hidden_state
+ ```"""
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if pixel_values is None:
+ raise ValueError("You have to specify pixel_values")
+
+ outputs = self.convnext(
+ pixel_values=pixel_values,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ training=training,
+ )
+
+ if not return_dict:
+ return (outputs[0],) + outputs[1:]
+
+ return TFBaseModelOutputWithPooling(
+ last_hidden_state=outputs.last_hidden_state,
+ pooler_output=outputs.pooler_output,
+ hidden_states=outputs.hidden_states,
+ )
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "convnext", None) is not None:
+ with tf.name_scope(self.convnext.name):
+ self.convnext.build(None)
+
+
+@add_start_docstrings(
+ """
+ ConvNext Model with an image classification head on top (a linear layer on top of the pooled features), e.g. for
+ ImageNet.
+ """,
+ CONVNEXT_START_DOCSTRING,
+)
+class TFConvNextForImageClassification(TFConvNextPreTrainedModel, TFSequenceClassificationLoss):
+ def __init__(self, config: ConvNextConfig, *inputs, **kwargs):
+ super().__init__(config, *inputs, **kwargs)
+
+ self.num_labels = config.num_labels
+ self.convnext = TFConvNextMainLayer(config, name="convnext")
+
+ # Classifier head
+ self.classifier = keras.layers.Dense(
+ units=config.num_labels,
+ kernel_initializer=get_initializer(config.initializer_range),
+ bias_initializer="zeros",
+ name="classifier",
+ )
+ self.config = config
+
+ @unpack_inputs
+ @add_start_docstrings_to_model_forward(CONVNEXT_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=TFSequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
+ def call(
+ self,
+ pixel_values: TFModelInputType | None = None,
+ output_hidden_states: bool | None = None,
+ return_dict: bool | None = None,
+ labels: np.ndarray | tf.Tensor | None = None,
+ training: bool | None = False,
+ ) -> TFSequenceClassifierOutput | tuple[tf.Tensor]:
+ r"""
+ labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size,)`, *optional*):
+ Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+
+ Returns:
+
+ Examples:
+
+ ```python
+ >>> from transformers import AutoImageProcessor, TFConvNextForImageClassification
+ >>> import tensorflow as tf
+ >>> from PIL import Image
+ >>> import requests
+
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> image_processor = AutoImageProcessor.from_pretrained("facebook/convnext-tiny-224")
+ >>> model = TFConvNextForImageClassification.from_pretrained("facebook/convnext-tiny-224")
+
+ >>> inputs = image_processor(images=image, return_tensors="tf")
+ >>> outputs = model(**inputs)
+ >>> logits = outputs.logits
+ >>> # model predicts one of the 1000 ImageNet classes
+ >>> predicted_class_idx = tf.math.argmax(logits, axis=-1)[0]
+ >>> print("Predicted class:", model.config.id2label[int(predicted_class_idx)])
+ ```"""
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if pixel_values is None:
+ raise ValueError("You have to specify pixel_values")
+
+ outputs = self.convnext(
+ pixel_values,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ training=training,
+ )
+
+ pooled_output = outputs.pooler_output if return_dict else outputs[1]
+
+ logits = self.classifier(pooled_output)
+ loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits)
+
+ if not return_dict:
+ output = (logits,) + outputs[2:]
+ return ((loss,) + output) if loss is not None else output
+
+ return TFSequenceClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ )
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "convnext", None) is not None:
+ with tf.name_scope(self.convnext.name):
+ self.convnext.build(None)
+ if getattr(self, "classifier", None) is not None:
+ if hasattr(self.classifier, "name"):
+ with tf.name_scope(self.classifier.name):
+ self.classifier.build([None, None, self.config.hidden_sizes[-1]])
+
+
+__all__ = ["TFConvNextForImageClassification", "TFConvNextModel", "TFConvNextPreTrainedModel"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/data2vec/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/data2vec/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..7000ac3d353bf4eba157b07e350b9ac5f7552a98
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/data2vec/__init__.py
@@ -0,0 +1,32 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_data2vec_audio import *
+ from .configuration_data2vec_text import *
+ from .configuration_data2vec_vision import *
+ from .modeling_data2vec_audio import *
+ from .modeling_data2vec_text import *
+ from .modeling_data2vec_vision import *
+ from .modeling_tf_data2vec_vision import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/data2vec/configuration_data2vec_audio.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/data2vec/configuration_data2vec_audio.py
new file mode 100644
index 0000000000000000000000000000000000000000..3d88a9de6543afaa60c6d5353f2c32d871d5ee21
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/data2vec/configuration_data2vec_audio.py
@@ -0,0 +1,288 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Data2VecText configuration"""
+
+import math
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class Data2VecAudioConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`Data2VecAudioModel`]. It is used to instantiate
+ an Data2VecAudio model according to the specified arguments, defining the model architecture. Instantiating a
+ configuration with the defaults will yield a similar configuration to that of the Data2VecAudio
+ [facebook/data2vec-audio-base-960h](https://huggingface.co/facebook/data2vec-audio-base-960h) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+
+ Args:
+        vocab_size (`int`, *optional*, defaults to 32):
+            Vocabulary size of the Data2VecAudio model. Defines the number of different tokens that can be represented
+            by the `inputs_ids` passed when calling [`Data2VecAudioModel`] or [`TFData2VecAudioModel`].
+ hidden_size (`int`, *optional*, defaults to 768):
+ Dimensionality of the encoder layers and the pooler layer.
+ num_hidden_layers (`int`, *optional*, defaults to 12):
+ Number of hidden layers in the Transformer encoder.
+ num_attention_heads (`int`, *optional*, defaults to 12):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ intermediate_size (`int`, *optional*, defaults to 3072):
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+ `"relu"`, `"selu"` and `"gelu_new"` are supported.
+ hidden_dropout (`float`, *optional*, defaults to 0.1):
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+ activation_dropout (`float`, *optional*, defaults to 0.1):
+ The dropout ratio for activations inside the fully connected layer.
+ attention_dropout (`float`, *optional*, defaults to 0.1):
+ The dropout ratio for the attention probabilities.
+ final_dropout (`float`, *optional*, defaults to 0.1):
+ The dropout probability for the final projection layer of [`Data2VecAudioForCTC`].
+ layerdrop (`float`, *optional*, defaults to 0.1):
+            The LayerDrop probability. See the [LayerDrop paper](https://huggingface.co/papers/1909.11556) for more
+ details.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        layer_norm_eps (`float`, *optional*, defaults to 1e-05):
+ The epsilon used by the layer normalization layers.
+ feat_proj_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout probability for output of the feature encoder.
+        feat_extract_activation (`str`, *optional*, defaults to `"gelu"`):
+ The non-linear activation function (function or string) in the 1D convolutional layers of the feature
+ extractor. If string, `"gelu"`, `"relu"`, `"selu"` and `"gelu_new"` are supported.
+ conv_dim (`tuple[int]` or `list[int]`, *optional*, defaults to `(512, 512, 512, 512, 512, 512, 512)`):
+ A tuple of integers defining the number of input and output channels of each 1D convolutional layer in the
+ feature encoder. The length of *conv_dim* defines the number of 1D convolutional layers.
+ conv_stride (`tuple[int]` or `list[int]`, *optional*, defaults to `(5, 2, 2, 2, 2, 2, 2)`):
+ A tuple of integers defining the stride of each 1D convolutional layer in the feature encoder. The length
+ of *conv_stride* defines the number of convolutional layers and has to match the length of *conv_dim*.
+        conv_kernel (`tuple[int]` or `list[int]`, *optional*, defaults to `(10, 3, 3, 3, 3, 2, 2)`):
+ A tuple of integers defining the kernel size of each 1D convolutional layer in the feature encoder. The
+ length of *conv_kernel* defines the number of convolutional layers and has to match the length of
+ *conv_dim*.
+ conv_bias (`bool`, *optional*, defaults to `False`):
+ Whether the 1D convolutional layers have a bias.
+        num_conv_pos_embeddings (`int`, *optional*, defaults to 5):
+            Number of 1D convolutional positional embedding layers.
+        conv_pos_kernel_size (`int`, *optional*, defaults to 19):
+            Kernel size of each 1D convolutional positional embedding layer.
+        num_conv_pos_embedding_groups (`int`, *optional*, defaults to 16):
+            Number of groups of the 1D convolutional positional embedding layers.
+ mask_time_prob (`float`, *optional*, defaults to 0.05):
+            Percentage (between 0 and 1) of all feature vectors along the time axis which will be masked. The masking
+            procedure generates ''mask_time_prob*len(time_axis)/mask_time_length'' independent masks over the axis. If
+            reasoning from the probability of each feature vector to be chosen as the start of the vector span to be
+            masked, *mask_time_prob* should be `prob_vector_start*mask_time_length`. Note that overlap may decrease the
+            actual percentage of masked vectors. This is only relevant if `apply_spec_augment is True`.
+ mask_time_length (`int`, *optional*, defaults to 10):
+ Length of vector span along the time axis.
+        mask_time_min_masks (`int`, *optional*, defaults to 2):
+            The minimum number of masks of length `mask_time_length` generated along the time axis, each time step,
+            irrespectively of `mask_time_prob`. Only relevant if ''mask_time_prob*len(time_axis)/mask_time_length <
+            mask_time_min_masks''
+ mask_feature_prob (`float`, *optional*, defaults to 0.0):
+ Percentage (between 0 and 1) of all feature vectors along the feature axis which will be masked. The
+ masking procedure generates ''mask_feature_prob*len(feature_axis)/mask_time_length'' independent masks over
+ the axis. If reasoning from the probability of each feature vector to be chosen as the start of the vector
+ span to be masked, *mask_feature_prob* should be `prob_vector_start*mask_feature_length`. Note that overlap
+ may decrease the actual percentage of masked vectors. This is only relevant if `apply_spec_augment is
+ True`.
+ mask_feature_length (`int`, *optional*, defaults to 10):
+ Length of vector span along the feature axis.
+        mask_feature_min_masks (`int`, *optional*, defaults to 0):
+ The minimum number of masks of length `mask_feature_length` generated along the feature axis, each time
+ step, irrespectively of `mask_feature_prob`. Only relevant if
+ ''mask_feature_prob*len(feature_axis)/mask_feature_length < mask_feature_min_masks''
+ ctc_loss_reduction (`str`, *optional*, defaults to `"sum"`):
+ Specifies the reduction to apply to the output of `torch.nn.CTCLoss`. Only relevant when training an
+ instance of [`Data2VecAudioForCTC`].
+ ctc_zero_infinity (`bool`, *optional*, defaults to `False`):
+ Whether to zero infinite losses and the associated gradients of `torch.nn.CTCLoss`. Infinite losses mainly
+ occur when the inputs are too short to be aligned to the targets. Only relevant when training an instance
+ of [`Data2VecAudioForCTC`].
+ use_weighted_layer_sum (`bool`, *optional*, defaults to `False`):
+ Whether to use a weighted average of layer outputs with learned weights. Only relevant when using an
+ instance of [`Data2VecAudioForSequenceClassification`].
+ classifier_proj_size (`int`, *optional*, defaults to 256):
+ Dimensionality of the projection before token mean-pooling for classification.
+ tdnn_dim (`tuple[int]` or `list[int]`, *optional*, defaults to `(512, 512, 512, 512, 1500)`):
+ A tuple of integers defining the number of output channels of each 1D convolutional layer in the *TDNN*
+ module of the *XVector* model. The length of *tdnn_dim* defines the number of *TDNN* layers.
+ tdnn_kernel (`tuple[int]` or `list[int]`, *optional*, defaults to `(5, 3, 3, 1, 1)`):
+ A tuple of integers defining the kernel size of each 1D convolutional layer in the *TDNN* module of the
+ *XVector* model. The length of *tdnn_kernel* has to match the length of *tdnn_dim*.
+ tdnn_dilation (`tuple[int]` or `list[int]`, *optional*, defaults to `(1, 2, 3, 1, 1)`):
+ A tuple of integers defining the dilation factor of each 1D convolutional layer in *TDNN* module of the
+ *XVector* model. The length of *tdnn_dilation* has to match the length of *tdnn_dim*.
+ xvector_output_dim (`int`, *optional*, defaults to 512):
+ Dimensionality of the *XVector* embedding vectors.
+ add_adapter (`bool`, *optional*, defaults to `False`):
+ Whether a convolutional network should be stacked on top of the Data2VecAudio Encoder. Can be very useful
+ for warm-starting Data2VecAudio for SpeechEncoderDecoder models.
+ adapter_kernel_size (`int`, *optional*, defaults to 3):
+ Kernel size of the convolutional layers in the adapter network. Only relevant if `add_adapter is True`.
+ adapter_stride (`int`, *optional*, defaults to 2):
+ Stride of the convolutional layers in the adapter network. Only relevant if `add_adapter is True`.
+ num_adapter_layers (`int`, *optional*, defaults to 3):
+ Number of convolutional layers that should be used in the adapter network. Only relevant if `add_adapter is
+ True`.
+ output_hidden_size (`int`, *optional*):
+ Dimensionality of the encoder output layer. If not defined, this defaults to *hidden-size*. Only relevant
+ if `add_adapter is True`.
+
+ Example:
+
+ ```python
+ >>> from transformers import Data2VecAudioConfig, Data2VecAudioModel
+
+ >>> # Initializing a Data2VecAudio facebook/data2vec-audio-base-960h style configuration
+ >>> configuration = Data2VecAudioConfig()
+
+ >>> # Initializing a model (with random weights) from the facebook/data2vec-audio-base-960h style configuration
+ >>> model = Data2VecAudioModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "data2vec-audio"
+
+ def __init__(
+ self,
+ vocab_size=32,
+ hidden_size=768,
+ num_hidden_layers=12,
+ num_attention_heads=12,
+ intermediate_size=3072,
+ hidden_act="gelu",
+ hidden_dropout=0.1,
+ activation_dropout=0.1,
+ attention_dropout=0.1,
+ feat_proj_dropout=0.0,
+ final_dropout=0.1,
+ layerdrop=0.1,
+ initializer_range=0.02,
+ layer_norm_eps=1e-5,
+ feat_extract_activation="gelu",
+ conv_dim=(512, 512, 512, 512, 512, 512, 512),
+ conv_stride=(5, 2, 2, 2, 2, 2, 2),
+ conv_kernel=(10, 3, 3, 3, 3, 2, 2),
+ conv_bias=False,
+ num_conv_pos_embedding_groups=16,
+ conv_pos_kernel_size=19,
+ num_conv_pos_embeddings=5,
+ mask_time_prob=0.05,
+ mask_time_length=10,
+ mask_time_min_masks=2,
+ mask_feature_prob=0.0,
+ mask_feature_length=10,
+ mask_feature_min_masks=0,
+ ctc_loss_reduction="sum",
+ ctc_zero_infinity=False,
+ use_weighted_layer_sum=False,
+ classifier_proj_size=256,
+ tdnn_dim=(512, 512, 512, 512, 1500),
+ tdnn_kernel=(5, 3, 3, 1, 1),
+ tdnn_dilation=(1, 2, 3, 1, 1),
+ xvector_output_dim=512,
+ pad_token_id=0,
+ bos_token_id=1,
+ eos_token_id=2,
+ add_adapter=False,
+ adapter_kernel_size=3,
+ adapter_stride=2,
+ num_adapter_layers=3,
+ output_hidden_size=None,
+ **kwargs,
+ ):
+ super().__init__(**kwargs, pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id)
+ self.hidden_size = hidden_size
+ self.feat_extract_activation = feat_extract_activation
+ self.conv_dim = list(conv_dim)
+ self.conv_stride = list(conv_stride)
+ self.conv_kernel = list(conv_kernel)
+ self.conv_bias = conv_bias
+ self.num_conv_pos_embeddings = num_conv_pos_embeddings
+ self.num_conv_pos_embedding_groups = num_conv_pos_embedding_groups
+ self.conv_pos_kernel_size = conv_pos_kernel_size
+ self.num_feat_extract_layers = len(self.conv_dim)
+ self.num_hidden_layers = num_hidden_layers
+ self.intermediate_size = intermediate_size
+ self.hidden_act = hidden_act
+ self.num_attention_heads = num_attention_heads
+ self.hidden_dropout = hidden_dropout
+ self.attention_dropout = attention_dropout
+ self.activation_dropout = activation_dropout
+ self.feat_proj_dropout = feat_proj_dropout
+ self.final_dropout = final_dropout
+ self.layerdrop = layerdrop
+ self.layer_norm_eps = layer_norm_eps
+ self.initializer_range = initializer_range
+ self.vocab_size = vocab_size
+ self.use_weighted_layer_sum = use_weighted_layer_sum
+
+ if (
+ (len(self.conv_stride) != self.num_feat_extract_layers)
+ or (len(self.conv_kernel) != self.num_feat_extract_layers)
+ or (len(self.conv_dim) != self.num_feat_extract_layers)
+ ):
+ raise ValueError(
+ "Configuration for convolutional layers is incorrect. It is required that `len(config.conv_dim)` =="
+ " `len(config.conv_stride)` == `len(config.conv_kernel)`, but is `len(config.conv_dim) ="
+ f" {len(self.conv_dim)}`, `len(config.conv_stride) = {len(self.conv_stride)}`,"
+ f" `len(config.conv_kernel) = {len(self.conv_kernel)}`."
+ )
+
+ # fine-tuning config parameters for SpecAugment: https://huggingface.co/papers/1904.08779
+ self.mask_time_prob = mask_time_prob
+ self.mask_time_length = mask_time_length
+ self.mask_time_min_masks = mask_time_min_masks
+ self.mask_feature_prob = mask_feature_prob
+ self.mask_feature_length = mask_feature_length
+ self.mask_feature_min_masks = mask_feature_min_masks
+
+ # ctc loss
+ self.ctc_loss_reduction = ctc_loss_reduction
+ self.ctc_zero_infinity = ctc_zero_infinity
+
+ # adapter
+ self.add_adapter = add_adapter
+ self.adapter_kernel_size = adapter_kernel_size
+ self.adapter_stride = adapter_stride
+ self.num_adapter_layers = num_adapter_layers
+ self.output_hidden_size = output_hidden_size or hidden_size
+
+ # SequenceClassification-specific parameter. Feel free to ignore for other classes.
+ self.classifier_proj_size = classifier_proj_size
+
+ # XVector-specific parameters. Feel free to ignore for other classes.
+ self.tdnn_dim = list(tdnn_dim)
+ self.tdnn_kernel = list(tdnn_kernel)
+ self.tdnn_dilation = list(tdnn_dilation)
+ self.xvector_output_dim = xvector_output_dim
+
+ @property
+ def inputs_to_logits_ratio(self):
+ return math.prod(self.conv_stride)
+
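+# Illustrative arithmetic for `inputs_to_logits_ratio` with the default conv_stride (5, 2, 2, 2, 2, 2, 2):
+#   math.prod((5, 2, 2, 2, 2, 2, 2)) = 320, i.e. one output frame per 320 input samples
+#   (20 ms per frame assuming the usual 16 kHz sampling rate).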
+
+__all__ = ["Data2VecAudioConfig"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/data2vec/configuration_data2vec_text.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/data2vec/configuration_data2vec_text.py
new file mode 100644
index 0000000000000000000000000000000000000000..f9518d67bf665f01a2cfb46cfdb6f529f5f22bce
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/data2vec/configuration_data2vec_text.py
@@ -0,0 +1,154 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Data2VecText configuration"""
+
+from collections import OrderedDict
+from collections.abc import Mapping
+
+from ...configuration_utils import PretrainedConfig
+from ...onnx import OnnxConfig
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class Data2VecTextConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`Data2VecTextModel`]. It is used to instantiate a
+ Data2VecText model according to the specified arguments, defining the model architecture.
+ Instantiating a configuration with the defaults will yield a similar configuration to that of the Data2VecText
+ [facebook/data2vec-text-base](https://huggingface.co/facebook/data2vec-text-base) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 30522):
+ Vocabulary size of the Data2VecText model. Defines the number of different tokens that can be represented by
+ the `input_ids` passed when calling [`Data2VecTextModel`].
+ hidden_size (`int`, *optional*, defaults to 768):
+ Dimensionality of the encoder layers and the pooler layer.
+ num_hidden_layers (`int`, *optional*, defaults to 12):
+ Number of hidden layers in the Transformer encoder.
+ num_attention_heads (`int`, *optional*, defaults to 12):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ intermediate_size (`int`, *optional*, defaults to 3072):
+ Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder.
+ hidden_act (`str` or `Callable`, *optional*, defaults to `"gelu"`):
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+ `"relu"`, `"silu"` and `"gelu_new"` are supported.
+ hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+ attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
+ The dropout ratio for the attention probabilities.
+ max_position_embeddings (`int`, *optional*, defaults to 512):
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
+ just in case (e.g., 512 or 1024 or 2048).
+ type_vocab_size (`int`, *optional*, defaults to 2):
+ The vocabulary size of the `token_type_ids` passed when calling [`Data2VecTextModel`].
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ layer_norm_eps (`float`, *optional*, defaults to 1e-12):
+ The epsilon used by the layer normalization layers.
+ position_embedding_type (`str`, *optional*, defaults to `"absolute"`):
+ Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. For
+ positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to
+ [Self-Attention with Relative Position Representations (Shaw et al.)](https://huggingface.co/papers/1803.02155).
+ For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models
+ with Better Relative Position Embeddings (Huang et al.)](https://huggingface.co/papers/2009.13658).
+ is_decoder (`bool`, *optional*, defaults to `False`):
+ Whether the model is used as a decoder or not. If `False`, the model is used as an encoder.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
+ relevant if `config.is_decoder=True`.
+ classifier_dropout (`float`, *optional*):
+ The dropout ratio for the classification head.
+
+ Examples:
+
+ ```python
+ >>> from transformers import Data2VecTextConfig, Data2VecTextModel
+
+ >>> # Initializing a Data2VecText facebook/data2vec-text-base style configuration
+ >>> configuration = Data2VecTextConfig()
+
+ >>> # Initializing a model (with random weights) from the facebook/data2vec-text-base style configuration
+ >>> model = Data2VecTextModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "data2vec-text"
+
+ def __init__(
+ self,
+ vocab_size=30522,
+ hidden_size=768,
+ num_hidden_layers=12,
+ num_attention_heads=12,
+ intermediate_size=3072,
+ hidden_act="gelu",
+ hidden_dropout_prob=0.1,
+ attention_probs_dropout_prob=0.1,
+ max_position_embeddings=512,
+ type_vocab_size=2,
+ initializer_range=0.02,
+ layer_norm_eps=1e-12,
+ pad_token_id=1,
+ bos_token_id=0,
+ eos_token_id=2,
+ position_embedding_type="absolute",
+ use_cache=True,
+ classifier_dropout=None,
+ **kwargs,
+ ):
+ super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
+
+ self.vocab_size = vocab_size
+ self.hidden_size = hidden_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.hidden_act = hidden_act
+ self.intermediate_size = intermediate_size
+ self.hidden_dropout_prob = hidden_dropout_prob
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
+ self.max_position_embeddings = max_position_embeddings
+ self.type_vocab_size = type_vocab_size
+ self.initializer_range = initializer_range
+ self.layer_norm_eps = layer_norm_eps
+ self.position_embedding_type = position_embedding_type
+ self.use_cache = use_cache
+ self.classifier_dropout = classifier_dropout
+
+
+class Data2VecTextOnnxConfig(OnnxConfig):
+ @property
+ def inputs(self) -> Mapping[str, Mapping[int, str]]:
+ if self.task == "multiple-choice":
+ dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
+ else:
+ dynamic_axis = {0: "batch", 1: "sequence"}
+ return OrderedDict(
+ [
+ ("input_ids", dynamic_axis),
+ ("attention_mask", dynamic_axis),
+ ]
+ )
+
+
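+# Illustrative usage sketch (an assumption, not part of this module's API surface):
+# for non-multiple-choice tasks the exported ONNX inputs use dynamic batch and
+# sequence axes.
+#
+#     onnx_config = Data2VecTextOnnxConfig(Data2VecTextConfig(), task="default")
+#     assert onnx_config.inputs["input_ids"] == {0: "batch", 1: "sequence"}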
+__all__ = ["Data2VecTextConfig", "Data2VecTextOnnxConfig"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/data2vec/configuration_data2vec_vision.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/data2vec/configuration_data2vec_vision.py
new file mode 100644
index 0000000000000000000000000000000000000000..2de256f9d7d7abda3b065472f116a371c03ee4da
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/data2vec/configuration_data2vec_vision.py
@@ -0,0 +1,194 @@
+# coding=utf-8
+# Copyright Meta Platforms and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Data2VecVision model configuration"""
+
+from collections import OrderedDict
+from collections.abc import Mapping
+
+from packaging import version
+
+from ...configuration_utils import PretrainedConfig
+from ...onnx import OnnxConfig
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class Data2VecVisionConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`Data2VecVisionModel`]. It is used to instantiate
+ a Data2VecVision model according to the specified arguments, defining the model architecture. Instantiating a
+ configuration with the defaults will yield a similar configuration to that of the Data2VecVision
+ [facebook/data2vec-vision-base](https://huggingface.co/facebook/data2vec-vision-base) architecture.
+
+ Args:
+ hidden_size (`int`, *optional*, defaults to 768):
+ Dimensionality of the encoder layers and the pooler layer.
+ num_hidden_layers (`int`, *optional*, defaults to 12):
+ Number of hidden layers in the Transformer encoder.
+ num_attention_heads (`int`, *optional*, defaults to 12):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ intermediate_size (`int`, *optional*, defaults to 3072):
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+ `"relu"`, `"selu"` and `"gelu_new"` are supported.
+ hidden_dropout_prob (`float`, *optional*, defaults to 0.0):
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+ attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ layer_norm_eps (`float`, *optional*, defaults to 1e-12):
+ The epsilon used by the layer normalization layers.
+ image_size (`int`, *optional*, defaults to 224):
+ The size (resolution) of each image.
+ patch_size (`int`, *optional*, defaults to 16):
+ The size (resolution) of each patch.
+ num_channels (`int`, *optional*, defaults to 3):
+ The number of input channels.
+ use_mask_token (`bool`, *optional*, defaults to `False`):
+ Whether to use a mask token for masked image modeling.
+ use_absolute_position_embeddings (`bool`, *optional*, defaults to `False`):
+ Whether to use BERT-style absolute position embeddings.
+ use_relative_position_bias (`bool`, *optional*, defaults to `False`):
+ Whether to use T5-style relative position embeddings in the self-attention layers.
+ use_shared_relative_position_bias (`bool`, *optional*, defaults to `False`):
+ Whether to use the same relative position embeddings across all self-attention layers of the Transformer.
+ layer_scale_init_value (`float`, *optional*, defaults to 0.1):
+ Scale to use in the self-attention layers. 0.1 for base, 1e-5 for large. Set 0 to disable layer scale.
+ drop_path_rate (`float`, *optional*, defaults to 0.1):
+ Stochastic depth rate per sample (when applied in the main path of residual layers).
+ use_mean_pooling (`bool`, *optional*, defaults to `True`):
+ Whether to mean pool the final hidden states of the patches instead of using the final hidden state of the
+ CLS token, before applying the classification head.
+ out_indices (`list[int]`, *optional*, defaults to `[3, 5, 7, 11]`):
+ Indices of the feature maps to use for semantic segmentation.
+ pool_scales (`tuple[int]`, *optional*, defaults to `[1, 2, 3, 6]`):
+ Pooling scales used in Pooling Pyramid Module applied on the last feature map.
+ use_auxiliary_head (`bool`, *optional*, defaults to `True`):
+ Whether to use an auxiliary head during training.
+ auxiliary_loss_weight (`float`, *optional*, defaults to 0.4):
+ Weight of the cross-entropy loss of the auxiliary head.
+ auxiliary_channels (`int`, *optional*, defaults to 256):
+ Number of channels to use in the auxiliary head.
+ auxiliary_num_convs (`int`, *optional*, defaults to 1):
+ Number of convolutional layers to use in the auxiliary head.
+ auxiliary_concat_input (`bool`, *optional*, defaults to `False`):
+ Whether to concatenate the output of the auxiliary head with the input before the classification layer.
+ semantic_loss_ignore_index (`int`, *optional*, defaults to 255):
+ The index that is ignored by the loss function of the semantic segmentation model.
+
+ Example:
+
+ ```python
+ >>> from transformers import Data2VecVisionConfig, Data2VecVisionModel
+
+ >>> # Initializing a Data2VecVision data2vec_vision-base-patch16-224-in22k style configuration
+ >>> configuration = Data2VecVisionConfig()
+
+ >>> # Initializing a model (with random weights) from the data2vec_vision-base-patch16-224-in22k style configuration
+ >>> model = Data2VecVisionModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "data2vec-vision"
+
+ def __init__(
+ self,
+ hidden_size=768,
+ num_hidden_layers=12,
+ num_attention_heads=12,
+ intermediate_size=3072,
+ hidden_act="gelu",
+ hidden_dropout_prob=0.0,
+ attention_probs_dropout_prob=0.0,
+ initializer_range=0.02,
+ layer_norm_eps=1e-12,
+ image_size=224,
+ patch_size=16,
+ num_channels=3,
+ use_mask_token=False,
+ use_absolute_position_embeddings=False,
+ use_relative_position_bias=False,
+ use_shared_relative_position_bias=False,
+ layer_scale_init_value=0.1,
+ drop_path_rate=0.1,
+ use_mean_pooling=True,
+ out_indices=[3, 5, 7, 11],
+ pool_scales=[1, 2, 3, 6],
+ use_auxiliary_head=True,
+ auxiliary_loss_weight=0.4,
+ auxiliary_channels=256,
+ auxiliary_num_convs=1,
+ auxiliary_concat_input=False,
+ semantic_loss_ignore_index=255,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+
+ self.hidden_size = hidden_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.intermediate_size = intermediate_size
+ self.hidden_act = hidden_act
+ self.hidden_dropout_prob = hidden_dropout_prob
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
+ self.initializer_range = initializer_range
+ self.layer_norm_eps = layer_norm_eps
+
+ self.image_size = image_size
+ self.patch_size = patch_size
+ self.num_channels = num_channels
+ self.use_mask_token = use_mask_token
+ self.use_absolute_position_embeddings = use_absolute_position_embeddings
+ self.use_relative_position_bias = use_relative_position_bias
+ self.use_shared_relative_position_bias = use_shared_relative_position_bias
+ self.layer_scale_init_value = layer_scale_init_value
+ self.drop_path_rate = drop_path_rate
+ self.use_mean_pooling = use_mean_pooling
+ # decode head attributes (semantic segmentation)
+ self.out_indices = out_indices
+ self.pool_scales = pool_scales
+ # auxiliary head attributes (semantic segmentation)
+ self.use_auxiliary_head = use_auxiliary_head
+ self.auxiliary_loss_weight = auxiliary_loss_weight
+ self.auxiliary_channels = auxiliary_channels
+ self.auxiliary_num_convs = auxiliary_num_convs
+ self.auxiliary_concat_input = auxiliary_concat_input
+ self.semantic_loss_ignore_index = semantic_loss_ignore_index
+
+
+# Copied from transformers.models.vit.configuration_vit.ViTOnnxConfig
+class Data2VecVisionOnnxConfig(OnnxConfig):
+ torch_onnx_minimum_version = version.parse("1.11")
+
+ @property
+ def inputs(self) -> Mapping[str, Mapping[int, str]]:
+ return OrderedDict(
+ [
+ ("pixel_values", {0: "batch", 1: "num_channels", 2: "height", 3: "width"}),
+ ]
+ )
+
+ @property
+ def atol_for_validation(self) -> float:
+ return 1e-4
+
+
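+# Illustrative usage sketch (an assumption, not part of this module's API surface):
+# the vision export takes a single 4D `pixel_values` input with fully dynamic axes
+# and validates exported outputs against PyTorch with an absolute tolerance of 1e-4.
+#
+#     onnx_config = Data2VecVisionOnnxConfig(Data2VecVisionConfig())
+#     assert "pixel_values" in onnx_config.inputs
+#     assert onnx_config.atol_for_validation == 1e-4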
+__all__ = ["Data2VecVisionConfig", "Data2VecVisionOnnxConfig"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/data2vec/modeling_data2vec_audio.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/data2vec/modeling_data2vec_audio.py
new file mode 100644
index 0000000000000000000000000000000000000000..c9b3f01f42d4ee528c11e12865d0e020ed5f0cf7
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/data2vec/modeling_data2vec_audio.py
@@ -0,0 +1,1397 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/data2vec/modular_data2vec_audio.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_data2vec_audio.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+import warnings
+from typing import Callable, Optional, Union
+
+import numpy as np
+import torch
+from torch import nn
+from torch.nn import CrossEntropyLoss
+
+from ...activations import ACT2FN
+from ...integrations.deepspeed import is_deepspeed_zero3_enabled
+from ...integrations.fsdp import is_fsdp_managed_module
+from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa
+from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_outputs import (
+ BaseModelOutput,
+ CausalLMOutput,
+ SequenceClassifierOutput,
+ TokenClassifierOutput,
+ Wav2Vec2BaseModelOutput,
+ XVectorOutput,
+)
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
+from ...utils import auto_docstring, is_peft_available, is_torch_flex_attn_available
+from .configuration_data2vec_audio import Data2VecAudioConfig
+
+
+if is_torch_flex_attn_available():
+ from ...integrations.flex_attention import make_flex_block_causal_mask
+
+
+class Data2VecAudioConvLayer(GradientCheckpointingLayer):
+ def __init__(self, config, layer_id=0):
+ super().__init__()
+ self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
+ self.out_conv_dim = config.conv_dim[layer_id]
+
+ self.conv = nn.Conv1d(
+ self.in_conv_dim,
+ self.out_conv_dim,
+ kernel_size=config.conv_kernel[layer_id],
+ stride=config.conv_stride[layer_id],
+ bias=config.conv_bias,
+ )
+ self.layer_norm = nn.LayerNorm(self.out_conv_dim, elementwise_affine=True)
+ self.activation = ACT2FN[config.feat_extract_activation]
+
+ def forward(self, hidden_states):
+ hidden_states = self.conv(hidden_states)
+
+ hidden_states = hidden_states.transpose(-2, -1)
+ hidden_states = self.layer_norm(hidden_states)
+ hidden_states = hidden_states.transpose(-2, -1)
+
+ hidden_states = self.activation(hidden_states)
+ return hidden_states
+
+
+class Data2VecAudioPadLayer(nn.Module):
+ def __init__(self, num_conv_pos_embeddings):
+ super().__init__()
+ self.num_pad_remove = 1 if num_conv_pos_embeddings % 2 == 0 else 0
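+ # With padding of kernel_size // 2 on both sides, an even kernel size yields one
+ # extra output frame, so the last frame is trimmed to restore the original length.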
+
+ def forward(self, hidden_states):
+ if self.num_pad_remove > 0:
+ hidden_states = hidden_states[:, :, : -self.num_pad_remove]
+ return hidden_states
+
+
+class Data2VecAudioPositionalConvLayer(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.conv = nn.Conv1d(
+ config.hidden_size,
+ config.hidden_size,
+ kernel_size=config.conv_pos_kernel_size,
+ padding=config.conv_pos_kernel_size // 2,
+ groups=config.num_conv_pos_embedding_groups,
+ )
+
+ self.padding = Data2VecAudioPadLayer(config.conv_pos_kernel_size)
+ self.activation = ACT2FN[config.feat_extract_activation]
+ # no learnable parameters
+ self.layer_norm = nn.LayerNorm(config.hidden_size, elementwise_affine=False)
+
+ def forward(self, hidden_states):
+ hidden_states = self.conv(hidden_states)
+ hidden_states = self.padding(hidden_states)
+
+ hidden_states = hidden_states.transpose(1, 2)
+ hidden_states = self.layer_norm(hidden_states)
+ hidden_states = hidden_states.transpose(1, 2)
+ hidden_states = self.activation(hidden_states)
+ return hidden_states
+
+
+class Data2VecAudioPositionalConvEmbedding(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.layers = nn.ModuleList(
+ [Data2VecAudioPositionalConvLayer(config) for _ in range(config.num_conv_pos_embeddings)]
+ )
+
+ def forward(self, hidden_states):
+ hidden_states = hidden_states.transpose(1, 2)
+ for layer in self.layers:
+ hidden_states = layer(hidden_states)
+ hidden_states = hidden_states.transpose(1, 2)
+ return hidden_states
+
+
+class Data2VecAudioFeatureEncoder(nn.Module):
+ """Construct the features from raw audio waveform"""
+
+ def __init__(self, config):
+ super().__init__()
+ self.conv_layers = nn.ModuleList(
+ [Data2VecAudioConvLayer(config, layer_id=i) for i in range(config.num_feat_extract_layers)]
+ )
+ self.gradient_checkpointing = False
+ self._requires_grad = True
+
+ def _freeze_parameters(self):
+ for param in self.parameters():
+ param.requires_grad = False
+ self._requires_grad = False
+
+ def forward(self, input_values):
+ hidden_states = input_values[:, None]
+
+ # make sure hidden_states require grad for gradient_checkpointing
+ if self._requires_grad and self.training:
+ hidden_states.requires_grad = True
+
+ for conv_layer in self.conv_layers:
+ hidden_states = conv_layer(hidden_states)
+
+ return hidden_states
+
+
+class Data2VecAudioFeatureProjection(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.layer_norm = nn.LayerNorm(config.conv_dim[-1], eps=config.layer_norm_eps)
+ self.projection = nn.Linear(config.conv_dim[-1], config.hidden_size)
+ self.dropout = nn.Dropout(config.feat_proj_dropout)
+
+ def forward(self, hidden_states):
+ # non-projected hidden states are needed for quantization
+ norm_hidden_states = self.layer_norm(hidden_states)
+ hidden_states = self.projection(norm_hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ return hidden_states, norm_hidden_states
+
+
+def eager_attention_forward(
+ module: nn.Module,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_mask: Optional[torch.Tensor],
+ scaling: Optional[float] = None,
+ dropout: float = 0.0,
+ head_mask: Optional[torch.Tensor] = None,
+ **kwargs,
+):
+ if scaling is None:
+ scaling = query.size(-1) ** -0.5
+
+ attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
+ if attention_mask is not None:
+ attn_weights = attn_weights + attention_mask
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+
+ if head_mask is not None:
+ attn_weights = attn_weights * head_mask.view(1, -1, 1, 1)
+
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+ attn_output = torch.matmul(attn_weights, value)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ return attn_output, attn_weights
+
+
+class Data2VecAudioAttention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(
+ self,
+ embed_dim: int,
+ num_heads: int,
+ dropout: float = 0.0,
+ is_decoder: bool = False,
+ bias: bool = True,
+ is_causal: bool = False,
+ config: Optional[Data2VecAudioConfig] = None,
+ ):
+ super().__init__()
+ self.embed_dim = embed_dim
+ self.num_heads = num_heads
+ self.dropout = dropout
+ self.head_dim = embed_dim // num_heads
+ self.config = config
+
+ if (self.head_dim * num_heads) != self.embed_dim:
+ raise ValueError(
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
+ f" and `num_heads`: {num_heads})."
+ )
+ self.scaling = self.head_dim**-0.5
+ self.is_decoder = is_decoder
+ self.is_causal = is_causal
+
+ self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+ self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+ self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+ self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ key_value_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ layer_head_mask: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = False,
+ # TODO: we need a refactor so that the different attention modules can get their specific kwargs
+ # At the moment, encoder, decoder, and encoder-decoder attention kwargs are mixed together here
+ **kwargs: Unpack[FlashAttentionKwargs],
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
+ """Input shape: Batch x Time x Channel"""
+
+ # if key_value_states are provided this layer is used as a cross-attention layer
+ # for the decoder
+ is_cross_attention = key_value_states is not None
+
+ # determine input shapes
+ bsz, tgt_len = hidden_states.shape[:-1]
+ src_len = key_value_states.shape[1] if is_cross_attention else tgt_len
+
+ q_input_shape = (bsz, tgt_len, -1, self.head_dim)
+ kv_input_shape = (bsz, src_len, -1, self.head_dim)
+
+ # get query proj
+ query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2)
+
+ current_states = key_value_states if is_cross_attention else hidden_states
+ key_states = self.k_proj(current_states).view(*kv_input_shape).transpose(1, 2)
+ value_states = self.v_proj(current_states).view(*kv_input_shape).transpose(1, 2)
+
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ attn_output, attn_weights = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ dropout=0.0 if not self.training else self.dropout,
+ scaling=self.scaling,
+ output_attentions=output_attentions,
+ head_mask=layer_head_mask,
+ **kwargs,
+ )
+
+ attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous()
+ attn_output = self.out_proj(attn_output)
+
+ return attn_output, attn_weights, None
+
+
+class Data2VecAudioFeedForward(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.intermediate_dropout = nn.Dropout(config.activation_dropout)
+
+ self.intermediate_dense = nn.Linear(config.hidden_size, config.intermediate_size)
+ if isinstance(config.hidden_act, str):
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
+ else:
+ self.intermediate_act_fn = config.hidden_act
+
+ self.output_dense = nn.Linear(config.intermediate_size, config.hidden_size)
+ self.output_dropout = nn.Dropout(config.hidden_dropout)
+
+ def forward(self, hidden_states):
+ hidden_states = self.intermediate_dense(hidden_states)
+ hidden_states = self.intermediate_act_fn(hidden_states)
+ hidden_states = self.intermediate_dropout(hidden_states)
+
+ hidden_states = self.output_dense(hidden_states)
+ hidden_states = self.output_dropout(hidden_states)
+ return hidden_states
+
+
+class Data2VecAudioEncoderLayer(GradientCheckpointingLayer):
+ def __init__(self, config):
+ super().__init__()
+ self.attention = Data2VecAudioAttention(
+ embed_dim=config.hidden_size,
+ num_heads=config.num_attention_heads,
+ dropout=config.attention_dropout,
+ is_decoder=False,
+ config=config,
+ )
+
+ self.dropout = nn.Dropout(config.hidden_dropout)
+ self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.feed_forward = Data2VecAudioFeedForward(config)
+ self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+ def forward(self, hidden_states, attention_mask=None, output_attentions=False):
+ attn_residual = hidden_states
+ hidden_states, attn_weights, _ = self.attention(
+ hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
+ )
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = attn_residual + hidden_states
+
+ hidden_states = self.layer_norm(hidden_states)
+ hidden_states = hidden_states + self.feed_forward(hidden_states)
+ hidden_states = self.final_layer_norm(hidden_states)
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (attn_weights,)
+
+ return outputs
+
+
+class Data2VecAudioEncoder(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.pos_conv_embed = Data2VecAudioPositionalConvEmbedding(config)
+ self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.dropout = nn.Dropout(config.hidden_dropout)
+ self.layers = nn.ModuleList([Data2VecAudioEncoderLayer(config) for _ in range(config.num_hidden_layers)])
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ output_attentions: bool = False,
+ output_hidden_states: bool = False,
+ return_dict: bool = True,
+ ):
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attentions = () if output_attentions else None
+
+ if attention_mask is not None:
+ # make sure padded tokens output 0
+ expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
+ hidden_states[~expand_attention_mask] = 0
+
+ attention_mask = self._update_full_mask(
+ attention_mask,
+ hidden_states,
+ )
+
+ position_embeddings = self.pos_conv_embed(hidden_states)
+ hidden_states = hidden_states + position_embeddings
+ hidden_states = self.layer_norm(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+
+ synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)
+
+ for layer in self.layers:
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
+ dropout_probability = torch.rand([])
+
+ skip_the_layer = self.training and dropout_probability < self.config.layerdrop
+ if not skip_the_layer or synced_gpus:
+ # under fsdp or deepspeed zero3 all gpus must run in sync
+ layer_outputs = layer(
+ hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
+ )
+ hidden_states = layer_outputs[0]
+
+ if skip_the_layer:
+ layer_outputs = (None, None)
+
+ if output_attentions:
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
+ return BaseModelOutput(
+ last_hidden_state=hidden_states,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attentions,
+ )
+
+ def _update_full_mask(
+ self,
+ attention_mask: Union[torch.Tensor, None],
+ inputs_embeds: torch.Tensor,
+ ):
+ if attention_mask is not None:
+ if self.config._attn_implementation == "flash_attention_2":
+ attention_mask = attention_mask if 0 in attention_mask else None
+ elif self.config._attn_implementation == "sdpa":
+ # output_attentions=True & head_mask can not be supported when using SDPA, fall back to
+ # the manual implementation that requires a 4D causal mask in all cases.
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype)
+ elif self.config._attn_implementation == "flex_attention":
+ if isinstance(attention_mask, torch.Tensor):
+ attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False)
+ else:
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
+
+ return attention_mask
+
+
+class Data2VecAudioAdapterLayer(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.conv = nn.Conv1d(
+ config.output_hidden_size,
+ 2 * config.output_hidden_size,
+ config.adapter_kernel_size,
+ stride=config.adapter_stride,
+ padding=1,
+ )
+
+ def forward(self, hidden_states):
+ hidden_states = self.conv(hidden_states)
+ hidden_states = nn.functional.glu(hidden_states, dim=1)
+
+ return hidden_states
+
+
+class Data2VecAudioAdapter(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+
+ # feature dim might need to be down-projected
+ if config.output_hidden_size != config.hidden_size:
+ self.proj = nn.Linear(config.hidden_size, config.output_hidden_size)
+ self.proj_layer_norm = nn.LayerNorm(config.output_hidden_size)
+ else:
+ self.proj = self.proj_layer_norm = None
+
+ self.layers = nn.ModuleList(Data2VecAudioAdapterLayer(config) for _ in range(config.num_adapter_layers))
+ self.layerdrop = config.layerdrop
+
+ def forward(self, hidden_states):
+ # down project hidden_states if necessary
+ if self.proj is not None and self.proj_layer_norm is not None:
+ hidden_states = self.proj(hidden_states)
+ hidden_states = self.proj_layer_norm(hidden_states)
+
+ hidden_states = hidden_states.transpose(1, 2)
+
+ for layer in self.layers:
+ layerdrop_prob = np.random.random()
+ if not self.training or (layerdrop_prob > self.layerdrop):
+ hidden_states = layer(hidden_states)
+
+ hidden_states = hidden_states.transpose(1, 2)
+ return hidden_states
+
+
+@auto_docstring
+class Data2VecAudioPreTrainedModel(PreTrainedModel):
+ config: Data2VecAudioConfig
+ base_model_prefix = "data2vec_audio"
+ main_input_name = "input_values"
+ supports_gradient_checkpointing = True
+ _supports_flash_attn = True
+ _supports_sdpa = True
+ _supports_flex_attn = True
+
+ def _init_weights(self, module):
+ """Initialize the weights"""
+ if isinstance(module, Data2VecAudioFeatureProjection):
+ k = math.sqrt(1 / module.projection.in_features)
+ nn.init.uniform_(module.projection.weight, a=-k, b=k)
+ nn.init.uniform_(module.projection.bias, a=-k, b=k)
+ elif isinstance(module, Data2VecAudioPositionalConvLayer):
+ nn.init.constant_(module.conv.bias, 0)
+ elif isinstance(module, nn.Linear):
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
+ if module.bias is not None:
+ module.bias.data.zero_()
+ if module.weight is not None:
+ module.weight.data.fill_(1.0)
+ elif isinstance(module, nn.Conv1d):
+ nn.init.kaiming_normal_(module.weight)
+
+ if module.bias is not None:
+ k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0]))
+ nn.init.uniform_(module.bias, a=-k, b=k)
+
+ def _get_feat_extract_output_lengths(
+ self, input_lengths: Union[torch.LongTensor, int], add_adapter: Optional[bool] = None
+ ):
+ """
+ Computes the output length of the convolutional layers
+ """
+
+ add_adapter = self.config.add_adapter if add_adapter is None else add_adapter
+
+ def _conv_out_length(input_length, kernel_size, stride):
+ # 1D convolutional layer output length formula taken
+ # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
+ return torch.div(input_length - kernel_size, stride, rounding_mode="floor") + 1
+
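+ # e.g. a layer with kernel_size=10 and stride=5 maps 16000 input samples to
+ # (16000 - 10) // 5 + 1 = 3199 output frames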
+ for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
+ input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
+
+ if add_adapter:
+ for _ in range(self.config.num_adapter_layers):
+ input_lengths = _conv_out_length(input_lengths, 1, self.config.adapter_stride)
+
+ return input_lengths
+
+ def _get_feature_vector_attention_mask(
+ self, feature_vector_length: int, attention_mask: torch.LongTensor, add_adapter=None
+ ):
+ # Effectively attention_mask.sum(-1), but not inplace to be able to run
+ # in inference mode.
+ non_padded_lengths = attention_mask.cumsum(dim=-1)[:, -1]
+
+ output_lengths = self._get_feat_extract_output_lengths(non_padded_lengths, add_adapter=add_adapter)
+ output_lengths = output_lengths.to(torch.long)
+
+ batch_size = attention_mask.shape[0]
+
+ attention_mask = torch.zeros(
+ (batch_size, feature_vector_length), dtype=attention_mask.dtype, device=attention_mask.device
+ )
+ # these two operations make sure that all values before the output length indices are attended to
+ attention_mask[(torch.arange(attention_mask.shape[0], device=attention_mask.device), output_lengths - 1)] = 1
+ attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
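+ # e.g. feature_vector_length=6 with an output length of 3:
+ # [0, 0, 1, 0, 0, 0] -> flip -> [0, 0, 0, 1, 0, 0] -> cumsum -> [0, 0, 0, 1, 1, 1]
+ # -> flip -> [1, 1, 1, 0, 0, 0], i.e. exactly the first 3 frames are attended to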
+ return attention_mask
+
+
+def _compute_mask_indices(
+ shape: tuple[int, int],
+ mask_prob: float,
+ mask_length: int,
+ attention_mask: Optional[torch.LongTensor] = None,
+ min_masks: int = 0,
+) -> np.ndarray:
+ """
+ Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for
+ ASR](https://huggingface.co/papers/1904.08779). Note that this method is not optimized to run on TPU and should be run on
+ CPU as part of the preprocessing during training.
+
+ Args:
+ shape: The shape for which to compute masks. This should be of a tuple of size 2 where
+ the first element is the batch size and the second element is the length of the axis to span.
+ mask_prob: The percentage of the whole axis (between 0 and 1) which will be masked. The number of
+ independently generated mask spans of length `mask_length` is computed by
+ `mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the
+ actual percentage will be smaller.
+ mask_length: size of the mask
+ min_masks: minimum number of masked spans
+ attention_mask: A (right-padded) attention mask which independently shortens the feature axis of
+ each batch dimension.
+ """
+ batch_size, sequence_length = shape
+
+ if mask_length < 1:
+ raise ValueError("`mask_length` has to be bigger than 0.")
+
+ if mask_length > sequence_length:
+ raise ValueError(
+ f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}"
+ f" and `sequence_length`: {sequence_length}`"
+ )
+
+ # epsilon is used for probabilistic rounding
+ epsilon = np.random.rand(1).item()
+
+ def compute_num_masked_span(input_length):
+ """Given input length, compute how many spans should be masked"""
+ num_masked_span = int(mask_prob * input_length / mask_length + epsilon)
+ num_masked_span = max(num_masked_span, min_masks)
+
+ # make sure num masked span <= sequence_length
+ if num_masked_span * mask_length > sequence_length:
+ num_masked_span = sequence_length // mask_length
+
+ # make sure num_masked span is also <= input_length - (mask_length - 1)
+ if input_length - (mask_length - 1) < num_masked_span:
+ num_masked_span = max(input_length - (mask_length - 1), 0)
+
+ return num_masked_span
+
+ # compute number of masked spans in batch
+ input_lengths = (
+ attention_mask.detach().sum(-1).tolist()
+ if attention_mask is not None
+ else [sequence_length for _ in range(batch_size)]
+ )
+
+ # SpecAugment mask to fill
+ spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool)
+ spec_aug_mask_idxs = []
+
+ max_num_masked_span = compute_num_masked_span(sequence_length)
+
+ if max_num_masked_span == 0:
+ return spec_aug_mask
+
+ for input_length in input_lengths:
+ # compute num of masked spans for this input
+ num_masked_span = compute_num_masked_span(input_length)
+
+ # get random indices to mask
+ spec_aug_mask_idx = np.random.choice(
+ np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False
+ )
+
+ # pick first sampled index that will serve as a dummy index to pad vector
+ # to ensure same dimension for all batches due to probabilistic rounding
+ # Picking first sample just pads those vectors twice.
+ if len(spec_aug_mask_idx) == 0:
+ # this case can only happen if `input_length` is strictly smaller than
+ # `sequence_length` in which case the last token has to be a padding
+ # token which we can use as a dummy mask id
+ dummy_mask_idx = sequence_length - 1
+ else:
+ dummy_mask_idx = spec_aug_mask_idx[0]
+
+ spec_aug_mask_idx = np.concatenate(
+ [spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx]
+ )
+ spec_aug_mask_idxs.append(spec_aug_mask_idx)
+
+ spec_aug_mask_idxs = np.array(spec_aug_mask_idxs)
+
+ # expand masked indices to masked spans
+ spec_aug_mask_idxs = np.broadcast_to(
+ spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length)
+ )
+ spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length)
+
+ # add offset to the starting indexes so that indexes now create a span
+ offsets = np.arange(mask_length)[None, None, :]
+ offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape(
+ batch_size, max_num_masked_span * mask_length
+ )
+ spec_aug_mask_idxs = spec_aug_mask_idxs + offsets
+
+ # ensure that we cannot have indices larger than sequence_length
+ if spec_aug_mask_idxs.max() > sequence_length - 1:
+ spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1
+
+ # scatter indices to mask
+ np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1)
+
+ return spec_aug_mask
+
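+# Minimal usage sketch for the helper above (illustrative values, not defaults):
+# mask roughly 5% of 10 frames with spans of length 2 for a batch of 2 sequences.
+#
+#     mask = _compute_mask_indices(shape=(2, 10), mask_prob=0.05, mask_length=2)
+#     assert mask.shape == (2, 10) and mask.dtype == bool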
+
+Data2VecAudioBaseModelOutput = Wav2Vec2BaseModelOutput
+
+
+@auto_docstring
+class Data2VecAudioModel(Data2VecAudioPreTrainedModel):
+ def __init__(self, config: Data2VecAudioConfig):
+ super().__init__(config)
+ self.config = config
+ self.feature_extractor = Data2VecAudioFeatureEncoder(config)
+ self.feature_projection = Data2VecAudioFeatureProjection(config)
+
+ # model only needs masking vector if mask prob is > 0.0
+ if config.mask_time_prob > 0.0 or config.mask_feature_prob > 0.0:
+ self.masked_spec_embed = nn.Parameter(torch.Tensor(config.hidden_size).uniform_())
+
+ self.encoder = Data2VecAudioEncoder(config)
+
+ self.adapter = Data2VecAudioAdapter(config) if config.add_adapter else None
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def freeze_feature_encoder(self):
+ """
+ Calling this function will disable the gradient computation for the feature encoder so that its parameters will
+ not be updated during training.
+ """
+ self.feature_extractor._freeze_parameters()
+
+ def _mask_hidden_states(
+ self,
+ hidden_states: torch.FloatTensor,
+ mask_time_indices: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.LongTensor] = None,
+ ):
+ """
+ Masks extracted features along time axis and/or along feature axis according to
+ [SpecAugment](https://huggingface.co/papers/1904.08779).
+ """
+
+ # `config.apply_spec_augment` can set masking to False
+ if not getattr(self.config, "apply_spec_augment", True):
+ return hidden_states
+
+ # generate indices & apply SpecAugment along time axis
+ batch_size, sequence_length, hidden_size = hidden_states.size()
+
+ if mask_time_indices is not None:
+ # apply SpecAugment along time axis with given mask_time_indices
+ hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
+ elif self.config.mask_time_prob > 0 and self.training:
+ mask_time_indices = _compute_mask_indices(
+ (batch_size, sequence_length),
+ mask_prob=self.config.mask_time_prob,
+ mask_length=self.config.mask_time_length,
+ attention_mask=attention_mask,
+ min_masks=self.config.mask_time_min_masks,
+ )
+ mask_time_indices = torch.tensor(mask_time_indices, device=hidden_states.device, dtype=torch.bool)
+ hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
+
+ if self.config.mask_feature_prob > 0 and self.training:
+ # generate indices & apply SpecAugment along feature axis
+ mask_feature_indices = _compute_mask_indices(
+ (batch_size, hidden_size),
+ mask_prob=self.config.mask_feature_prob,
+ mask_length=self.config.mask_feature_length,
+ min_masks=self.config.mask_feature_min_masks,
+ )
+ mask_feature_indices = torch.tensor(mask_feature_indices, device=hidden_states.device, dtype=torch.bool)
+ mask_feature_indices = mask_feature_indices[:, None].expand(-1, sequence_length, -1)
+ hidden_states[mask_feature_indices] = 0
+
+ return hidden_states
+
+ @auto_docstring
+ def forward(
+ self,
+ input_values: Optional[torch.Tensor],
+ attention_mask: Optional[torch.Tensor] = None,
+ mask_time_indices: Optional[torch.FloatTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[tuple, Data2VecAudioBaseModelOutput]:
+ r"""
+ mask_time_indices (`torch.BoolTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Indices to mask extracted features for the masked prediction objective. When in training mode, the model
+ learns to predict the masked extracted features.
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ extract_features = self.feature_extractor(input_values)
+ extract_features = extract_features.transpose(1, 2)
+
+ if attention_mask is not None:
+ # compute reduced attention_mask corresponding to feature vectors
+ attention_mask = self._get_feature_vector_attention_mask(
+ extract_features.shape[1], attention_mask, add_adapter=False
+ )
+
+ hidden_states, extract_features = self.feature_projection(extract_features)
+ hidden_states = self._mask_hidden_states(
+ hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask
+ )
+
+ encoder_outputs = self.encoder(
+ hidden_states,
+ attention_mask=attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ hidden_states = encoder_outputs[0]
+
+ if self.adapter is not None:
+ hidden_states = self.adapter(hidden_states)
+
+ if not return_dict:
+ return (hidden_states, extract_features) + encoder_outputs[1:]
+
+ return Data2VecAudioBaseModelOutput(
+ last_hidden_state=hidden_states,
+ extract_features=extract_features,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ )
+
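+# Minimal feature-extraction sketch (illustrative; the checkpoint name and the
+# 16 kHz mono waveform tensor `input_values` are assumptions, not defined here):
+#
+#     model = Data2VecAudioModel.from_pretrained("facebook/data2vec-audio-base-960h")
+#     features = model(input_values).last_hidden_state  # (batch, frames, hidden_size)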
+
+_HIDDEN_STATES_START_POSITION = 2
+
+
+@auto_docstring(
+ custom_intro="""
+ Data2VecAudio Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).
+ """
+)
+class Data2VecAudioForCTC(Data2VecAudioPreTrainedModel):
+ def __init__(self, config):
+ r"""
+ target_lang (`str`, *optional*):
+ Language id of adapter weights. Adapter weights are stored in the format adapter.<lang>.safetensors or
+ adapter.<lang>.bin. Only relevant when using an instance of [`Data2VecAudioForCTC`] with adapters. Uses 'eng' by
+ default.
+ """
+ super().__init__(config)
+
+ self.data2vec_audio = Data2VecAudioModel(config)
+ self.dropout = nn.Dropout(config.final_dropout)
+
+ if config.vocab_size is None:
+ raise ValueError(
+ f"You are trying to instantiate {self.__class__} with a configuration that "
+ "does not define the vocabulary size of the language model head. Please "
+ "instantiate the model as follows: `Data2VecAudioForCTC.from_pretrained(..., vocab_size=vocab_size)`. "
+ "or define `vocab_size` of your model's configuration."
+ )
+ output_hidden_size = (
+ config.output_hidden_size if hasattr(config, "add_adapter") and config.add_adapter else config.hidden_size
+ )
+ self.lm_head = nn.Linear(output_hidden_size, config.vocab_size)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def freeze_feature_extractor(self):
+ """
+ Calling this function will disable the gradient computation for the feature encoder so that its parameters will
+ not be updated during training.
+ """
+ warnings.warn(
+ "The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5. "
+ "Please use the equivalent `freeze_feature_encoder` method instead.",
+ FutureWarning,
+ )
+ self.freeze_feature_encoder()
+
+ def freeze_feature_encoder(self):
+ """
+ Calling this function will disable the gradient computation for the feature encoder so that its parameters will
+ not be updated during training.
+ """
+ self.data2vec_audio.feature_extractor._freeze_parameters()
+
+ @auto_docstring
+ def forward(
+ self,
+ input_values: Optional[torch.Tensor],
+ attention_mask: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ labels: Optional[torch.Tensor] = None,
+ ) -> Union[tuple, CausalLMOutput]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, target_length)`, *optional*):
+ Labels for connectionist temporal classification. Note that `target_length` has to be smaller than or equal to
+ the sequence length of the output logits. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`.
+ All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ...,
+ config.vocab_size - 1]`.
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if labels is not None and labels.max() >= self.config.vocab_size:
+ raise ValueError(f"Label values must be <= vocab_size: {self.config.vocab_size}")
+
+ outputs = self.data2vec_audio(
+ input_values,
+ attention_mask=attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ hidden_states = outputs[0]
+ hidden_states = self.dropout(hidden_states)
+
+ logits = self.lm_head(hidden_states)
+
+ loss = None
+ if labels is not None:
+ # retrieve loss input_lengths from attention_mask
+ attention_mask = (
+ attention_mask if attention_mask is not None else torch.ones_like(input_values, dtype=torch.long)
+ )
+ input_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)
+
+ # assuming that padded tokens are filled with -100
+ # when not being attended to
+ labels_mask = labels >= 0
+ target_lengths = labels_mask.sum(-1)
+ flattened_targets = labels.masked_select(labels_mask)
+
+ # ctc_loss doesn't support fp16
+ log_probs = nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1)
+
+ with torch.backends.cudnn.flags(enabled=False):
+ loss = nn.functional.ctc_loss(
+ log_probs,
+ flattened_targets,
+ input_lengths,
+ target_lengths,
+ blank=self.config.pad_token_id,
+ reduction=self.config.ctc_loss_reduction,
+ zero_infinity=self.config.ctc_zero_infinity,
+ )
+
+ if not return_dict:
+ output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
+ return ((loss,) + output) if loss is not None else output
+
+ return CausalLMOutput(
+ loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
+ )
+
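+# Minimal CTC inference sketch (illustrative; the processor usage, the checkpoint
+# name, and `raw_audio` — a 1-D float waveform sampled at 16 kHz — are assumptions,
+# not defined in this file):
+#
+#     from transformers import AutoProcessor
+#     processor = AutoProcessor.from_pretrained("facebook/data2vec-audio-base-960h")
+#     model = Data2VecAudioForCTC.from_pretrained("facebook/data2vec-audio-base-960h")
+#     inputs = processor(raw_audio, sampling_rate=16000, return_tensors="pt")
+#     logits = model(**inputs).logits
+#     transcription = processor.batch_decode(logits.argmax(dim=-1))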
+
+@auto_docstring(
+ custom_intro="""
+ Data2VecAudio Model with a sequence classification head on top (a linear layer over the pooled output) for tasks like
+ SUPERB Keyword Spotting.
+ """
+)
+class Data2VecAudioForSequenceClassification(Data2VecAudioPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+
+ if hasattr(config, "add_adapter") and config.add_adapter:
+ raise ValueError(
+ "Sequence classification does not support the use of Data2VecAudio adapters (config.add_adapter=True)"
+ )
+ self.data2vec_audio = Data2VecAudioModel(config)
+ num_layers = config.num_hidden_layers + 1 # transformer layers + input embeddings
+ if config.use_weighted_layer_sum:
+ self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
+ self.projector = nn.Linear(config.hidden_size, config.classifier_proj_size)
+ self.classifier = nn.Linear(config.classifier_proj_size, config.num_labels)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def freeze_feature_extractor(self):
+ """
+ Calling this function will disable the gradient computation for the feature encoder so that its parameters will
+ not be updated during training.
+ """
+ warnings.warn(
+ "The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5. "
+ "Please use the equivalent `freeze_feature_encoder` method instead.",
+ FutureWarning,
+ )
+ self.freeze_feature_encoder()
+
+ def freeze_feature_encoder(self):
+ """
+ Calling this function will disable the gradient computation for the feature encoder so that its parameters will
+ not be updated during training.
+ """
+ self.data2vec_audio.feature_extractor._freeze_parameters()
+
+ def freeze_base_model(self):
+ """
+ Calling this function will disable the gradient computation for the base model so that its parameters will not
+ be updated during training. Only the classification head will be updated.
+ """
+ for param in self.data2vec_audio.parameters():
+ param.requires_grad = False
+
+ @auto_docstring
+ def forward(
+ self,
+ input_values: Optional[torch.Tensor],
+ attention_mask: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ labels: Optional[torch.Tensor] = None,
+ ) -> Union[tuple, SequenceClassifierOutput]:
+ r"""
+ input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
+ Float values of input raw speech waveform. Values can be obtained by loading a `.flac` or `.wav` audio file
+ into an array of type `list[float]`, a `numpy.ndarray` or a `torch.Tensor`, *e.g.* via the torchcodec library
+ (`pip install torchcodec`) or the soundfile library (`pip install soundfile`).
+ To prepare the array into `input_values`, the [`AutoProcessor`] should be used for padding and conversion
+ into a tensor of type `torch.FloatTensor`. See [`Data2VecAudioProcessor.__call__`] for details.
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), if
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+ """
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states
+
+ outputs = self.data2vec_audio(
+ input_values,
+ attention_mask=attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ if self.config.use_weighted_layer_sum:
+ hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
+ hidden_states = torch.stack(hidden_states, dim=1)
+ norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
+ hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
+ else:
+ hidden_states = outputs[0]
+
+ hidden_states = self.projector(hidden_states)
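+ # mean-pool the projected frames over time; when an attention mask is given,
+ # padded frames are zeroed out and excluded from the denominator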
+ if attention_mask is None:
+ pooled_output = hidden_states.mean(dim=1)
+ else:
+ padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask)
+ expand_padding_mask = padding_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
+ hidden_states[~expand_padding_mask] = 0.0
+ pooled_output = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1)
+
+ logits = self.classifier(pooled_output)
+
+ loss = None
+ if labels is not None:
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
+
+ if not return_dict:
+ output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
+ return ((loss,) + output) if loss is not None else output
+
+ return SequenceClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+@auto_docstring
+class Data2VecAudioForAudioFrameClassification(Data2VecAudioPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+
+ if hasattr(config, "add_adapter") and config.add_adapter:
+ raise ValueError(
+ "Audio frame classification does not support the use of Data2VecAudio adapters (config.add_adapter=True)"
+ )
+ self.data2vec_audio = Data2VecAudioModel(config)
+ num_layers = config.num_hidden_layers + 1 # transformer layers + input embeddings
+ if config.use_weighted_layer_sum:
+ self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+ self.num_labels = config.num_labels
+
+ self.init_weights()
+
+ def freeze_feature_extractor(self):
+ """
+ Calling this function will disable the gradient computation for the feature encoder so that its parameters will
+ not be updated during training.
+ """
+ warnings.warn(
+ "The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5. "
+ "Please use the equivalent `freeze_feature_encoder` method instead.",
+ FutureWarning,
+ )
+ self.freeze_feature_encoder()
+
+ def freeze_feature_encoder(self):
+ """
+ Calling this function will disable the gradient computation for the feature encoder so that its parameters will
+ not be updated during training.
+ """
+ self.data2vec_audio.feature_extractor._freeze_parameters()
+
+ def freeze_base_model(self):
+ """
+ Calling this function will disable the gradient computation for the base model so that its parameters will not
+ be updated during training. Only the classification head will be updated.
+ """
+ for param in self.data2vec_audio.parameters():
+ param.requires_grad = False
+
+ @auto_docstring
+ def forward(
+ self,
+ input_values: Optional[torch.Tensor],
+ attention_mask: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[tuple, TokenClassifierOutput]:
+ r"""
+ input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
+ Float values of input raw speech waveform. Values can be obtained by loading a `.flac` or `.wav` audio file
+ into an array of type `list[float]`, a `numpy.ndarray` or a `torch.Tensor`, *e.g.* via the torchcodec library
+ (`pip install torchcodec`) or the soundfile library (`pip install soundfile`).
+ To prepare the array into `input_values`, the [`AutoProcessor`] should be used for padding and conversion
+ into a tensor of type `torch.FloatTensor`. See [`Data2VecAudioProcessor.__call__`] for details.
+ labels (`torch.LongTensor` of shape `(batch_size, num_frames, config.num_labels)`, *optional*):
+ Frame-level labels for computing the audio frame classification loss, one-hot encoded over
+ `config.num_labels` classes for every output frame. A classification loss (Cross-Entropy) is computed over
+ the flattened frames.
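+
+ Example (minimal sketch; the base checkpoint carries no frame-classification head, so `logits` below come
+ from a randomly initialized classifier and only illustrate the output shape):
+
+ ```python
+ >>> import numpy as np
+ >>> import torch
+ >>> from transformers import AutoFeatureExtractor, Data2VecAudioForAudioFrameClassification
+
+ >>> feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/data2vec-audio-base-960h")
+ >>> model = Data2VecAudioForAudioFrameClassification.from_pretrained("facebook/data2vec-audio-base-960h")
+
+ >>> waveform = np.zeros(16000, dtype=np.float32)
+ >>> inputs = feature_extractor(waveform, sampling_rate=16000, return_tensors="pt")
+ >>> with torch.no_grad():
+ ...     logits = model(**inputs).logits  # shape: (batch_size, num_frames, config.num_labels)
+ ```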
+ """
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states
+
+ outputs = self.data2vec_audio(
+ input_values,
+ attention_mask=attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ if self.config.use_weighted_layer_sum:
+ hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
+ hidden_states = torch.stack(hidden_states, dim=1)
+ norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
+ hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
+ else:
+ hidden_states = outputs[0]
+
+ logits = self.classifier(hidden_states)
+
+ loss = None
+ if labels is not None:
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(logits.view(-1, self.num_labels), torch.argmax(labels.view(-1, self.num_labels), axis=1))
+
+ if not return_dict:
+ output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
+ return ((loss,) + output) if loss is not None else output
+
+ return TokenClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+class AMSoftmaxLoss(nn.Module):
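+ """
+ Additive-margin softmax loss. The class weights and the hidden states are L2-normalized, so the logits are
+ cosine similarities; the target-class logit is reduced by `margin` before all logits are multiplied by
+ `scale` and passed to a standard cross-entropy loss. Compared with plain softmax, this enforces a larger
+ angular margin between classes, which is useful for speaker-verification style embeddings.
+ """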
+ def __init__(self, input_dim, num_labels, scale=30.0, margin=0.4):
+ super().__init__()
+ self.scale = scale
+ self.margin = margin
+ self.num_labels = num_labels
+ self.weight = nn.Parameter(torch.randn(input_dim, num_labels), requires_grad=True)
+ self.loss = nn.CrossEntropyLoss()
+
+ def forward(self, hidden_states, labels):
+ labels = labels.flatten()
+ weight = nn.functional.normalize(self.weight, dim=0)
+ hidden_states = nn.functional.normalize(hidden_states, dim=1)
+ cos_theta = torch.mm(hidden_states, weight)
+ psi = cos_theta - self.margin
+
+ onehot = nn.functional.one_hot(labels, self.num_labels)
+ logits = self.scale * torch.where(onehot.bool(), psi, cos_theta)
+ loss = self.loss(logits, labels)
+
+ return loss
+
+
+class TDNNLayer(nn.Module):
+ def __init__(self, config, layer_id=0):
+ super().__init__()
+ self.in_conv_dim = config.tdnn_dim[layer_id - 1] if layer_id > 0 else config.tdnn_dim[layer_id]
+ self.out_conv_dim = config.tdnn_dim[layer_id]
+ self.kernel_size = config.tdnn_kernel[layer_id]
+ self.dilation = config.tdnn_dilation[layer_id]
+
+ self.kernel = nn.Linear(self.in_conv_dim * self.kernel_size, self.out_conv_dim)
+ self.activation = nn.ReLU()
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ if is_peft_available():
+ from peft.tuners.lora import LoraLayer
+
+ if is_peft_available():
+ if isinstance(self.kernel, LoraLayer):
+ warnings.warn(
+ "Detected LoRA on TDNNLayer. LoRA weights won't be applied due to optimization. "
+ "You should exclude TDNNLayer from LoRA's target modules.",
+ )
+
+ # for backward compatibility, we keep nn.Linear but call F.conv1d for speed up
+ hidden_states = hidden_states.transpose(1, 2)
+ weight = self.kernel.weight.view(self.out_conv_dim, self.kernel_size, self.in_conv_dim).transpose(1, 2)
+ hidden_states = nn.functional.conv1d(hidden_states, weight, self.kernel.bias, dilation=self.dilation)
+ hidden_states = hidden_states.transpose(1, 2)
+
+ hidden_states = self.activation(hidden_states)
+ return hidden_states
+
+
+@auto_docstring(
+ custom_intro="""
+ Data2VecAudio Model with an XVector feature extraction head on top for tasks like Speaker Verification.
+ """
+)
+class Data2VecAudioForXVector(Data2VecAudioPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+
+ self.data2vec_audio = Data2VecAudioModel(config)
+ num_layers = config.num_hidden_layers + 1 # transformer layers + input embeddings
+ if config.use_weighted_layer_sum:
+ self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
+ self.projector = nn.Linear(config.hidden_size, config.tdnn_dim[0])
+
+ tdnn_layers = [TDNNLayer(config, i) for i in range(len(config.tdnn_dim))]
+ self.tdnn = nn.ModuleList(tdnn_layers)
+
+ self.feature_extractor = nn.Linear(config.tdnn_dim[-1] * 2, config.xvector_output_dim)
+ self.classifier = nn.Linear(config.xvector_output_dim, config.xvector_output_dim)
+
+ self.objective = AMSoftmaxLoss(config.xvector_output_dim, config.num_labels)
+
+ self.init_weights()
+
+ def freeze_feature_extractor(self):
+ """
+ Calling this function will disable the gradient computation for the feature encoder so that its parameters will
+ not be updated during training.
+ """
+ warnings.warn(
+ "The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5. "
+ "Please use the equivalent `freeze_feature_encoder` method instead.",
+ FutureWarning,
+ )
+ self.freeze_feature_encoder()
+
+ def freeze_feature_encoder(self):
+ """
+ Calling this function will disable the gradient computation for the feature encoder so that its parameters will
+ not be updated during training.
+ """
+ self.data2vec_audio.feature_extractor._freeze_parameters()
+
+ def freeze_base_model(self):
+ """
+ Calling this function will disable the gradient computation for the base model so that its parameters will not
+ be updated during training. Only the classification head will be updated.
+ """
+ for param in self.data2vec_audio.parameters():
+ param.requires_grad = False
+
+ def _get_tdnn_output_lengths(self, input_lengths: Union[torch.LongTensor, int]):
+ """
+ Computes the output length of the TDNN layers
+ """
+
+ def _conv_out_length(input_length, kernel_size, stride):
+ # 1D convolutional layer output length formula taken
+ # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
+ return (input_length - kernel_size) // stride + 1
+
+ for kernel_size in self.config.tdnn_kernel:
+ input_lengths = _conv_out_length(input_lengths, kernel_size, 1)
+
+ return input_lengths
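+
+ # Worked example (assuming, e.g., `config.tdnn_kernel = (5, 3, 3, 1, 1)`, all with stride 1): an input of
+ # 100 frames shrinks to (100 - 5) + 1 = 96 frames after the first layer, then 94, 92, 92 and 92 frames
+ # after the remaining layers.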
+
+ @auto_docstring
+ def forward(
+ self,
+ input_values: Optional[torch.Tensor],
+ attention_mask: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ labels: Optional[torch.Tensor] = None,
+ ) -> Union[tuple, XVectorOutput]:
+ r"""
+ input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
+ Float values of input raw speech waveform. Values can be obtained by loading a `.flac` or `.wav` audio file
+ into an array of type `list[float]`, a `numpy.ndarray` or a `torch.Tensor`, *e.g.* via the torchcodec library
+ (`pip install torchcodec`) or the soundfile library (`pip install soundfile`).
+ To prepare the array into `input_values`, the [`AutoProcessor`] should be used for padding and conversion
+ into a tensor of type `torch.FloatTensor`. See [`Data2VecAudioProcessor.__call__`] for details.
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels (e.g. speaker identities) for computing the classification loss on top of the x-vector output.
+ Indices should be in `[0, ..., config.num_labels - 1]`. An additive-margin softmax (AMSoftmax) classification
+ loss is computed.
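+
+ Example (illustrative sketch; the base checkpoint has no trained x-vector head, so the embeddings below only
+ demonstrate the API):
+
+ ```python
+ >>> import numpy as np
+ >>> import torch
+ >>> from transformers import AutoFeatureExtractor, Data2VecAudioForXVector
+
+ >>> feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/data2vec-audio-base-960h")
+ >>> model = Data2VecAudioForXVector.from_pretrained("facebook/data2vec-audio-base-960h")
+
+ >>> waveforms = [np.zeros(16000, dtype=np.float32), np.zeros(24000, dtype=np.float32)]
+ >>> inputs = feature_extractor(waveforms, sampling_rate=16000, padding=True, return_tensors="pt")
+ >>> with torch.no_grad():
+ ...     embeddings = model(**inputs).embeddings  # shape: (batch_size, config.xvector_output_dim)
+ >>> similarity = torch.nn.functional.cosine_similarity(embeddings[0], embeddings[1], dim=-1)
+ ```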
+ """
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states
+
+ outputs = self.data2vec_audio(
+ input_values,
+ attention_mask=attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ if self.config.use_weighted_layer_sum:
+ hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
+ hidden_states = torch.stack(hidden_states, dim=1)
+ norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
+ hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
+ else:
+ hidden_states = outputs[0]
+
+ hidden_states = self.projector(hidden_states)
+
+ for tdnn_layer in self.tdnn:
+ hidden_states = tdnn_layer(hidden_states)
+
+ # Statistic Pooling
+ if attention_mask is None:
+ mean_features = hidden_states.mean(dim=1)
+ std_features = hidden_states.std(dim=1)
+ else:
+ feat_extract_output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(dim=1))
+ tdnn_output_lengths = self._get_tdnn_output_lengths(feat_extract_output_lengths)
+ mean_features = []
+ std_features = []
+ for i, length in enumerate(tdnn_output_lengths):
+ mean_features.append(hidden_states[i, :length].mean(dim=0))
+ std_features.append(hidden_states[i, :length].std(dim=0))
+ mean_features = torch.stack(mean_features)
+ std_features = torch.stack(std_features)
+ statistic_pooling = torch.cat([mean_features, std_features], dim=-1)
+
+ output_embeddings = self.feature_extractor(statistic_pooling)
+ logits = self.classifier(output_embeddings)
+
+ loss = None
+ if labels is not None:
+ loss = self.objective(logits, labels)
+
+ if not return_dict:
+ output = (logits, output_embeddings) + outputs[_HIDDEN_STATES_START_POSITION:]
+ return ((loss,) + output) if loss is not None else output
+
+ return XVectorOutput(
+ loss=loss,
+ logits=logits,
+ embeddings=output_embeddings,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+__all__ = [
+ "Data2VecAudioForAudioFrameClassification",
+ "Data2VecAudioForCTC",
+ "Data2VecAudioForSequenceClassification",
+ "Data2VecAudioForXVector",
+ "Data2VecAudioModel",
+ "Data2VecAudioPreTrainedModel",
+]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/data2vec/modeling_data2vec_text.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/data2vec/modeling_data2vec_text.py
new file mode 100644
index 0000000000000000000000000000000000000000..f866dd9144a627b938de67ed3d3f79816e849722
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/data2vec/modeling_data2vec_text.py
@@ -0,0 +1,1378 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch Data2VecText model."""
+
+import math
+from typing import Optional, Union
+
+import torch
+from torch import nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+
+from ...activations import ACT2FN, gelu
+from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
+from ...generation import GenerationMixin
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_outputs import (
+ BaseModelOutputWithPastAndCrossAttentions,
+ BaseModelOutputWithPoolingAndCrossAttentions,
+ CausalLMOutputWithCrossAttentions,
+ MaskedLMOutput,
+ MultipleChoiceModelOutput,
+ QuestionAnsweringModelOutput,
+ SequenceClassifierOutput,
+ TokenClassifierOutput,
+)
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
+from ...utils import auto_docstring, logging
+from ...utils.deprecation import deprecate_kwarg
+from .configuration_data2vec_text import Data2VecTextConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+_HIDDEN_STATES_START_POSITION = 2
+
+
+# Copied from transformers.models.roberta.modeling_roberta.RobertaEmbeddings with Roberta->Data2VecText
+class Data2VecTextForTextEmbeddings(nn.Module):
+ """
+ Same as BertEmbeddings with a tiny tweak for positional embeddings indexing.
+ """
+
+ # Copied from transformers.models.bert.modeling_bert.BertEmbeddings.__init__
+ def __init__(self, config):
+ super().__init__()
+ self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
+ self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
+ self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
+
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
+ # any TensorFlow checkpoint file
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
+ self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
+ self.register_buffer(
+ "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
+ )
+ self.register_buffer(
+ "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
+ )
+
+ # End copy
+ self.padding_idx = config.pad_token_id
+ self.position_embeddings = nn.Embedding(
+ config.max_position_embeddings, config.hidden_size, padding_idx=self.padding_idx
+ )
+
+ def forward(
+ self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
+ ):
+ if position_ids is None:
+ if input_ids is not None:
+ # Create the position ids from the input token ids. Any padded tokens remain padded.
+ position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx, past_key_values_length)
+ else:
+ position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds)
+
+ if input_ids is not None:
+ input_shape = input_ids.size()
+ else:
+ input_shape = inputs_embeds.size()[:-1]
+
+ seq_length = input_shape[1]
+
+ # If token_type_ids is not passed, fall back to the all-zeros buffer registered in the constructor. This is
+ # the usual case when token_type_ids are auto-generated; the registered buffer lets users trace the model
+ # without passing token_type_ids explicitly (solves issue #5664).
+ if token_type_ids is None:
+ if hasattr(self, "token_type_ids"):
+ buffered_token_type_ids = self.token_type_ids[:, :seq_length]
+ buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
+ token_type_ids = buffered_token_type_ids_expanded
+ else:
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
+
+ if inputs_embeds is None:
+ inputs_embeds = self.word_embeddings(input_ids)
+ token_type_embeddings = self.token_type_embeddings(token_type_ids)
+
+ embeddings = inputs_embeds + token_type_embeddings
+ if self.position_embedding_type == "absolute":
+ position_embeddings = self.position_embeddings(position_ids)
+ embeddings += position_embeddings
+ embeddings = self.LayerNorm(embeddings)
+ embeddings = self.dropout(embeddings)
+ return embeddings
+
+ def create_position_ids_from_inputs_embeds(self, inputs_embeds):
+ """
+ We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids.
+
+ Args:
+ inputs_embeds: torch.Tensor
+
+ Returns: torch.Tensor
+ """
+ input_shape = inputs_embeds.size()[:-1]
+ sequence_length = input_shape[1]
+
+ position_ids = torch.arange(
+ self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device
+ )
+ return position_ids.unsqueeze(0).expand(input_shape)
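+
+ # Example: with `padding_idx = 1` (the RoBERTa-style default used by data2vec-text) and a sequence length of
+ # 4, the generated position ids are [2, 3, 4, 5]; padding positions are never produced here because padding
+ # cannot be inferred from embeddings.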
+
+
+# Copied from transformers.models.roberta.modeling_roberta.RobertaSelfAttention with Roberta->Data2VecText
+class Data2VecTextSelfAttention(nn.Module):
+ def __init__(self, config, position_embedding_type=None, layer_idx=None):
+ super().__init__()
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
+ raise ValueError(
+ f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
+ f"heads ({config.num_attention_heads})"
+ )
+
+ self.num_attention_heads = config.num_attention_heads
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
+
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
+
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
+ self.position_embedding_type = position_embedding_type or getattr(
+ config, "position_embedding_type", "absolute"
+ )
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
+ self.max_position_embeddings = config.max_position_embeddings
+ self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
+
+ self.is_decoder = config.is_decoder
+ self.layer_idx = layer_idx
+
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ output_attentions: Optional[bool] = False,
+ cache_position: Optional[torch.Tensor] = None,
+ ) -> tuple[torch.Tensor]:
+ batch_size, seq_length, _ = hidden_states.shape
+ query_layer = self.query(hidden_states)
+ query_layer = query_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(
+ 1, 2
+ )
+
+ is_updated = False
+ is_cross_attention = encoder_hidden_states is not None
+ if past_key_values is not None:
+ if isinstance(past_key_values, EncoderDecoderCache):
+ is_updated = past_key_values.is_updated.get(self.layer_idx)
+ if is_cross_attention:
+ # after the first generated id, we can subsequently re-use all key/value_layer from cache
+ curr_past_key_value = past_key_values.cross_attention_cache
+ else:
+ curr_past_key_value = past_key_values.self_attention_cache
+ else:
+ curr_past_key_value = past_key_values
+
+ current_states = encoder_hidden_states if is_cross_attention else hidden_states
+ if is_cross_attention and past_key_values is not None and is_updated:
+ # reuse k,v, cross_attentions
+ key_layer = curr_past_key_value.layers[self.layer_idx].keys
+ value_layer = curr_past_key_value.layers[self.layer_idx].values
+ else:
+ key_layer = self.key(current_states)
+ key_layer = key_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(
+ 1, 2
+ )
+ value_layer = self.value(current_states)
+ value_layer = value_layer.view(
+ batch_size, -1, self.num_attention_heads, self.attention_head_size
+ ).transpose(1, 2)
+
+ if past_key_values is not None:
+ # save all key/value_layer to cache to be re-used for fast auto-regressive generation
+ cache_position = cache_position if not is_cross_attention else None
+ key_layer, value_layer = curr_past_key_value.update(
+ key_layer, value_layer, self.layer_idx, {"cache_position": cache_position}
+ )
+ # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
+ if is_cross_attention and isinstance(past_key_values, EncoderDecoderCache):
+ past_key_values.is_updated[self.layer_idx] = True
+
+ # Take the dot product between "query" and "key" to get the raw attention scores.
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
+
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
+ query_length, key_length = query_layer.shape[2], key_layer.shape[2]
+ if past_key_values is not None:
+ position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(
+ -1, 1
+ )
+ else:
+ position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
+ position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
+ distance = position_ids_l - position_ids_r
+
+ positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
+ positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
+
+ if self.position_embedding_type == "relative_key":
+ relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
+ attention_scores = attention_scores + relative_position_scores
+ elif self.position_embedding_type == "relative_key_query":
+ relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
+ relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
+ attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
+
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
+ if attention_mask is not None:
+ # Apply the attention mask is (precomputed for all layers in Data2VecTextModel forward() function)
+ attention_scores = attention_scores + attention_mask
+
+ # Normalize the attention scores to probabilities.
+ attention_probs = nn.functional.softmax(attention_scores, dim=-1)
+
+ # This is actually dropping out entire tokens to attend to, which might
+ # seem a bit unusual, but is taken from the original Transformer paper.
+ attention_probs = self.dropout(attention_probs)
+
+ # Mask heads if we want to
+ if head_mask is not None:
+ attention_probs = attention_probs * head_mask
+
+ context_layer = torch.matmul(attention_probs, value_layer)
+
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+ context_layer = context_layer.view(new_context_layer_shape)
+
+ return context_layer, attention_probs
+
+
+# Copied from transformers.models.bert.modeling_bert.BertSelfOutput
+class Data2VecTextSelfOutput(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
+ return hidden_states
+
+
+DATA2VEC_TEXT_SELF_ATTENTION_CLASSES = {
+ "eager": Data2VecTextSelfAttention,
+}
+
+
+# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->Data2VecText,BERT->DATA2VEC_TEXT
+class Data2VecTextAttention(nn.Module):
+ def __init__(self, config, position_embedding_type=None, layer_idx=None):
+ super().__init__()
+ self.self = DATA2VEC_TEXT_SELF_ATTENTION_CLASSES[config._attn_implementation](
+ config,
+ position_embedding_type=position_embedding_type,
+ layer_idx=layer_idx,
+ )
+ self.output = Data2VecTextSelfOutput(config)
+ self.pruned_heads = set()
+
+ def prune_heads(self, heads):
+ if len(heads) == 0:
+ return
+ heads, index = find_pruneable_heads_and_indices(
+ heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
+ )
+
+ # Prune linear layers
+ self.self.query = prune_linear_layer(self.self.query, index)
+ self.self.key = prune_linear_layer(self.self.key, index)
+ self.self.value = prune_linear_layer(self.self.value, index)
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
+
+ # Update hyper params and store pruned heads
+ self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
+ self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
+ self.pruned_heads = self.pruned_heads.union(heads)
+
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ output_attentions: Optional[bool] = False,
+ cache_position: Optional[torch.Tensor] = None,
+ ) -> tuple[torch.Tensor]:
+ self_outputs = self.self(
+ hidden_states,
+ attention_mask=attention_mask,
+ head_mask=head_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ past_key_values=past_key_values,
+ output_attentions=output_attentions,
+ cache_position=cache_position,
+ )
+ attention_output = self.output(self_outputs[0], hidden_states)
+ outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
+ return outputs
+
+
+# Copied from transformers.models.bert.modeling_bert.BertIntermediate
+class Data2VecTextIntermediate(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
+ if isinstance(config.hidden_act, str):
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
+ else:
+ self.intermediate_act_fn = config.hidden_act
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.intermediate_act_fn(hidden_states)
+ return hidden_states
+
+
+# Copied from transformers.models.bert.modeling_bert.BertOutput
+class Data2VecTextOutput(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
+ return hidden_states
+
+
+# Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->Data2VecText
+class Data2VecTextLayer(GradientCheckpointingLayer):
+ def __init__(self, config, layer_idx=None):
+ super().__init__()
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
+ self.seq_len_dim = 1
+ self.attention = Data2VecTextAttention(config, layer_idx=layer_idx)
+ self.is_decoder = config.is_decoder
+ self.add_cross_attention = config.add_cross_attention
+ if self.add_cross_attention:
+ if not self.is_decoder:
+ raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
+ self.crossattention = Data2VecTextAttention(
+ config, position_embedding_type="absolute", layer_idx=layer_idx
+ )
+ self.intermediate = Data2VecTextIntermediate(config)
+ self.output = Data2VecTextOutput(config)
+
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ output_attentions: Optional[bool] = False,
+ cache_position: Optional[torch.Tensor] = None,
+ ) -> tuple[torch.Tensor]:
+ self_attention_outputs = self.attention(
+ hidden_states,
+ attention_mask=attention_mask,
+ head_mask=head_mask,
+ output_attentions=output_attentions,
+ past_key_values=past_key_values,
+ cache_position=cache_position,
+ )
+ attention_output = self_attention_outputs[0]
+ outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
+
+ if self.is_decoder and encoder_hidden_states is not None:
+ if not hasattr(self, "crossattention"):
+ raise ValueError(
+ f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
+ " by setting `config.add_cross_attention=True`"
+ )
+
+ cross_attention_outputs = self.crossattention(
+ attention_output,
+ attention_mask=encoder_attention_mask,
+ head_mask=head_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ past_key_values=past_key_values,
+ output_attentions=output_attentions,
+ cache_position=cache_position,
+ )
+ attention_output = cross_attention_outputs[0]
+ outputs = outputs + cross_attention_outputs[1:] # add cross attentions if we output attention weights
+
+ layer_output = apply_chunking_to_forward(
+ self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
+ )
+ outputs = (layer_output,) + outputs
+
+ return outputs
+
+ def feed_forward_chunk(self, attention_output):
+ intermediate_output = self.intermediate(attention_output)
+ layer_output = self.output(intermediate_output, attention_output)
+ return layer_output
+
+
+# Copied from transformers.models.bert.modeling_bert.BertEncoder with Bert->Data2VecText
+class Data2VecTextEncoder(nn.Module):
+ def __init__(self, config, layer_idx=None):
+ super().__init__()
+ self.config = config
+ self.layer = nn.ModuleList([Data2VecTextLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)])
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = False,
+ output_hidden_states: Optional[bool] = False,
+ return_dict: Optional[bool] = True,
+ cache_position: Optional[torch.Tensor] = None,
+ ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attentions = () if output_attentions else None
+ all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
+
+ if self.gradient_checkpointing and self.training:
+ if use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+ )
+ use_cache = False
+
+ if use_cache and self.config.is_decoder and past_key_values is None:
+ past_key_values = EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
+
+ if use_cache and self.config.is_decoder and isinstance(past_key_values, tuple):
+ logger.warning_once(
+ "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. "
+ "You should pass an instance of `EncoderDecoderCache` instead, e.g. "
+ "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`."
+ )
+ past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values)
+
+ for i, layer_module in enumerate(self.layer):
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ layer_head_mask = head_mask[i] if head_mask is not None else None
+
+ layer_outputs = layer_module(
+ hidden_states,
+ attention_mask,
+ layer_head_mask,
+ encoder_hidden_states, # as a positional argument for gradient checkpointing
+ encoder_attention_mask=encoder_attention_mask,
+ past_key_values=past_key_values,
+ output_attentions=output_attentions,
+ cache_position=cache_position,
+ )
+
+ hidden_states = layer_outputs[0]
+ if output_attentions:
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
+ if self.config.add_cross_attention:
+ all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(
+ v
+ for v in [
+ hidden_states,
+ past_key_values,
+ all_hidden_states,
+ all_self_attentions,
+ all_cross_attentions,
+ ]
+ if v is not None
+ )
+ return BaseModelOutputWithPastAndCrossAttentions(
+ last_hidden_state=hidden_states,
+ past_key_values=past_key_values,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attentions,
+ cross_attentions=all_cross_attentions,
+ )
+
+
+# Copied from transformers.models.bert.modeling_bert.BertPooler
+class Data2VecTextPooler(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ self.activation = nn.Tanh()
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ # We "pool" the model by simply taking the hidden state corresponding
+ # to the first token.
+ first_token_tensor = hidden_states[:, 0]
+ pooled_output = self.dense(first_token_tensor)
+ pooled_output = self.activation(pooled_output)
+ return pooled_output
+
+
+@auto_docstring
+class Data2VecTextPreTrainedModel(PreTrainedModel):
+ config: Data2VecTextConfig
+ base_model_prefix = "data2vec_text"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["Data2VecTextForTextEmbeddings", "Data2VecTextLayer"]
+
+ def _init_weights(self, module):
+ """Initialize the weights"""
+ if isinstance(module, nn.Linear):
+ # Slightly different from the TF version which uses truncated_normal for initialization
+ # cf https://github.com/pytorch/pytorch/pull/5617
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+ elif isinstance(module, nn.LayerNorm):
+ if hasattr(module, "bias") and module.bias is not None:
+ module.bias.data.zero_()
+ if hasattr(module, "weight") and module.weight is not None:
+ module.weight.data.fill_(1.0)
+
+
+@auto_docstring
+class Data2VecTextModel(Data2VecTextPreTrainedModel):
+ """
+
+ The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
+ cross-attention is added between the self-attention layers, following the architecture described in *Attention is
+ all you need*_ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz
+ Kaiser and Illia Polosukhin.
+
+ To behave as a decoder, the model needs to be initialized with the `is_decoder` argument of the configuration
+ set to `True`. To be used in a Seq2Seq model, the model needs to be initialized with both the `is_decoder`
+ argument and `add_cross_attention` set to `True`; `encoder_hidden_states` is then expected as an input to the
+ forward pass.
+
+ .. _*Attention is all you need*: https://huggingface.co/papers/1706.03762
+
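+ Example (illustrative sketch of the decoder setup described above, using the same checkpoint as the other
+ examples in this file):
+
+ ```python
+ >>> from transformers import Data2VecTextConfig, Data2VecTextModel
+
+ >>> config = Data2VecTextConfig.from_pretrained("facebook/data2vec-text-base")
+ >>> config.is_decoder = True
+ >>> config.add_cross_attention = True  # only needed when cross-attending to an encoder
+ >>> model = Data2VecTextModel(config)
+ ```
+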
+ """
+
+ def __init__(self, config, add_pooling_layer=True):
+ r"""
+ add_pooling_layer (bool, *optional*, defaults to `True`):
+ Whether to add a pooling layer
+ """
+ super().__init__(config)
+ self.config = config
+
+ self.embeddings = Data2VecTextForTextEmbeddings(config)
+ self.encoder = Data2VecTextEncoder(config)
+
+ self.pooler = Data2VecTextPooler(config) if add_pooling_layer else None
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.embeddings.word_embeddings
+
+ def set_input_embeddings(self, value):
+ self.embeddings.word_embeddings = value
+
+ def _prune_heads(self, heads_to_prune):
+ """
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
+ class PreTrainedModel
+ """
+ for layer, heads in heads_to_prune.items():
+ self.encoder.layer[layer].attention.prune_heads(heads)
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ token_type_ids: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ past_key_values: Optional[Cache] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.Tensor] = None,
+ ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if self.config.is_decoder:
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ else:
+ use_cache = False
+
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+ elif input_ids is not None:
+ self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
+ input_shape = input_ids.size()
+ elif inputs_embeds is not None:
+ input_shape = inputs_embeds.size()[:-1]
+ else:
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+ batch_size, seq_length = input_shape
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
+
+ past_key_values_length = 0
+ if past_key_values is not None:
+ past_key_values_length = (
+ past_key_values[0][0].shape[-2]
+ if not isinstance(past_key_values, Cache)
+ else past_key_values.get_seq_length()
+ )
+
+ if attention_mask is None:
+ attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
+
+ if token_type_ids is None:
+ if hasattr(self.embeddings, "token_type_ids"):
+ buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
+ buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
+ token_type_ids = buffered_token_type_ids_expanded
+ else:
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
+
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
+ # ourselves in which case we just need to make it broadcastable to all heads.
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
+
+ # If a 2D or 3D attention mask is provided for the cross-attention
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
+ if self.config.is_decoder and encoder_hidden_states is not None:
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
+ if encoder_attention_mask is None:
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
+ else:
+ encoder_extended_attention_mask = None
+
+ # Prepare head mask if needed
+ # 1.0 in head_mask indicate we keep the head
+ # attention_probs has shape bsz x n_heads x N x N
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
+
+ embedding_output = self.embeddings(
+ input_ids=input_ids,
+ position_ids=position_ids,
+ token_type_ids=token_type_ids,
+ inputs_embeds=inputs_embeds,
+ past_key_values_length=past_key_values_length,
+ )
+ encoder_outputs = self.encoder(
+ embedding_output,
+ attention_mask=extended_attention_mask,
+ head_mask=head_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_extended_attention_mask,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ cache_position=cache_position,
+ )
+ sequence_output = encoder_outputs[0]
+ pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
+
+ if not return_dict:
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
+
+ return BaseModelOutputWithPoolingAndCrossAttentions(
+ last_hidden_state=sequence_output,
+ pooler_output=pooled_output,
+ past_key_values=encoder_outputs.past_key_values,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ cross_attentions=encoder_outputs.cross_attentions,
+ )
+
+
+@auto_docstring(
+ custom_intro="""
+ Data2VecText Model with a `language modeling` head on top for CLM fine-tuning.
+ """
+)
+class Data2VecTextForCausalLM(Data2VecTextPreTrainedModel, GenerationMixin):
+ _tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"]
+
+ def __init__(self, config):
+ super().__init__(config)
+
+ if not config.is_decoder:
+ logger.warning("If you want to use `Data2VecTextForCausalLM` as a standalone, add `is_decoder=True`.")
+
+ self.data2vec_text = Data2VecTextModel(config, add_pooling_layer=False)
+ self.lm_head = Data2VecTextLMHead(config)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_output_embeddings(self):
+ return self.lm_head.decoder
+
+ def set_output_embeddings(self, new_embeddings):
+ self.lm_head.decoder = new_embeddings
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.Tensor] = None,
+ **kwargs,
+ ) -> Union[tuple, CausalLMOutputWithCrossAttentions]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
+ `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring). Tokens with indices set to `-100` are
+ ignored (masked); the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, Data2VecTextForCausalLM, Data2VecTextConfig
+ >>> import torch
+
+ >>> tokenizer = AutoTokenizer.from_pretrained("facebook/data2vec-text-base")
+ >>> config = Data2VecTextConfig.from_pretrained("facebook/data2vec-text-base")
+ >>> config.is_decoder = True
+ >>> model = Data2VecTextForCausalLM.from_pretrained("facebook/data2vec-text-base", config=config)
+
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
+ >>> outputs = model(**inputs)
+
+ >>> prediction_logits = outputs.logits
+ ```"""
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ if labels is not None:
+ use_cache = False
+
+ outputs = self.data2vec_text(
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ cache_position=cache_position,
+ )
+
+ sequence_output = outputs[0]
+ prediction_scores = self.lm_head(sequence_output)
+
+ lm_loss = None
+ if labels is not None:
+ lm_loss = self.loss_function(
+ prediction_scores,
+ labels,
+ vocab_size=self.config.vocab_size,
+ **kwargs,
+ )
+
+ if not return_dict:
+ output = (prediction_scores,) + outputs[2:]
+ return ((lm_loss,) + output) if lm_loss is not None else output
+
+ return CausalLMOutputWithCrossAttentions(
+ loss=lm_loss,
+ logits=prediction_scores,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ cross_attentions=outputs.cross_attentions,
+ )
+
+
+@auto_docstring
+class Data2VecTextForMaskedLM(Data2VecTextPreTrainedModel):
+ _tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"]
+
+ def __init__(self, config):
+ super().__init__(config)
+
+ if config.is_decoder:
+ logger.warning(
+ "If you want to use `Data2VecTextForMaskedLM` make sure `config.is_decoder=False` for "
+ "bi-directional self-attention."
+ )
+
+ self.data2vec_text = Data2VecTextModel(config, add_pooling_layer=False)
+ self.lm_head = Data2VecTextLMHead(config)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_output_embeddings(self):
+ return self.lm_head.decoder
+
+ def set_output_embeddings(self, new_embeddings):
+ self.lm_head.decoder = new_embeddings
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[tuple, MaskedLMOutput]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
+ config.vocab_size]` (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked); the
+ loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
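+
+ Example (minimal sketch; `facebook/data2vec-text-base` uses a RoBERTa-style tokenizer whose mask token is
+ `<mask>`):
+
+ ```python
+ >>> import torch
+ >>> from transformers import AutoTokenizer, Data2VecTextForMaskedLM
+
+ >>> tokenizer = AutoTokenizer.from_pretrained("facebook/data2vec-text-base")
+ >>> model = Data2VecTextForMaskedLM.from_pretrained("facebook/data2vec-text-base")
+
+ >>> inputs = tokenizer("The capital of France is <mask>.", return_tensors="pt")
+ >>> with torch.no_grad():
+ ...     logits = model(**inputs).logits
+
+ >>> mask_index = (inputs.input_ids == tokenizer.mask_token_id)[0].nonzero(as_tuple=True)[0]
+ >>> predicted_token = tokenizer.decode(logits[0, mask_index].argmax(-1))
+ ```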
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.data2vec_text(
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ sequence_output = outputs[0]
+ prediction_scores = self.lm_head(sequence_output)
+
+ masked_lm_loss = None
+ if labels is not None:
+ loss_fct = CrossEntropyLoss()
+
+ labels = labels.to(prediction_scores.device)
+ masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
+
+ if not return_dict:
+ output = (prediction_scores,) + outputs[2:]
+ return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
+
+ return MaskedLMOutput(
+ loss=masked_lm_loss,
+ logits=prediction_scores,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+# Copied from transformers.models.roberta.modeling_roberta.RobertaLMHead with Roberta->Data2VecText
+class Data2VecTextLMHead(nn.Module):
+ """Data2VecText Head for masked language modeling."""
+
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+ self.decoder = nn.Linear(config.hidden_size, config.vocab_size)
+ self.bias = nn.Parameter(torch.zeros(config.vocab_size))
+ self.decoder.bias = self.bias
+
+ def forward(self, features, **kwargs):
+ x = self.dense(features)
+ x = gelu(x)
+ x = self.layer_norm(x)
+
+ # project back to size of vocabulary with bias
+ x = self.decoder(x)
+
+ return x
+
+ def _tie_weights(self):
+ # To tie those two weights if they get disconnected (on TPU or when the bias is resized)
+ # For accelerate compatibility and to not break backward compatibility
+ if self.decoder.bias.device.type == "meta":
+ self.decoder.bias = self.bias
+ else:
+ self.bias = self.decoder.bias
+
+
+@auto_docstring(
+ custom_intro="""
+ Data2VecText Model transformer with a sequence classification/regression head on top (a linear layer on top of the
+ pooled output) e.g. for GLUE tasks.
+ """
+)
+class Data2VecTextForSequenceClassification(Data2VecTextPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+ self.config = config
+
+ self.data2vec_text = Data2VecTextModel(config, add_pooling_layer=False)
+ self.classifier = Data2VecTextClassificationHead(config)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[tuple, SequenceClassifierOutput]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
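+
+ Example (illustrative sketch; loading the base checkpoint leaves the classification head randomly
+ initialized, so the prediction is only meaningful after fine-tuning):
+
+ ```python
+ >>> import torch
+ >>> from transformers import AutoTokenizer, Data2VecTextForSequenceClassification
+
+ >>> tokenizer = AutoTokenizer.from_pretrained("facebook/data2vec-text-base")
+ >>> model = Data2VecTextForSequenceClassification.from_pretrained("facebook/data2vec-text-base", num_labels=2)
+
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
+ >>> with torch.no_grad():
+ ...     logits = model(**inputs).logits
+ >>> predicted_class_id = logits.argmax(-1).item()
+ ```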
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.data2vec_text(
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ sequence_output = outputs[0]
+ logits = self.classifier(sequence_output)
+
+ loss = None
+ if labels is not None:
+ labels = labels.to(logits.device)
+
+ if self.config.problem_type is None:
+ if self.num_labels == 1:
+ self.config.problem_type = "regression"
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+ self.config.problem_type = "single_label_classification"
+ else:
+ self.config.problem_type = "multi_label_classification"
+
+ if self.config.problem_type == "regression":
+ loss_fct = MSELoss()
+ if self.num_labels == 1:
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
+ else:
+ loss = loss_fct(logits, labels)
+ elif self.config.problem_type == "single_label_classification":
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+ elif self.config.problem_type == "multi_label_classification":
+ loss_fct = BCEWithLogitsLoss()
+ loss = loss_fct(logits, labels)
+
+ if not return_dict:
+ output = (logits,) + outputs[2:]
+ return ((loss,) + output) if loss is not None else output
+
+ return SequenceClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+@auto_docstring
+class Data2VecTextForMultipleChoice(Data2VecTextPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+
+ self.data2vec_text = Data2VecTextModel(config)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+ self.classifier = nn.Linear(config.hidden_size, 1)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[tuple, MultipleChoiceModelOutput]:
+ r"""
+ input_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ token_type_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
+ Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
+ 1]`:
+
+ - 0 corresponds to a *sentence A* token,
+ - 1 corresponds to a *sentence B* token.
+
+ [What are token type IDs?](../glossary#token-type-ids)
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
+ num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
+ `input_ids` above)
+ position_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+ config.max_position_embeddings - 1]`.
+
+ [What are position IDs?](../glossary#position-ids)
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_choices, sequence_length, hidden_size)`, *optional*):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
+ model's internal embedding lookup matrix.
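+
+ Example (minimal sketch of the expected `(batch_size, num_choices, sequence_length)` layout; the choice head
+ is randomly initialized when loading the base checkpoint):
+
+ ```python
+ >>> import torch
+ >>> from transformers import AutoTokenizer, Data2VecTextForMultipleChoice
+
+ >>> tokenizer = AutoTokenizer.from_pretrained("facebook/data2vec-text-base")
+ >>> model = Data2VecTextForMultipleChoice.from_pretrained("facebook/data2vec-text-base")
+
+ >>> prompt = "The glass fell off the table."
+ >>> choices = ["It shattered on the floor.", "It started to sing."]
+ >>> encoding = tokenizer([prompt, prompt], choices, return_tensors="pt", padding=True)
+ >>> inputs = {k: v.unsqueeze(0) for k, v in encoding.items()}  # add the batch dimension
+ >>> with torch.no_grad():
+ ...     logits = model(**inputs).logits  # shape: (batch_size, num_choices)
+ ```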
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
+
+ flat_input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
+ flat_position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
+ flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
+ flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
+ flat_inputs_embeds = (
+ inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
+ if inputs_embeds is not None
+ else None
+ )
+
+ outputs = self.data2vec_text(
+ flat_input_ids,
+ position_ids=flat_position_ids,
+ token_type_ids=flat_token_type_ids,
+ attention_mask=flat_attention_mask,
+ head_mask=head_mask,
+ inputs_embeds=flat_inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ pooled_output = outputs[1]
+
+ pooled_output = self.dropout(pooled_output)
+ logits = self.classifier(pooled_output)
+ reshaped_logits = logits.view(-1, num_choices)
+
+ loss = None
+ if labels is not None:
+ loss_fct = CrossEntropyLoss()
+
+ labels = labels.to(reshaped_logits.device)
+ loss = loss_fct(reshaped_logits, labels)
+
+ if not return_dict:
+ output = (reshaped_logits,) + outputs[2:]
+ return ((loss,) + output) if loss is not None else output
+
+ return MultipleChoiceModelOutput(
+ loss=loss,
+ logits=reshaped_logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
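+ # Illustrative sketch (not part of the library): how the multiple-choice head above flattens
+ # `(batch_size, num_choices, seq_len)` inputs for the encoder and reshapes the per-choice logits
+ # back. All sizes below are assumptions chosen for demonstration only.
+ def _example_multiple_choice_reshape():
+     batch_size, num_choices, seq_len, hidden = 2, 4, 8, 16
+     input_ids = torch.zeros(batch_size, num_choices, seq_len, dtype=torch.long)
+     flat_input_ids = input_ids.view(-1, input_ids.size(-1))  # (8, 8): choices become extra batch rows
+     pooled_output = torch.randn(batch_size * num_choices, hidden)  # stand-in for the encoder's pooled output
+     logits = nn.Linear(hidden, 1)(pooled_output)  # (8, 1): one score per (example, choice) pair
+     reshaped_logits = logits.view(-1, num_choices)  # (2, 4): scores regrouped per example
+     return flat_input_ids.shape, reshaped_logits.shape
+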
+
+@auto_docstring
+class Data2VecTextForTokenClassification(Data2VecTextPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+
+ self.data2vec_text = Data2VecTextModel(config, add_pooling_layer=False)
+ classifier_dropout = (
+ config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
+ )
+ self.dropout = nn.Dropout(classifier_dropout)
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[tuple, TokenClassifierOutput]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.data2vec_text(
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ sequence_output = outputs[0]
+
+ sequence_output = self.dropout(sequence_output)
+ logits = self.classifier(sequence_output)
+
+ loss = None
+ if labels is not None:
+ loss_fct = CrossEntropyLoss()
+
+ labels = labels.to(logits.device)
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+
+ if not return_dict:
+ output = (logits,) + outputs[2:]
+ return ((loss,) + output) if loss is not None else output
+
+ return TokenClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+# Copied from transformers.models.roberta.modeling_roberta.RobertaClassificationHead with Roberta->Data2VecText
+class Data2VecTextClassificationHead(nn.Module):
+ """Head for sentence-level classification tasks."""
+
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ classifier_dropout = (
+ config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
+ )
+ self.dropout = nn.Dropout(classifier_dropout)
+ self.out_proj = nn.Linear(config.hidden_size, config.num_labels)
+
+ def forward(self, features, **kwargs):
+ x = features[:, 0, :]  # take <s> token (equiv. to [CLS])
+ x = self.dropout(x)
+ x = self.dense(x)
+ x = torch.tanh(x)
+ x = self.dropout(x)
+ x = self.out_proj(x)
+ return x
+
+
+@auto_docstring
+class Data2VecTextForQuestionAnswering(Data2VecTextPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+
+ self.data2vec_text = Data2VecTextModel(config, add_pooling_layer=False)
+ self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ start_positions: Optional[torch.LongTensor] = None,
+ end_positions: Optional[torch.LongTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[tuple, QuestionAnsweringModelOutput]:
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.data2vec_text(
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ sequence_output = outputs[0]
+
+ logits = self.qa_outputs(sequence_output)
+ start_logits, end_logits = logits.split(1, dim=-1)
+ start_logits = start_logits.squeeze(-1).contiguous()
+ end_logits = end_logits.squeeze(-1).contiguous()
+
+ total_loss = None
+ if start_positions is not None and end_positions is not None:
+ # If we are on multi-GPU, splitting can add an extra dimension; squeeze it
+ if len(start_positions.size()) > 1:
+ start_positions = start_positions.squeeze(-1)
+ if len(end_positions.size()) > 1:
+ end_positions = end_positions.squeeze(-1)
+ # sometimes the start/end positions are outside our model inputs; we ignore these terms
+ ignored_index = start_logits.size(1)
+ start_positions = start_positions.clamp(0, ignored_index)
+ end_positions = end_positions.clamp(0, ignored_index)
+
+ loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
+ start_loss = loss_fct(start_logits, start_positions)
+ end_loss = loss_fct(end_logits, end_positions)
+ total_loss = (start_loss + end_loss) / 2
+
+ if not return_dict:
+ output = (start_logits, end_logits) + outputs[2:]
+ return ((total_loss,) + output) if total_loss is not None else output
+
+ return QuestionAnsweringModelOutput(
+ loss=total_loss,
+ start_logits=start_logits,
+ end_logits=end_logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
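+ # Illustrative sketch (not part of the library): the question-answering head above clamps answer
+ # positions that fall outside the sequence to `ignored_index` so that `CrossEntropyLoss` skips them.
+ # The sequence length and positions below are made-up values.
+ def _example_qa_position_clamping():
+     seq_len = 10
+     ignored_index = seq_len
+     start_positions = torch.tensor([3, 25]).clamp(0, ignored_index)  # -> tensor([3, 10]); 25 was out of range
+     start_logits = torch.randn(2, seq_len)
+     loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
+     return loss_fct(start_logits, start_positions)  # only the in-range position contributes
+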
+
+def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0):
+ """
+ Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
+ are ignored. This is modified from fairseq's `utils.make_positions`.
+
+ Args:
+ input_ids (`torch.Tensor`): Input token ids; padding positions are identified via `padding_idx`.
+ padding_idx (`int`): The id of the padding token.
+ past_key_values_length (`int`, *optional*, defaults to 0): Offset added to the positions when past key values are used.
+
+ Returns: `torch.Tensor` of position ids with the same shape as `input_ids`.
+ """
+ # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
+ mask = input_ids.ne(padding_idx).int()
+ incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask
+ return incremental_indices.long() + padding_idx
+
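+ # Illustrative sketch (not part of the library): a worked example of `create_position_ids_from_input_ids`
+ # above. With an assumed `padding_idx` of 1, padded positions keep the padding index and real tokens get
+ # positions starting at `padding_idx + 1`.
+ def _example_create_position_ids():
+     input_ids = torch.tensor([[5, 6, 7, 1, 1]])  # 1 is the (assumed) padding token id
+     position_ids = create_position_ids_from_input_ids(input_ids, padding_idx=1)
+     return position_ids  # tensor([[2, 3, 4, 1, 1]])
+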
+
+__all__ = [
+ "Data2VecTextForCausalLM",
+ "Data2VecTextForMaskedLM",
+ "Data2VecTextForMultipleChoice",
+ "Data2VecTextForQuestionAnswering",
+ "Data2VecTextForSequenceClassification",
+ "Data2VecTextForTokenClassification",
+ "Data2VecTextModel",
+ "Data2VecTextPreTrainedModel",
+]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/data2vec/modeling_data2vec_vision.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/data2vec/modeling_data2vec_vision.py
new file mode 100644
index 0000000000000000000000000000000000000000..f214f8eb6a0bcefe5eb6a7072864dfbbe4b0b22a
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/data2vec/modeling_data2vec_vision.py
@@ -0,0 +1,1348 @@
+# coding=utf-8
+# Copyright 2022 Meta Platforms and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch Data2VecVision model."""
+
+import collections.abc
+import math
+import warnings
+from dataclasses import dataclass
+from typing import Optional, Union
+
+import torch
+from torch import nn
+from torch.nn import CrossEntropyLoss
+
+from ...activations import ACT2FN
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_outputs import (
+ BaseModelOutput,
+ BaseModelOutputWithPooling,
+ ImageClassifierOutput,
+ SemanticSegmenterOutput,
+)
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import compile_compatible_method_lru_cache, find_pruneable_heads_and_indices, prune_linear_layer
+from ...utils import auto_docstring, logging, torch_int
+from .configuration_data2vec_vision import Data2VecVisionConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Class for outputs of [`Data2VecVisionModel`].
+ """
+)
+# Copied from transformers.models.beit.modeling_beit.BeitModelOutputWithPooling with Beit->Data2VecVision
+class Data2VecVisionModelOutputWithPooling(BaseModelOutputWithPooling):
+ r"""
+ pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`):
+ Average of the last layer hidden states of the patch tokens (excluding the *[CLS]* token) if
+ *config.use_mean_pooling* is set to True. If set to False, then the final hidden state of the *[CLS]* token
+ will be returned.
+ """
+
+
+# Copied from transformers.models.beit.modeling_beit.drop_path
+def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
+ """
+ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+
+ Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
+ however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
+ See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
+ layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
+ argument.
+ """
+ if drop_prob == 0.0 or not training:
+ return input
+ keep_prob = 1 - drop_prob
+ shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
+ random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
+ random_tensor.floor_() # binarize
+ output = input.div(keep_prob) * random_tensor
+ return output
+
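+ # Illustrative sketch (not part of the library): during training, `drop_path` above zeroes an entire
+ # sample's residual branch with probability `drop_prob` and rescales the survivors by 1 / (1 - drop_prob)
+ # so the expected value is unchanged. The tensor sizes and probability are assumptions.
+ def _example_drop_path():
+     hidden_states = torch.ones(4, 3, 8)  # (batch, tokens, hidden)
+     output = drop_path(hidden_states, drop_prob=0.5, training=True)
+     # each of the 4 samples is now either all zeros or all 2.0 (= 1 / (1 - 0.5))
+     return output
+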
+
+# Copied from transformers.models.beit.modeling_beit.BeitDropPath with Beit->Data2VecVision
+class Data2VecVisionDropPath(nn.Module):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
+
+ def __init__(self, drop_prob: Optional[float] = None) -> None:
+ super().__init__()
+ self.drop_prob = drop_prob
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ return drop_path(hidden_states, self.drop_prob, self.training)
+
+ def extra_repr(self) -> str:
+ return f"p={self.drop_prob}"
+
+
+# Copied from transformers.models.beit.modeling_beit.BeitEmbeddings with Beit->Data2VecVision
+class Data2VecVisionEmbeddings(nn.Module):
+ """
+ Construct the CLS token, position and patch embeddings. Optionally, also the mask token.
+ """
+
+ def __init__(self, config: Data2VecVisionConfig) -> None:
+ super().__init__()
+
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
+ if config.use_mask_token:
+ self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
+ else:
+ self.mask_token = None
+ self.patch_embeddings = Data2VecVisionPatchEmbeddings(config)
+ self.patch_size = config.patch_size
+ self.image_size = (
+ config.image_size
+ if isinstance(config.image_size, collections.abc.Iterable)
+ else (config.image_size, config.image_size)
+ )
+ num_patches = self.patch_embeddings.num_patches
+ if config.use_absolute_position_embeddings:
+ self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.hidden_size))
+ else:
+ self.position_embeddings = None
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ # Copied from transformers.models.vit.modeling_vit.ViTEmbeddings.interpolate_pos_encoding
+ def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
+ """
+ This method interpolates the pre-trained position encodings so that the model can be used on higher-resolution
+ images. It is also adapted to support torch.jit tracing.
+
+ Adapted from:
+ - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
+ - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
+ """
+
+ num_patches = embeddings.shape[1] - 1
+ num_positions = self.position_embeddings.shape[1] - 1
+
+ # always interpolate when tracing to ensure the exported model works for dynamic input shapes
+ if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
+ return self.position_embeddings
+
+ class_pos_embed = self.position_embeddings[:, :1]
+ patch_pos_embed = self.position_embeddings[:, 1:]
+
+ dim = embeddings.shape[-1]
+
+ new_height = height // self.patch_size
+ new_width = width // self.patch_size
+
+ sqrt_num_positions = torch_int(num_positions**0.5)
+ patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
+ patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
+
+ patch_pos_embed = nn.functional.interpolate(
+ patch_pos_embed,
+ size=(new_height, new_width),
+ mode="bicubic",
+ align_corners=False,
+ )
+
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+
+ return torch.cat((class_pos_embed, patch_pos_embed), dim=1)
+
+ def forward(
+ self,
+ pixel_values: torch.Tensor,
+ bool_masked_pos: Optional[torch.BoolTensor] = None,
+ interpolate_pos_encoding: Optional[bool] = None,
+ ) -> torch.Tensor:
+ if self.position_embeddings is not None and interpolate_pos_encoding is not None:
+ warnings.warn(
+ "`interpolate_pos_encoding` argument has no effect for BEiTEmbeddings, embeddings are always "
+ "interpolated to the input image size. The argument will be removed in transformers v4.51.0."
+ )
+
+ _, _, height, width = pixel_values.shape
+ embeddings, (patch_height, patch_width) = self.patch_embeddings(pixel_values)
+ batch_size, seq_len, _ = embeddings.size()
+
+ if bool_masked_pos is not None:
+ mask_tokens = self.mask_token.expand(batch_size, seq_len, -1)
+ # replace the masked visual tokens by mask_tokens
+ w = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
+ embeddings = embeddings * (1 - w) + mask_tokens * w
+
+ cls_tokens = self.cls_token.expand(batch_size, -1, -1)
+ embeddings = torch.cat((cls_tokens, embeddings), dim=1)
+
+ if self.position_embeddings is not None:
+ embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
+
+ embeddings = self.dropout(embeddings)
+
+ return embeddings, (patch_height, patch_width)
+
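+ # Illustrative sketch (not part of the library): the bicubic interpolation used in
+ # `interpolate_pos_encoding` above lets position embeddings learned on an assumed 14x14 patch grid be
+ # resized to, e.g., a 24x24 grid when the model sees larger images.
+ def _example_interpolate_positions():
+     dim = 768  # assumed hidden size
+     patch_pos_embed = torch.randn(1, 14, 14, dim).permute(0, 3, 1, 2)  # (1, dim, 14, 14)
+     resized = nn.functional.interpolate(patch_pos_embed, size=(24, 24), mode="bicubic", align_corners=False)
+     return resized.permute(0, 2, 3, 1).reshape(1, -1, dim).shape  # (1, 576, 768)
+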
+
+# Copied from transformers.models.beit.modeling_beit.BeitPatchEmbeddings with Beit->Data2VecVision
+class Data2VecVisionPatchEmbeddings(nn.Module):
+ """
+ This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
+ `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
+ Transformer.
+ """
+
+ def __init__(self, config):
+ super().__init__()
+ image_size, patch_size = config.image_size, config.patch_size
+ num_channels, hidden_size = config.num_channels, config.hidden_size
+
+ image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
+ patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
+ num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
+ patch_shape = (image_size[0] // patch_size[0], image_size[1] // patch_size[1])
+ self.image_size = image_size
+ self.patch_size = patch_size
+ self.num_channels = num_channels
+ self.num_patches = num_patches
+ self.patch_shape = patch_shape
+
+ self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
+
+ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
+ batch_size, num_channels, height, width = pixel_values.shape
+ if num_channels != self.num_channels:
+ raise ValueError(
+ "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
+ )
+
+ embeddings = self.projection(pixel_values)
+ patch_height, patch_width = embeddings.shape[2], embeddings.shape[3]
+ embeddings = embeddings.flatten(2).transpose(1, 2)
+
+ return embeddings, (patch_height, patch_width)
+
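+ # Illustrative sketch (not part of the library): shape walkthrough of the patch embedding above with
+ # assumed sizes (a 224x224 RGB image, 16x16 patches, hidden size 768).
+ def _example_patch_embedding_shapes():
+     projection = nn.Conv2d(3, 768, kernel_size=16, stride=16)
+     pixel_values = torch.randn(1, 3, 224, 224)
+     embeddings = projection(pixel_values)  # (1, 768, 14, 14): one vector per 16x16 patch
+     embeddings = embeddings.flatten(2).transpose(1, 2)  # (1, 196, 768): a sequence of patch tokens
+     return embeddings.shape
+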
+
+# Copied from transformers.models.beit.modeling_beit.BeitSelfAttention with Beit->Data2VecVision
+class Data2VecVisionSelfAttention(nn.Module):
+ def __init__(self, config: Data2VecVisionConfig, window_size: Optional[tuple] = None) -> None:
+ super().__init__()
+ self.config = config
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
+ raise ValueError(
+ f"The hidden size {config.hidden_size} is not a multiple of the number of attention "
+ f"heads {config.num_attention_heads}."
+ )
+
+ self.num_attention_heads = config.num_attention_heads
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
+
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
+ self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=False)
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
+
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
+
+ self.has_relative_position_bias = bool(window_size)
+ if self.has_relative_position_bias:
+ self.relative_position_bias = Data2VecVisionRelativePositionBias(config, window_size=window_size)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ head_mask: Optional[torch.Tensor] = None,
+ output_attentions: bool = False,
+ relative_position_bias: Optional[torch.Tensor] = None,
+ interpolate_pos_encoding: bool = False,
+ resolution: Optional[tuple[int]] = None,
+ ) -> Union[tuple[torch.Tensor], tuple[torch.Tensor, torch.Tensor]]:
+ batch_size, seq_length, _ = hidden_states.shape
+ query_layer = (
+ self.query(hidden_states)
+ .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
+ .transpose(1, 2)
+ )
+ key_layer = (
+ self.key(hidden_states)
+ .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
+ .transpose(1, 2)
+ )
+ value_layer = (
+ self.value(hidden_states)
+ .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
+ .transpose(1, 2)
+ )
+
+ # Take the dot product between "query" and "key" to get the raw attention scores.
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
+
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
+
+ # Add relative position bias if present.
+ if self.has_relative_position_bias:
+ height, width = resolution
+ window_size = (height // self.config.patch_size, width // self.config.patch_size)
+ attention_scores = attention_scores + self.relative_position_bias(
+ window_size, interpolate_pos_encoding, dim_size=hidden_states.shape[1]
+ )
+
+ # Add shared relative position bias if provided.
+ if relative_position_bias is not None:
+ attention_scores = attention_scores + relative_position_bias
+
+ # Normalize the attention scores to probabilities.
+ attention_probs = nn.functional.softmax(attention_scores, dim=-1)
+
+ # This is actually dropping out entire tokens to attend to, which might
+ # seem a bit unusual, but is taken from the original Transformer paper.
+ attention_probs = self.dropout(attention_probs)
+
+ # Mask heads if we want to
+ if head_mask is not None:
+ attention_probs = attention_probs * head_mask
+
+ context_layer = torch.matmul(attention_probs, value_layer)
+
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+ context_layer = context_layer.view(*new_context_layer_shape)
+
+ outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
+
+ return outputs
+
+
+# Copied from transformers.models.beit.modeling_beit.BeitSdpaSelfAttention with Beit->Data2VecVision
+class Data2VecVisionSdpaSelfAttention(Data2VecVisionSelfAttention):
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ head_mask: Optional[torch.Tensor] = None,
+ output_attentions: bool = False,
+ relative_position_bias: Optional[torch.Tensor] = None,
+ interpolate_pos_encoding: bool = False,
+ resolution: Optional[tuple[int]] = None,
+ ) -> Union[tuple[torch.Tensor], tuple[torch.Tensor, torch.Tensor]]:
+ if output_attentions or head_mask is not None:
+ logger.warning_once(
+ "`Data2VecVisionSdpaSelfAttention` is used but `torch.nn.functional.scaled_dot_product_attention` does not "
+ "support `output_attentions=True` or `head_mask`. Falling back to the manual attention implementation, "
+ "but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. "
+ 'This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+ )
+ return super().forward(
+ hidden_states=hidden_states,
+ head_mask=head_mask,
+ output_attentions=output_attentions,
+ relative_position_bias=relative_position_bias,
+ interpolate_pos_encoding=interpolate_pos_encoding,
+ resolution=resolution,
+ )
+
+ batch_size, seq_length, _ = hidden_states.shape
+ query_layer = (
+ self.query(hidden_states)
+ .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
+ .transpose(1, 2)
+ )
+ key_layer = (
+ self.key(hidden_states)
+ .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
+ .transpose(1, 2)
+ )
+ value_layer = (
+ self.value(hidden_states)
+ .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
+ .transpose(1, 2)
+ )
+
+ attn_bias = None
+ if self.has_relative_position_bias:
+ height, width = resolution
+ window_size = (height // self.config.patch_size, width // self.config.patch_size)
+ attn_bias = self.relative_position_bias(
+ window_size, interpolate_pos_encoding, dim_size=hidden_states.shape[1]
+ )
+
+ # Add shared relative position bias if provided.
+ if relative_position_bias is not None:
+ if attn_bias is None:
+ attn_bias = relative_position_bias
+ else:
+ attn_bias += relative_position_bias
+
+ scaling = 1 / math.sqrt(self.attention_head_size)
+ context_layer = torch.nn.functional.scaled_dot_product_attention(
+ query_layer,
+ key_layer,
+ value_layer,
+ attn_mask=attn_bias,
+ dropout_p=self.config.attention_probs_dropout_prob if self.training else 0.0,
+ is_causal=False,
+ scale=scaling,
+ )
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+ context_layer = context_layer.view(*new_context_layer_shape)
+ return context_layer, None
+
+
+# Copied from transformers.models.beit.modeling_beit.BeitSelfOutput with Beit->Data2VecVision
+class Data2VecVisionSelfOutput(nn.Module):
+ """
+ The residual connection is defined in Data2VecVisionLayer instead of here (as is the case with other models), due to the
+ layernorm applied before each block.
+ """
+
+ def __init__(self, config: Data2VecVisionConfig) -> None:
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor, gamma=None) -> torch.Tensor:
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+
+ return hidden_states
+
+
+DATA2VEC_VISION_SELF_ATTENTION_CLASSES = {
+ "eager": Data2VecVisionSelfAttention,
+ "sdpa": Data2VecVisionSdpaSelfAttention,
+}
+
+
+# Copied from transformers.models.beit.modeling_beit.BeitAttention with Beit->Data2VecVision, BEIT->DATA2VEC_VISION
+class Data2VecVisionAttention(nn.Module):
+ def __init__(self, config: Data2VecVisionConfig, window_size: Optional[tuple] = None) -> None:
+ super().__init__()
+ self.attention = DATA2VEC_VISION_SELF_ATTENTION_CLASSES[config._attn_implementation](
+ config, window_size=window_size
+ )
+ self.output = Data2VecVisionSelfOutput(config)
+ self.pruned_heads = set()
+
+ def prune_heads(self, heads):
+ if len(heads) == 0:
+ return
+ heads, index = find_pruneable_heads_and_indices(
+ heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads
+ )
+
+ # Prune linear layers
+ self.attention.query = prune_linear_layer(self.attention.query, index)
+ self.attention.key = prune_linear_layer(self.attention.key, index)
+ self.attention.value = prune_linear_layer(self.attention.value, index)
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
+
+ # Update hyper params and store pruned heads
+ self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads)
+ self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
+ self.pruned_heads = self.pruned_heads.union(heads)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ head_mask: Optional[torch.Tensor] = None,
+ output_attentions: bool = False,
+ relative_position_bias: Optional["Data2VecVisionRelativePositionBias"] = None,
+ interpolate_pos_encoding: bool = False,
+ resolution: Optional[tuple[int]] = None,
+ ) -> Union[tuple[torch.Tensor], tuple[torch.Tensor, torch.Tensor]]:
+ self_outputs = self.attention(
+ hidden_states, head_mask, output_attentions, relative_position_bias, interpolate_pos_encoding, resolution
+ )
+
+ attention_output = self.output(self_outputs[0], hidden_states)
+
+ outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
+ return outputs
+
+
+# Copied from transformers.models.beit.modeling_beit.BeitIntermediate with Beit->Data2VecVision
+class Data2VecVisionIntermediate(nn.Module):
+ def __init__(self, config: Data2VecVisionConfig) -> None:
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
+ if isinstance(config.hidden_act, str):
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
+ else:
+ self.intermediate_act_fn = config.hidden_act
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.intermediate_act_fn(hidden_states)
+
+ return hidden_states
+
+
+# Copied from transformers.models.beit.modeling_beit.BeitOutput with Beit->Data2VecVision
+class Data2VecVisionOutput(nn.Module):
+ def __init__(self, config: Data2VecVisionConfig) -> None:
+ super().__init__()
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+
+ return hidden_states
+
+
+# Copied from transformers.models.beit.modeling_beit.BeitLayer with Beit->Data2VecVision,BEiT->Data2VecVision
+class Data2VecVisionLayer(GradientCheckpointingLayer):
+ """This corresponds to the Block class in the timm implementation."""
+
+ def __init__(
+ self, config: Data2VecVisionConfig, window_size: Optional[tuple] = None, drop_path_rate: float = 0.0
+ ) -> None:
+ super().__init__()
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
+ self.seq_len_dim = 1
+ self.attention = Data2VecVisionAttention(config, window_size=window_size)
+ self.intermediate = Data2VecVisionIntermediate(config)
+ self.output = Data2VecVisionOutput(config)
+ self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.drop_path = Data2VecVisionDropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
+ self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+ init_values = config.layer_scale_init_value
+ if init_values > 0:
+ self.lambda_1 = nn.Parameter(init_values * torch.ones(config.hidden_size), requires_grad=True)
+ self.lambda_2 = nn.Parameter(init_values * torch.ones(config.hidden_size), requires_grad=True)
+ else:
+ self.lambda_1, self.lambda_2 = None, None
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ head_mask: Optional[torch.Tensor] = None,
+ output_attentions: bool = False,
+ relative_position_bias: Optional[torch.Tensor] = None,
+ interpolate_pos_encoding: bool = False,
+ resolution: Optional[tuple[int, int]] = None,
+ ) -> Union[tuple[torch.Tensor], tuple[torch.Tensor, torch.Tensor]]:
+ self_attention_outputs = self.attention(
+ self.layernorm_before(hidden_states), # in Data2VecVision, layernorm is applied before self-attention
+ head_mask,
+ output_attentions=output_attentions,
+ relative_position_bias=relative_position_bias,
+ interpolate_pos_encoding=interpolate_pos_encoding,
+ resolution=resolution,
+ )
+ attention_output = self_attention_outputs[0]
+ outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
+
+ # apply lambda_1 if present
+ if self.lambda_1 is not None:
+ attention_output = self.lambda_1 * attention_output
+
+ # first residual connection
+ hidden_states = self.drop_path(attention_output) + hidden_states
+
+ # in Data2VecVision, layernorm is also applied after self-attention
+ layer_output = self.layernorm_after(hidden_states)
+
+ layer_output = self.intermediate(layer_output)
+ layer_output = self.output(layer_output)
+
+ if self.lambda_2 is not None:
+ layer_output = self.lambda_2 * layer_output
+
+ # second residual connection
+ layer_output = self.drop_path(layer_output) + hidden_states
+
+ outputs = (layer_output,) + outputs
+
+ return outputs
+
+
+# Copied from transformers.models.beit.modeling_beit.BeitRelativePositionBias with Beit->Data2VecVision
+class Data2VecVisionRelativePositionBias(nn.Module):
+ def __init__(self, config: Data2VecVisionConfig, window_size: tuple) -> None:
+ super().__init__()
+ self.window_size = window_size
+ self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
+ self.relative_position_bias_table = nn.Parameter(
+ torch.zeros(self.num_relative_distance, config.num_attention_heads)
+ ) # 2*Wh-1 * 2*Ww-1, nH
+ # cls to token & token 2 cls & cls to cls
+
+ @compile_compatible_method_lru_cache(maxsize=10)
+ def generate_relative_position_index(self, window_size: tuple[int, int]) -> torch.Tensor:
+ """
+ This method creates the relative position index, modified to support arbitrary window sizes,
+ as introduced in [MiDaS v3.1](https://huggingface.co/papers/2307.14460).
+ """
+ num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
+ # cls to token & token 2 cls & cls to cls
+ # get pair-wise relative position index for each token inside the window
+ window_area = window_size[0] * window_size[1]
+ grid = torch.meshgrid(torch.arange(window_size[0]), torch.arange(window_size[1]), indexing="ij")
+ coords = torch.stack(grid) # 2, Wh, Ww
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
+ relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
+ relative_coords[:, :, 1] += window_size[1] - 1
+ relative_coords[:, :, 0] *= 2 * window_size[1] - 1
+ relative_position_index = torch.zeros(size=(window_area + 1,) * 2, dtype=relative_coords.dtype)
+ relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
+ relative_position_index[0, 0:] = num_relative_distance - 3
+ relative_position_index[0:, 0] = num_relative_distance - 2
+ relative_position_index[0, 0] = num_relative_distance - 1
+ return relative_position_index
+
+ def forward(self, window_size, interpolate_pos_encoding: bool = False, dim_size=None) -> torch.Tensor:
+ """
+ Modification of timm.models.beit.py: Attention._get_rel_pos_bias to support arbitrary window sizes.
+ """
+ old_height = 2 * self.window_size[0] - 1
+ old_width = 2 * self.window_size[1] - 1
+
+ new_height = 2 * window_size[0] - 1
+ new_width = 2 * window_size[1] - 1
+
+ old_relative_position_bias_table = self.relative_position_bias_table
+
+ old_num_relative_distance = self.num_relative_distance
+ new_num_relative_distance = new_height * new_width + 3
+
+ old_sub_table = old_relative_position_bias_table[: old_num_relative_distance - 3]
+
+ old_sub_table = old_sub_table.reshape(1, old_width, old_height, -1).permute(0, 3, 1, 2)
+ new_sub_table = nn.functional.interpolate(
+ old_sub_table, size=(torch_int(new_height), torch_int(new_width)), mode="bilinear"
+ )
+ new_sub_table = new_sub_table.permute(0, 2, 3, 1).reshape(new_num_relative_distance - 3, -1)
+
+ new_relative_position_bias_table = torch.cat(
+ [new_sub_table, old_relative_position_bias_table[old_num_relative_distance - 3 :]]
+ )
+
+ relative_position_index = self.generate_relative_position_index(window_size)
+ relative_position_bias = new_relative_position_bias_table[relative_position_index.view(-1)]
+
+ # patch_size*num_patches_height, patch_size*num_patches_width, num_attention_heads
+ relative_position_bias = relative_position_bias.view(
+ window_size[0] * window_size[1] + 1, window_size[0] * window_size[1] + 1, -1
+ )
+ # num_attention_heads, patch_size*num_patches_width, patch_size*num_patches_height
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
+
+ if interpolate_pos_encoding:
+ relative_position_bias = nn.functional.interpolate(
+ relative_position_bias.unsqueeze(1),
+ size=(dim_size, dim_size),
+ mode="bilinear",
+ align_corners=False,
+ ).squeeze(1)
+
+ return relative_position_bias.unsqueeze(0)
+
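+ # Illustrative sketch (not part of the library): the size of the relative position bias table for an
+ # assumed 2x2 window, matching the formula used in the class above.
+ def _example_relative_position_counts():
+     window_size = (2, 2)
+     num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3  # 9 offsets + 3 cls entries
+     window_area = window_size[0] * window_size[1]
+     index_shape = (window_area + 1, window_area + 1)  # covers the CLS token plus every patch in the window
+     return num_relative_distance, index_shape  # (12, (5, 5))
+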
+
+# Copied from transformers.models.beit.modeling_beit.BeitEncoder with Beit->Data2VecVision
+class Data2VecVisionEncoder(nn.Module):
+ def __init__(self, config: Data2VecVisionConfig, window_size: Optional[tuple] = None) -> None:
+ super().__init__()
+ self.config = config
+ self.has_relative_position_bias = config.use_shared_relative_position_bias
+ if self.has_relative_position_bias:
+ self.relative_position_bias = Data2VecVisionRelativePositionBias(config, window_size=window_size)
+
+ # stochastic depth decay rule
+ dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, config.num_hidden_layers, device="cpu")]
+ self.layer = nn.ModuleList(
+ [
+ Data2VecVisionLayer(
+ config,
+ window_size=window_size if config.use_relative_position_bias else None,
+ drop_path_rate=dpr[i],
+ )
+ for i in range(config.num_hidden_layers)
+ ]
+ )
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ head_mask: Optional[torch.Tensor] = None,
+ output_attentions: bool = False,
+ output_hidden_states: bool = False,
+ interpolate_pos_encoding: bool = False,
+ resolution: Optional[tuple[int, int]] = None,
+ return_dict: bool = True,
+ ) -> Union[tuple, BaseModelOutput]:
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attentions = () if output_attentions else None
+
+ for i, layer_module in enumerate(self.layer):
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if self.has_relative_position_bias:
+ height, width = resolution
+ window_size = (height // self.config.patch_size, width // self.config.patch_size)
+ relative_position_bias = self.relative_position_bias(
+ window_size, interpolate_pos_encoding=interpolate_pos_encoding, dim_size=hidden_states.shape[1]
+ )
+ else:
+ relative_position_bias = None
+
+ layer_head_mask = head_mask[i] if head_mask is not None else None
+
+ layer_outputs = layer_module(
+ hidden_states,
+ head_mask=layer_head_mask,
+ output_attentions=output_attentions,
+ relative_position_bias=relative_position_bias,
+ interpolate_pos_encoding=interpolate_pos_encoding,
+ resolution=resolution,
+ )
+
+ hidden_states = layer_outputs[0]
+
+ if output_attentions:
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
+ return BaseModelOutput(
+ last_hidden_state=hidden_states,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attentions,
+ )
+
+
+@auto_docstring
+# Copied from transformers.models.beit.modeling_beit.BeitPreTrainedModel with Beit->Data2VecVision,beit->data2vec_vision
+class Data2VecVisionPreTrainedModel(PreTrainedModel):
+ config: Data2VecVisionConfig
+ base_model_prefix = "data2vec_vision"
+ main_input_name = "pixel_values"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["Data2VecVisionLayer"]
+ _keys_to_ignore_on_load_unexpected = [r".*relative_position_index.*"]
+ _supports_sdpa = True
+
+ def _init_weights(self, module):
+ """Initialize the weights"""
+ if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)):
+ # Slightly different from the TF version which uses truncated_normal for initialization
+ # cf https://github.com/pytorch/pytorch/pull/5617
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+ elif isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+ elif isinstance(module, Data2VecVisionEmbeddings):
+ module.cls_token.data.zero_()
+ if module.mask_token is not None:
+ module.mask_token.data.zero_()
+ if module.position_embeddings is not None:
+ module.position_embeddings.data.zero_()
+ elif isinstance(module, Data2VecVisionRelativePositionBias):
+ module.relative_position_bias_table.data.zero_()
+ elif isinstance(module, Data2VecVisionLayer):
+ if module.lambda_1 is not None:
+ module.lambda_1.data.fill_(self.config.layer_scale_init_value)
+ module.lambda_2.data.fill_(self.config.layer_scale_init_value)
+
+
+@auto_docstring
+# Copied from transformers.models.beit.modeling_beit.BeitModel with BEIT->DATA2VEC_VISION,Beit->Data2VecVision,True->False
+class Data2VecVisionModel(Data2VecVisionPreTrainedModel):
+ def __init__(self, config: Data2VecVisionConfig, add_pooling_layer: bool = False) -> None:
+ r"""
+ add_pooling_layer (bool, *optional*, defaults to `False`):
+ Whether to add a pooling layer
+ """
+ super().__init__(config)
+ self.config = config
+
+ self.embeddings = Data2VecVisionEmbeddings(config)
+ self.encoder = Data2VecVisionEncoder(config, window_size=self.embeddings.patch_embeddings.patch_shape)
+
+ self.layernorm = (
+ nn.Identity() if config.use_mean_pooling else nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ )
+ self.pooler = Data2VecVisionPooler(config) if add_pooling_layer else None
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.embeddings.patch_embeddings
+
+ def _prune_heads(self, heads_to_prune):
+ """
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
+ class PreTrainedModel
+ """
+ for layer, heads in heads_to_prune.items():
+ self.encoder.layer[layer].attention.prune_heads(heads)
+
+ @auto_docstring
+ def forward(
+ self,
+ pixel_values: torch.Tensor,
+ bool_masked_pos: Optional[torch.BoolTensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ interpolate_pos_encoding: bool = False,
+ return_dict: Optional[bool] = None,
+ ) -> Union[tuple, Data2VecVisionModelOutputWithPooling]:
+ r"""
+ bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*):
+ Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # Prepare head mask if needed
+ # 1.0 in head_mask indicates we keep the head
+ # attention_probs has shape bsz x n_heads x N x N
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
+
+ embedding_output, _ = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos)
+ resolution = pixel_values.shape[2:]
+
+ encoder_outputs = self.encoder(
+ embedding_output,
+ head_mask=head_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ resolution=resolution,
+ return_dict=return_dict,
+ interpolate_pos_encoding=interpolate_pos_encoding,
+ )
+ sequence_output = encoder_outputs[0]
+ sequence_output = self.layernorm(sequence_output)
+ pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
+
+ if not return_dict:
+ head_outputs = (sequence_output, pooled_output) if pooled_output is not None else (sequence_output,)
+ return head_outputs + encoder_outputs[1:]
+
+ return Data2VecVisionModelOutputWithPooling(
+ last_hidden_state=sequence_output,
+ pooler_output=pooled_output,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ )
+
+
+# Copied from transformers.models.beit.modeling_beit.BeitPooler with Beit->Data2VecVision
+class Data2VecVisionPooler(nn.Module):
+ def __init__(self, config: Data2VecVisionConfig) -> None:
+ super().__init__()
+ self.layernorm = (
+ nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) if config.use_mean_pooling else None
+ )
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ if self.layernorm is not None:
+ # Mean pool the final hidden states of the patch tokens
+ patch_tokens = hidden_states[:, 1:, :]
+ pooled_output = self.layernorm(patch_tokens.mean(1))
+ else:
+ # Pool by simply taking the final hidden state of the [CLS] token
+ pooled_output = hidden_states[:, 0]
+
+ return pooled_output
+
+
+@auto_docstring(
+ custom_intro="""
+ Data2VecVision Model transformer with an image classification head on top (a linear layer on top of the average of
+ the final hidden states of the patch tokens) e.g. for ImageNet.
+ """
+)
+# Copied from transformers.models.beit.modeling_beit.BeitForImageClassification with BEIT->DATA2VEC_VISION,Beit->Data2VecVision,beit->data2vec_vision
+class Data2VecVisionForImageClassification(Data2VecVisionPreTrainedModel):
+ def __init__(self, config: Data2VecVisionConfig) -> None:
+ super().__init__(config)
+
+ self.num_labels = config.num_labels
+ self.data2vec_vision = Data2VecVisionModel(config, add_pooling_layer=True)
+
+ # Classifier head
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity()
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @auto_docstring
+ def forward(
+ self,
+ pixel_values: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ interpolate_pos_encoding: bool = False,
+ return_dict: Optional[bool] = None,
+ ) -> Union[tuple, ImageClassifierOutput]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ outputs = self.data2vec_vision(
+ pixel_values,
+ head_mask=head_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ interpolate_pos_encoding=interpolate_pos_encoding,
+ return_dict=return_dict,
+ )
+
+ pooled_output = outputs.pooler_output if return_dict else outputs[1]
+
+ logits = self.classifier(pooled_output)
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(labels, logits, self.config)
+
+ if not return_dict:
+ output = (logits,) + outputs[2:]
+ return ((loss,) + output) if loss is not None else output
+
+ return ImageClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
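+ # Illustrative sketch (not part of the library): minimal usage of the classification model above.
+ # The tiny config values are assumptions so the example runs quickly; real checkpoints use the defaults.
+ def _example_image_classification():
+     config = Data2VecVisionConfig(
+         image_size=32, patch_size=8, hidden_size=32, num_hidden_layers=2,
+         num_attention_heads=2, intermediate_size=64, num_labels=3,
+     )
+     model = Data2VecVisionForImageClassification(config).eval()
+     pixel_values = torch.randn(1, 3, 32, 32)
+     with torch.no_grad():
+         logits = model(pixel_values).logits  # (1, 3)
+     return logits.argmax(-1)
+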
+
+# Copied from transformers.models.beit.modeling_beit.BeitConvModule with Beit->Data2VecVision
+class Data2VecVisionConvModule(nn.Module):
+ """
+ A convolutional block that bundles conv/norm/activation layers. This block simplifies the usage of convolution
+ layers, which are commonly used with a norm layer (e.g., BatchNorm) and activation layer (e.g., ReLU).
+
+ Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation.
+ """
+
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: Union[int, tuple[int, int]],
+ padding: Union[int, tuple[int, int], str] = 0,
+ bias: bool = False,
+ dilation: Union[int, tuple[int, int]] = 1,
+ ) -> None:
+ super().__init__()
+ self.conv = nn.Conv2d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=kernel_size,
+ padding=padding,
+ bias=bias,
+ dilation=dilation,
+ )
+ self.bn = nn.BatchNorm2d(out_channels)
+ self.activation = nn.ReLU()
+
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
+ output = self.conv(input)
+ output = self.bn(output)
+ output = self.activation(output)
+
+ return output
+
+
+# Copied from transformers.models.beit.modeling_beit.BeitPyramidPoolingBlock with Beit->Data2VecVision
+class Data2VecVisionPyramidPoolingBlock(nn.Module):
+ def __init__(self, pool_scale: int, in_channels: int, channels: int) -> None:
+ super().__init__()
+ self.layers = [
+ nn.AdaptiveAvgPool2d(pool_scale),
+ Data2VecVisionConvModule(in_channels, channels, kernel_size=1),
+ ]
+ for i, layer in enumerate(self.layers):
+ self.add_module(str(i), layer)
+
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
+ hidden_state = input
+ for layer in self.layers:
+ hidden_state = layer(hidden_state)
+ return hidden_state
+
+
+# Copied from transformers.models.beit.modeling_beit.BeitPyramidPoolingModule with Beit->Data2VecVision
+class Data2VecVisionPyramidPoolingModule(nn.Module):
+ """
+ Pyramid Pooling Module (PPM) used in PSPNet.
+
+ Args:
+ pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
+ Module.
+ in_channels (int): Input channels.
+ channels (int): Channels after modules, before conv_seg.
+ align_corners (bool): align_corners argument of F.interpolate.
+
+ Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation.
+ """
+
+ def __init__(self, pool_scales: tuple[int, ...], in_channels: int, channels: int, align_corners: bool) -> None:
+ super().__init__()
+ self.pool_scales = pool_scales
+ self.align_corners = align_corners
+ self.in_channels = in_channels
+ self.channels = channels
+ self.blocks = []
+ for i, pool_scale in enumerate(pool_scales):
+ block = Data2VecVisionPyramidPoolingBlock(
+ pool_scale=pool_scale, in_channels=in_channels, channels=channels
+ )
+ self.blocks.append(block)
+ self.add_module(str(i), block)
+
+ def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
+ ppm_outs = []
+ for ppm in self.blocks:
+ ppm_out = ppm(x)
+ upsampled_ppm_out = nn.functional.interpolate(
+ ppm_out, size=x.size()[2:], mode="bilinear", align_corners=self.align_corners
+ )
+ ppm_outs.append(upsampled_ppm_out)
+ return ppm_outs
+
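+ # Illustrative sketch (not part of the library): with the assumed pool scales (1, 2, 3, 6), the module
+ # above pools a feature map to 1x1, 2x2, 3x3 and 6x6 grids, projects each to `channels`, and upsamples
+ # everything back to the input resolution so the outputs can be concatenated.
+ def _example_pyramid_pooling():
+     ppm = Data2VecVisionPyramidPoolingModule((1, 2, 3, 6), in_channels=32, channels=8, align_corners=False).eval()
+     features = torch.randn(1, 32, 24, 24)
+     outputs = ppm(features)
+     return [o.shape for o in outputs]  # four tensors, each (1, 8, 24, 24)
+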
+
+# Copied from transformers.models.beit.modeling_beit.BeitUperHead with Beit->Data2VecVision
+class Data2VecVisionUperHead(nn.Module):
+ """
+ Unified Perceptual Parsing for Scene Understanding. This head is the implementation of
+ [UPerNet](https://huggingface.co/papers/1807.10221).
+
+ Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation.
+ """
+
+ def __init__(self, config: Data2VecVisionConfig) -> None:
+ super().__init__()
+
+ self.pool_scales = config.pool_scales # e.g. (1, 2, 3, 6)
+ self.in_channels = [config.hidden_size] * 4 # e.g. [768, 768, 768, 768]
+ self.channels = config.hidden_size
+ self.align_corners = False
+ self.classifier = nn.Conv2d(self.channels, config.num_labels, kernel_size=1)
+
+ # PSP Module
+ self.psp_modules = Data2VecVisionPyramidPoolingModule(
+ self.pool_scales,
+ self.in_channels[-1],
+ self.channels,
+ align_corners=self.align_corners,
+ )
+ self.bottleneck = Data2VecVisionConvModule(
+ self.in_channels[-1] + len(self.pool_scales) * self.channels,
+ self.channels,
+ kernel_size=3,
+ padding=1,
+ )
+ # FPN Module
+ self.lateral_convs = nn.ModuleList()
+ self.fpn_convs = nn.ModuleList()
+ for in_channels in self.in_channels[:-1]: # skip the top layer
+ l_conv = Data2VecVisionConvModule(in_channels, self.channels, kernel_size=1)
+ fpn_conv = Data2VecVisionConvModule(self.channels, self.channels, kernel_size=3, padding=1)
+ self.lateral_convs.append(l_conv)
+ self.fpn_convs.append(fpn_conv)
+
+ self.fpn_bottleneck = Data2VecVisionConvModule(
+ len(self.in_channels) * self.channels,
+ self.channels,
+ kernel_size=3,
+ padding=1,
+ )
+
+ def psp_forward(self, inputs):
+ x = inputs[-1]
+ psp_outs = [x]
+ psp_outs.extend(self.psp_modules(x))
+ psp_outs = torch.cat(psp_outs, dim=1)
+ output = self.bottleneck(psp_outs)
+
+ return output
+
+ def forward(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
+ # build laterals
+ laterals = [lateral_conv(encoder_hidden_states[i]) for i, lateral_conv in enumerate(self.lateral_convs)]
+
+ laterals.append(self.psp_forward(encoder_hidden_states))
+
+ # build top-down path
+ used_backbone_levels = len(laterals)
+ for i in range(used_backbone_levels - 1, 0, -1):
+ prev_shape = laterals[i - 1].shape[2:]
+ laterals[i - 1] = laterals[i - 1] + nn.functional.interpolate(
+ laterals[i], size=prev_shape, mode="bilinear", align_corners=self.align_corners
+ )
+
+ # build outputs
+ fpn_outs = [self.fpn_convs[i](laterals[i]) for i in range(used_backbone_levels - 1)]
+ # append psp feature
+ fpn_outs.append(laterals[-1])
+
+ for i in range(used_backbone_levels - 1, 0, -1):
+ fpn_outs[i] = nn.functional.interpolate(
+ fpn_outs[i], size=fpn_outs[0].shape[2:], mode="bilinear", align_corners=self.align_corners
+ )
+ fpn_outs = torch.cat(fpn_outs, dim=1)
+ output = self.fpn_bottleneck(fpn_outs)
+ output = self.classifier(output)
+
+ return output
+
+
+# Copied from transformers.models.beit.modeling_beit.BeitFCNHead with Beit->Data2VecVision
+class Data2VecVisionFCNHead(nn.Module):
+ """
+ Fully Convolutional Networks for Semantic Segmentation. This head is the implementation of
+ [FCNNet](https://huggingface.co/papers/1411.4038).
+
+ Args:
+ config (Data2VecVisionConfig): Configuration.
+ in_index (int): Index of the encoder hidden state to use as input. Default: 2.
+ kernel_size (int): The kernel size for convs in the head. Default: 3.
+ dilation (int): The dilation rate for convs in the head. Default: 1.
+
+
+ Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation.
+ """
+
+ def __init__(
+ self,
+ config: Data2VecVisionConfig,
+ in_index: int = 2,
+ kernel_size: int = 3,
+ dilation: Union[int, tuple[int, int]] = 1,
+ ) -> None:
+ super().__init__()
+ self.in_channels = config.hidden_size
+ self.channels = config.auxiliary_channels
+ self.num_convs = config.auxiliary_num_convs
+ self.concat_input = config.auxiliary_concat_input
+ self.in_index = in_index
+
+ conv_padding = (kernel_size // 2) * dilation
+ convs = []
+ convs.append(
+ Data2VecVisionConvModule(
+ self.in_channels, self.channels, kernel_size=kernel_size, padding=conv_padding, dilation=dilation
+ )
+ )
+ for i in range(self.num_convs - 1):
+ convs.append(
+ Data2VecVisionConvModule(
+ self.channels, self.channels, kernel_size=kernel_size, padding=conv_padding, dilation=dilation
+ )
+ )
+ if self.num_convs == 0:
+ self.convs = nn.Identity()
+ else:
+ self.convs = nn.Sequential(*convs)
+ if self.concat_input:
+ self.conv_cat = Data2VecVisionConvModule(
+ self.in_channels + self.channels, self.channels, kernel_size=kernel_size, padding=kernel_size // 2
+ )
+
+ self.classifier = nn.Conv2d(self.channels, config.num_labels, kernel_size=1)
+
+ def forward(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
+ # just take the relevant feature maps
+ hidden_states = encoder_hidden_states[self.in_index]
+ output = self.convs(hidden_states)
+ if self.concat_input:
+ output = self.conv_cat(torch.cat([hidden_states, output], dim=1))
+ output = self.classifier(output)
+ return output
+
+
+@auto_docstring
+# Copied from transformers.models.beit.modeling_beit.BeitForSemanticSegmentation with BEIT->DATA2VEC_VISION,Beit->Data2VecVision,microsoft/beit-base-finetuned-ade-640-640->facebook/data2vec-vision-base,beit->data2vec_vision
+class Data2VecVisionForSemanticSegmentation(Data2VecVisionPreTrainedModel):
+ def __init__(self, config: Data2VecVisionConfig) -> None:
+ super().__init__(config)
+
+ self.num_labels = config.num_labels
+ self.data2vec_vision = Data2VecVisionModel(config, add_pooling_layer=False)
+
+ # FPNs
+ if len(self.config.out_indices) != 4:
+ raise ValueError(
+ "Data2VecVisionForSemanticSegmentation requires config.out_indices to be a list of 4 integers, "
+ "specifying which features to use from the backbone. One can use [3, 5, 7, 11] in case of "
+ "a base-sized architecture."
+ )
+ self.fpn1 = nn.Sequential(
+ nn.ConvTranspose2d(config.hidden_size, config.hidden_size, kernel_size=2, stride=2),
+ nn.BatchNorm2d(config.hidden_size),
+ nn.GELU(),
+ nn.ConvTranspose2d(config.hidden_size, config.hidden_size, kernel_size=2, stride=2),
+ )
+ self.fpn2 = nn.Sequential(
+ nn.ConvTranspose2d(config.hidden_size, config.hidden_size, kernel_size=2, stride=2),
+ )
+ self.fpn3 = nn.Identity()
+ self.fpn4 = nn.MaxPool2d(kernel_size=2, stride=2)
+
+ # Semantic segmentation head(s)
+ self.decode_head = Data2VecVisionUperHead(config)
+ self.auxiliary_head = Data2VecVisionFCNHead(config) if config.use_auxiliary_head else None
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def compute_loss(self, logits, auxiliary_logits, labels):
+ # upsample logits to the images' original size
+ upsampled_logits = nn.functional.interpolate(
+ logits, size=labels.shape[-2:], mode="bilinear", align_corners=False
+ )
+ if auxiliary_logits is not None:
+ upsampled_auxiliary_logits = nn.functional.interpolate(
+ auxiliary_logits, size=labels.shape[-2:], mode="bilinear", align_corners=False
+ )
+ # compute weighted loss
+ loss_fct = CrossEntropyLoss(ignore_index=self.config.semantic_loss_ignore_index)
+ main_loss = loss_fct(upsampled_logits, labels)
+ loss = main_loss
+ if auxiliary_logits is not None:
+ auxiliary_loss = loss_fct(upsampled_auxiliary_logits, labels)
+ loss += self.config.auxiliary_loss_weight * auxiliary_loss
+
+ return loss
+
+ @auto_docstring
+ def forward(
+ self,
+ pixel_values: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ interpolate_pos_encoding: bool = False,
+ return_dict: Optional[bool] = None,
+ ) -> Union[tuple, SemanticSegmenterOutput]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*):
+ Ground truth semantic segmentation maps for computing the loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels > 1`, a classification loss is computed (Cross-Entropy).
+
+ Examples:
+
+ ```python
+ >>> from transformers import AutoImageProcessor, Data2VecVisionForSemanticSegmentation
+ >>> from PIL import Image
+ >>> import requests
+
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> image_processor = AutoImageProcessor.from_pretrained("facebook/data2vec-vision-base")
+ >>> model = Data2VecVisionForSemanticSegmentation.from_pretrained("facebook/data2vec-vision-base")
+
+ >>> inputs = image_processor(images=image, return_tensors="pt")
+ >>> outputs = model(**inputs)
+ >>> # logits are of shape (batch_size, num_labels, height, width)
+ >>> logits = outputs.logits
+ ```"""
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+
+ if labels is not None and self.config.num_labels == 1:
+ raise ValueError("The number of labels should be greater than one")
+
+ outputs = self.data2vec_vision(
+ pixel_values,
+ head_mask=head_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=True, # we need the intermediate hidden states
+ interpolate_pos_encoding=interpolate_pos_encoding,
+ return_dict=return_dict,
+ )
+
+ encoder_hidden_states = outputs.hidden_states if return_dict else outputs[1]
+
+ # only keep certain features, and reshape
+ # note that we do +1 as the encoder_hidden_states also includes the initial embeddings
+ features = [feature for idx, feature in enumerate(encoder_hidden_states) if idx + 1 in self.config.out_indices]
+ batch_size = pixel_values.shape[0]
+ patch_resolution = self.config.image_size // self.config.patch_size
+ features = [
+ x[:, 1:, :].permute(0, 2, 1).reshape(batch_size, -1, patch_resolution, patch_resolution) for x in features
+ ]
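+         # e.g. with image_size=224 and patch_size=16, patch_resolution is 14, so after dropping the
+         # CLS token each reshaped feature map has shape (batch_size, hidden_size, 14, 14)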
+
+ # apply FPNs
+ ops = [self.fpn1, self.fpn2, self.fpn3, self.fpn4]
+ for i in range(len(features)):
+ features[i] = ops[i](features[i])
+
+ logits = self.decode_head(features)
+
+ auxiliary_logits = None
+ if self.auxiliary_head is not None:
+ auxiliary_logits = self.auxiliary_head(features)
+
+ loss = None
+ if labels is not None:
+ loss = self.compute_loss(logits, auxiliary_logits, labels)
+
+ if not return_dict:
+ if output_hidden_states:
+ output = (logits,) + outputs[1:]
+ else:
+ output = (logits,) + outputs[2:]
+ return ((loss,) + output) if loss is not None else output
+
+ return SemanticSegmenterOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states if output_hidden_states else None,
+ attentions=outputs.attentions,
+ )
+
+
+__all__ = [
+ "Data2VecVisionForImageClassification",
+ "Data2VecVisionForSemanticSegmentation",
+ "Data2VecVisionModel",
+ "Data2VecVisionPreTrainedModel",
+]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/data2vec/modeling_tf_data2vec_vision.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/data2vec/modeling_tf_data2vec_vision.py
new file mode 100644
index 0000000000000000000000000000000000000000..0fa0fe1f811ee1748956c2481d28b0ad1ef516e0
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/data2vec/modeling_tf_data2vec_vision.py
@@ -0,0 +1,1723 @@
+# coding=utf-8
+# Copyright 2022 Meta Platforms and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""TF 2.0 Data2Vec Vision model."""
+
+from __future__ import annotations
+
+import collections.abc
+import math
+from dataclasses import dataclass
+
+import numpy as np
+import tensorflow as tf
+
+from ...activations_tf import get_tf_activation
+from ...modeling_tf_outputs import (
+ TFBaseModelOutput,
+ TFBaseModelOutputWithPooling,
+ TFSemanticSegmenterOutput,
+ TFSequenceClassifierOutput,
+)
+from ...modeling_tf_utils import (
+ TFModelInputType,
+ TFPreTrainedModel,
+ TFSequenceClassificationLoss,
+ get_initializer,
+ keras,
+ keras_serializable,
+ unpack_inputs,
+)
+from ...tf_utils import shape_list, stable_softmax
+from ...utils import (
+ add_code_sample_docstrings,
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ logging,
+ replace_return_docstrings,
+)
+from .configuration_data2vec_vision import Data2VecVisionConfig
+
+
+logger = logging.get_logger(__name__)
+
+# General docstring
+_CONFIG_FOR_DOC = "Data2VecVisionConfig"
+
+# Base docstring
+_CHECKPOINT_FOR_DOC = "facebook/data2vec-vision-base"
+_EXPECTED_OUTPUT_SHAPE = [1, 197, 768]
+
+# Image classification docstring
+_IMAGE_CLASS_CHECKPOINT = "facebook/data2vec-vision-base-ft1k"
+_IMAGE_CLASS_EXPECTED_OUTPUT = "remote control, remote"
+
+
+@dataclass
+class TFData2VecVisionModelOutputWithPooling(TFBaseModelOutputWithPooling):
+ """
+ Class for outputs of [`TFData2VecVisionModel`].
+
+ Args:
+ last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`):
+ Sequence of hidden-states at the output of the last layer of the model.
+ pooler_output (`tf.Tensor` of shape `(batch_size, hidden_size)`):
+ Average of the last layer hidden states of the patch tokens (excluding the *[CLS]* token) if
+ *config.use_mean_pooling* is set to True. If set to False, then the final hidden state of the *[CLS]* token
+ will be returned.
+ hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
+ `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+ attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ """
+
+ last_hidden_state: tf.Tensor | None = None
+ pooler_output: tf.Tensor | None = None
+ hidden_states: tuple[tf.Tensor] | None = None
+ attentions: tuple[tf.Tensor] | None = None
+
+
+class TFData2VecVisionDropPath(keras.layers.Layer):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+ References:
+ (1) github.com:rwightman/pytorch-image-models
+ """
+
+ def __init__(self, drop_path, **kwargs):
+ super().__init__(**kwargs)
+ self.drop_path = drop_path
+
+ def call(self, x, training=None):
+ if training:
+ keep_prob = 1 - self.drop_path
+ shape = (tf.shape(x)[0],) + (1,) * (len(tf.shape(x)) - 1)
+ random_tensor = keep_prob + tf.random.uniform(shape, 0, 1)
+ random_tensor = tf.floor(random_tensor)
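+             # random_tensor is now 1 with probability keep_prob and 0 otherwise; dividing the kept
+             # activations by keep_prob keeps the expected output equal to the inference-time output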
+ return (x / keep_prob) * random_tensor
+ return x
+
+
+class TFData2VecVisionEmbeddings(keras.layers.Layer):
+ """
+ Construct the CLS token, position and patch embeddings. Optionally, also the mask token.
+
+ """
+
+ def __init__(self, config: Data2VecVisionConfig, **kwargs):
+ super().__init__(**kwargs)
+ self.config = config
+
+ self.patch_embeddings = TFData2VecVisionPatchEmbeddings(config, name="patch_embeddings")
+ self.num_patches = self.patch_embeddings.num_patches
+ self.config = config
+
+ self.dropout = keras.layers.Dropout(config.hidden_dropout_prob)
+
+ def build(self, input_shape=None):
+ self.cls_token = self.add_weight(
+ shape=(1, 1, self.config.hidden_size),
+ initializer=tf.random_normal_initializer(stddev=self.config.initializer_range),
+ trainable=True,
+ name="cls_token",
+ )
+ if self.config.use_mask_token:
+ self.mask_token = self.add_weight(
+ shape=(1, 1, self.config.hidden_size),
+ initializer=tf.random_normal_initializer(stddev=self.config.initializer_range),
+ trainable=True,
+ name="mask_token",
+ )
+ else:
+ self.mask_token = None
+
+ if self.config.use_absolute_position_embeddings:
+ self.position_embeddings = self.add_weight(
+ shape=(1, self.num_patches + 1, self.config.hidden_size),
+ initializer=tf.random_normal_initializer(stddev=self.config.initializer_range),
+ trainable=True,
+ name="position_embeddings",
+ )
+ else:
+ self.position_embeddings = None
+
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "patch_embeddings", None) is not None:
+ with tf.name_scope(self.patch_embeddings.name):
+ self.patch_embeddings.build(None)
+
+ def call(self, pixel_values: tf.Tensor, bool_masked_pos: tf.Tensor | None = None) -> tf.Tensor:
+ embeddings = self.patch_embeddings(pixel_values)
+ batch_size, seq_len, projection_dim = shape_list(embeddings)
+
+ cls_tokens = tf.tile(self.cls_token, (batch_size, 1, 1))
+
+ if bool_masked_pos is not None:
+ mask_tokens = tf.broadcast_to(self.mask_token, (batch_size, seq_len, projection_dim))
+ # replace the masked visual tokens by mask_tokens
+ w = bool_masked_pos[..., None]
+ w = tf.cast(w, mask_tokens.dtype)
+ # since TF doesn't support eager tensor assignment
+ embeddings = embeddings * (1 - w) + mask_tokens * w
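+             # w is 1.0 at masked patch positions and 0.0 elsewhere, so the blend above swaps masked
+             # patch embeddings for the learned mask token while leaving unmasked patches unchanged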
+
+ embeddings = tf.concat([cls_tokens, embeddings], axis=1)
+ if self.position_embeddings is not None:
+ embeddings = embeddings + self.position_embeddings
+ embeddings = self.dropout(embeddings)
+
+ return embeddings
+
+
+class TFData2VecVisionPatchEmbeddings(keras.layers.Layer):
+ """
+ Image to Patch Embedding.
+ """
+
+ def __init__(self, config: Data2VecVisionConfig, **kwargs):
+ super().__init__(**kwargs)
+ self.config = config
+
+ image_size, patch_size = config.image_size, config.patch_size
+ num_channels, hidden_size = config.num_channels, config.hidden_size
+
+ image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
+ patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
+ num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
+ patch_shape = (image_size[0] // patch_size[0], image_size[1] // patch_size[1])
+ self.image_size = image_size
+ self.patch_size = patch_size
+ self.num_patches = num_patches
+ self.patch_shape = patch_shape
+ self.num_channels = num_channels
+
+ self.projection = keras.layers.Conv2D(
+ filters=hidden_size,
+ kernel_size=patch_size,
+ strides=patch_size,
+ padding="valid",
+ data_format="channels_last",
+ kernel_initializer="glorot_uniform", # following torch.nn.Linear
+ bias_initializer="zeros",
+ name="projection",
+ )
+
+ def call(self, pixel_values: tf.Tensor, training: bool = False) -> tf.Tensor:
+ batch_size, num_channels, height, width = shape_list(pixel_values)
+ if tf.executing_eagerly():
+ if num_channels != self.num_channels:
+ raise ValueError(
+ "Make sure that the channel dimension of the pixel values match with the one set in the"
+ " configuration."
+ )
+ if height != self.image_size[0] or width != self.image_size[1]:
+ raise ValueError(
+ f"Input image size ({height}*{width}) doesn't match model"
+ f" ({self.image_size[0]}*{self.image_size[1]})."
+ )
+
+ # When running on CPU, `keras.layers.Conv2D` doesn't support `NCHW` format.
+ # So change the input format from `NCHW` to `NHWC`.
+ # shape = (batch_size, in_height, in_width, in_channels=num_channels)
+ pixel_values = tf.transpose(pixel_values, perm=(0, 2, 3, 1))
+
+ projection = self.projection(pixel_values)
+
+ # Change the 2D spatial dimensions to a single temporal dimension.
+ # shape = (batch_size, num_patches, out_channels=embed_dim)
+ num_patches = (width // self.patch_size[1]) * (height // self.patch_size[0])
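+         # e.g. a 224x224 image with 16x16 patches yields (224 // 16) * (224 // 16) = 196 patches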
+
+ return tf.reshape(tensor=projection, shape=(batch_size, num_patches, -1))
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "projection", None) is not None:
+ with tf.name_scope(self.projection.name):
+ self.projection.build([None, None, None, self.num_channels])
+
+
+class TFData2VecVisionSelfAttention(keras.layers.Layer):
+ def __init__(self, config: Data2VecVisionConfig, window_size: tuple | None = None, **kwargs):
+ super().__init__(**kwargs)
+
+ if config.hidden_size % config.num_attention_heads != 0:
+ raise ValueError(
+ f"The hidden size ({config.hidden_size}) is not a multiple of the number "
+ f"of attention heads ({config.num_attention_heads})"
+ )
+
+ self.num_attention_heads = config.num_attention_heads
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
+ self.sqrt_att_head_size = math.sqrt(self.attention_head_size)
+
+ self.query = keras.layers.Dense(
+ units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="query"
+ )
+ self.key = keras.layers.Dense(
+ units=self.all_head_size,
+ kernel_initializer=get_initializer(config.initializer_range),
+ name="key",
+ use_bias=False,
+ )
+ self.value = keras.layers.Dense(
+ units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="value"
+ )
+ self.dropout = keras.layers.Dropout(rate=config.attention_probs_dropout_prob)
+
+ if window_size:
+ self.relative_position_bias = TFData2VecVisionRelativePositionBias(
+ config, window_size=window_size, name="relative_position_bias"
+ )
+ else:
+ self.relative_position_bias = None
+ self.config = config
+
+ def transpose_for_scores(self, tensor: tf.Tensor, batch_size: int) -> tf.Tensor:
+ # Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size]
+ tensor = tf.reshape(tensor=tensor, shape=(batch_size, -1, self.num_attention_heads, self.attention_head_size))
+
+ # Transpose the tensor from [batch_size, seq_length, num_attention_heads, attention_head_size] to [batch_size, num_attention_heads, seq_length, attention_head_size]
+ return tf.transpose(tensor, perm=[0, 2, 1, 3])
+
+ def call(
+ self,
+ hidden_states: tf.Tensor,
+ head_mask: tf.Tensor,
+ output_attentions: bool,
+ relative_position_bias: TFData2VecVisionRelativePositionBias | None = None,
+ training: bool = False,
+ ) -> tuple[tf.Tensor]:
+ batch_size = shape_list(hidden_states)[0]
+ mixed_query_layer = self.query(inputs=hidden_states)
+ mixed_key_layer = self.key(inputs=hidden_states)
+ mixed_value_layer = self.value(inputs=hidden_states)
+ query_layer = self.transpose_for_scores(mixed_query_layer, batch_size)
+ key_layer = self.transpose_for_scores(mixed_key_layer, batch_size)
+ value_layer = self.transpose_for_scores(mixed_value_layer, batch_size)
+
+ # Take the dot product between "query" and "key" to get the raw attention scores.
+ # (batch size, num_heads, seq_len_q, seq_len_k)
+ attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True)
+ attention_scores = attention_scores / self.sqrt_att_head_size
+
+ # Add relative position bias if present.
+ if self.relative_position_bias is not None:
+ # Passing `0.0` to the `relative_position_bias()` layer because otherwise Keras
+ # might complain about `Layer.call()` not being invoked properly. In this case this input
+ # i.e., 0.0 is not going to be used in any calculations so we're safe.
+ attention_scores = attention_scores + self.relative_position_bias(0.0)[None, ...]
+
+ # Add shared relative position bias if provided.
+ if relative_position_bias is not None:
+ attention_scores = attention_scores + relative_position_bias
+
+ # Normalize the attention scores to probabilities.
+ attention_probs = stable_softmax(logits=attention_scores, axis=-1)
+
+ # This is actually dropping out entire tokens to attend to, which might
+ # seem a bit unusual, but is taken from the original Transformer paper.
+ attention_probs = self.dropout(inputs=attention_probs, training=training)
+
+ # Mask heads if we want to
+ if head_mask is not None:
+ attention_probs = tf.multiply(attention_probs, head_mask)
+
+ attention_output = tf.matmul(attention_probs, value_layer)
+ attention_output = tf.transpose(attention_output, perm=[0, 2, 1, 3])
+
+ # (batch_size, seq_len_q, all_head_size)
+ attention_output = tf.reshape(tensor=attention_output, shape=(batch_size, -1, self.all_head_size))
+ outputs = (attention_output, attention_probs) if output_attentions else (attention_output,)
+
+ return outputs
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "query", None) is not None:
+ with tf.name_scope(self.query.name):
+ self.query.build([None, None, self.config.hidden_size])
+ if getattr(self, "key", None) is not None:
+ with tf.name_scope(self.key.name):
+ self.key.build([None, None, self.config.hidden_size])
+ if getattr(self, "value", None) is not None:
+ with tf.name_scope(self.value.name):
+ self.value.build([None, None, self.config.hidden_size])
+ if getattr(self, "relative_position_bias", None) is not None:
+ with tf.name_scope(self.relative_position_bias.name):
+ self.relative_position_bias.build(None)
+
+
+class TFData2VecVisionSelfOutput(keras.layers.Layer):
+ """
+ The residual connection is defined in TFData2VecVisionLayer instead of here (as is the case with other models), due
+ to the layernorm applied before each block.
+ """
+
+ def __init__(self, config: Data2VecVisionConfig, **kwargs):
+ super().__init__(**kwargs)
+
+ self.dense = keras.layers.Dense(
+ units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
+ )
+ self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob)
+ self.config = config
+
+ def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, gamma=None, training: bool = False) -> tf.Tensor:
+ hidden_states = self.dense(inputs=hidden_states)
+ hidden_states = self.dropout(inputs=hidden_states, training=training)
+
+ return hidden_states
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "dense", None) is not None:
+ with tf.name_scope(self.dense.name):
+ self.dense.build([None, None, self.config.hidden_size])
+
+
+class TFData2VecVisionAttention(keras.layers.Layer):
+ def __init__(self, config: Data2VecVisionConfig, window_size: tuple | None = None, **kwargs):
+ super().__init__(**kwargs)
+
+ self.attention = TFData2VecVisionSelfAttention(config, window_size=window_size, name="attention")
+ self.dense_output = TFData2VecVisionSelfOutput(config, name="output")
+
+ def prune_heads(self, heads):
+ raise NotImplementedError
+
+ def call(
+ self,
+ input_tensor: tf.Tensor,
+ head_mask: tf.Tensor,
+ output_attentions: bool,
+ relative_position_bias: TFData2VecVisionRelativePositionBias | None = None,
+ training: bool = False,
+ ) -> tuple[tf.Tensor]:
+ self_outputs = self.attention(
+ hidden_states=input_tensor,
+ head_mask=head_mask,
+ output_attentions=output_attentions,
+ relative_position_bias=relative_position_bias,
+ training=training,
+ )
+ attention_output = self.dense_output(
+ hidden_states=self_outputs[0], input_tensor=input_tensor, training=training
+ )
+ outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
+
+ return outputs
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "attention", None) is not None:
+ with tf.name_scope(self.attention.name):
+ self.attention.build(None)
+ if getattr(self, "dense_output", None) is not None:
+ with tf.name_scope(self.dense_output.name):
+ self.dense_output.build(None)
+
+
+# Copied from transformers.models.vit.modeling_tf_vit.TFViTIntermediate with ViT->Data2VecVision
+class TFData2VecVisionIntermediate(keras.layers.Layer):
+ def __init__(self, config: Data2VecVisionConfig, **kwargs):
+ super().__init__(**kwargs)
+
+ self.dense = keras.layers.Dense(
+ units=config.intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
+ )
+
+ if isinstance(config.hidden_act, str):
+ self.intermediate_act_fn = get_tf_activation(config.hidden_act)
+ else:
+ self.intermediate_act_fn = config.hidden_act
+ self.config = config
+
+ def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
+ hidden_states = self.dense(inputs=hidden_states)
+ hidden_states = self.intermediate_act_fn(hidden_states)
+
+ return hidden_states
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "dense", None) is not None:
+ with tf.name_scope(self.dense.name):
+ self.dense.build([None, None, self.config.hidden_size])
+
+
+class TFData2VecVisionOutput(keras.layers.Layer):
+ def __init__(self, config: Data2VecVisionConfig, **kwargs):
+ super().__init__(**kwargs)
+
+ self.dense = keras.layers.Dense(
+ units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
+ )
+ self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob)
+ self.config = config
+
+ def call(self, hidden_states: tf.Tensor, training: bool = False) -> tf.Tensor:
+ hidden_states = self.dense(inputs=hidden_states)
+ hidden_states = self.dropout(inputs=hidden_states, training=training)
+
+ return hidden_states
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "dense", None) is not None:
+ with tf.name_scope(self.dense.name):
+ self.dense.build([None, None, self.config.intermediate_size])
+
+
+class TFData2VecVisionLayer(keras.layers.Layer):
+ """This corresponds to the Block class in the timm implementation."""
+
+ def __init__(
+ self, config: Data2VecVisionConfig, window_size: tuple | None = None, drop_path_rate: float = 0.0, **kwargs
+ ):
+ super().__init__(**kwargs)
+ self.config = config
+
+ self.attention = TFData2VecVisionAttention(config, window_size=window_size, name="attention")
+ self.intermediate = TFData2VecVisionIntermediate(config, name="intermediate")
+ self.data2vec_output = TFData2VecVisionOutput(config, name="output")
+
+ self.layernorm_before = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm_before")
+ self.layernorm_after = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm_after")
+ # Using `layers.Activation` instead of `tf.identity` to better control `training`
+ # behaviour.
+ self.drop_path = (
+ TFData2VecVisionDropPath(drop_path_rate, name="drop_path")
+ if drop_path_rate > 0.0
+ else keras.layers.Activation("linear", name="drop_path")
+ )
+ self.init_values = config.layer_scale_init_value
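+         # When layer_scale_init_value > 0, build() creates per-channel LayerScale parameters
+         # (lambda_1 / lambda_2) initialized to that value, so each residual branch starts small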
+
+ def build(self, input_shape: tf.TensorShape = None):
+ if self.init_values > 0:
+ self.lambda_1 = self.add_weight(
+ shape=(self.config.hidden_size),
+ initializer="ones",
+ trainable=True,
+ name="lambda_1",
+ )
+ self.lambda_2 = self.add_weight(
+ shape=(self.config.hidden_size),
+ initializer="ones",
+ trainable=True,
+ name="lambda_2",
+ )
+ self.lambda_1.assign(self.init_values * tf.ones(self.config.hidden_size))
+ self.lambda_2.assign(self.init_values * tf.ones(self.config.hidden_size))
+ else:
+ self.lambda_1, self.lambda_2 = None, None
+
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "attention", None) is not None:
+ with tf.name_scope(self.attention.name):
+ self.attention.build(None)
+ if getattr(self, "intermediate", None) is not None:
+ with tf.name_scope(self.intermediate.name):
+ self.intermediate.build(None)
+ if getattr(self, "data2vec_output", None) is not None:
+ with tf.name_scope(self.data2vec_output.name):
+ self.data2vec_output.build(None)
+ if getattr(self, "layernorm_before", None) is not None:
+ with tf.name_scope(self.layernorm_before.name):
+ self.layernorm_before.build([None, None, self.config.hidden_size])
+ if getattr(self, "layernorm_after", None) is not None:
+ with tf.name_scope(self.layernorm_after.name):
+ self.layernorm_after.build([None, None, self.config.hidden_size])
+ if getattr(self, "drop_path", None) is not None:
+ with tf.name_scope(self.drop_path.name):
+ self.drop_path.build(None)
+
+ def call(
+ self,
+ hidden_states: tf.Tensor,
+ head_mask: tf.Tensor,
+ output_attentions: bool,
+ relative_position_bias: TFData2VecVisionRelativePositionBias | None = None,
+ training: bool = False,
+ ) -> tuple[tf.Tensor]:
+ self_attention_outputs = self.attention(
+ # in Data2VecVision, layernorm is applied before self-attention
+ input_tensor=self.layernorm_before(inputs=hidden_states),
+ head_mask=head_mask,
+ output_attentions=output_attentions,
+ relative_position_bias=relative_position_bias,
+ training=training,
+ )
+ attention_output = self_attention_outputs[0]
+ outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
+
+ # apply lambda_1 if present
+ if self.lambda_1 is not None:
+ attention_output = self.lambda_1 * attention_output
+
+ # first residual connection
+ hidden_states = self.drop_path(attention_output) + hidden_states
+
+ # in Data2VecVision, layernorm is also applied after self-attention
+ layer_output = self.layernorm_after(hidden_states)
+
+ layer_output = self.intermediate(layer_output)
+ layer_output = self.data2vec_output(layer_output)
+
+ if self.lambda_2 is not None:
+ layer_output = self.lambda_2 * layer_output
+
+ # second residual connection
+ layer_output = self.drop_path(layer_output) + hidden_states
+
+ outputs = (layer_output,) + outputs
+
+ return outputs
+
+
+# Taken and modified from here:
+# https://github.com/leondgarse/keras_cv_attention_models/blob/main/keras_cv_attention_models/beit/beit.py#L28
+class TFData2VecVisionRelativePositionBias(keras.layers.Layer):
+ def __init__(self, config: Data2VecVisionConfig, window_size: tuple, **kwargs) -> None:
+ super().__init__(**kwargs)
+ self.config = config
+
+ self.window_size = window_size
+ # +3 for cls_token_pos_len
+ # window_size can be something like (14, 14)
+ self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
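+         # e.g. window_size=(14, 14) gives (2*14 - 1) * (2*14 - 1) + 3 = 732 entries in the bias table:
+         # 729 relative (dy, dx) offsets plus cls-to-token, token-to-cls and cls-to-cls biases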
+
+ self.relative_position_index = self.get_position_index()
+
+ def build(self, input_shape):
+ self.relative_position_bias_table = self.add_weight(
+ shape=(self.num_relative_distance, self.config.num_attention_heads),
+ initializer="zeros",
+ trainable=True,
+ name="relative_position_bias_table",
+ ) # [2*Wh-1 * 2*Ww-1, nH]
+         # cls to token, token to cls, and cls to cls
+
+ super().build(input_shape)
+
+ def get_position_index(self):
+ # get pair-wise relative position index for each token inside the window
+ xx, yy = tf.meshgrid(range(self.window_size[0]), range(self.window_size[1]))
+ coords = tf.stack([yy, xx], axis=0) # [2, Wh, Ww]
+ coords_flatten = tf.reshape(coords, [2, -1]) # [2, Wh*Ww]
+
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # [2, Wh*Ww, Wh*Ww]
+ relative_coords = tf.transpose(relative_coords, perm=[1, 2, 0]) # [Wh*Ww, Wh*Ww, 2]
+
+ xx = (relative_coords[:, :, 0] + self.window_size[0] - 1) * (2 * self.window_size[1] - 1)
+ yy = relative_coords[:, :, 1] + self.window_size[1] - 1
+ relative_coords = tf.stack([xx, yy], axis=-1)
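+         # shifting by window_size - 1 makes the offsets non-negative, and scaling the first axis by
+         # (2 * Ww - 1) makes the summed index below unique for every (dy, dx) pair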
+
+ relative_position_index = tf.reduce_sum(relative_coords, axis=-1) # [Wh*Ww, Wh*Ww]
+
+ top = tf.ones((1, relative_position_index.shape[1]), dtype=relative_position_index.dtype) * (
+ self.num_relative_distance - 3
+ )
+ left = tf.ones((relative_position_index.shape[0], 1), dtype=relative_position_index.dtype) * (
+ self.num_relative_distance - 2
+ )
+ corner = tf.ones((1, 1), dtype=relative_position_index.dtype) * (self.num_relative_distance - 1)
+
+ left_corner = tf.concat([corner, left], axis=0)
+ relative_position_index = tf.concat([top, relative_position_index], axis=0)
+ relative_position_index = tf.concat([left_corner, relative_position_index], axis=1) # [Wh*Ww + 1, Wh*Ww + 1]
+ return relative_position_index
+
+ def call(self, inputs=None) -> tf.Tensor:
+ relative_position_bias = tf.gather(self.relative_position_bias_table, self.relative_position_index, axis=0)
+ return tf.transpose(relative_position_bias, [2, 0, 1])
+
+
+class TFData2VecVisionEncoder(keras.layers.Layer):
+ def __init__(self, config: Data2VecVisionConfig, window_size: tuple | None = None, **kwargs):
+ super().__init__(**kwargs)
+ self.config = config
+ if config.use_shared_relative_position_bias:
+ self.relative_position_bias = TFData2VecVisionRelativePositionBias(
+ config, window_size=window_size, name="relative_position_bias"
+ )
+ else:
+ self.relative_position_bias = None
+
+ # stochastic depth decay rule
+ dpr = list(tf.linspace(0.0, config.drop_path_rate, config.num_hidden_layers))
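+         # e.g. drop_path_rate=0.1 with 12 layers gives per-layer rates rising linearly from 0.0 to 0.1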
+ self.layer = [
+ TFData2VecVisionLayer(
+ config,
+ window_size=window_size if config.use_relative_position_bias else None,
+ drop_path_rate=dpr[i],
+ name=f"layer_._{i}",
+ )
+ for i in range(config.num_hidden_layers)
+ ]
+
+ def call(
+ self,
+ hidden_states: tf.Tensor,
+ head_mask: tf.Tensor | None = None,
+ output_attentions: bool = False,
+ output_hidden_states: bool = False,
+ return_dict: bool = True,
+ ) -> tuple | TFBaseModelOutput:
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attentions = () if output_attentions else None
+
+ for i, layer_module in enumerate(self.layer):
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ layer_head_mask = head_mask[i] if head_mask is not None else None
+ # Passing `0.0` to the `relative_position_bias()` layer because otherwise Keras
+ # might complain about `Layer.call()` not being invoked properly. In this case this input
+ # i.e., 0.0 is not going to be used in any calculations so we're safe.
+ relative_position_bias = (
+ self.relative_position_bias(0.0) if self.relative_position_bias is not None else None
+ )
+ layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions, relative_position_bias)
+
+ hidden_states = layer_outputs[0]
+
+ if output_attentions:
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
+
+ return TFBaseModelOutput(
+ last_hidden_state=hidden_states,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attentions,
+ )
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "relative_position_bias", None) is not None:
+ with tf.name_scope(self.relative_position_bias.name):
+ self.relative_position_bias.build(None)
+ if getattr(self, "layer", None) is not None:
+ for layer in self.layer:
+ with tf.name_scope(layer.name):
+ layer.build(None)
+
+
+@keras_serializable
+class TFData2VecVisionMainLayer(keras.layers.Layer):
+ config_class = Data2VecVisionConfig
+
+ def __init__(self, config: Data2VecVisionConfig, add_pooling_layer: bool = True, **kwargs):
+ super().__init__(**kwargs)
+
+ self.config = config
+ self.add_pooling_layer = add_pooling_layer
+
+ self.embeddings = TFData2VecVisionEmbeddings(config, name="embeddings")
+ self.encoder = TFData2VecVisionEncoder(
+ config, window_size=self.embeddings.patch_embeddings.patch_shape, name="encoder"
+ )
+ self.layernorm = (
+ tf.identity
+ if config.use_mean_pooling
+ else keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm")
+ )
+
+ # We are setting the `data_format` like so because from here on we will revert to the
+ # NCHW output format
+ self.pooler = TFData2VecVisionPooler(config, name="pooler") if add_pooling_layer else None
+
+ def get_input_embeddings(self) -> keras.layers.Layer:
+ return self.embeddings.patch_embeddings
+
+ def _prune_heads(self, heads_to_prune):
+ """
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
+ class PreTrainedModel
+ """
+ raise NotImplementedError
+
+ @unpack_inputs
+ def call(
+ self,
+ pixel_values: tf.Tensor | None = None,
+ bool_masked_pos: tf.Tensor | None = None,
+ head_mask: tf.Tensor | None = None,
+ output_attentions: bool | None = None,
+ output_hidden_states: bool | None = None,
+ return_dict: bool | None = None,
+ training: bool = False,
+ ) -> tuple | TFData2VecVisionModelOutputWithPooling:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if pixel_values is None:
+ raise ValueError("You have to specify pixel_values")
+
+ # Prepare head mask if needed
+ # 1.0 in head_mask indicate we keep the head
+ # attention_probs has shape bsz x n_heads x N x N
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+ if head_mask is not None:
+ raise NotImplementedError
+ else:
+ head_mask = [None] * self.config.num_hidden_layers
+
+ embedding_output = self.embeddings(pixel_values, bool_masked_pos, training=training)
+
+ encoder_outputs = self.encoder(
+ embedding_output,
+ head_mask=head_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ training=training,
+ )
+
+ sequence_output = encoder_outputs[0]
+ sequence_output = self.layernorm(sequence_output)
+ pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
+
+ if not return_dict:
+ head_outputs = (sequence_output, pooled_output) if pooled_output is not None else (sequence_output,)
+ return head_outputs + encoder_outputs[1:]
+
+ return TFData2VecVisionModelOutputWithPooling(
+ last_hidden_state=sequence_output,
+ pooler_output=pooled_output,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ )
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "embeddings", None) is not None:
+ with tf.name_scope(self.embeddings.name):
+ self.embeddings.build(None)
+ if getattr(self, "encoder", None) is not None:
+ with tf.name_scope(self.encoder.name):
+ self.encoder.build(None)
+ if getattr(self, "layernorm", None) is not None:
+ if hasattr(self.layernorm, "name"):
+ with tf.name_scope(self.layernorm.name):
+ self.layernorm.build((None, self.config.hidden_size))
+ if getattr(self, "pooler", None) is not None:
+ with tf.name_scope(self.pooler.name):
+ self.pooler.build(None)
+
+
+class TFData2VecVisionPooler(keras.layers.Layer):
+ def __init__(self, config: Data2VecVisionConfig, **kwargs):
+ super().__init__(**kwargs)
+ self.layernorm = (
+ keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm")
+ if config.use_mean_pooling
+ else None
+ )
+ self.config = config
+
+ def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
+ if self.layernorm is not None:
+ # Mean pool the final hidden states of the patch tokens
+ patch_tokens = hidden_states[:, 1:, :]
+ pooled_output = self.layernorm(tf.reduce_mean(patch_tokens, axis=1))
+ else:
+ # Pool by simply taking the final hidden state of the [CLS] token
+ pooled_output = hidden_states[:, 0]
+
+ return pooled_output
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "layernorm", None) is not None:
+ if hasattr(self.layernorm, "name"):
+ with tf.name_scope(self.layernorm.name):
+ self.layernorm.build((None, self.config.hidden_size))
+
+
+class TFData2VecVisionPreTrainedModel(TFPreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = Data2VecVisionConfig
+ base_model_prefix = "data2vec_vision"
+ main_input_name = "pixel_values"
+ _keys_to_ignore_on_load_unexpected = [r"relative_position_index"]
+
+
+DATA2VEC_VISION_START_DOCSTRING = r"""
+ This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
+ etc.).
+
+ This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it
+ as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and
+ behavior.
+
+
+     <Tip>
+
+ TensorFlow models and layers in `transformers` accept two formats as input:
+
+ - having all inputs as keyword arguments (like PyTorch models), or
+ - having all inputs as a list, tuple or dict in the first positional argument.
+
+ The reason the second format is supported is that Keras methods prefer this format when passing inputs to models
+ and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just
+ pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second
+ format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with
+ the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first
+ positional argument:
+
+ - a single Tensor with `pixel_values` only and nothing else: `model(pixel_values)`
+ - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
+ `model([pixel_values, attention_mask])` or `model([pixel_values, attention_mask, token_type_ids])`
+ - a dictionary with one or several input Tensors associated to the input names given in the docstring:
+ `model({"pixel_values": pixel_values, "token_type_ids": token_type_ids})`
+
+ Note that when creating models and layers with
+ [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry
+ about any of this, as you can just pass inputs like you would to any other Python function!
+
+     </Tip>
+
+
+ Args:
+ config ([`Data2VecVisionConfig`]): Model configuration class with all the parameters of the model.
+ Initializing with a config file does not load the weights associated with the model, only the
+ configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+DATA2VEC_VISION_INPUTS_DOCSTRING = r"""
+ Args:
+ pixel_values (`np.ndarray`, `tf.Tensor`, `list[tf.Tensor]` `dict[str, tf.Tensor]` or `dict[str, np.ndarray]` and each example must have the shape `(batch_size, num_channels, height, width)`):
+ Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
+ [`BeitImageProcessor.__call__`] for details.
+
+ head_mask (`np.ndarray` or `tf.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple. This argument can be used
+ in eager mode, in graph mode the value will always be set to True.
+
+         training (`bool`, *optional*, defaults to `False`):
+ Whether or not to use the model in training mode (some modules like dropout modules have different
+ behaviors between training and evaluation).
+"""
+
+
+@add_start_docstrings(
+ "The bare Data2VecVision Model transformer outputting raw hidden-states without any specific head on top.",
+ DATA2VEC_VISION_START_DOCSTRING,
+)
+class TFData2VecVisionModel(TFData2VecVisionPreTrainedModel):
+ def __init__(self, config: Data2VecVisionConfig, add_pooling_layer: bool = False, *inputs, **kwargs):
+ super().__init__(config, *inputs, **kwargs)
+ self.config = config
+
+ self.data2vec_vision = TFData2VecVisionMainLayer(
+ config, add_pooling_layer=add_pooling_layer, name="data2vec_vision"
+ )
+
+ def get_input_embeddings(self):
+ return self.data2vec_vision.get_input_embeddings()
+
+ @unpack_inputs
+ @add_start_docstrings_to_model_forward(DATA2VEC_VISION_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=TFData2VecVisionModelOutputWithPooling,
+ config_class=_CONFIG_FOR_DOC,
+ modality="vision",
+ expected_output=_EXPECTED_OUTPUT_SHAPE,
+ )
+ def call(
+ self,
+ pixel_values: TFModelInputType | None = None,
+ bool_masked_pos: tf.Tensor | None = None,
+ head_mask: np.ndarray | tf.Tensor | None = None,
+ output_attentions: bool | None = None,
+ output_hidden_states: bool | None = None,
+ return_dict: bool | None = None,
+ training: bool = False,
+ ) -> tuple | TFData2VecVisionModelOutputWithPooling:
+ r"""
+ bool_masked_pos (`tf.Tensor` of shape `(batch_size, num_patches)`, *optional*):
+ Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
+ """
+ outputs = self.data2vec_vision(
+ pixel_values=pixel_values,
+ bool_masked_pos=bool_masked_pos,
+ head_mask=head_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ training=training,
+ )
+
+ return outputs
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "data2vec_vision", None) is not None:
+ with tf.name_scope(self.data2vec_vision.name):
+ self.data2vec_vision.build(None)
+
+
+@add_start_docstrings(
+ """
+ Data2VecVision Model transformer with an image classification head on top (a linear layer on top of the average of
+ the final hidden states of the patch tokens) e.g. for ImageNet.
+ """,
+ DATA2VEC_VISION_START_DOCSTRING,
+)
+class TFData2VecVisionForImageClassification(TFData2VecVisionPreTrainedModel, TFSequenceClassificationLoss):
+ def __init__(self, config: Data2VecVisionConfig, *inputs, **kwargs):
+ super().__init__(config, *inputs, **kwargs)
+
+ self.num_labels = config.num_labels
+ self.data2vec_vision = TFData2VecVisionMainLayer(config, add_pooling_layer=True, name="data2vec_vision")
+
+ # Classifier head
+ self.classifier = keras.layers.Dense(
+ units=config.num_labels,
+ kernel_initializer=get_initializer(config.initializer_range),
+ name="classifier",
+ )
+ self.config = config
+
+ @unpack_inputs
+ @add_start_docstrings_to_model_forward(DATA2VEC_VISION_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ checkpoint=_IMAGE_CLASS_CHECKPOINT,
+ output_type=TFSequenceClassifierOutput,
+ config_class=_CONFIG_FOR_DOC,
+ expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
+ )
+ def call(
+ self,
+ pixel_values: TFModelInputType | None = None,
+ head_mask: np.ndarray | tf.Tensor | None = None,
+ output_attentions: bool | None = None,
+ output_hidden_states: bool | None = None,
+ return_dict: bool | None = None,
+ labels: np.ndarray | tf.Tensor | None = None,
+ training: bool | None = False,
+ ) -> TFSequenceClassifierOutput | tuple:
+ r"""
+ labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size,)`, *optional*):
+ Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.data2vec_vision(
+ pixel_values=pixel_values,
+ head_mask=head_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ training=training,
+ )
+
+ pooled_output = outputs.pooler_output if return_dict else outputs[1]
+ logits = self.classifier(pooled_output)
+ loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits)
+
+ if not return_dict:
+ output = (logits,) + outputs[2:]
+ return ((loss,) + output) if loss is not None else output
+
+ return TFSequenceClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "data2vec_vision", None) is not None:
+ with tf.name_scope(self.data2vec_vision.name):
+ self.data2vec_vision.build(None)
+ if getattr(self, "classifier", None) is not None:
+ with tf.name_scope(self.classifier.name):
+ self.classifier.build([None, None, self.config.hidden_size])
+
+
+class TFData2VecVisionConvModule(keras.layers.Layer):
+ """
+ A convolutional block that bundles conv/norm/activation layers. This block simplifies the usage of convolution
+ layers, which are commonly used with a norm layer (e.g., BatchNorm) and activation layer (e.g., ReLU).
+
+ Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation.
+ """
+
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: int | tuple[int, int],
+ padding: str = "valid",
+ bias: bool = False,
+ dilation: int | tuple[int, int] = 1,
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+ self.conv = keras.layers.Conv2D(
+ filters=out_channels,
+ kernel_size=kernel_size,
+ padding=padding,
+ use_bias=bias,
+ dilation_rate=dilation,
+ name="conv",
+ )
+ self.bn = keras.layers.BatchNormalization(name="bn", momentum=0.9, epsilon=1e-5)
+ self.activation = tf.nn.relu
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+
+ def call(self, input: tf.Tensor) -> tf.Tensor:
+ output = self.conv(input)
+ output = self.bn(output)
+ output = self.activation(output)
+ return output
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "conv", None) is not None:
+ with tf.name_scope(self.conv.name):
+ self.conv.build([None, None, None, self.in_channels])
+ if getattr(self, "bn", None) is not None:
+ with tf.name_scope(self.bn.name):
+ self.bn.build((None, None, None, self.out_channels))
+
+
+class TFAdaptiveAvgPool2D(keras.layers.Layer):
+ def __init__(self, output_dims: tuple[int, int], input_ordering: str = "NHWC", **kwargs):
+ super().__init__(**kwargs)
+ self.output_dims = output_dims
+ self.input_ordering = input_ordering
+ if input_ordering not in ("NCHW", "NHWC"):
+ raise ValueError("Unrecognized input_ordering, should be 'NCHW' or 'NHWC'!")
+ self.h_axis = input_ordering.index("H")
+ self.w_axis = input_ordering.index("W")
+
+ def pseudo_1d_pool(self, inputs: tf.Tensor, h_pooling: bool):
+ # Figure out which axis we're pooling on
+ if h_pooling:
+ axis = self.h_axis
+ output_dim = self.output_dims[0]
+ else:
+ axis = self.w_axis
+ output_dim = self.output_dims[1]
+ input_dim = inputs.shape[axis]
+
+ # Figure out the potential pooling windows
+ # This is the key idea - the torch op always uses only two
+ # consecutive pooling window sizes, like 3 and 4. Therefore,
+ # if we pool with both possible sizes, we simply need to gather
+ # the 'correct' pool at each position to reimplement the torch op.
+ small_window = math.ceil(input_dim / output_dim)
+ big_window = small_window + 1
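+         # e.g. input_dim=5, output_dim=3: the torch op uses windows starting at (0, 1, 3) and ending
+         # at (2, 4, 5), i.e. sizes (2, 3, 2), so only small_window=2 and big_window=3 ever occur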
+ if h_pooling:
+ output_dim = self.output_dims[0]
+ small_window_shape = (small_window, 1)
+ big_window_shape = (big_window, 1)
+ else:
+ output_dim = self.output_dims[1]
+ small_window_shape = (1, small_window)
+ big_window_shape = (1, big_window)
+
+ # For resizes to 1, or integer resizes, we can take quick shortcuts
+ if output_dim == input_dim:
+ return inputs
+ elif output_dim == 1:
+ return tf.reduce_mean(inputs, axis=axis, keepdims=True)
+ elif input_dim % output_dim == 0:
+ return tf.nn.avg_pool2d(
+ inputs,
+ ksize=small_window_shape,
+ strides=small_window_shape,
+ padding="VALID",
+ data_format=self.input_ordering,
+ )
+ # When upscaling by an integer factor we can also take a quick shortcut
+ elif output_dim > input_dim and output_dim % input_dim == 0:
+ return tf.repeat(inputs, repeats=output_dim // input_dim, axis=axis)
+
+ # For non-integer resizes, we pool with both possible window sizes and concatenate them
+ if output_dim < input_dim:
+ small_pool = tf.nn.avg_pool2d(
+ inputs, ksize=small_window_shape, strides=1, padding="VALID", data_format=self.input_ordering
+ )
+ big_pool = tf.nn.avg_pool2d(
+ inputs, ksize=big_window_shape, strides=1, padding="VALID", data_format=self.input_ordering
+ )
+ both_pool = tf.concat([small_pool, big_pool], axis=axis)
+ else:
+ # When we're actually upscaling instead, then we build the pools a bit differently
+ small_pool = inputs
+ big_pool = tf.nn.avg_pool2d(
+ inputs, ksize=big_window_shape, strides=1, padding="VALID", data_format=self.input_ordering
+ )
+ both_pool = tf.concat([small_pool, big_pool], axis=axis)
+
+ # We compute vectors of the start and end positions for each pooling window
+ # Each (start, end) pair here corresponds to a single output position
+ window_starts = tf.math.floor((tf.range(output_dim, dtype=tf.float32) * input_dim) / output_dim)
+ window_starts = tf.cast(window_starts, tf.int64)
+ window_ends = tf.math.ceil((tf.range(1, output_dim + 1, dtype=tf.float32) * input_dim) / output_dim)
+ window_ends = tf.cast(window_ends, tf.int64)
+
+ # pool_selector is a boolean array of shape (output_dim,) where 1 indicates that output position
+ # has a big receptive field and 0 indicates that that output position has a small receptive field
+ pool_selector = tf.cast(window_ends - window_starts - small_window, tf.bool)
+
+ # Since we concatenated the small and big pools, we need to do a bit of
+ # pointer arithmetic to get the indices of the big pools
+ small_indices = window_starts
+ big_indices = window_starts + small_pool.shape[axis]
+
+ # Finally, we use the pool_selector to generate a list of indices, one per output position
+ gather_indices = tf.where(pool_selector, big_indices, small_indices)
+
+ # Gathering from those indices yields the final, correct pooling
+ return tf.gather(both_pool, gather_indices, axis=axis)
+
+ def call(self, inputs: tf.Tensor):
+ if self.input_ordering == "NHWC":
+ input_shape = inputs.shape[1:3]
+ else:
+ input_shape = inputs.shape[2:]
+
+ # We break the task down into each possible case
+ # Firstly, if we're resizing down to 1, it's just tf.reduce_mean
+ if self.output_dims[0] == self.output_dims[1] == 1:
+ if self.input_ordering == "NHWC":
+ reduce_dims = [1, 2]
+ else:
+ reduce_dims = [2, 3]
+ return tf.reduce_mean(inputs, axis=reduce_dims, keepdims=True)
+ # Secondly, if we're resizing by an integer factor on both dimensions, we can take a quick shortcut
+ elif input_shape[0] % self.output_dims[0] == 0 and input_shape[1] % self.output_dims[1] == 0:
+ h_resize = int(input_shape[0] // self.output_dims[0])
+ w_resize = int(input_shape[1] // self.output_dims[1])
+ return tf.nn.avg_pool2d(
+ inputs,
+ ksize=(h_resize, w_resize),
+ strides=(h_resize, w_resize),
+ padding="VALID",
+ data_format=self.input_ordering,
+ )
+ else:
+ # Finally, if we can't take the shortcut, we do a 1D pool on each axis. pseudo_1d_pool will take a shortcut
+ # for dimensions where an integer resize is possible. It can also handle upscaling.
+ h_pooled = self.pseudo_1d_pool(inputs, h_pooling=True)
+ return self.pseudo_1d_pool(h_pooled, h_pooling=False)
+
+
+class TFData2VecVisionPyramidPoolingModule(keras.layers.Layer):
+ """
+ Pyramid Pooling Module (PPM) used in PSPNet.
+
+ Args:
+ pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
+ Module.
+ channels (int): Channels after modules, before conv_seg.
+
+ Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation.
+ """
+
+ def __init__(self, pool_scales: tuple[int, ...], in_channels: int, out_channels: int, **kwargs) -> None:
+ super().__init__(**kwargs)
+ self.pool_scales = pool_scales
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+
+ self.layer_list = []
+ for idx, pool_scale in enumerate(pool_scales):
+ pool_scale = pool_scale if isinstance(pool_scale, collections.abc.Iterable) else (pool_scale, pool_scale)
+ self.layer_list.append(
+ [
+ TFAdaptiveAvgPool2D(output_dims=pool_scale),
+ TFData2VecVisionConvModule(
+ in_channels=in_channels, out_channels=self.out_channels, kernel_size=1, name=f"{idx}.1"
+ ),
+ ]
+ )
+
+ def call(self, x: tf.Tensor) -> list[tf.Tensor]:
+ ppm_outs = []
+ inputs = x
+
+ for ppm in self.layer_list:
+ for layer_module in ppm:
+ ppm_out = layer_module(x)
+ x = ppm_out
+
+ upsampled_ppm_out = tf.image.resize(ppm_out, size=shape_list(inputs)[1:-1], method="bilinear")
+ ppm_outs.append(upsampled_ppm_out)
+ return ppm_outs
+
+ def build(self, input_shape=None):
+ for layer in self.layer_list:
+ for layer_module in layer:
+ with tf.name_scope(layer_module.name):
+ layer_module.build(None)
+
+
+class TFData2VecVisionUperHead(keras.layers.Layer):
+ """
+ Unified Perceptual Parsing for Scene Understanding. This head is the implementation of
+ [UPerNet](https://huggingface.co/papers/1807.10221).
+
+ Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation.
+ """
+
+ def __init__(self, config: Data2VecVisionConfig, **kwargs) -> None:
+ super().__init__(**kwargs)
+
+ self.pool_scales = config.pool_scales # e.g. (1, 2, 3, 6)
+ self.in_channels = [config.hidden_size] * 4 # e.g. [768, 768, 768, 768]
+ self.channels = config.hidden_size
+ self.classifier = keras.layers.Conv2D(config.num_labels, kernel_size=1, name="classifier")
+
+ # PSP Module
+ self.psp_modules = TFData2VecVisionPyramidPoolingModule(
+ self.pool_scales, self.in_channels[-1], self.channels, name="psp_modules"
+ )
+ self.bottleneck = TFData2VecVisionConvModule(
+ self.in_channels[-1] + len(self.pool_scales) * self.channels,
+ self.channels,
+ kernel_size=3,
+ padding="same",
+ name="bottleneck",
+ )
+ # FPN Module
+ self.lateral_convs = []
+ self.fpn_convs = []
+ for idx, in_channels in enumerate(self.in_channels[:-1]): # skip the top layer
+ l_conv = TFData2VecVisionConvModule(
+ in_channels, out_channels=self.channels, kernel_size=1, name=f"lateral_convs.{idx}"
+ )
+ fpn_conv = TFData2VecVisionConvModule(
+ in_channels=self.channels,
+ out_channels=self.channels,
+ kernel_size=3,
+ padding="same",
+ name=f"fpn_convs.{idx}",
+ )
+ self.lateral_convs.append(l_conv)
+ self.fpn_convs.append(fpn_conv)
+
+ self.fpn_bottleneck = TFData2VecVisionConvModule(
+ in_channels=len(self.in_channels) * self.channels,
+ out_channels=self.channels,
+ kernel_size=3,
+ padding="same",
+ name="fpn_bottleneck",
+ )
+
+ def psp_forward(self, inputs):
+ x = inputs[-1]
+ psp_outs = [x]
+ psp_outs.extend(self.psp_modules(x))
+ psp_outs = tf.concat(psp_outs, axis=-1)
+ output = self.bottleneck(psp_outs)
+
+ return output
+
+ def call(self, encoder_hidden_states: tf.Tensor) -> tf.Tensor:
+ # build laterals
+ laterals = [lateral_conv(encoder_hidden_states[i]) for i, lateral_conv in enumerate(self.lateral_convs)]
+
+ laterals.append(self.psp_forward(encoder_hidden_states))
+
+ # build top-down path
+ used_backbone_levels = len(laterals)
+ for i in range(used_backbone_levels - 1, 0, -1):
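+             # upsample the coarser lateral to the finer lateral's spatial size and fuse by
+             # element-wise addition (standard FPN top-down pathway)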
+ prev_shape = shape_list(laterals[i - 1])[1:-1]
+ laterals[i - 1] = laterals[i - 1] + tf.image.resize(laterals[i], size=prev_shape, method="bilinear")
+
+ # build outputs
+ fpn_outs = [self.fpn_convs[i](laterals[i]) for i in range(used_backbone_levels - 1)]
+ # append psp feature
+ fpn_outs.append(laterals[-1])
+
+ for i in range(used_backbone_levels - 1, 0, -1):
+ fpn_outs[i] = tf.image.resize(fpn_outs[i], size=shape_list(fpn_outs[0])[1:-1], method="bilinear")
+ fpn_outs = tf.concat(fpn_outs, axis=-1)
+ output = self.fpn_bottleneck(fpn_outs)
+ output = self.classifier(output)
+
+ return output
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "classifier", None) is not None:
+ with tf.name_scope(self.classifier.name):
+ self.classifier.build([None, None, None, self.channels])
+ if getattr(self, "psp_modules", None) is not None:
+ with tf.name_scope(self.psp_modules.name):
+ self.psp_modules.build(None)
+ if getattr(self, "bottleneck", None) is not None:
+ with tf.name_scope(self.bottleneck.name):
+ self.bottleneck.build(None)
+ if getattr(self, "fpn_bottleneck", None) is not None:
+ with tf.name_scope(self.fpn_bottleneck.name):
+ self.fpn_bottleneck.build(None)
+ for layer in self.lateral_convs:
+ with tf.name_scope(layer.name):
+ layer.build(None)
+ for layer in self.fpn_convs:
+ with tf.name_scope(layer.name):
+ layer.build(None)
+
+
+class TFData2VecVisionFCNHead(keras.layers.Layer):
+ """
+    Fully Convolutional Networks for Semantic Segmentation. This head is the implementation of
+    [FCNNet](https://huggingface.co/papers/1411.4038).
+
+    Args:
+        config (Data2VecVisionConfig): Configuration.
+        in_index (int): Index of the encoder hidden state to use as input to the head. Default: 2.
+        kernel_size (int): The kernel size for convs in the head. Default: 3.
+        dilation (int): The dilation rate for convs in the head. Default: 1.
+
+
+ Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation.
+ """
+
+ def __init__(
+ self,
+ config: Data2VecVisionConfig,
+ in_index: int = 2,
+ kernel_size: int = 3,
+ dilation: int | tuple[int, int] = 1,
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+ self.in_channels = config.hidden_size
+ self.channels = config.auxiliary_channels
+ self.num_convs = config.auxiliary_num_convs
+ self.concat_input = config.auxiliary_concat_input
+ self.in_index = in_index
+
+ convs = []
+ convs.append(
+ TFData2VecVisionConvModule(
+ in_channels=self.in_channels,
+ out_channels=self.channels,
+ kernel_size=kernel_size,
+ padding="same",
+ dilation=dilation,
+ name="convs.0",
+ )
+ )
+ for i in range(self.num_convs - 1):
+ convs.append(
+ TFData2VecVisionConvModule(
+ in_channels=self.channels,
+ out_channels=self.channels,
+ kernel_size=kernel_size,
+ padding="same",
+ dilation=dilation,
+ name=f"conv_module_{i + 2}",
+ )
+ )
+ if self.num_convs == 0:
+ self.convs = [tf.identity]
+ else:
+ self.convs = convs
+ if self.concat_input:
+ self.conv_cat = TFData2VecVisionConvModule(
+ self.in_channels + self.channels,
+ out_channels=self.channels,
+ kernel_size=kernel_size,
+ padding="same",
+ name="conv_cat",
+ )
+
+ self.classifier = keras.layers.Conv2D(config.num_labels, kernel_size=1, name="classifier")
+
+ def call(self, encoder_hidden_states: tf.Tensor) -> tf.Tensor:
+ # just take the relevant feature maps
+ hidden_states = encoder_hidden_states[self.in_index]
+ output = hidden_states
+ for layer_module in self.convs:
+ output = layer_module(output)
+ if self.concat_input:
+ output = self.conv_cat(tf.concat([hidden_states, output], axis=-1))
+ output = self.classifier(output)
+ return output
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "classifier", None) is not None:
+ with tf.name_scope(self.classifier.name):
+ self.classifier.build([None, None, None, self.channels])
+ if getattr(self, "conv_cat", None) is not None:
+ with tf.name_scope(self.conv_cat.name):
+ self.conv_cat.build(None)
+
+
+@add_start_docstrings(
+ """
+ Data2VecVision Model transformer with a semantic segmentation head on top e.g. for ADE20k, CityScapes.
+ """,
+ DATA2VEC_VISION_START_DOCSTRING,
+)
+class TFData2VecVisionForSemanticSegmentation(TFData2VecVisionPreTrainedModel):
+ def __init__(self, config: Data2VecVisionConfig, *inputs, **kwargs) -> None:
+ super().__init__(config, *inputs, **kwargs)
+ self.num_labels = config.num_labels
+ self.data2vec_vision = TFData2VecVisionMainLayer(config, add_pooling_layer=False, name="data2vec_vision")
+
+ # FPNs
+ self.fpn1 = [
+ keras.layers.Conv2DTranspose(config.hidden_size, kernel_size=2, strides=2, name="fpn1.0"),
+ keras.layers.BatchNormalization(name="fpn1.1", momentum=0.9, epsilon=1e-5),
+ keras.layers.Activation("gelu"),
+ keras.layers.Conv2DTranspose(config.hidden_size, kernel_size=2, strides=2, name="fpn1.3"),
+ ]
+ self.fpn2 = [keras.layers.Conv2DTranspose(config.hidden_size, kernel_size=2, strides=2, name="fpn2.0")]
+
+ self.fpn3 = tf.identity
+ self.fpn4 = keras.layers.MaxPool2D(pool_size=2, strides=2)
+
+ # Semantic segmentation head(s)
+ self.decode_head = TFData2VecVisionUperHead(config, name="decode_head")
+ self.auxiliary_head = (
+ TFData2VecVisionFCNHead(config, name="auxiliary_head") if config.use_auxiliary_head else None
+ )
+
+ def compute_loss(self, logits, auxiliary_logits, labels):
+ # upsample logits to the images' original size
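+        # labels may be (batch, height, width) or channels-last (batch, height, width, 1); recover the spatial size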
+ if len(shape_list(labels)) > 3:
+ label_interp_shape = shape_list(labels)[1:-1]
+ else:
+ label_interp_shape = shape_list(labels)[-2:]
+
+ upsampled_logits = tf.image.resize(logits, size=label_interp_shape, method="bilinear")
+ if auxiliary_logits is not None:
+ upsampled_auxiliary_logits = tf.image.resize(auxiliary_logits, size=label_interp_shape, method="bilinear")
+ # compute weighted loss
+ loss_fct = keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction="none")
+
+ # Copied from https://www.tensorflow.org/text/tutorials/transformer#loss_and_metrics.
+ # Utility to mask the index to ignore during computing the loss.
+ def masked_loss(real, pred):
+ mask = tf.math.logical_not(tf.math.equal(real, self.config.semantic_loss_ignore_index))
+ loss_ = loss_fct(real, pred)
+ mask = tf.cast(mask, dtype=loss_.dtype)
+ loss_ *= mask
+ reduced_masked_loss = tf.reduce_sum(loss_) / tf.reduce_sum(mask)
+ return tf.reshape(reduced_masked_loss, (1,))
+
+ main_loss = masked_loss(labels, upsampled_logits)
+ auxiliary_loss = masked_loss(labels, upsampled_auxiliary_logits)
+ loss = main_loss + self.config.auxiliary_loss_weight * auxiliary_loss
+
+ return loss
+
+ @unpack_inputs
+ @add_start_docstrings_to_model_forward(DATA2VEC_VISION_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=TFSemanticSegmenterOutput, config_class=_CONFIG_FOR_DOC)
+ def call(
+ self,
+ pixel_values: tf.Tensor | None = None,
+ head_mask: tf.Tensor | None = None,
+ labels: tf.Tensor | None = None,
+ output_attentions: bool | None = None,
+ output_hidden_states: bool | None = None,
+ return_dict: bool | None = None,
+ ) -> tuple | TFSemanticSegmenterOutput:
+ r"""
+ labels (`tf.Tensor` of shape `(batch_size, height, width)`, *optional*):
+ Ground truth semantic segmentation maps for computing the loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels > 1`, a classification loss is computed (Cross-Entropy).
+
+ Returns:
+
+ Examples:
+
+ ```python
+ >>> from transformers import AutoImageProcessor, TFData2VecVisionForSemanticSegmentation
+ >>> from PIL import Image
+ >>> import requests
+
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> image_processor = AutoImageProcessor.from_pretrained("facebook/data2vec-vision-base")
+ >>> model = TFData2VecVisionForSemanticSegmentation.from_pretrained("facebook/data2vec-vision-base")
+
+        >>> inputs = image_processor(images=image, return_tensors="tf")
+ >>> outputs = model(**inputs)
+ >>> # logits are of shape (batch_size, num_labels, height, width)
+ >>> logits = outputs.logits
+ ```"""
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+
+ if labels is not None and self.config.num_labels == 1:
+ raise ValueError("The number of labels should be greater than one")
+
+ outputs = self.data2vec_vision(
+ pixel_values,
+ head_mask=head_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=True, # we need the intermediate hidden states
+ return_dict=return_dict,
+ )
+ encoder_hidden_states = outputs.hidden_states if return_dict else outputs[1]
+
+ # only keep certain features, and reshape
+ # note that we do +1 as the encoder_hidden_states also includes the initial embeddings
+ features = [feature for idx, feature in enumerate(encoder_hidden_states) if idx + 1 in self.config.out_indices]
+ patch_resolution = self.config.image_size // self.config.patch_size
+
+ def reshape_features(x):
+ # We do it this way so TF can always infer the non-batch dims at compile time
+ x = tf.reshape(x, (-1, patch_resolution, patch_resolution, self.config.hidden_size))
+ return x
+
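+        # drop the prepended [CLS] token and reshape the remaining patch tokens into 2D feature maps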
+ features = [reshape_features(x[:, 1:, :]) for x in features]
+
+ # apply FPNs
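+        # fpn1 upsamples 4x (two stride-2 transposed convs), fpn2 upsamples 2x, fpn3 is the identity
+        # and fpn4 downsamples 2x, turning the four selected feature maps into a feature pyramid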
+ ops = [self.fpn1, self.fpn2, self.fpn3, self.fpn4]
+ for module in ops[0]:
+ features[0] = module(features[0])
+ features[1] = ops[1][0](features[1])
+ for i in range(len(features[2:])):
+ features[i + 2] = ops[i + 2](features[i + 2])
+
+ logits = self.decode_head(features)
+ # Transpose the logits to maintain consistency in the output formats.
+ transposed_logits = tf.transpose(logits, perm=[0, 3, 1, 2])
+
+ auxiliary_logits = None
+ if self.auxiliary_head is not None:
+ auxiliary_logits = self.auxiliary_head(features)
+
+ loss = None
+ if labels is not None:
+ loss = self.compute_loss(logits, auxiliary_logits, labels)
+
+ if not return_dict:
+ if output_hidden_states:
+ output = (logits,) + outputs[1:]
+ else:
+ output = (logits,) + outputs[2:]
+ return ((loss,) + output) if loss is not None else output
+
+ return TFSemanticSegmenterOutput(
+ loss=loss,
+ logits=transposed_logits,
+ hidden_states=outputs.hidden_states if output_hidden_states else None,
+ attentions=outputs.attentions,
+ )
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "data2vec_vision", None) is not None:
+ with tf.name_scope(self.data2vec_vision.name):
+ self.data2vec_vision.build(None)
+ if getattr(self, "decode_head", None) is not None:
+ with tf.name_scope(self.decode_head.name):
+ self.decode_head.build(None)
+ if getattr(self, "auxiliary_head", None) is not None:
+ with tf.name_scope(self.auxiliary_head.name):
+ self.auxiliary_head.build(None)
+ if getattr(self, "fpn1", None) is not None:
+ with tf.name_scope(self.fpn1[0].name):
+ self.fpn1[0].build([None, None, None, self.config.hidden_size])
+ with tf.name_scope(self.fpn1[1].name):
+ self.fpn1[1].build((None, None, None, self.config.hidden_size))
+ with tf.name_scope(self.fpn1[3].name):
+ self.fpn1[3].build([None, None, None, self.config.hidden_size])
+ if getattr(self, "fpn2", None) is not None:
+ with tf.name_scope(self.fpn2[0].name):
+ self.fpn2[0].build([None, None, None, self.config.hidden_size])
+
+
+__all__ = [
+ "TFData2VecVisionForImageClassification",
+ "TFData2VecVisionForSemanticSegmentation",
+ "TFData2VecVisionModel",
+ "TFData2VecVisionPreTrainedModel",
+]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/data2vec/modular_data2vec_audio.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/data2vec/modular_data2vec_audio.py
new file mode 100644
index 0000000000000000000000000000000000000000..91cb04730e4aec01edff0f35701eb12c3ee5af3d
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/data2vec/modular_data2vec_audio.py
@@ -0,0 +1,267 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch Data2VecText model."""
+
+import math
+
+import torch
+from torch import nn
+
+from ...activations import ACT2FN
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_outputs import Wav2Vec2BaseModelOutput
+from ...modeling_utils import PreTrainedModel
+from ..wav2vec2.modeling_wav2vec2 import (
+ Wav2Vec2Adapter,
+ Wav2Vec2Encoder,
+ Wav2Vec2FeatureEncoder,
+ Wav2Vec2FeatureProjection,
+ Wav2Vec2ForAudioFrameClassification,
+ Wav2Vec2ForCTC,
+ Wav2Vec2ForSequenceClassification,
+ Wav2Vec2ForXVector,
+ Wav2Vec2Model,
+ Wav2Vec2PreTrainedModel,
+ Wav2Vec2SamePadLayer,
+)
+from .configuration_data2vec_audio import Data2VecAudioConfig
+
+
+class Data2VecAudioConvLayer(GradientCheckpointingLayer):
+ def __init__(self, config, layer_id=0):
+ super().__init__()
+ self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
+ self.out_conv_dim = config.conv_dim[layer_id]
+
+ self.conv = nn.Conv1d(
+ self.in_conv_dim,
+ self.out_conv_dim,
+ kernel_size=config.conv_kernel[layer_id],
+ stride=config.conv_stride[layer_id],
+ bias=config.conv_bias,
+ )
+ self.layer_norm = nn.LayerNorm(self.out_conv_dim, elementwise_affine=True)
+ self.activation = ACT2FN[config.feat_extract_activation]
+
+ def forward(self, hidden_states):
+ hidden_states = self.conv(hidden_states)
+
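+        # LayerNorm normalizes over the last dimension, so move the channel axis last and back again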
+ hidden_states = hidden_states.transpose(-2, -1)
+ hidden_states = self.layer_norm(hidden_states)
+ hidden_states = hidden_states.transpose(-2, -1)
+
+ hidden_states = self.activation(hidden_states)
+ return hidden_states
+
+
+class Data2VecAudioPadLayer(Wav2Vec2SamePadLayer):
+ pass
+
+
+class Data2VecAudioPositionalConvLayer(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.conv = nn.Conv1d(
+ config.hidden_size,
+ config.hidden_size,
+ kernel_size=config.conv_pos_kernel_size,
+ padding=config.conv_pos_kernel_size // 2,
+ groups=config.num_conv_pos_embedding_groups,
+ )
+
+ self.padding = Data2VecAudioPadLayer(config.conv_pos_kernel_size)
+ self.activation = ACT2FN[config.feat_extract_activation]
+ # no learnable parameters
+ self.layer_norm = nn.LayerNorm(config.hidden_size, elementwise_affine=False)
+
+ def forward(self, hidden_states):
+ hidden_states = self.conv(hidden_states)
+ hidden_states = self.padding(hidden_states)
+
+ hidden_states = hidden_states.transpose(1, 2)
+ hidden_states = self.layer_norm(hidden_states)
+ hidden_states = hidden_states.transpose(1, 2)
+ hidden_states = self.activation(hidden_states)
+ return hidden_states
+
+
+class Data2VecAudioPositionalConvEmbedding(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.layers = nn.ModuleList(
+ [Data2VecAudioPositionalConvLayer(config) for _ in range(config.num_conv_pos_embeddings)]
+ )
+
+ def forward(self, hidden_states):
+ hidden_states = hidden_states.transpose(1, 2)
+ for layer in self.layers:
+ hidden_states = layer(hidden_states)
+ hidden_states = hidden_states.transpose(1, 2)
+ return hidden_states
+
+
+class Data2VecAudioFeatureEncoder(Wav2Vec2FeatureEncoder):
+ def __init__(self, config):
+ nn.Module.__init__(self)
+ self.conv_layers = nn.ModuleList(
+ [Data2VecAudioConvLayer(config, layer_id=i) for i in range(config.num_feat_extract_layers)]
+ )
+ self.gradient_checkpointing = False
+ self._requires_grad = True
+
+
+class Data2VecAudioFeatureProjection(Wav2Vec2FeatureProjection):
+ pass
+
+
+class Data2VecAudioEncoder(Wav2Vec2Encoder):
+ pass
+
+
+class Data2VecAudioAdapter(Wav2Vec2Adapter):
+ pass
+
+
+class Data2VecAudioPreTrainedModel(PreTrainedModel, Wav2Vec2PreTrainedModel):
+ config: Data2VecAudioConfig
+ base_model_prefix = "data2vec_audio"
+ main_input_name = "input_values"
+ supports_gradient_checkpointing = True
+ _supports_flash_attn = True
+ _supports_sdpa = True
+ _supports_flex_attn = True
+
+ def _init_weights(self, module):
+ """Initialize the weights"""
+ if isinstance(module, Data2VecAudioFeatureProjection):
+ k = math.sqrt(1 / module.projection.in_features)
+ nn.init.uniform_(module.projection.weight, a=-k, b=k)
+ nn.init.uniform_(module.projection.bias, a=-k, b=k)
+ elif isinstance(module, Data2VecAudioPositionalConvLayer):
+ nn.init.constant_(module.conv.bias, 0)
+ elif isinstance(module, nn.Linear):
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
+ if module.bias is not None:
+ module.bias.data.zero_()
+ if module.weight is not None:
+ module.weight.data.fill_(1.0)
+ elif isinstance(module, nn.Conv1d):
+ nn.init.kaiming_normal_(module.weight)
+
+ if module.bias is not None:
+ k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0]))
+ nn.init.uniform_(module.bias, a=-k, b=k)
+
+ def _get_adapters(self):
+ raise AttributeError("Not needed for Data2VecAudio")
+
+ def init_adapter_layers(self):
+ raise AttributeError("Not needed for Data2VecAudio")
+
+ def load_adapter(self):
+ raise AttributeError("Not needed for Data2VecAudio")
+
+
+Data2VecAudioBaseModelOutput = Wav2Vec2BaseModelOutput
+
+
+class Data2VecAudioModel(Data2VecAudioPreTrainedModel, Wav2Vec2Model):
+ def __init__(self, config: Data2VecAudioConfig):
+ Data2VecAudioPreTrainedModel.__init__(self, config)
+ self.config = config
+ self.feature_extractor = Data2VecAudioFeatureEncoder(config)
+ self.feature_projection = Data2VecAudioFeatureProjection(config)
+
+ # model only needs masking vector if mask prob is > 0.0
+ if config.mask_time_prob > 0.0 or config.mask_feature_prob > 0.0:
+ self.masked_spec_embed = nn.Parameter(torch.Tensor(config.hidden_size).uniform_())
+
+ self.encoder = Data2VecAudioEncoder(config)
+
+ self.adapter = Data2VecAudioAdapter(config) if config.add_adapter else None
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def freeze_feature_extractor(self):
+ raise AttributeError("Not needed for Data2VecAudio")
+
+ def freeze_feature_encoder(self):
+ """
+ Calling this function will disable the gradient computation for the feature encoder so that its parameter will
+ not be updated during training.
+ """
+ self.feature_extractor._freeze_parameters()
+
+ def forward(self, **super_kwargs):
+ return super().forward(**super_kwargs)
+
+
+class Data2VecAudioForCTC(Data2VecAudioPreTrainedModel, Wav2Vec2ForCTC):
+ def __init__(self, config):
+ Data2VecAudioPreTrainedModel.__init__(self, config)
+
+ self.data2vec_audio = Data2VecAudioModel(config)
+ self.dropout = nn.Dropout(config.final_dropout)
+
+ if config.vocab_size is None:
+ raise ValueError(
+ f"You are trying to instantiate {self.__class__} with a configuration that "
+ "does not define the vocabulary size of the language model head. Please "
+ "instantiate the model as follows: `Data2VecAudioForCTC.from_pretrained(..., vocab_size=vocab_size)`. "
+ "or define `vocab_size` of your model's configuration."
+ )
+ output_hidden_size = (
+ config.output_hidden_size if hasattr(config, "add_adapter") and config.add_adapter else config.hidden_size
+ )
+ self.lm_head = nn.Linear(output_hidden_size, config.vocab_size)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def freeze_base_model(self):
+ raise AttributeError("Not needed for Data2VecAudio")
+
+ def tie_weights(self):
+ raise AttributeError("Not needed for Data2VecAudio")
+
+ def forward(self, **super_kwargs):
+ return super().forward(**super_kwargs)
+
+
+class Data2VecAudioForSequenceClassification(Wav2Vec2ForSequenceClassification):
+ pass
+
+
+class Data2VecAudioForAudioFrameClassification(Wav2Vec2ForAudioFrameClassification):
+ pass
+
+
+class Data2VecAudioForXVector(Wav2Vec2ForXVector):
+ pass
+
+
+__all__ = [
+ "Data2VecAudioForAudioFrameClassification",
+ "Data2VecAudioForCTC",
+ "Data2VecAudioForSequenceClassification",
+ "Data2VecAudioForXVector",
+ "Data2VecAudioModel",
+ "Data2VecAudioPreTrainedModel",
+]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/dbrx/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/dbrx/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..cce0f34c778d21a81d03940e7a7951707d898c86
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/dbrx/__init__.py
@@ -0,0 +1,27 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_dbrx import *
+ from .modeling_dbrx import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/dbrx/configuration_dbrx.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/dbrx/configuration_dbrx.py
new file mode 100644
index 0000000000000000000000000000000000000000..17b6b2a368cc5f13783ae7b333dc7af78aee9d05
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/dbrx/configuration_dbrx.py
@@ -0,0 +1,232 @@
+# coding=utf-8
+# Copyright 2024 Databricks Mosaic Research and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""DBRX model configuration"""
+
+from typing import Any, Optional
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class DbrxAttentionConfig(PretrainedConfig):
+ """Configuration class for Dbrx Attention.
+
+    This is the configuration class to store the configuration of a [`DbrxAttention`] class. It is used to
+    instantiate attention layers according to the specified arguments, defining the layers architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ attn_pdrop (`float`, *optional*, defaults to 0.0):
+ The dropout probability for the attention layers.
+ clip_qkv (`float`, *optional*):
+ If set, clip the queries, keys, and values in the attention layer to this value.
+        kv_n_heads (`int`, *optional*, defaults to 1): For grouped query attention, the number of key/value heads.
+ rope_theta (`float`, *optional*, defaults to 10000.0): The base frequency for rope.
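+
+    Example (illustrative; the sub-config is normally created through [`DbrxConfig`], and the values below are
+    arbitrary placeholders rather than the released DBRX settings):
+
+    ```python
+    >>> from transformers import DbrxConfig
+
+    >>> config = DbrxConfig(attn_config={"kv_n_heads": 8, "clip_qkv": 8.0, "rope_theta": 500000.0})
+    >>> config.attn_config.kv_n_heads
+    8
+    ```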
+ """
+
+ base_config_key = "attn_config"
+
+ def __init__(
+ self,
+ attn_pdrop: float = 0.0,
+ clip_qkv: Optional[float] = None,
+ kv_n_heads: int = 1,
+ rope_theta: float = 10000.0,
+ **kwargs: Any,
+ ):
+ super().__init__(**kwargs)
+ self.attn_pdrop = attn_pdrop
+ self.clip_qkv = clip_qkv
+ self.kv_n_heads = kv_n_heads
+ self.rope_theta = rope_theta
+
+ for k in ["model_type", "attn_implementation", "transformers_version", "_commit_hash", "torch_dtype", "dtype"]:
+ if k in kwargs:
+ kwargs.pop(k)
+ if len(kwargs) != 0:
+ raise ValueError(f"Found unknown {kwargs=}")
+
+
+class DbrxFFNConfig(PretrainedConfig):
+ """Configuration class for Dbrx FFN.
+
+    This is the configuration class to store the configuration of a [`DbrxFFN`] class. It is used to instantiate
+    feedforward layers according to the specified arguments, defining the layers architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ ffn_act_fn (`dict`, *optional*, defaults to `None`): A dict specifying activation function for the FFN.
+ The dict should have a key 'name' with the value being the name of the activation function along with
+ any additional keyword arguments. If `None`, then set to `{"name": "silu"}`.
+ ffn_hidden_size (`int`, *optional*, defaults to 3584): The hidden size of the feedforward network.
+ moe_num_experts (`int`, *optional*, defaults to 4): The number of experts in the mixture of experts layer.
+ moe_top_k (`int`, *optional*, defaults to 1): The number of experts to use in the mixture of experts layer.
+ moe_jitter_eps (`float`, *optional*, defaults to `None`): If not `None`, the jitter epsilon for the mixture of experts layer.
+ moe_loss_weight (`float`, *optional*, defaults to 0.01): The loss weight for the mixture of experts layer.
+ moe_normalize_expert_weights (`float`, *optional*, defaults to 1.0): The normalization factor for the expert weights.
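+
+    Example (illustrative; the sub-config is normally created through [`DbrxConfig`], with arbitrary placeholder
+    values):
+
+    ```python
+    >>> from transformers import DbrxConfig
+
+    >>> config = DbrxConfig(ffn_config={"ffn_hidden_size": 1024, "moe_num_experts": 8, "moe_top_k": 2})
+    >>> config.ffn_config.moe_top_k
+    2
+    ```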
+ """
+
+ base_config_key = "ffn_config"
+
+ def __init__(
+ self,
+ ffn_act_fn: Optional[dict] = None,
+ ffn_hidden_size: int = 3584,
+ moe_num_experts: int = 4,
+ moe_top_k: int = 1,
+ moe_jitter_eps: Optional[float] = None,
+ moe_loss_weight: float = 0.01,
+ moe_normalize_expert_weights: Optional[float] = 1.0,
+ **kwargs: Any,
+ ):
+ super().__init__()
+ if ffn_act_fn is None:
+ ffn_act_fn = {"name": "silu"}
+ self.ffn_act_fn = ffn_act_fn
+ self.ffn_hidden_size = ffn_hidden_size
+ self.moe_num_experts = moe_num_experts
+ self.moe_top_k = moe_top_k
+ self.moe_jitter_eps = moe_jitter_eps
+ self.moe_loss_weight = moe_loss_weight
+ self.moe_normalize_expert_weights = moe_normalize_expert_weights
+
+ for k in ["model_type", "attn_implementation", "transformers_version", "_commit_hash", "torch_dtype", "dtype"]:
+ if k in kwargs:
+ kwargs.pop(k)
+ if len(kwargs) != 0:
+ raise ValueError(f"Found unknown {kwargs=}")
+
+
+class DbrxConfig(PretrainedConfig):
+ r"""
+
+ This is the configuration class to store the configuration of a [`DbrxModel`]. It is used to instantiate a Dbrx model according to the
+ specified arguments, defining the model architecture. Instantiating a configuration with the
+ defaults will yield a different configuration to that of the [databricks/dbrx-instruct](https://huggingface.co/databricks/dbrx-instruct) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+
+ Args:
+ d_model (`int`, *optional*, defaults to 2048):
+ Dimensionality of the embeddings and hidden states.
+ n_heads (`int`, *optional*, defaults to 16):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ n_layers (`int`, *optional*, defaults to 24):
+ Number of hidden layers in the Transformer encoder.
+ max_seq_len (`int`, *optional*, defaults to 2048):
+ The maximum sequence length of the model.
+ vocab_size (`int`, *optional*, defaults to 32000):
+ Vocabulary size of the Dbrx model. Defines the maximum number of different tokens that can be represented by
+ the `inputs_ids` passed when calling [`DbrxModel`].
+ resid_pdrop (`float`, *optional*, defaults to 0.0):
+ The dropout probability applied to the attention output before combining with residual.
+ emb_pdrop (`float`, *optional*, defaults to 0.0):
+ The dropout probability for the embedding layer.
+ attn_config (`dict`, *optional*):
+ A dictionary used to configure the model's attention module.
+ ffn_config (`dict`, *optional*):
+ A dictionary used to configure the model's FFN module.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return the last key/values attentions (not used by all models).
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ output_router_logits (`bool`, *optional*, defaults to `False`):
+ Whether or not the router logits should be returned by the model. Enabling this will also
+            allow the model to output the auxiliary load balancing loss.
+
+
+ Example:
+ ```python
+ >>> from transformers import DbrxConfig, DbrxModel
+
+ >>> # Initializing a Dbrx configuration
+ >>> configuration = DbrxConfig(n_layers=2, d_model=256, n_heads=8, vocab_size=128)
+
+ >>> # Initializing a model (with random weights) from the configuration
+ >>> model = DbrxModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```
+ """
+
+ model_type = "dbrx"
+ sub_configs = {"attn_config": DbrxAttentionConfig, "ffn_config": DbrxFFNConfig}
+ attribute_map = {
+ "num_attention_heads": "n_heads",
+ "hidden_size": "d_model",
+ "num_hidden_layers": "n_layers",
+ "max_position_embeddings": "max_seq_len",
+ }
+
+ def __init__(
+ self,
+ d_model: int = 2048,
+ n_heads: int = 16,
+ n_layers: int = 24,
+ max_seq_len: int = 2048,
+ vocab_size: int = 32000,
+ resid_pdrop: float = 0.0,
+ emb_pdrop: float = 0.0,
+ attn_config: Optional[DbrxAttentionConfig] = None,
+ ffn_config: Optional[DbrxFFNConfig] = None,
+ use_cache: bool = True,
+ initializer_range: float = 0.02,
+ output_router_logits: bool = False,
+ **kwargs: Any,
+ ):
+ if attn_config is None:
+ self.attn_config = DbrxAttentionConfig()
+ elif isinstance(attn_config, dict):
+ self.attn_config = DbrxAttentionConfig(**attn_config)
+ else:
+ self.attn_config = attn_config
+
+ if ffn_config is None:
+ self.ffn_config = DbrxFFNConfig()
+ elif isinstance(ffn_config, dict):
+ self.ffn_config = DbrxFFNConfig(**ffn_config)
+ else:
+ self.ffn_config = ffn_config
+
+ self.d_model = d_model
+ self.n_heads = n_heads
+ self.n_layers = n_layers
+ self.max_seq_len = max_seq_len
+ self.vocab_size = vocab_size
+ self.resid_pdrop = resid_pdrop
+ self.emb_pdrop = emb_pdrop
+ self.use_cache = use_cache
+ self.initializer_range = initializer_range
+ self.output_router_logits = output_router_logits
+ self.num_key_value_heads = self.attn_config.kv_n_heads
+
+ tie_word_embeddings = kwargs.pop("tie_word_embeddings", False)
+ if tie_word_embeddings:
+ raise ValueError("tie_word_embeddings is not supported for DBRX models.")
+
+ super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
+
+
+__all__ = ["DbrxConfig"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/dbrx/modeling_dbrx.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/dbrx/modeling_dbrx.py
new file mode 100644
index 0000000000000000000000000000000000000000..5f3a423213cba3ef11b95a36a3efa5aa01da69c2
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/dbrx/modeling_dbrx.py
@@ -0,0 +1,1248 @@
+# coding=utf-8
+# Copyright 2024 Databricks Mosaic Research and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch DBRX model."""
+
+import math
+from typing import Any, Optional, Union
+
+import torch
+from torch import nn
+
+from ...activations import ACT2FN
+from ...cache_utils import Cache, DynamicCache, StaticCache
+from ...generation import GenerationMixin
+from ...modeling_attn_mask_utils import AttentionMaskConverter
+from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast
+from ...modeling_utils import PreTrainedModel
+from ...utils import auto_docstring, is_torch_flex_attn_available, logging
+from ...utils.deprecation import deprecate_kwarg
+from .configuration_dbrx import DbrxConfig
+
+
+if is_torch_flex_attn_available():
+ from torch.nn.attention.flex_attention import BlockMask
+
+ from ...integrations.flex_attention import make_flex_block_causal_mask
+
+
+if is_flash_attn_available():
+ from ...modeling_flash_attention_utils import _flash_attention_forward
+
+logger = logging.get_logger(__name__)
+
+
+class DbrxRotaryEmbedding(nn.Module):
+ inv_freq: torch.Tensor # fix linting for `register_buffer`
+
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
+ super().__init__()
+
+ self.dim = dim
+ self.max_position_embeddings = max_position_embeddings
+ self.base = base
+
+ inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim))
+ self.register_buffer("inv_freq", tensor=inv_freq, persistent=False)
+
+ @torch.no_grad()
+ def forward(self, x, position_ids, seq_len=None):
+ # x: [bs, num_attention_heads, seq_len, head_size]
+ self.inv_freq.to(x.device)
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
+ position_ids_expanded = position_ids[:, None, :].float()
+ # Force float32 since bfloat16 loses precision on long contexts
+ # See https://github.com/huggingface/transformers/pull/29285
+ device_type = x.device.type
+ device_type = device_type if device_type != "mps" else "cpu"
+ with torch.autocast(device_type=device_type, enabled=False):
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+ emb = torch.cat((freqs, freqs), dim=-1)
+ cos = emb.cos()
+ sin = emb.sin()
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
+# Copied from transformers.models.llama.modeling_llama.rotate_half
+def rotate_half(x):
+ """Rotates half the hidden dims of the input."""
+ x1 = x[..., : x.shape[-1] // 2]
+ x2 = x[..., x.shape[-1] // 2 :]
+ return torch.cat((-x2, x1), dim=-1)
+
+
+# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+ """Applies Rotary Position Embedding to the query and key tensors.
+
+ Args:
+ q (`torch.Tensor`): The query tensor.
+ k (`torch.Tensor`): The key tensor.
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
+ position_ids (`torch.Tensor`, *optional*):
+ Deprecated and unused.
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+ Returns:
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+ """
+ cos = cos.unsqueeze(unsqueeze_dim)
+ sin = sin.unsqueeze(unsqueeze_dim)
+ q_embed = (q * cos) + (rotate_half(q) * sin)
+ k_embed = (k * cos) + (rotate_half(k) * sin)
+ return q_embed, k_embed
+
+
+# Copied from transformers.models.llama.modeling_llama.repeat_kv
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+ """
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+ """
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+ if n_rep == 1:
+ return hidden_states
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
+def load_balancing_loss_func(
+ gate_probabilities: torch.Tensor,
+ num_experts: int,
+ top_k: int,
+ attention_mask: Optional[torch.Tensor],
+) -> torch.Tensor:
+ r"""Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.
+
+ See Switch Transformer (https://huggingface.co/papers/2101.03961) for more details. This function implements the loss
+ function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between
+ experts is too unbalanced.
+
+ Args:
+        gate_probabilities (`tuple[torch.Tensor]`):
+            Routing probabilities from the gate, a tuple of `model.config.num_hidden_layers` tensors of
+            shape `[batch_size X sequence_length, num_experts]`.
+ num_experts (`int`):
+ Number of experts.
+ top_k (`int`):
+ The number of experts each token is routed to.
+ attention_mask (`torch.Tensor`, *optional*):
+ The attention_mask used in forward function
+ shape [batch_size X sequence_length] if not None.
+
+ Returns:
+ The auxiliary loss.
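+
+    Example (illustrative sketch only; the inputs are random, so the returned value varies between runs):
+
+    ```python
+    >>> import torch
+
+    >>> # one probability tensor of shape [num_tokens, num_experts] per layer
+    >>> gate_probs = tuple(torch.randn(4, 8).softmax(dim=-1) for _ in range(2))
+    >>> aux_loss = load_balancing_loss_func(gate_probs, num_experts=8, top_k=2, attention_mask=None)
+    ```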
+ """
+ if gate_probabilities is None or not isinstance(gate_probabilities, tuple):
+ return torch.tensor(0.0)
+
+ if isinstance(gate_probabilities, tuple):
+ compute_device = gate_probabilities[0].device
+ routing_weights = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_probabilities], dim=0)
+
+ _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
+
+ expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)
+
+ if attention_mask is None:
+ # Compute the percentage of tokens routed to each experts
+ tokens_per_expert = torch.mean(expert_mask.float(), dim=0)
+
+ # Compute the average probability of routing to these experts
+ router_prob_per_expert = torch.mean(routing_weights, dim=0)
+ else:
+ batch_size, sequence_length = attention_mask.shape
+ num_hidden_layers = routing_weights.shape[0] // (batch_size * sequence_length)
+
+ # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask
+ expert_attention_mask = (
+ attention_mask[None, :, :, None, None]
+ .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts))
+ .reshape(-1, top_k, num_experts)
+ .to(compute_device)
+ )
+
+ # Compute the percentage of tokens routed to each experts
+ tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum(
+ expert_attention_mask, dim=0
+ )
+
+ # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert
+ router_per_expert_attention_mask = (
+ attention_mask[None, :, :, None]
+ .expand((num_hidden_layers, batch_size, sequence_length, num_experts))
+ .reshape(-1, num_experts)
+ .to(compute_device)
+ )
+
+ # Compute the average probability of routing to these experts
+ router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum(
+ router_per_expert_attention_mask, dim=0
+ )
+
+ overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
+ return overall_loss * num_experts
+
+
+class DbrxAttention(nn.Module):
+ """Multi-head self attention."""
+
+ def __init__(self, config: DbrxConfig, block_idx: Optional[int] = None):
+ super().__init__()
+ self.config = config
+ self.hidden_size = config.d_model
+ self.num_heads = config.n_heads
+ self.head_dim = self.hidden_size // self.num_heads
+ self.max_position_embeddings = config.max_seq_len
+ self.block_idx = block_idx
+ if block_idx is None:
+ logger.warning_once(
+ f"Instantiating {self.__class__.__name__} without passing a `block_idx` is not recommended and will "
+ + "lead to errors during the forward call if caching is used. Please make sure to provide a `block_idx` "
+ + "when creating this class."
+ )
+
+ attn_config = config.attn_config
+ self.attn_pdrop = attn_config.attn_pdrop
+ self.clip_qkv = attn_config.clip_qkv
+ self.num_key_value_heads = attn_config.kv_n_heads
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
+ self.rope_theta = attn_config.rope_theta
+ self.is_causal = True
+
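+        # fused projection: queries for all heads plus grouped keys and values in a single matmul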
+ self.Wqkv = nn.Linear(
+ self.hidden_size, self.hidden_size + 2 * self.num_key_value_heads * self.head_dim, bias=False
+ )
+ self.out_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
+ self.rotary_emb = DbrxRotaryEmbedding(
+ self.head_dim,
+ max_position_embeddings=self.max_position_embeddings,
+ base=self.rope_theta,
+ )
+
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_ids: torch.LongTensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ past_key_values: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs: Any,
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
+ bsz, q_len, _ = hidden_states.size()
+
+ qkv_states = self.Wqkv(hidden_states)
+ min_val = -self.clip_qkv if self.clip_qkv is not None else None
+ max_val = self.clip_qkv
+ qkv_states = qkv_states.clamp(min=min_val, max=max_val)
+
+ query_states, key_states, value_states = qkv_states.split(
+ [
+ self.hidden_size,
+ self.num_key_value_heads * self.head_dim,
+ self.num_key_value_heads * self.head_dim,
+ ],
+ dim=2,
+ )
+
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ cos, sin = self.rotary_emb(value_states, position_ids)
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_values is not None:
+ # sin and cos are specific to RoPE models; position_ids needed for the static cache
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ key_states, value_states = past_key_values.update(key_states, value_states, self.block_idx, cache_kwargs)
+
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+
+ if attention_mask is not None: # no matter the length, we just slice it
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+ attn_weights = attn_weights + causal_mask
+
+ # upcast attention to fp32
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+ attn_weights = nn.functional.dropout(attn_weights, p=self.attn_pdrop, training=self.training)
+ attn_output = torch.matmul(attn_weights, value_states)
+
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+ raise ValueError(
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+ + f" {attn_output.size()}"
+ )
+
+ attn_output = attn_output.transpose(1, 2).contiguous()
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+ attn_output = self.out_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights
+
+
+class DbrxFlashAttention2(DbrxAttention):
+ """Dbrx flash attention module.
+
+ This module inherits from `DbrxAttention` as the weights of the module stays
+ untouched. The only required change would be on the forward pass where it
+ calls the public API of flash attention.
+ """
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
+ # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
+ self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask()
+
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs: Any,
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
+ if isinstance(past_key_values, StaticCache):
+ raise ValueError(
+ "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
+ "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
+ )
+ logger.info("Implicitly setting `output_attentions` to False as it is not supported in Flash Attention.")
+ output_attentions = False
+
+ bsz, q_len, _ = hidden_states.size()
+
+ qkv_states = self.Wqkv(hidden_states)
+ if self.clip_qkv is not None:
+ qkv_states = qkv_states.clamp(min=-self.clip_qkv, max=self.clip_qkv)
+
+ query_states, key_states, value_states = qkv_states.split(
+ [
+ self.hidden_size,
+ self.num_key_value_heads * self.head_dim,
+ self.num_key_value_heads * self.head_dim,
+ ],
+ dim=2,
+ )
+
+ # Flash attention requires the input to have the shape
+ # batch_size x seq_length x head_dim x hidden_dim
+ # therefore we just need to keep the original shape
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ cos, sin = self.rotary_emb(value_states, position_ids)
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_values is not None:
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ key_states, value_states = past_key_values.update(key_states, value_states, self.block_idx, cache_kwargs)
+
+ # TODO: These transpose are quite inefficient but Flash Attention requires the layout
+ # [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
+ # to be able to avoid many of these transpose/reshape/view.
+ query_states = query_states.transpose(1, 2)
+ key_states = key_states.transpose(1, 2)
+ value_states = value_states.transpose(1, 2)
+
+ dropout_rate = self.attn_pdrop if self.training else 0.0
+
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
+ # therefore the input hidden states gets silently casted in float32. Hence, we need
+ # cast them back in the correct dtype just to be sure everything works as expected.
+ # This might slowdown training & inference so it is recommended to not cast the LayerNorms
+ # in fp32. (LlamaRMSNorm handles it correctly)
+ input_dtype = query_states.dtype
+ device_type = query_states.device.type if query_states.device.type != "mps" else "cpu"
+ if input_dtype == torch.float32:
+ if torch.is_autocast_enabled():
+ target_dtype = (
+ torch.get_autocast_dtype(device_type)
+ if hasattr(torch, "get_autocast_dtype")
+ else torch.get_autocast_gpu_dtype()
+ )
+ # Handle the case where the model is quantized
+ elif hasattr(self.config, "_pre_quantization_dtype"):
+ target_dtype = self.config._pre_quantization_dtype
+ else:
+ target_dtype = query_states.dtype
+
+ logger.warning_once(
+ "The input hidden states seems to be silently casted in float32, this might be "
+ + "related to the fact you have upcasted embedding or layer norm layers in "
+ + f"float32. We will cast back the input in {target_dtype}."
+ )
+
+ query_states = query_states.to(target_dtype)
+ key_states = key_states.to(target_dtype)
+ value_states = value_states.to(target_dtype)
+
+ attn_output = _flash_attention_forward(
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ q_len,
+ position_ids=position_ids,
+ dropout=dropout_rate,
+ is_causal=self.is_causal,
+ use_top_left_mask=self._flash_attn_uses_top_left_mask,
+ )
+
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
+ attn_output = self.out_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights
+
+
+class DbrxSdpaAttention(DbrxAttention):
+ """
+ Dbrx attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
+ `DbrxAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
+ SDPA API.
+ """
+
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
+ if output_attentions:
+ # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
+ logger.warning_once(
+ "DbrxModel is using DbrxSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
+ 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+ )
+ return super().forward(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ )
+
+ bsz, q_len, _ = hidden_states.size()
+
+ qkv_states = self.Wqkv(hidden_states)
+ if self.clip_qkv is not None:
+ qkv_states = qkv_states.clamp(min=-self.clip_qkv, max=self.clip_qkv)
+
+ query_states, key_states, value_states = qkv_states.split(
+ [
+ self.hidden_size,
+ self.num_key_value_heads * self.head_dim,
+ self.num_key_value_heads * self.head_dim,
+ ],
+ dim=2,
+ )
+
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None)
+
+ if past_key_values is not None:
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ key_states, value_states = past_key_values.update(key_states, value_states, self.block_idx, cache_kwargs)
+
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ causal_mask = attention_mask
+ if attention_mask is not None:
+ causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
+
+ # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
+ # Reference: https://github.com/pytorch/pytorch/issues/112577.
+ if query_states.device.type == "cuda" and causal_mask is not None:
+ query_states = query_states.contiguous()
+ key_states = key_states.contiguous()
+ value_states = value_states.contiguous()
+
+ # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+ # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
+ is_causal = causal_mask is None and q_len > 1
+
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
+ query_states,
+ key_states,
+ value_states,
+ attn_mask=causal_mask,
+ dropout_p=self.attn_pdrop if self.training else 0.0,
+ is_causal=is_causal,
+ )
+
+ attn_output = attn_output.transpose(1, 2).contiguous()
+ attn_output = attn_output.view(bsz, q_len, -1)
+
+ attn_output = self.out_proj(attn_output)
+
+ return attn_output, None
+
+
+DBRX_ATTENTION_CLASSES = {
+ "eager": DbrxAttention,
+ "flash_attention_2": DbrxFlashAttention2,
+ "sdpa": DbrxSdpaAttention,
+}
+
+
+class DbrxNormAttentionNorm(nn.Module):
+ def __init__(self, config: DbrxConfig, block_idx: Optional[int] = None):
+ super().__init__()
+ self.block_idx = block_idx
+ self.resid_pdrop = config.resid_pdrop
+ self.norm_1 = nn.LayerNorm(config.d_model, bias=False)
+ self.attn = DBRX_ATTENTION_CLASSES[config._attn_implementation](
+ config=config,
+ block_idx=block_idx,
+ )
+ self.norm_2 = nn.LayerNorm(config.d_model, bias=False)
+
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_ids: torch.LongTensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ past_key_values: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs: Any,
+ ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
+ residual_states = hidden_states
+ hidden_states = self.norm_1(hidden_states).to(hidden_states.dtype)
+
+ hidden_states, attn_weights = self.attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+ hidden_states = nn.functional.dropout(hidden_states, p=self.resid_pdrop, training=self.training)
+ hidden_states = hidden_states + residual_states
+
+ residual_states = hidden_states
+ hidden_states = self.norm_2(hidden_states).to(hidden_states.dtype)
+
+ return residual_states, hidden_states, attn_weights
+
+
+class DbrxRouter(nn.Module):
+ def __init__(
+ self,
+ hidden_size: int,
+ moe_num_experts: int,
+ moe_top_k: int,
+ moe_jitter_eps: Optional[float],
+ moe_normalize_expert_weights: Optional[float],
+ ):
+ super().__init__()
+ self.hidden_size = hidden_size
+ self.moe_num_experts = moe_num_experts
+ self.moe_top_k = moe_top_k
+ self.moe_jitter_eps = moe_jitter_eps
+ self.moe_normalize_expert_weights = moe_normalize_expert_weights
+
+ self.layer = nn.Linear(self.hidden_size, self.moe_num_experts, bias=False)
+
+ def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.LongTensor]:
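+        # during training, optionally apply multiplicative jitter noise to the router inputs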
+ if self.training and self.moe_jitter_eps is not None:
+ hidden_states *= torch.empty_like(hidden_states).uniform_(
+ 1.0 - self.moe_jitter_eps, 1.0 + self.moe_jitter_eps
+ )
+ hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
+ weights = self.layer(hidden_states).softmax(dim=-1, dtype=torch.float32)
+ top_weights, top_experts = torch.topk(weights, self.moe_top_k, dim=-1)
+
+ top_weights_scale = (
+ torch.norm(top_weights, p=self.moe_normalize_expert_weights, dim=-1, keepdim=True)
+ if self.moe_normalize_expert_weights is not None
+ else 1.0
+ )
+ top_weights = top_weights / top_weights_scale
+
+ weights = weights.to(hidden_states.dtype)
+ top_weights = top_weights.to(hidden_states.dtype)
+ return weights, top_weights, top_experts
+
+
+class DbrxExpertGLU(nn.Module):
+ def __init__(self, hidden_size: int, ffn_hidden_size: int, moe_num_experts: int, ffn_act_fn: dict):
+ super().__init__()
+ self.hidden_size = hidden_size
+ self.ffn_hidden_size = ffn_hidden_size
+ self.moe_num_experts = moe_num_experts
+
+ self.w1 = nn.Parameter(torch.empty(moe_num_experts * ffn_hidden_size, hidden_size))
+ self.v1 = nn.Parameter(torch.empty(moe_num_experts * ffn_hidden_size, hidden_size))
+ self.w2 = nn.Parameter(torch.empty(moe_num_experts * ffn_hidden_size, hidden_size))
+
+ act_fn_name = ffn_act_fn.get("name", "silu")
+ self.activation_fn = ACT2FN[act_fn_name]
+
+ def forward(
+ self, x: torch.Tensor, expert_w1: torch.Tensor, expert_v1: torch.Tensor, expert_w2: torch.Tensor
+ ) -> torch.Tensor:
+ gate_proj = x.matmul(expert_w1.t())
+ up_proj = x.matmul(expert_v1.t())
+ gate_proj = self.activation_fn(gate_proj)
+ intermediate_states = gate_proj * up_proj
+ down_proj = intermediate_states.matmul(expert_w2)
+ return down_proj
+
+
+class DbrxExperts(nn.Module):
+ def __init__(self, hidden_size: int, ffn_hidden_size: int, moe_num_experts: int, ffn_act_fn: dict):
+ super().__init__()
+ self.moe_num_experts = moe_num_experts
+ self.mlp = DbrxExpertGLU(
+ hidden_size=hidden_size,
+ ffn_hidden_size=ffn_hidden_size,
+ moe_num_experts=moe_num_experts,
+ ffn_act_fn=ffn_act_fn,
+ )
+
+ def forward(
+ self, x: torch.Tensor, weights: torch.Tensor, top_weights: torch.Tensor, top_experts: torch.LongTensor
+ ) -> torch.Tensor:
+ bsz, q_len, hidden_size = x.shape
+ x = x.view(-1, hidden_size)
+ out = torch.zeros_like(x)
+
+ expert_mask = nn.functional.one_hot(top_experts, num_classes=self.moe_num_experts).permute(2, 1, 0)
+ # Chunk the stacked expert parameters once up front to avoid storing the full parameter tensors multiple times in autograd
+ w1_chunked = self.mlp.w1.view(self.mlp.moe_num_experts, self.mlp.ffn_hidden_size, self.mlp.hidden_size).chunk(
+ self.moe_num_experts, dim=0
+ )
+ v1_chunked = self.mlp.v1.view(self.mlp.moe_num_experts, self.mlp.ffn_hidden_size, self.mlp.hidden_size).chunk(
+ self.moe_num_experts, dim=0
+ )
+ w2_chunked = self.mlp.w2.view(self.mlp.moe_num_experts, self.mlp.ffn_hidden_size, self.mlp.hidden_size).chunk(
+ self.moe_num_experts, dim=0
+ )
+ w1_chunked = [w1.squeeze(dim=0) for w1 in w1_chunked]
+ v1_chunked = [v1.squeeze(dim=0) for v1 in v1_chunked]
+ w2_chunked = [w2.squeeze(dim=0) for w2 in w2_chunked]
+ for expert_idx in range(0, self.moe_num_experts):
+ # (This causes torch.compile to fail with `torch._dynamo.exc.Unsupported: dynamic shape operator: aten.nonzero.default`.)
+ # (Setting torch._dynamo.config.capture_dynamic_output_shape_ops = True may help, but this has not been tested.)
+ topk_idx, token_idx = torch.where(expert_mask[expert_idx])
+ if token_idx.shape[0] == 0:
+ continue
+
+ token_list = token_idx
+ topk_list = topk_idx
+
+ expert_tokens = x[None, token_list].reshape(-1, hidden_size)
+ expert_out = (
+ self.mlp(expert_tokens, w1_chunked[expert_idx], v1_chunked[expert_idx], w2_chunked[expert_idx])
+ * top_weights[token_list, topk_list, None]
+ )
+
+ out.index_add_(0, token_idx, expert_out)
+
+ out = out.reshape(bsz, q_len, hidden_size)
+ return out
+
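+# Shape walk-through of the dispatch above (illustrative; toy sizes): with batch=2, seq_len=3,
+# hidden=8, 4 experts and top_k=2, the flattened input is (6, 8) and expert_mask is (4, 2, 6).
+# For each expert, torch.where picks the tokens routed to it, the shared DbrxExpertGLU is run
+# with that expert's weight slices, the result is scaled by the matching top_weights entry, and
+# index_add_ accumulates it back into `out`, which is finally reshaped to (2, 3, 8).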
+
+class DbrxFFN(nn.Module):
+ def __init__(self, config: DbrxConfig):
+ super().__init__()
+
+ ffn_config = config.ffn_config
+ self.router = DbrxRouter(
+ hidden_size=config.d_model,
+ moe_num_experts=ffn_config.moe_num_experts,
+ moe_top_k=ffn_config.moe_top_k,
+ moe_jitter_eps=ffn_config.moe_jitter_eps,
+ moe_normalize_expert_weights=ffn_config.moe_normalize_expert_weights,
+ )
+
+ self.experts = DbrxExperts(
+ hidden_size=config.d_model,
+ ffn_hidden_size=ffn_config.ffn_hidden_size,
+ moe_num_experts=ffn_config.moe_num_experts,
+ ffn_act_fn=ffn_config.ffn_act_fn,
+ )
+
+ def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+ weights, top_weights, top_experts = self.router(x)
+ out = self.experts(x, weights, top_weights, top_experts)
+ return out, weights
+
+
+class DbrxBlock(GradientCheckpointingLayer):
+ def __init__(self, config: DbrxConfig, block_idx: int):
+ super().__init__()
+ self.hidden_size = config.d_model
+ self.resid_pdrop = config.resid_pdrop
+ self.block_idx = block_idx
+ self.norm_attn_norm = DbrxNormAttentionNorm(
+ config=config,
+ block_idx=block_idx,
+ )
+ self.ffn = DbrxFFN(config=config)
+
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ output_attentions: Optional[bool] = False,
+ output_router_logits: Optional[bool] = False,
+ use_cache: Optional[bool] = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs: Any,
+ ) -> Union[
+ tuple[torch.Tensor],
+ tuple[torch.Tensor, Optional[torch.Tensor]],
+ tuple[torch.Tensor, Optional[Cache]],
+ tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]],
+ tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]],
+ tuple[torch.Tensor, Optional[Cache], Optional[torch.Tensor]],
+ tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache], Optional[torch.Tensor]],
+ ]:
+ """Forward function for DbrxBlock.
+
+ Args:
+ hidden_states (`torch.Tensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+ position_ids (`torch.LongTensor`): position ids of shape `(batch, seq_len)`
+ attention_mask (`torch.Tensor`, *optional*): attention mask of size (batch_size, sequence_length)
+ if flash attention is used or (batch_size, 1, query_sequence_length, key_sequence_length)
+ if default attention is used.
+ past_key_values (`Cache`, *optional*): cached past key and value projection states
+ output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all
+ attention layers. See `attentions` under returned tensors for more detail.
+ output_router_logits (`bool`, *optional*): Whether or not to return the router logits.
+ use_cache (`bool`, *optional*): If set to `True`, `past_key_values` key value states are
+ returned and can be used to speed up decoding (see `past_key_values`).
+ cache_position (`torch.LongTensor`, *optional*): position ids of the cache
+ """
+
+ # Norm + Attention + Norm
+ resid_states, hidden_states, self_attn_weights = self.norm_attn_norm(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+ # Fully Connected
+ hidden_states, router_logits = self.ffn(hidden_states)
+ hidden_states = nn.functional.dropout(hidden_states, p=self.resid_pdrop, training=self.training)
+ hidden_states = resid_states + hidden_states
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (self_attn_weights,)
+
+ if output_router_logits:
+ outputs += (router_logits,)
+
+ return outputs
+
+
+@auto_docstring
+class DbrxPreTrainedModel(PreTrainedModel):
+ config: DbrxConfig
+ base_model_prefix = "transformer"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["DbrxBlock"]
+ _skip_keys_device_placement = ["past_key_values"]
+ _supports_flash_attn = True
+ _supports_sdpa = True
+
+ _can_compile_fullgraph = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported)
+
+ def _init_weights(self, module: nn.Module):
+ std = self.config.initializer_range
+ if isinstance(module, nn.Linear):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+ elif isinstance(module, nn.LayerNorm):
+ module.weight.data.fill_(1.0)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, DbrxExpertGLU):
+ module.w1.data.normal_(mean=0.0, std=std)
+ module.v1.data.normal_(mean=0.0, std=std)
+ module.w2.data.normal_(mean=0.0, std=std)
+
+
+@auto_docstring
+class DbrxModel(DbrxPreTrainedModel):
+ """Transformer decoder consisting of *config.num_hidden_layers*. Each layer is a [`DbrxBlock`] layer.
+
+ Args:
+ config ([`DbrxConfig`]): Model configuration class with all parameters of the model.
+ Initializing with a config file does not load the weights associated with the model, only the
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+ """
+
+ def __init__(self, config: DbrxConfig):
+ super().__init__(config)
+ self.padding_idx = config.pad_token_id
+ self.vocab_size = config.vocab_size
+ self.emb_pdrop = config.emb_pdrop
+
+ self.wte = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)
+ self.blocks = nn.ModuleList([DbrxBlock(config, block_idx) for block_idx in range(config.n_layers)])
+ self.norm_f = nn.LayerNorm(config.d_model, bias=False)
+ self.gradient_checkpointing = False
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self) -> nn.Embedding:
+ return self.wte
+
+ def set_input_embeddings(self, value: nn.Embedding):
+ self.wte = value
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ output_router_logits: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs, # NOOP kwargs, for now
+ ) -> Union[tuple, MoeModelOutputWithPast]:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ output_router_logits = (
+ output_router_logits if output_router_logits is not None else self.config.output_router_logits
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ if self.gradient_checkpointing and self.training and use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
+ )
+ use_cache = False
+
+ if inputs_embeds is None:
+ inputs_embeds = self.wte(input_ids)
+
+ inputs_embeds = nn.functional.dropout(inputs_embeds, p=self.emb_pdrop, training=self.training)
+
+ if use_cache and past_key_values is None:
+ past_key_values = DynamicCache(config=self.config)
+
+ if cache_position is None:
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ cache_position = torch.arange(
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+ )
+
+ if position_ids is None:
+ position_ids = cache_position.unsqueeze(0)
+
+ causal_mask = self._update_causal_mask(
+ attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
+ )
+
+ # embed positions
+ hidden_states = inputs_embeds
+
+ # decoder layers
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+ all_router_logits = () if output_router_logits else None
+
+ for block in self.blocks:
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ block_outputs = block(
+ hidden_states,
+ attention_mask=causal_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ output_attentions=output_attentions,
+ output_router_logits=output_router_logits,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ )
+
+ hidden_states = block_outputs[0]
+
+ if output_attentions:
+ all_self_attns += (block_outputs[1],)
+
+ if output_router_logits:
+ all_router_logits += (block_outputs[-1],)
+
+ hidden_states = self.norm_f(hidden_states)
+
+ # add hidden states from the last decoder layer
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ if not return_dict:
+ return tuple(
+ v
+ for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns, all_router_logits]
+ if v is not None
+ )
+ return MoeModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=past_key_values,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ router_logits=all_router_logits,
+ )
+
+ # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask
+ def _update_causal_mask(
+ self,
+ attention_mask: Union[torch.Tensor, "BlockMask"],
+ input_tensor: torch.Tensor,
+ cache_position: torch.Tensor,
+ past_key_values: Cache,
+ output_attentions: bool = False,
+ ):
+ if self.config._attn_implementation == "flash_attention_2":
+ if attention_mask is not None and (attention_mask == 0.0).any():
+ return attention_mask
+ return None
+ if self.config._attn_implementation == "flex_attention":
+ if isinstance(attention_mask, torch.Tensor):
+ attention_mask = make_flex_block_causal_mask(attention_mask)
+ return attention_mask
+
+ # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
+ # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
+ # to infer the attention mask.
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False
+
+ # When output_attentions is True, the sdpa implementation's forward method calls the eager implementation's forward
+ if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions:
+ if AttentionMaskConverter._ignore_causal_mask_sdpa(
+ attention_mask,
+ inputs_embeds=input_tensor,
+ past_key_values_length=past_seen_tokens,
+ is_training=self.training,
+ ):
+ return None
+
+ dtype = input_tensor.dtype
+ sequence_length = input_tensor.shape[1]
+ if using_compilable_cache:
+ target_length = past_key_values.get_max_cache_shape()
+ else:
+ target_length = (
+ attention_mask.shape[-1]
+ if isinstance(attention_mask, torch.Tensor)
+ else past_seen_tokens + sequence_length + 1
+ )
+
+ # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
+ causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
+ attention_mask,
+ sequence_length=sequence_length,
+ target_length=target_length,
+ dtype=dtype,
+ cache_position=cache_position,
+ batch_size=input_tensor.shape[0],
+ )
+
+ if (
+ self.config._attn_implementation == "sdpa"
+ and attention_mask is not None
+ and attention_mask.device.type in ["cuda", "xpu", "npu"]
+ and not output_attentions
+ ):
+ # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
+ # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
+ # Details: https://github.com/pytorch/pytorch/issues/110213
+ min_dtype = torch.finfo(dtype).min
+ causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
+
+ return causal_mask
+
+ @staticmethod
+ # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position
+ def _prepare_4d_causal_attention_mask_with_cache_position(
+ attention_mask: torch.Tensor,
+ sequence_length: int,
+ target_length: int,
+ dtype: torch.dtype,
+ cache_position: torch.Tensor,
+ batch_size: int,
+ **kwargs,
+ ):
+ """
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
+ `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
+
+ Args:
+ attention_mask (`torch.Tensor`):
+ A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
+ `(batch_size, 1, query_length, key_value_length)`.
+ sequence_length (`int`):
+ The sequence length being processed.
+ target_length (`int`):
+ The target length: when generating with static cache, the mask should be as long as the static cache,
+ to account for the 0 padding, the part of the cache that is not filled yet.
+ dtype (`torch.dtype`):
+ The dtype to use for the 4D attention mask.
+ cache_position (`torch.Tensor`):
+ Indices depicting the position of the input sequence tokens in the sequence.
+ batch_size (`int`):
+ Batch size.
+ """
+ if attention_mask is not None and attention_mask.dim() == 4:
+ # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
+ causal_mask = attention_mask
+ else:
+ min_dtype = torch.finfo(dtype).min
+ causal_mask = torch.full(
+ (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
+ )
+ if sequence_length != 1:
+ causal_mask = torch.triu(causal_mask, diagonal=1)
+ causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
+ causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
+ if attention_mask is not None:
+ causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
+ mask_length = attention_mask.shape[-1]
+ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
+ causal_mask.device
+ )
+ padding_mask = padding_mask == 0
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+ padding_mask, min_dtype
+ )
+
+ return causal_mask
+
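+# Worked example for the mask helper above (illustrative): with sequence_length=2, target_length=4
+# and cache_position=[2, 3] (two new tokens appended after a cache of length 2), the causal part
+# of the mask lets the first new token attend to key positions 0-2 and the second to positions
+# 0-3; all other positions are filled with the minimum value of `dtype`. A 2D padding mask of
+# shape (batch_size, 4) then additionally masks out padded key positions.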
+
+@auto_docstring(
+ custom_intro="""
+ The DBRX Model transformer for causal language modeling.
+ """
+)
+class DbrxForCausalLM(DbrxPreTrainedModel, GenerationMixin):
+ def __init__(self, config: DbrxConfig):
+ super().__init__(config)
+ self.transformer = DbrxModel(config)
+ self.vocab_size = config.vocab_size
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+ self.moe_loss_weight = config.ffn_config.moe_loss_weight
+ self.num_experts = config.ffn_config.moe_num_experts
+ self.num_experts_per_tok = config.ffn_config.moe_top_k
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self) -> nn.Embedding:
+ return self.transformer.get_input_embeddings()
+
+ def set_input_embeddings(self, value: nn.Embedding):
+ self.transformer.set_input_embeddings(value)
+
+ def get_output_embeddings(self) -> nn.Linear:
+ return self.lm_head
+
+ def set_output_embeddings(self, new_embeddings: nn.Linear):
+ self.lm_head = new_embeddings
+
+ def set_decoder(self, decoder: DbrxModel):
+ self.transformer = decoder
+
+ def get_decoder(self) -> DbrxModel:
+ return self.transformer
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ output_router_logits: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ logits_to_keep: Union[int, torch.Tensor] = 0,
+ **kwargs,
+ ) -> Union[tuple, MoeCausalLMOutputWithPast]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+ Example:
+
+ ```python
+ >> from transformers import AutoTokenizer, DbrxForCausalLM
+
+ >> model = DbrxForCausalLM.from_pretrained("databricks/dbrx-instruct")
+ >> tokenizer = AutoTokenizer.from_pretrained("databricks/dbrx-instruct")
+
+ >> prompt = "Hey, are you conscious? Can you talk to me?"
+ >> inputs = tokenizer(prompt, return_tensors="pt")
+
+ >> # Generate
+ >> generate_ids = model.generate(inputs.input_ids, max_length=30)
+ >> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+ ```
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ output_router_logits = (
+ output_router_logits if output_router_logits is not None else self.config.output_router_logits
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+ outputs = self.transformer(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ output_router_logits=output_router_logits,
+ return_dict=return_dict,
+ cache_position=cache_position,
+ )
+
+ hidden_states = outputs[0]
+ # No upscaling to float was ever done for Dbrx
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(
+ logits,
+ labels,
+ vocab_size=self.config.vocab_size,
+ **kwargs,
+ )
+
+ aux_loss = None
+ if output_router_logits:
+ aux_loss = load_balancing_loss_func(
+ outputs.router_logits if return_dict else outputs[-1],
+ self.num_experts,
+ self.num_experts_per_tok,
+ attention_mask,
+ )
+ if labels is not None and loss is not None:
+ loss += self.moe_loss_weight * aux_loss.to(loss.device)  # make sure aux_loss is on the same device as loss
+
+ if not return_dict:
+ output = (logits,) + outputs[1:]
+ if output_router_logits:
+ output = (aux_loss,) + output
+ return (loss,) + output if loss is not None else output
+
+ return MoeCausalLMOutputWithPast(
+ loss=loss,
+ aux_loss=aux_loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ router_logits=outputs.router_logits,
+ )
+
+
+__all__ = ["DbrxForCausalLM", "DbrxModel", "DbrxPreTrainedModel"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/deit/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/deit/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..98236a86d7a1e8b4ff16b53fb3ff37befbf1d7ac
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/deit/__init__.py
@@ -0,0 +1,31 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_deit import *
+ from .feature_extraction_deit import *
+ from .image_processing_deit import *
+ from .image_processing_deit_fast import *
+ from .modeling_deit import *
+ from .modeling_tf_deit import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/deit/configuration_deit.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/deit/configuration_deit.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a321ebe293e191e7bbce29b528dfa2f6b00d141
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/deit/configuration_deit.py
@@ -0,0 +1,152 @@
+# coding=utf-8
+# Copyright 2021 Facebook AI Research (FAIR) and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""DeiT model configuration"""
+
+from collections import OrderedDict
+from collections.abc import Mapping
+
+from packaging import version
+
+from ...configuration_utils import PretrainedConfig
+from ...onnx import OnnxConfig
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class DeiTConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`DeiTModel`]. It is used to instantiate a DeiT
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
+ defaults will yield a similar configuration to that of the DeiT
+ [facebook/deit-base-distilled-patch16-224](https://huggingface.co/facebook/deit-base-distilled-patch16-224)
+ architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+
+ Args:
+ hidden_size (`int`, *optional*, defaults to 768):
+ Dimensionality of the encoder layers and the pooler layer.
+ num_hidden_layers (`int`, *optional*, defaults to 12):
+ Number of hidden layers in the Transformer encoder.
+ num_attention_heads (`int`, *optional*, defaults to 12):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ intermediate_size (`int`, *optional*, defaults to 3072):
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+ `"relu"`, `"selu"` and `"gelu_new"` are supported.
+ hidden_dropout_prob (`float`, *optional*, defaults to 0.0):
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+ attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ layer_norm_eps (`float`, *optional*, defaults to 1e-12):
+ The epsilon used by the layer normalization layers.
+ image_size (`int`, *optional*, defaults to 224):
+ The size (resolution) of each image.
+ patch_size (`int`, *optional*, defaults to 16):
+ The size (resolution) of each patch.
+ num_channels (`int`, *optional*, defaults to 3):
+ The number of input channels.
+ qkv_bias (`bool`, *optional*, defaults to `True`):
+ Whether to add a bias to the queries, keys and values.
+ encoder_stride (`int`, *optional*, defaults to 16):
+ Factor to increase the spatial resolution by in the decoder head for masked image modeling.
+ pooler_output_size (`int`, *optional*):
+ Dimensionality of the pooler layer. If None, defaults to `hidden_size`.
+ pooler_act (`str`, *optional*, defaults to `"tanh"`):
+ The activation function to be used by the pooler. Keys of ACT2FN are supported for Flax and
+ Pytorch, and elements of https://www.tensorflow.org/api_docs/python/tf/keras/activations are
+ supported for Tensorflow.
+
+ Example:
+
+ ```python
+ >>> from transformers import DeiTConfig, DeiTModel
+
+ >>> # Initializing a DeiT deit-base-distilled-patch16-224 style configuration
+ >>> configuration = DeiTConfig()
+
+ >>> # Initializing a model (with random weights) from the deit-base-distilled-patch16-224 style configuration
+ >>> model = DeiTModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "deit"
+
+ def __init__(
+ self,
+ hidden_size=768,
+ num_hidden_layers=12,
+ num_attention_heads=12,
+ intermediate_size=3072,
+ hidden_act="gelu",
+ hidden_dropout_prob=0.0,
+ attention_probs_dropout_prob=0.0,
+ initializer_range=0.02,
+ layer_norm_eps=1e-12,
+ image_size=224,
+ patch_size=16,
+ num_channels=3,
+ qkv_bias=True,
+ encoder_stride=16,
+ pooler_output_size=None,
+ pooler_act="tanh",
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+
+ self.hidden_size = hidden_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.intermediate_size = intermediate_size
+ self.hidden_act = hidden_act
+ self.hidden_dropout_prob = hidden_dropout_prob
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
+ self.initializer_range = initializer_range
+ self.layer_norm_eps = layer_norm_eps
+ self.image_size = image_size
+ self.patch_size = patch_size
+ self.num_channels = num_channels
+ self.qkv_bias = qkv_bias
+ self.encoder_stride = encoder_stride
+ self.pooler_output_size = pooler_output_size if pooler_output_size else hidden_size
+ self.pooler_act = pooler_act
+
+
+class DeiTOnnxConfig(OnnxConfig):
+ torch_onnx_minimum_version = version.parse("1.11")
+
+ @property
+ def inputs(self) -> Mapping[str, Mapping[int, str]]:
+ return OrderedDict(
+ [
+ ("pixel_values", {0: "batch", 1: "num_channels", 2: "height", 3: "width"}),
+ ]
+ )
+
+ @property
+ def atol_for_validation(self) -> float:
+ return 1e-4
+
+
+__all__ = ["DeiTConfig", "DeiTOnnxConfig"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/deit/feature_extraction_deit.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/deit/feature_extraction_deit.py
new file mode 100644
index 0000000000000000000000000000000000000000..d040fd08395f8e921ec688228d7d5faa8963ab81
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/deit/feature_extraction_deit.py
@@ -0,0 +1,38 @@
+# coding=utf-8
+# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Feature extractor class for DeiT."""
+
+import warnings
+
+from ...utils import logging
+from ...utils.import_utils import requires
+from .image_processing_deit import DeiTImageProcessor
+
+
+logger = logging.get_logger(__name__)
+
+
+@requires(backends=("vision",))
+class DeiTFeatureExtractor(DeiTImageProcessor):
+ def __init__(self, *args, **kwargs) -> None:
+ warnings.warn(
+ "The class DeiTFeatureExtractor is deprecated and will be removed in version 5 of Transformers. Please"
+ " use DeiTImageProcessor instead.",
+ FutureWarning,
+ )
+ super().__init__(*args, **kwargs)
+
+
+__all__ = ["DeiTFeatureExtractor"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/deit/image_processing_deit.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/deit/image_processing_deit.py
new file mode 100644
index 0000000000000000000000000000000000000000..1e2f6c3b5ae5f0f1cf2eb1727d2e3235443b81b9
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/deit/image_processing_deit.py
@@ -0,0 +1,301 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Image processor class for DeiT."""
+
+from typing import Optional, Union
+
+import numpy as np
+
+from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
+from ...image_transforms import resize, to_channel_dimension_format
+from ...image_utils import (
+ IMAGENET_STANDARD_MEAN,
+ IMAGENET_STANDARD_STD,
+ ChannelDimension,
+ ImageInput,
+ PILImageResampling,
+ infer_channel_dimension_format,
+ is_scaled_image,
+ make_flat_list_of_images,
+ to_numpy_array,
+ valid_images,
+ validate_preprocess_arguments,
+)
+from ...utils import TensorType, filter_out_non_signature_kwargs, is_vision_available, logging
+from ...utils.import_utils import requires
+
+
+if is_vision_available():
+ import PIL
+
+
+logger = logging.get_logger(__name__)
+
+
+@requires(backends=("vision",))
+class DeiTImageProcessor(BaseImageProcessor):
+ r"""
+ Constructs a DeiT image processor.
+
+ Args:
+ do_resize (`bool`, *optional*, defaults to `True`):
+ Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by
+ `do_resize` in `preprocess`.
+ size (`dict[str, int]`, *optional*, defaults to `{"height": 256, "width": 256}`):
+ Size of the image after `resize`. Can be overridden by `size` in `preprocess`.
+ resample (`PILImageResampling` filter, *optional*, defaults to `Resampling.BICUBIC`):
+ Resampling filter to use if resizing the image. Can be overridden by `resample` in `preprocess`.
+ do_center_crop (`bool`, *optional*, defaults to `True`):
+ Whether to center crop the image. If the input size is smaller than `crop_size` along any edge, the image
+ is padded with 0's and then center cropped. Can be overridden by `do_center_crop` in `preprocess`.
+ crop_size (`dict[str, int]`, *optional*, defaults to `{"height": 224, "width": 224}`):
+ Desired output size when applying center-cropping. Can be overridden by `crop_size` in `preprocess`.
+ rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
+ Scale factor to use if rescaling the image. Can be overridden by the `rescale_factor` parameter in the
+ `preprocess` method.
+ do_rescale (`bool`, *optional*, defaults to `True`):
+ Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the `do_rescale`
+ parameter in the `preprocess` method.
+ do_normalize (`bool`, *optional*, defaults to `True`):
+ Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
+ method.
+ image_mean (`float` or `list[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`):
+ Mean to use if normalizing the image. This is a float or list of floats the length of the number of
+ channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
+ image_std (`float` or `list[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`):
+ Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
+ number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
+ """
+
+ model_input_names = ["pixel_values"]
+
+ def __init__(
+ self,
+ do_resize: bool = True,
+ size: Optional[dict[str, int]] = None,
+ resample: PILImageResampling = PIL.Image.BICUBIC,
+ do_center_crop: bool = True,
+ crop_size: Optional[dict[str, int]] = None,
+ rescale_factor: Union[int, float] = 1 / 255,
+ do_rescale: bool = True,
+ do_normalize: bool = True,
+ image_mean: Optional[Union[float, list[float]]] = None,
+ image_std: Optional[Union[float, list[float]]] = None,
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+ size = size if size is not None else {"height": 256, "width": 256}
+ size = get_size_dict(size)
+ crop_size = crop_size if crop_size is not None else {"height": 224, "width": 224}
+ crop_size = get_size_dict(crop_size, param_name="crop_size")
+
+ self.do_resize = do_resize
+ self.size = size
+ self.resample = resample
+ self.do_center_crop = do_center_crop
+ self.crop_size = crop_size
+ self.do_rescale = do_rescale
+ self.rescale_factor = rescale_factor
+ self.do_normalize = do_normalize
+ self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
+ self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
+
+ # Copied from transformers.models.vit.image_processing_vit.ViTImageProcessor.resize with PILImageResampling.BILINEAR->PILImageResampling.BICUBIC
+ def resize(
+ self,
+ image: np.ndarray,
+ size: dict[str, int],
+ resample: PILImageResampling = PILImageResampling.BICUBIC,
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ **kwargs,
+ ) -> np.ndarray:
+ """
+ Resize an image to `(size["height"], size["width"])`.
+
+ Args:
+ image (`np.ndarray`):
+ Image to resize.
+ size (`dict[str, int]`):
+ Dictionary in the format `{"height": int, "width": int}` specifying the size of the output image.
+ resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
+ `PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BICUBIC`.
+ data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format for the output image. If unset, the channel dimension format of the input
+ image is used. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
+ input_data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
+ from the input image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
+
+ Returns:
+ `np.ndarray`: The resized image.
+ """
+ size = get_size_dict(size)
+ if "height" not in size or "width" not in size:
+ raise ValueError(f"The `size` dictionary must contain the keys `height` and `width`. Got {size.keys()}")
+ output_size = (size["height"], size["width"])
+ return resize(
+ image,
+ size=output_size,
+ resample=resample,
+ data_format=data_format,
+ input_data_format=input_data_format,
+ **kwargs,
+ )
+
+ @filter_out_non_signature_kwargs()
+ def preprocess(
+ self,
+ images: ImageInput,
+ do_resize: Optional[bool] = None,
+ size: Optional[dict[str, int]] = None,
+ resample=None,
+ do_center_crop: Optional[bool] = None,
+ crop_size: Optional[dict[str, int]] = None,
+ do_rescale: Optional[bool] = None,
+ rescale_factor: Optional[float] = None,
+ do_normalize: Optional[bool] = None,
+ image_mean: Optional[Union[float, list[float]]] = None,
+ image_std: Optional[Union[float, list[float]]] = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ data_format: ChannelDimension = ChannelDimension.FIRST,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ ) -> PIL.Image.Image:
+ """
+ Preprocess an image or batch of images.
+
+ Args:
+ images (`ImageInput`):
+ Image to preprocess. Expects a single image or a batch of images with pixel values ranging from 0 to 255. If
+ passing in images with pixel values between 0 and 1, set `do_rescale=False`.
+ do_resize (`bool`, *optional*, defaults to `self.do_resize`):
+ Whether to resize the image.
+ size (`dict[str, int]`, *optional*, defaults to `self.size`):
+ Size of the image after `resize`.
+ resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
+ `PILImageResampling` filter to use if resizing the image. Only has an effect if `do_resize` is set to
+ `True`.
+ do_center_crop (`bool`, *optional*, defaults to `self.do_center_crop`):
+ Whether to center crop the image.
+ crop_size (`dict[str, int]`, *optional*, defaults to `self.crop_size`):
+ Size of the image after center crop. If one edge of the image is smaller than `crop_size`, it will be
+ padded with zeros and then cropped.
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
+ Whether to rescale the image values between [0 - 1].
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
+ Rescale factor to rescale the image by if `do_rescale` is set to `True`.
+ do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
+ Whether to normalize the image.
+ image_mean (`float` or `list[float]`, *optional*, defaults to `self.image_mean`):
+ Image mean.
+ image_std (`float` or `list[float]`, *optional*, defaults to `self.image_std`):
+ Image standard deviation.
+ return_tensors (`str` or `TensorType`, *optional*):
+ The type of tensors to return. Can be one of:
+ - `None`: Return a list of `np.ndarray`.
+ - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
+ - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
+ data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
+ The channel dimension format for the output image. Can be one of:
+ - `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ input_data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
+ from the input image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
+ """
+ do_resize = do_resize if do_resize is not None else self.do_resize
+ resample = resample if resample is not None else self.resample
+ do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop
+ do_rescale = do_rescale if do_rescale is not None else self.do_rescale
+ rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
+ do_normalize = do_normalize if do_normalize is not None else self.do_normalize
+ image_mean = image_mean if image_mean is not None else self.image_mean
+ image_std = image_std if image_std is not None else self.image_std
+
+ size = size if size is not None else self.size
+ size = get_size_dict(size)
+ crop_size = crop_size if crop_size is not None else self.crop_size
+ crop_size = get_size_dict(crop_size, param_name="crop_size")
+
+ images = make_flat_list_of_images(images)
+
+ if not valid_images(images):
+ raise ValueError(
+ "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
+ "torch.Tensor, tf.Tensor or jax.ndarray."
+ )
+ validate_preprocess_arguments(
+ do_rescale=do_rescale,
+ rescale_factor=rescale_factor,
+ do_normalize=do_normalize,
+ image_mean=image_mean,
+ image_std=image_std,
+ do_center_crop=do_center_crop,
+ crop_size=crop_size,
+ do_resize=do_resize,
+ size=size,
+ resample=resample,
+ )
+ # All transformations expect numpy arrays.
+ images = [to_numpy_array(image) for image in images]
+
+ if do_rescale and is_scaled_image(images[0]):
+ logger.warning_once(
+ "It looks like you are trying to rescale already rescaled images. If the input"
+ " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
+ )
+
+ if input_data_format is None:
+ # We assume that all images have the same channel dimension format.
+ input_data_format = infer_channel_dimension_format(images[0])
+
+ all_images = []
+ for image in images:
+ if do_resize:
+ image = self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
+
+ if do_center_crop:
+ image = self.center_crop(image=image, size=crop_size, input_data_format=input_data_format)
+
+ if do_rescale:
+ image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
+
+ if do_normalize:
+ image = self.normalize(
+ image=image, mean=image_mean, std=image_std, input_data_format=input_data_format
+ )
+
+ all_images.append(image)
+ images = [
+ to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
+ for image in all_images
+ ]
+
+ data = {"pixel_values": images}
+ return BatchFeature(data=data, tensor_type=return_tensors)
+
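+# Minimal usage sketch for the processor above (illustrative; assumes Pillow, NumPy and torch are
+# installed):
+#
+#     from PIL import Image
+#     import numpy as np
+#
+#     processor = DeiTImageProcessor()
+#     image = Image.fromarray(np.zeros((300, 300, 3), dtype=np.uint8))
+#     inputs = processor(images=image, return_tensors="pt")
+#     # inputs["pixel_values"] has shape (1, 3, 224, 224): resized to 256x256, center-cropped
+#     # to 224x224, rescaled to [0, 1] and normalised with the ImageNet mean/std defaults.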
+
+__all__ = ["DeiTImageProcessor"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/deit/image_processing_deit_fast.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/deit/image_processing_deit_fast.py
new file mode 100644
index 0000000000000000000000000000000000000000..3aafeaf50c09455cffeecb3776eb3598c8ceccf2
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/deit/image_processing_deit_fast.py
@@ -0,0 +1,41 @@
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Fast Image processor class for DeiT."""
+
+from ...image_processing_utils_fast import BaseImageProcessorFast
+from ...image_utils import (
+ IMAGENET_STANDARD_MEAN,
+ IMAGENET_STANDARD_STD,
+ PILImageResampling,
+)
+from ...utils import auto_docstring
+
+
+@auto_docstring
+class DeiTImageProcessorFast(BaseImageProcessorFast):
+ # To be checked against the slow image processor
+ # None values left after checking can be removed
+ resample = PILImageResampling.BICUBIC
+ image_mean = IMAGENET_STANDARD_MEAN
+ image_std = IMAGENET_STANDARD_STD
+ size = {"height": 256, "width": 256}
+ crop_size = {"height": 224, "width": 224}
+ do_resize = True
+ do_center_crop = True
+ do_rescale = True
+ do_normalize = True
+
+
+__all__ = ["DeiTImageProcessorFast"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/deit/modeling_deit.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/deit/modeling_deit.py
new file mode 100644
index 0000000000000000000000000000000000000000..ddb03c053f1ee08f011e650daad794821205ff33
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/deit/modeling_deit.py
@@ -0,0 +1,791 @@
+# coding=utf-8
+# Copyright 2021 Facebook AI Research (FAIR), Ross Wightman, The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch DeiT model."""
+
+import collections.abc
+from dataclasses import dataclass
+from typing import Callable, Optional, Union
+
+import torch
+from torch import nn
+
+from ...activations import ACT2FN
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_outputs import (
+ BaseModelOutput,
+ BaseModelOutputWithPooling,
+ ImageClassifierOutput,
+ MaskedImageModelingOutput,
+)
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
+from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
+from ...utils import ModelOutput, TransformersKwargs, auto_docstring, logging, torch_int
+from ...utils.generic import can_return_tuple, check_model_inputs
+from .configuration_deit import DeiTConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+class DeiTEmbeddings(nn.Module):
+ """
+ Construct the CLS token, distillation token, position and patch embeddings. Optionally, also the mask token.
+ """
+
+ def __init__(self, config: DeiTConfig, use_mask_token: bool = False) -> None:
+ super().__init__()
+
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
+ self.distillation_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
+ self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) if use_mask_token else None
+ self.patch_embeddings = DeiTPatchEmbeddings(config)
+ num_patches = self.patch_embeddings.num_patches
+ self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 2, config.hidden_size))
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+ self.patch_size = config.patch_size
+
+ def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
+ """
+ This method interpolates the pre-trained position encodings so that the model can be used on higher-resolution
+ images. It is also adapted to support torch.jit tracing and the 2 class embeddings (class and distillation tokens).
+
+ Adapted from:
+ - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
+ - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
+ """
+
+ num_patches = embeddings.shape[1] - 2
+ num_positions = self.position_embeddings.shape[1] - 2
+
+ # always interpolate when tracing to ensure the exported model works for dynamic input shapes
+ if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
+ return self.position_embeddings
+
+ class_and_dist_pos_embed = self.position_embeddings[:, :2]
+ patch_pos_embed = self.position_embeddings[:, 2:]
+
+ dim = embeddings.shape[-1]
+
+ new_height = height // self.patch_size
+ new_width = width // self.patch_size
+
+ sqrt_num_positions = torch_int(num_positions**0.5)
+ patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
+ patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
+
+ patch_pos_embed = nn.functional.interpolate(
+ patch_pos_embed,
+ size=(new_height, new_width),
+ mode="bicubic",
+ align_corners=False,
+ )
+
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+
+ return torch.cat((class_and_dist_pos_embed, patch_pos_embed), dim=1)
+
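+ # Shape note for interpolate_pos_encoding above (illustrative): for a 384x384 input with a
+ # 16-pixel patch size, the two class/distillation position embeddings are kept as-is and the
+ # pretraining grid of patch positions (14x14 when pretrained at 224x224) is bicubically
+ # resampled to the new 24x24 grid before being flattened back into a sequence.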
+ def forward(
+ self,
+ pixel_values: torch.Tensor,
+ bool_masked_pos: Optional[torch.BoolTensor] = None,
+ interpolate_pos_encoding: bool = False,
+ ) -> torch.Tensor:
+ _, _, height, width = pixel_values.shape
+ embeddings = self.patch_embeddings(pixel_values)
+
+ batch_size, seq_length, _ = embeddings.size()
+
+ if bool_masked_pos is not None:
+ mask_tokens = self.mask_token.expand(batch_size, seq_length, -1)
+ # replace the masked visual tokens by mask_tokens
+ mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
+ embeddings = embeddings * (1.0 - mask) + mask_tokens * mask
+
+ cls_tokens = self.cls_token.expand(batch_size, -1, -1)
+
+ distillation_tokens = self.distillation_token.expand(batch_size, -1, -1)
+
+ embeddings = torch.cat((cls_tokens, distillation_tokens, embeddings), dim=1)
+ position_embedding = self.position_embeddings
+
+ if interpolate_pos_encoding:
+ position_embedding = self.interpolate_pos_encoding(embeddings, height, width)
+
+ embeddings = embeddings + position_embedding
+ embeddings = self.dropout(embeddings)
+ return embeddings
+
+
+class DeiTPatchEmbeddings(nn.Module):
+ """
+ This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
+ `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
+ Transformer.
+ """
+
+ def __init__(self, config):
+ super().__init__()
+ image_size, patch_size = config.image_size, config.patch_size
+ num_channels, hidden_size = config.num_channels, config.hidden_size
+
+ image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
+ patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
+ num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
+ self.image_size = image_size
+ self.patch_size = patch_size
+ self.num_channels = num_channels
+ self.num_patches = num_patches
+
+ self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
+
+ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
+ batch_size, num_channels, height, width = pixel_values.shape
+ if num_channels != self.num_channels:
+ raise ValueError(
+ "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
+ )
+ x = self.projection(pixel_values).flatten(2).transpose(1, 2)
+ return x
+
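+# Example of the patch embedding shapes above (illustrative): a (1, 3, 224, 224) input with a
+# 16x16 patch size is projected into 14*14 = 196 patch tokens, so the returned tensor has shape
+# (1, 196, hidden_size).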
+
+# Copied from transformers.models.vit.modeling_vit.eager_attention_forward
+def eager_attention_forward(
+ module: nn.Module,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_mask: Optional[torch.Tensor],
+ scaling: float,
+ dropout: float = 0.0,
+ **kwargs,
+):
+ # Take the dot product between "query" and "key" to get the raw attention scores.
+ attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
+
+ # Normalize the attention scores to probabilities.
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+
+ # This is actually dropping out entire tokens to attend to, which might
+ # seem a bit unusual, but is taken from the original Transformer paper.
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+
+ # Mask heads if we want to
+ if attention_mask is not None:
+ attn_weights = attn_weights * attention_mask
+
+ attn_output = torch.matmul(attn_weights, value)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ return attn_output, attn_weights
+
+
+# Copied from transformers.models.vit.modeling_vit.ViTSelfAttention with ViT->DeiT
+class DeiTSelfAttention(nn.Module):
+ def __init__(self, config: DeiTConfig):
+ super().__init__()
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
+ raise ValueError(
+ f"The hidden size {config.hidden_size} is not a multiple of the number of attention "
+ f"heads {config.num_attention_heads}."
+ )
+
+ self.config = config
+ self.num_attention_heads = config.num_attention_heads
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
+ self.dropout_prob = config.attention_probs_dropout_prob
+ self.scaling = self.attention_head_size**-0.5
+ self.is_causal = False
+
+ self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
+ self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
+ self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
+
+ def forward(
+ self, hidden_states: torch.Tensor, head_mask: Optional[torch.Tensor] = None
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ batch_size = hidden_states.shape[0]
+ new_shape = batch_size, -1, self.num_attention_heads, self.attention_head_size
+
+ key_layer = self.key(hidden_states).view(*new_shape).transpose(1, 2)
+ value_layer = self.value(hidden_states).view(*new_shape).transpose(1, 2)
+ query_layer = self.query(hidden_states).view(*new_shape).transpose(1, 2)
+
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ context_layer, attention_probs = attention_interface(
+ self,
+ query_layer,
+ key_layer,
+ value_layer,
+ head_mask,
+ is_causal=self.is_causal,
+ scaling=self.scaling,
+ dropout=0.0 if not self.training else self.dropout_prob,
+ )
+
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+ context_layer = context_layer.reshape(new_context_layer_shape)
+
+ return context_layer, attention_probs
+
+
+# Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->DeiT
+class DeiTSelfOutput(nn.Module):
+ """
+ The residual connection is defined in DeiTLayer instead of here (as is the case with other models), due to the
+ layernorm applied before each block.
+ """
+
+ def __init__(self, config: DeiTConfig):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ return hidden_states
+
+
+# Copied from transformers.models.vit.modeling_vit.ViTAttention with ViT->DeiT
+class DeiTAttention(nn.Module):
+ def __init__(self, config: DeiTConfig):
+ super().__init__()
+ self.attention = DeiTSelfAttention(config)
+ self.output = DeiTSelfOutput(config)
+ self.pruned_heads = set()
+
+ def prune_heads(self, heads: set[int]):
+ if len(heads) == 0:
+ return
+ heads, index = find_pruneable_heads_and_indices(
+ heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads
+ )
+
+ # Prune linear layers
+ self.attention.query = prune_linear_layer(self.attention.query, index)
+ self.attention.key = prune_linear_layer(self.attention.key, index)
+ self.attention.value = prune_linear_layer(self.attention.value, index)
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
+
+ # Update hyper params and store pruned heads
+ self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads)
+ self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
+ self.pruned_heads = self.pruned_heads.union(heads)
+
+ def forward(self, hidden_states: torch.Tensor, head_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+ self_attn_output, _ = self.attention(hidden_states, head_mask)
+ output = self.output(self_attn_output, hidden_states)
+ return output
+
+
+# Copied from transformers.models.vit.modeling_vit.ViTIntermediate with ViT->DeiT
+class DeiTIntermediate(nn.Module):
+ def __init__(self, config: DeiTConfig):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
+ if isinstance(config.hidden_act, str):
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
+ else:
+ self.intermediate_act_fn = config.hidden_act
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.intermediate_act_fn(hidden_states)
+ return hidden_states
+
+
+# Copied from transformers.models.vit.modeling_vit.ViTOutput with ViT->DeiT
+class DeiTOutput(nn.Module):
+ def __init__(self, config: DeiTConfig):
+ super().__init__()
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = hidden_states + input_tensor
+ return hidden_states
+
+
+# Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->DeiT,VIT->DEIT
+class DeiTLayer(GradientCheckpointingLayer):
+ """This corresponds to the Block class in the timm implementation."""
+
+ def __init__(self, config: DeiTConfig):
+ super().__init__()
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
+ self.seq_len_dim = 1
+ self.attention = DeiTAttention(config)
+ self.intermediate = DeiTIntermediate(config)
+ self.output = DeiTOutput(config)
+ self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+ def forward(self, hidden_states: torch.Tensor, head_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+ hidden_states_norm = self.layernorm_before(hidden_states)
+ attention_output = self.attention(hidden_states_norm, head_mask)
+
+ # first residual connection
+ hidden_states = attention_output + hidden_states
+
+ # in DeiT, layernorm is also applied after self-attention
+ layer_output = self.layernorm_after(hidden_states)
+ layer_output = self.intermediate(layer_output)
+
+ # second residual connection is done here
+ layer_output = self.output(layer_output, hidden_states)
+
+ return layer_output
+
+
+# Copied from transformers.models.vit.modeling_vit.ViTEncoder with ViT->DeiT
+class DeiTEncoder(nn.Module):
+ def __init__(self, config: DeiTConfig):
+ super().__init__()
+ self.config = config
+ self.layer = nn.ModuleList([DeiTLayer(config) for _ in range(config.num_hidden_layers)])
+ self.gradient_checkpointing = False
+
+ def forward(self, hidden_states: torch.Tensor, head_mask: Optional[torch.Tensor] = None) -> BaseModelOutput:
+ for i, layer_module in enumerate(self.layer):
+ layer_head_mask = head_mask[i] if head_mask is not None else None
+ hidden_states = layer_module(hidden_states, layer_head_mask)
+
+ return BaseModelOutput(last_hidden_state=hidden_states)
+
+
+@auto_docstring
+class DeiTPreTrainedModel(PreTrainedModel):
+ config: DeiTConfig
+ base_model_prefix = "deit"
+ main_input_name = "pixel_values"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["DeiTLayer"]
+ _supports_sdpa = True
+ _supports_flash_attn = True
+ _supports_flex_attn = True
+ _supports_attention_backend = True
+ _can_record_outputs = {
+ "hidden_states": DeiTLayer,
+ "attentions": DeiTSelfAttention,
+ }
+
+ def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm, DeiTEmbeddings]) -> None:
+ """Initialize the weights"""
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
+ # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid
+ # `trunc_normal_cpu` not implemented in `half` issues
+ module.weight.data = nn.init.trunc_normal_(
+ module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range
+ ).to(module.weight.dtype)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+ elif isinstance(module, DeiTEmbeddings):
+ module.cls_token.data.zero_()
+ module.position_embeddings.data.zero_()
+ module.distillation_token.data.zero_()
+ if module.mask_token is not None:
+ module.mask_token.data.zero_()
+
+
+@auto_docstring
+class DeiTModel(DeiTPreTrainedModel):
+ def __init__(self, config: DeiTConfig, add_pooling_layer: bool = True, use_mask_token: bool = False) -> None:
+ r"""
+ add_pooling_layer (`bool`, *optional*, defaults to `True`):
+ Whether to add a pooling layer.
+ use_mask_token (`bool`, *optional*, defaults to `False`):
+ Whether to use a mask token for masked image modeling.
+ """
+ super().__init__(config)
+ self.config = config
+
+ self.embeddings = DeiTEmbeddings(config, use_mask_token=use_mask_token)
+ self.encoder = DeiTEncoder(config)
+
+ self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.pooler = DeiTPooler(config) if add_pooling_layer else None
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self) -> DeiTPatchEmbeddings:
+ return self.embeddings.patch_embeddings
+
+ def _prune_heads(self, heads_to_prune):
+ """
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer}. See the base
+ class `PreTrainedModel`.
+ """
+ for layer, heads in heads_to_prune.items():
+ self.encoder.layer[layer].attention.prune_heads(heads)
+
+ @check_model_inputs(tie_last_hidden_states=False)
+ @auto_docstring
+ def forward(
+ self,
+ pixel_values: Optional[torch.Tensor] = None,
+ bool_masked_pos: Optional[torch.BoolTensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ interpolate_pos_encoding: bool = False,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> BaseModelOutputWithPooling:
+ r"""
+ bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*):
+ Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
+ """
+
+ if pixel_values is None:
+ raise ValueError("You have to specify pixel_values")
+
+ # Prepare head mask if needed
+ # 1.0 in head_mask indicate we keep the head
+ # attention_probs has shape bsz x n_heads x N x N
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
+
+ # TODO: maybe have a cleaner way to cast the input (from `ImageProcessor` side?)
+ expected_dtype = self.embeddings.patch_embeddings.projection.weight.dtype
+ if pixel_values.dtype != expected_dtype:
+ pixel_values = pixel_values.to(expected_dtype)
+
+ embedding_output = self.embeddings(
+ pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding
+ )
+
+ encoder_outputs: BaseModelOutput = self.encoder(embedding_output, head_mask=head_mask)
+ sequence_output = encoder_outputs.last_hidden_state
+ sequence_output = self.layernorm(sequence_output)
+ pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
+
+ return BaseModelOutputWithPooling(
+ last_hidden_state=sequence_output,
+ pooler_output=pooled_output,
+ )
+
+
+# Copied from transformers.models.vit.modeling_vit.ViTPooler with ViT->DeiT
+class DeiTPooler(nn.Module):
+ def __init__(self, config: DeiTConfig):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.pooler_output_size)
+ self.activation = ACT2FN[config.pooler_act]
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ # We "pool" the model by simply taking the hidden state corresponding
+ # to the first token.
+ first_token_tensor = hidden_states[:, 0]
+ pooled_output = self.dense(first_token_tensor)
+ pooled_output = self.activation(pooled_output)
+ return pooled_output
+
+
+@auto_docstring(
+ custom_intro="""
+ DeiT Model with a decoder on top for masked image modeling, as proposed in [SimMIM](https://huggingface.co/papers/2111.09886).
+
+
+
+ Note that we provide a script to pre-train this model on custom data in our [examples
+ directory](https://github.com/huggingface/transformers/tree/main/examples/pytorch/image-pretraining).
+
+
+ """
+)
+class DeiTForMaskedImageModeling(DeiTPreTrainedModel):
+ def __init__(self, config: DeiTConfig) -> None:
+ super().__init__(config)
+
+ self.deit = DeiTModel(config, add_pooling_layer=False, use_mask_token=True)
+
+ self.decoder = nn.Sequential(
+ nn.Conv2d(
+ in_channels=config.hidden_size,
+ out_channels=config.encoder_stride**2 * config.num_channels,
+ kernel_size=1,
+ ),
+ nn.PixelShuffle(config.encoder_stride),
+ )
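+ # e.g. with hidden_size 768, encoder_stride 16 and num_channels 3, the 1x1 convolution maps each patch
+ # position to 16 * 16 * 3 values and PixelShuffle rearranges the (batch_size, 768, 14, 14) feature map
+ # into a reconstructed (batch_size, 3, 224, 224) image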
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ pixel_values: Optional[torch.Tensor] = None,
+ bool_masked_pos: Optional[torch.BoolTensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ interpolate_pos_encoding: bool = False,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> MaskedImageModelingOutput:
+ r"""
+ bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`):
+ Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
+
+ Examples:
+ ```python
+ >>> from transformers import AutoImageProcessor, DeiTForMaskedImageModeling
+ >>> import torch
+ >>> from PIL import Image
+ >>> import requests
+
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> image_processor = AutoImageProcessor.from_pretrained("facebook/deit-base-distilled-patch16-224")
+ >>> model = DeiTForMaskedImageModeling.from_pretrained("facebook/deit-base-distilled-patch16-224")
+
+ >>> num_patches = (model.config.image_size // model.config.patch_size) ** 2
+ >>> pixel_values = image_processor(images=image, return_tensors="pt").pixel_values
+ >>> # create random boolean mask of shape (batch_size, num_patches)
+ >>> bool_masked_pos = torch.randint(low=0, high=2, size=(1, num_patches)).bool()
+
+ >>> outputs = model(pixel_values, bool_masked_pos=bool_masked_pos)
+ >>> loss, reconstructed_pixel_values = outputs.loss, outputs.reconstruction
+ >>> list(reconstructed_pixel_values.shape)
+ [1, 3, 224, 224]
+ ```"""
+
+ outputs: BaseModelOutputWithPooling = self.deit(
+ pixel_values,
+ bool_masked_pos=bool_masked_pos,
+ head_mask=head_mask,
+ interpolate_pos_encoding=interpolate_pos_encoding,
+ **kwargs,
+ )
+
+ sequence_output = outputs.last_hidden_state
+
+ # Reshape to (batch_size, num_channels, height, width)
+ sequence_output = sequence_output[:, 1:-1]
+ batch_size, sequence_length, num_channels = sequence_output.shape
+ height = width = int(sequence_length**0.5)
+ sequence_output = sequence_output.permute(0, 2, 1).reshape(batch_size, num_channels, height, width)
+
+ # Reconstruct pixel values
+ reconstructed_pixel_values = self.decoder(sequence_output)
+
+ masked_im_loss = None
+ if bool_masked_pos is not None:
+ size = self.config.image_size // self.config.patch_size
+ bool_masked_pos = bool_masked_pos.reshape(-1, size, size)
+ mask = (
+ bool_masked_pos.repeat_interleave(self.config.patch_size, 1)
+ .repeat_interleave(self.config.patch_size, 2)
+ .unsqueeze(1)
+ .contiguous()
+ )
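+ # e.g. for 224x224 inputs with patch_size 16, the (batch_size, 196) boolean mask is reshaped to
+ # (batch_size, 14, 14) and upsampled to a per-pixel mask of shape (batch_size, 1, 224, 224),
+ # so the L1 loss only counts masked patches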
+ reconstruction_loss = nn.functional.l1_loss(pixel_values, reconstructed_pixel_values, reduction="none")
+ masked_im_loss = (reconstruction_loss * mask).sum() / (mask.sum() + 1e-5) / self.config.num_channels
+
+ return MaskedImageModelingOutput(
+ loss=masked_im_loss,
+ reconstruction=reconstructed_pixel_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+@auto_docstring(
+ custom_intro="""
+ DeiT Model transformer with an image classification head on top (a linear layer on top of the final hidden state of
+ the [CLS] token) e.g. for ImageNet.
+ """
+)
+class DeiTForImageClassification(DeiTPreTrainedModel):
+ def __init__(self, config: DeiTConfig) -> None:
+ super().__init__(config)
+
+ self.num_labels = config.num_labels
+ self.deit = DeiTModel(config, add_pooling_layer=False)
+
+ # Classifier head
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity()
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ pixel_values: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ interpolate_pos_encoding: bool = False,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> ImageClassifierOutput:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+
+ Examples:
+
+ ```python
+ >>> from transformers import AutoImageProcessor, DeiTForImageClassification
+ >>> import torch
+ >>> from PIL import Image
+ >>> import requests
+
+ >>> torch.manual_seed(3) # doctest: +IGNORE_RESULT
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> # note: the checkpoint on the hub was trained as a DeiTForImageClassificationWithTeacher,
+ >>> # so the classification head of DeiTForImageClassification is randomly initialized and the predictions will be random
+ >>> image_processor = AutoImageProcessor.from_pretrained("facebook/deit-base-distilled-patch16-224")
+ >>> model = DeiTForImageClassification.from_pretrained("facebook/deit-base-distilled-patch16-224")
+
+ >>> inputs = image_processor(images=image, return_tensors="pt")
+ >>> outputs = model(**inputs)
+ >>> logits = outputs.logits
+ >>> # model predicts one of the 1000 ImageNet classes
+ >>> predicted_class_idx = logits.argmax(-1).item()
+ >>> print("Predicted class:", model.config.id2label[predicted_class_idx])
+ Predicted class: Polaroid camera, Polaroid Land camera
+ ```"""
+
+ outputs: BaseModelOutputWithPooling = self.deit(
+ pixel_values,
+ head_mask=head_mask,
+ interpolate_pos_encoding=interpolate_pos_encoding,
+ **kwargs,
+ )
+
+ sequence_output = outputs.last_hidden_state
+
+ logits = self.classifier(sequence_output[:, 0, :])
+ # we don't use the distillation token
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(labels, logits, self.config, **kwargs)
+
+ return ImageClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Output type of [`DeiTForImageClassificationWithTeacher`].
+ """
+)
+class DeiTForImageClassificationWithTeacherOutput(ModelOutput):
+ r"""
+ logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
+ Prediction scores as the average of the cls_logits and distillation logits.
+ cls_logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
+ Prediction scores of the classification head (i.e. the linear layer on top of the final hidden state of the
+ class token).
+ distillation_logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
+ Prediction scores of the distillation head (i.e. the linear layer on top of the final hidden state of the
+ distillation token).
+ """
+
+ logits: Optional[torch.FloatTensor] = None
+ cls_logits: Optional[torch.FloatTensor] = None
+ distillation_logits: Optional[torch.FloatTensor] = None
+ hidden_states: Optional[tuple[torch.FloatTensor]] = None
+ attentions: Optional[tuple[torch.FloatTensor]] = None
+
+
+@auto_docstring(
+ custom_intro="""
+ DeiT Model transformer with image classification heads on top (a linear layer on top of the final hidden state of
+ the [CLS] token and a linear layer on top of the final hidden state of the distillation token) e.g. for ImageNet.
+
+ .. warning::
+
+ This model supports inference-only. Fine-tuning with distillation (i.e. with a teacher) is not yet
+ supported.
+ """
+)
+class DeiTForImageClassificationWithTeacher(DeiTPreTrainedModel):
+ def __init__(self, config: DeiTConfig) -> None:
+ super().__init__(config)
+
+ self.num_labels = config.num_labels
+ self.deit = DeiTModel(config, add_pooling_layer=False)
+
+ # Classifier heads
+ self.cls_classifier = (
+ nn.Linear(config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity()
+ )
+ self.distillation_classifier = (
+ nn.Linear(config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity()
+ )
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ pixel_values: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ interpolate_pos_encoding: bool = False,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> DeiTForImageClassificationWithTeacherOutput:
+ outputs: BaseModelOutputWithPooling = self.deit(
+ pixel_values,
+ head_mask=head_mask,
+ interpolate_pos_encoding=interpolate_pos_encoding,
+ **kwargs,
+ )
+
+ sequence_output = outputs.last_hidden_state
+
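+ # in the DeiT sequence, the CLS token is at position 0 and the distillation token at position 1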
+ cls_logits = self.cls_classifier(sequence_output[:, 0, :])
+ distillation_logits = self.distillation_classifier(sequence_output[:, 1, :])
+
+ # during inference, return the average of both classifier predictions
+ logits = (cls_logits + distillation_logits) / 2
+
+ return DeiTForImageClassificationWithTeacherOutput(
+ logits=logits,
+ cls_logits=cls_logits,
+ distillation_logits=distillation_logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+__all__ = [
+ "DeiTForImageClassification",
+ "DeiTForImageClassificationWithTeacher",
+ "DeiTForMaskedImageModeling",
+ "DeiTModel",
+ "DeiTPreTrainedModel",
+]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/deit/modeling_tf_deit.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/deit/modeling_tf_deit.py
new file mode 100644
index 0000000000000000000000000000000000000000..3c56eee87911edc445641e0bbc14f094e1c5efa7
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/deit/modeling_tf_deit.py
@@ -0,0 +1,1232 @@
+# coding=utf-8
+# Copyright 2022 Facebook AI Research (FAIR) and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""TensorFlow DeiT model."""
+
+from __future__ import annotations
+
+import collections.abc
+import math
+from dataclasses import dataclass
+
+import tensorflow as tf
+
+from ...activations_tf import get_tf_activation
+from ...modeling_tf_outputs import (
+ TFBaseModelOutput,
+ TFBaseModelOutputWithPooling,
+ TFImageClassifierOutput,
+ TFMaskedImageModelingOutput,
+)
+from ...modeling_tf_utils import (
+ TFPreTrainedModel,
+ TFSequenceClassificationLoss,
+ get_initializer,
+ keras,
+ keras_serializable,
+ unpack_inputs,
+)
+from ...tf_utils import shape_list, stable_softmax
+from ...utils import (
+ ModelOutput,
+ add_code_sample_docstrings,
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ logging,
+ replace_return_docstrings,
+)
+from .configuration_deit import DeiTConfig
+
+
+logger = logging.get_logger(__name__)
+
+# General docstring
+_CONFIG_FOR_DOC = "DeiTConfig"
+
+# Base docstring
+_CHECKPOINT_FOR_DOC = "facebook/deit-base-distilled-patch16-224"
+_EXPECTED_OUTPUT_SHAPE = [1, 198, 768]
+
+# Image classification docstring
+_IMAGE_CLASS_CHECKPOINT = "facebook/deit-base-distilled-patch16-224"
+_IMAGE_CLASS_EXPECTED_OUTPUT = "tabby, tabby cat"
+
+
+@dataclass
+class TFDeiTForImageClassificationWithTeacherOutput(ModelOutput):
+ """
+ Output type of [`DeiTForImageClassificationWithTeacher`].
+
+ Args:
+ logits (`tf.Tensor` of shape `(batch_size, config.num_labels)`):
+ Prediction scores as the average of the cls_logits and distillation logits.
+ cls_logits (`tf.Tensor` of shape `(batch_size, config.num_labels)`):
+ Prediction scores of the classification head (i.e. the linear layer on top of the final hidden state of the
+ class token).
+ distillation_logits (`tf.Tensor` of shape `(batch_size, config.num_labels)`):
+ Prediction scores of the distillation head (i.e. the linear layer on top of the final hidden state of the
+ distillation token).
+ hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
+ `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer plus
+ the initial embedding outputs.
+ attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
+ the self-attention heads.
+ """
+
+ logits: tf.Tensor | None = None
+ cls_logits: tf.Tensor | None = None
+ distillation_logits: tf.Tensor | None = None
+ hidden_states: tuple[tf.Tensor] | None = None
+ attentions: tuple[tf.Tensor] | None = None
+
+
+class TFDeiTEmbeddings(keras.layers.Layer):
+ """
+ Construct the CLS token, distillation token, position and patch embeddings. Optionally, also the mask token.
+ """
+
+ def __init__(self, config: DeiTConfig, use_mask_token: bool = False, **kwargs) -> None:
+ super().__init__(**kwargs)
+ self.config = config
+ self.use_mask_token = use_mask_token
+ self.patch_embeddings = TFDeiTPatchEmbeddings(config=config, name="patch_embeddings")
+ self.dropout = keras.layers.Dropout(config.hidden_dropout_prob, name="dropout")
+
+ def build(self, input_shape=None):
+ self.cls_token = self.add_weight(
+ shape=(1, 1, self.config.hidden_size),
+ initializer=keras.initializers.zeros(),
+ trainable=True,
+ name="cls_token",
+ )
+ self.distillation_token = self.add_weight(
+ shape=(1, 1, self.config.hidden_size),
+ initializer=keras.initializers.zeros(),
+ trainable=True,
+ name="distillation_token",
+ )
+ self.mask_token = None
+ if self.use_mask_token:
+ self.mask_token = self.add_weight(
+ shape=(1, 1, self.config.hidden_size),
+ initializer=keras.initializers.zeros(),
+ trainable=True,
+ name="mask_token",
+ )
+ num_patches = self.patch_embeddings.num_patches
+ self.position_embeddings = self.add_weight(
+ shape=(1, num_patches + 2, self.config.hidden_size),
+ initializer=keras.initializers.zeros(),
+ trainable=True,
+ name="position_embeddings",
+ )
+
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "patch_embeddings", None) is not None:
+ with tf.name_scope(self.patch_embeddings.name):
+ self.patch_embeddings.build(None)
+ if getattr(self, "dropout", None) is not None:
+ with tf.name_scope(self.dropout.name):
+ self.dropout.build(None)
+
+ def interpolate_pos_encoding(self, embeddings: tf.Tensor, height: int, width: int) -> tf.Tensor:
+ num_patches = embeddings.shape[1] - 2
+ num_positions = self.position_embeddings.shape[1] - 2
+
+ if num_patches == num_positions and height == width:
+ return self.position_embeddings
+
+ class_pos_embed = self.position_embeddings[:, 0, :]
+ dist_pos_embed = self.position_embeddings[:, 1, :]
+ patch_pos_embed = self.position_embeddings[:, 2:, :]
+ dim = embeddings.shape[-1]
+ h0 = height // self.config.patch_size
+ w0 = width // self.config.patch_size
+ # we add a small number to avoid floating point error in the interpolation
+ # see discussion at https://github.com/facebookresearch/dino/issues/8
+ h0, w0 = h0 + 0.1, w0 + 0.1
+ patch_pos_embed = tf.reshape(
+ patch_pos_embed, (1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
+ )
+ patch_pos_embed = tf.image.resize(patch_pos_embed, size=(int(h0), int(w0)), method="bicubic")
+ patch_pos_embed = tf.transpose(patch_pos_embed, perm=[0, 2, 3, 1])
+ patch_pos_embed = tf.reshape(patch_pos_embed, (1, -1, dim))
+
+ return tf.concat(
+ [tf.expand_dims(class_pos_embed, axis=0), tf.expand_dims(dist_pos_embed, axis=0), patch_pos_embed], axis=1
+ )
+
+ def call(
+ self,
+ pixel_values: tf.Tensor,
+ bool_masked_pos: tf.Tensor | None = None,
+ training: bool = False,
+ interpolate_pos_encoding: bool = False,
+ ) -> tf.Tensor:
+ _, height, width, _ = pixel_values.shape
+
+ embeddings = self.patch_embeddings(pixel_values)
+ batch_size, seq_length, _ = shape_list(embeddings)
+
+ if bool_masked_pos is not None:
+ mask_tokens = tf.tile(self.mask_token, [batch_size, seq_length, 1])
+ # replace the masked visual tokens by mask_tokens
+ mask = tf.expand_dims(bool_masked_pos, axis=-1)
+ mask = tf.cast(mask, dtype=mask_tokens.dtype)
+ embeddings = embeddings * (1.0 - mask) + mask_tokens * mask
+
+ cls_tokens = tf.repeat(self.cls_token, repeats=batch_size, axis=0)
+ distillation_tokens = tf.repeat(self.distillation_token, repeats=batch_size, axis=0)
+ embeddings = tf.concat((cls_tokens, distillation_tokens, embeddings), axis=1)
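+ # the resulting sequence is ordered as [CLS, distillation, patch_1, ..., patch_N]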
+ position_embedding = self.position_embeddings
+ if interpolate_pos_encoding:
+ position_embedding = self.interpolate_pos_encoding(embeddings, height, width)
+
+ embeddings = embeddings + position_embedding
+ embeddings = self.dropout(embeddings, training=training)
+ return embeddings
+
+
+class TFDeiTPatchEmbeddings(keras.layers.Layer):
+ """
+ This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
+ `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
+ Transformer.
+ """
+
+ def __init__(self, config: DeiTConfig, **kwargs) -> None:
+ super().__init__(**kwargs)
+ image_size, patch_size = config.image_size, config.patch_size
+ num_channels, hidden_size = config.num_channels, config.hidden_size
+
+ image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
+ patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
+ num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
+ self.image_size = image_size
+ self.patch_size = patch_size
+ self.num_channels = num_channels
+ self.num_patches = num_patches
+
+ self.projection = keras.layers.Conv2D(
+ hidden_size, kernel_size=patch_size, strides=patch_size, name="projection"
+ )
+
+ def call(self, pixel_values: tf.Tensor) -> tf.Tensor:
+ batch_size, height, width, num_channels = shape_list(pixel_values)
+ if tf.executing_eagerly() and num_channels != self.num_channels:
+ raise ValueError(
+ "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
+ )
+
+ x = self.projection(pixel_values)
+ batch_size, height, width, num_channels = shape_list(x)
+ x = tf.reshape(x, (batch_size, height * width, num_channels))
+ return x
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "projection", None) is not None:
+ with tf.name_scope(self.projection.name):
+ self.projection.build([None, None, None, self.num_channels])
+
+
+# Copied from transformers.models.vit.modeling_tf_vit.TFViTSelfAttention with ViT->DeiT
+class TFDeiTSelfAttention(keras.layers.Layer):
+ def __init__(self, config: DeiTConfig, **kwargs):
+ super().__init__(**kwargs)
+
+ if config.hidden_size % config.num_attention_heads != 0:
+ raise ValueError(
+ f"The hidden size ({config.hidden_size}) is not a multiple of the number "
+ f"of attention heads ({config.num_attention_heads})"
+ )
+
+ self.num_attention_heads = config.num_attention_heads
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
+ self.sqrt_att_head_size = math.sqrt(self.attention_head_size)
+
+ self.query = keras.layers.Dense(
+ units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="query"
+ )
+ self.key = keras.layers.Dense(
+ units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="key"
+ )
+ self.value = keras.layers.Dense(
+ units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="value"
+ )
+ self.dropout = keras.layers.Dropout(rate=config.attention_probs_dropout_prob)
+ self.config = config
+
+ def transpose_for_scores(self, tensor: tf.Tensor, batch_size: int) -> tf.Tensor:
+ # Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size]
+ tensor = tf.reshape(tensor=tensor, shape=(batch_size, -1, self.num_attention_heads, self.attention_head_size))
+
+ # Transpose the tensor from [batch_size, seq_length, num_attention_heads, attention_head_size] to [batch_size, num_attention_heads, seq_length, attention_head_size]
+ return tf.transpose(tensor, perm=[0, 2, 1, 3])
+
+ def call(
+ self,
+ hidden_states: tf.Tensor,
+ head_mask: tf.Tensor,
+ output_attentions: bool,
+ training: bool = False,
+ ) -> tuple[tf.Tensor]:
+ batch_size = shape_list(hidden_states)[0]
+ mixed_query_layer = self.query(inputs=hidden_states)
+ mixed_key_layer = self.key(inputs=hidden_states)
+ mixed_value_layer = self.value(inputs=hidden_states)
+ query_layer = self.transpose_for_scores(mixed_query_layer, batch_size)
+ key_layer = self.transpose_for_scores(mixed_key_layer, batch_size)
+ value_layer = self.transpose_for_scores(mixed_value_layer, batch_size)
+
+ # Take the dot product between "query" and "key" to get the raw attention scores.
+ # (batch size, num_heads, seq_len_q, seq_len_k)
+ attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True)
+ dk = tf.cast(self.sqrt_att_head_size, dtype=attention_scores.dtype)
+ attention_scores = tf.divide(attention_scores, dk)
+
+ # Normalize the attention scores to probabilities.
+ attention_probs = stable_softmax(logits=attention_scores, axis=-1)
+
+ # This is actually dropping out entire tokens to attend to, which might
+ # seem a bit unusual, but is taken from the original Transformer paper.
+ attention_probs = self.dropout(inputs=attention_probs, training=training)
+
+ # Mask heads if we want to
+ if head_mask is not None:
+ attention_probs = tf.multiply(attention_probs, head_mask)
+
+ attention_output = tf.matmul(attention_probs, value_layer)
+ attention_output = tf.transpose(attention_output, perm=[0, 2, 1, 3])
+
+ # (batch_size, seq_len_q, all_head_size)
+ attention_output = tf.reshape(tensor=attention_output, shape=(batch_size, -1, self.all_head_size))
+ outputs = (attention_output, attention_probs) if output_attentions else (attention_output,)
+
+ return outputs
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "query", None) is not None:
+ with tf.name_scope(self.query.name):
+ self.query.build([None, None, self.config.hidden_size])
+ if getattr(self, "key", None) is not None:
+ with tf.name_scope(self.key.name):
+ self.key.build([None, None, self.config.hidden_size])
+ if getattr(self, "value", None) is not None:
+ with tf.name_scope(self.value.name):
+ self.value.build([None, None, self.config.hidden_size])
+
+
+# Copied from transformers.models.vit.modeling_tf_vit.TFViTSelfOutput with ViT->DeiT
+class TFDeiTSelfOutput(keras.layers.Layer):
+ """
+ The residual connection is defined in TFDeiTLayer instead of here (as is the case with other models), due to the
+ layernorm applied before each block.
+ """
+
+ def __init__(self, config: DeiTConfig, **kwargs):
+ super().__init__(**kwargs)
+
+ self.dense = keras.layers.Dense(
+ units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
+ )
+ self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob)
+ self.config = config
+
+ def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor:
+ hidden_states = self.dense(inputs=hidden_states)
+ hidden_states = self.dropout(inputs=hidden_states, training=training)
+
+ return hidden_states
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "dense", None) is not None:
+ with tf.name_scope(self.dense.name):
+ self.dense.build([None, None, self.config.hidden_size])
+
+
+# Copied from transformers.models.vit.modeling_tf_vit.TFViTAttention with ViT->DeiT
+class TFDeiTAttention(keras.layers.Layer):
+ def __init__(self, config: DeiTConfig, **kwargs):
+ super().__init__(**kwargs)
+
+ self.self_attention = TFDeiTSelfAttention(config, name="attention")
+ self.dense_output = TFDeiTSelfOutput(config, name="output")
+
+ def prune_heads(self, heads):
+ raise NotImplementedError
+
+ def call(
+ self,
+ input_tensor: tf.Tensor,
+ head_mask: tf.Tensor,
+ output_attentions: bool,
+ training: bool = False,
+ ) -> tuple[tf.Tensor]:
+ self_outputs = self.self_attention(
+ hidden_states=input_tensor, head_mask=head_mask, output_attentions=output_attentions, training=training
+ )
+ attention_output = self.dense_output(
+ hidden_states=self_outputs[0], input_tensor=input_tensor, training=training
+ )
+ outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
+
+ return outputs
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "self_attention", None) is not None:
+ with tf.name_scope(self.self_attention.name):
+ self.self_attention.build(None)
+ if getattr(self, "dense_output", None) is not None:
+ with tf.name_scope(self.dense_output.name):
+ self.dense_output.build(None)
+
+
+# Copied from transformers.models.vit.modeling_tf_vit.TFViTIntermediate with ViT->DeiT
+class TFDeiTIntermediate(keras.layers.Layer):
+ def __init__(self, config: DeiTConfig, **kwargs):
+ super().__init__(**kwargs)
+
+ self.dense = keras.layers.Dense(
+ units=config.intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
+ )
+
+ if isinstance(config.hidden_act, str):
+ self.intermediate_act_fn = get_tf_activation(config.hidden_act)
+ else:
+ self.intermediate_act_fn = config.hidden_act
+ self.config = config
+
+ def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
+ hidden_states = self.dense(inputs=hidden_states)
+ hidden_states = self.intermediate_act_fn(hidden_states)
+
+ return hidden_states
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "dense", None) is not None:
+ with tf.name_scope(self.dense.name):
+ self.dense.build([None, None, self.config.hidden_size])
+
+
+# Copied from transformers.models.vit.modeling_tf_vit.TFViTOutput with ViT->DeiT
+class TFDeiTOutput(keras.layers.Layer):
+ def __init__(self, config: DeiTConfig, **kwargs):
+ super().__init__(**kwargs)
+
+ self.dense = keras.layers.Dense(
+ units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
+ )
+ self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob)
+ self.config = config
+
+ def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor:
+ hidden_states = self.dense(inputs=hidden_states)
+ hidden_states = self.dropout(inputs=hidden_states, training=training)
+ hidden_states = hidden_states + input_tensor
+
+ return hidden_states
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "dense", None) is not None:
+ with tf.name_scope(self.dense.name):
+ self.dense.build([None, None, self.config.intermediate_size])
+
+
+class TFDeiTLayer(keras.layers.Layer):
+ """This corresponds to the Block class in the timm implementation."""
+
+ def __init__(self, config: DeiTConfig, **kwargs):
+ super().__init__(**kwargs)
+
+ self.attention = TFDeiTAttention(config, name="attention")
+ self.intermediate = TFDeiTIntermediate(config, name="intermediate")
+ self.deit_output = TFDeiTOutput(config, name="output")
+
+ self.layernorm_before = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm_before")
+ self.layernorm_after = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm_after")
+ self.config = config
+
+ def call(
+ self,
+ hidden_states: tf.Tensor,
+ head_mask: tf.Tensor,
+ output_attentions: bool,
+ training: bool = False,
+ ) -> tuple[tf.Tensor]:
+ attention_outputs = self.attention(
+ # in DeiT, layernorm is applied before self-attention
+ input_tensor=self.layernorm_before(inputs=hidden_states, training=training),
+ head_mask=head_mask,
+ output_attentions=output_attentions,
+ training=training,
+ )
+ attention_output = attention_outputs[0]
+
+ # first residual connection
+ hidden_states = attention_output + hidden_states
+
+ # in DeiT, layernorm is also applied after self-attention
+ layer_output = self.layernorm_after(inputs=hidden_states, training=training)
+
+ intermediate_output = self.intermediate(hidden_states=layer_output, training=training)
+
+ # second residual connection is done here
+ layer_output = self.deit_output(
+ hidden_states=intermediate_output, input_tensor=hidden_states, training=training
+ )
+ outputs = (layer_output,) + attention_outputs[1:] # add attentions if we output them
+
+ return outputs
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "attention", None) is not None:
+ with tf.name_scope(self.attention.name):
+ self.attention.build(None)
+ if getattr(self, "intermediate", None) is not None:
+ with tf.name_scope(self.intermediate.name):
+ self.intermediate.build(None)
+ if getattr(self, "deit_output", None) is not None:
+ with tf.name_scope(self.deit_output.name):
+ self.deit_output.build(None)
+ if getattr(self, "layernorm_before", None) is not None:
+ with tf.name_scope(self.layernorm_before.name):
+ self.layernorm_before.build([None, None, self.config.hidden_size])
+ if getattr(self, "layernorm_after", None) is not None:
+ with tf.name_scope(self.layernorm_after.name):
+ self.layernorm_after.build([None, None, self.config.hidden_size])
+
+
+# Copied from transformers.models.vit.modeling_tf_vit.TFViTEncoder with ViT->DeiT
+class TFDeiTEncoder(keras.layers.Layer):
+ def __init__(self, config: DeiTConfig, **kwargs):
+ super().__init__(**kwargs)
+
+ self.layer = [TFDeiTLayer(config, name=f"layer_._{i}") for i in range(config.num_hidden_layers)]
+
+ def call(
+ self,
+ hidden_states: tf.Tensor,
+ head_mask: tf.Tensor,
+ output_attentions: bool,
+ output_hidden_states: bool,
+ return_dict: bool,
+ training: bool = False,
+ ) -> TFBaseModelOutput | tuple[tf.Tensor]:
+ all_hidden_states = () if output_hidden_states else None
+ all_attentions = () if output_attentions else None
+
+ for i, layer_module in enumerate(self.layer):
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ layer_outputs = layer_module(
+ hidden_states=hidden_states,
+ head_mask=head_mask[i],
+ output_attentions=output_attentions,
+ training=training,
+ )
+ hidden_states = layer_outputs[0]
+
+ if output_attentions:
+ all_attentions = all_attentions + (layer_outputs[1],)
+
+ # Add last layer
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
+
+ return TFBaseModelOutput(
+ last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
+ )
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "layer", None) is not None:
+ for layer in self.layer:
+ with tf.name_scope(layer.name):
+ layer.build(None)
+
+
+@keras_serializable
+class TFDeiTMainLayer(keras.layers.Layer):
+ config_class = DeiTConfig
+
+ def __init__(
+ self, config: DeiTConfig, add_pooling_layer: bool = True, use_mask_token: bool = False, **kwargs
+ ) -> None:
+ super().__init__(**kwargs)
+ self.config = config
+
+ self.embeddings = TFDeiTEmbeddings(config, use_mask_token=use_mask_token, name="embeddings")
+ self.encoder = TFDeiTEncoder(config, name="encoder")
+
+ self.layernorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm")
+ self.pooler = TFDeiTPooler(config, name="pooler") if add_pooling_layer else None
+
+ def get_input_embeddings(self) -> TFDeiTPatchEmbeddings:
+ return self.embeddings.patch_embeddings
+
+ def _prune_heads(self, heads_to_prune):
+ """
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer}. See the base
+ class `PreTrainedModel`.
+ """
+ raise NotImplementedError
+
+ def get_head_mask(self, head_mask):
+ if head_mask is not None:
+ raise NotImplementedError
+ else:
+ head_mask = [None] * self.config.num_hidden_layers
+
+ return head_mask
+
+ @unpack_inputs
+ def call(
+ self,
+ pixel_values: tf.Tensor | None = None,
+ bool_masked_pos: tf.Tensor | None = None,
+ head_mask: tf.Tensor | None = None,
+ output_attentions: bool | None = None,
+ output_hidden_states: bool | None = None,
+ return_dict: bool | None = None,
+ interpolate_pos_encoding: bool = False,
+ training: bool = False,
+ ) -> TFBaseModelOutputWithPooling | tuple[tf.Tensor, ...]:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if pixel_values is None:
+ raise ValueError("You have to specify pixel_values")
+
+ # TF 2.0 image layers can't use NCHW format when running on CPU.
+ # (batch_size, num_channels, height, width) -> (batch_size, height, width, num_channels)
+ pixel_values = tf.transpose(pixel_values, (0, 2, 3, 1))
+
+ # Prepare head mask if needed
+ # 1.0 in head_mask indicate we keep the head
+ # attention_probs has shape bsz x n_heads x N x N
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+ head_mask = self.get_head_mask(head_mask)
+
+ embedding_output = self.embeddings(
+ pixel_values,
+ bool_masked_pos=bool_masked_pos,
+ training=training,
+ interpolate_pos_encoding=interpolate_pos_encoding,
+ )
+
+ encoder_outputs = self.encoder(
+ embedding_output,
+ head_mask=head_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ training=training,
+ )
+ sequence_output = encoder_outputs[0]
+ sequence_output = self.layernorm(sequence_output, training=training)
+ pooled_output = self.pooler(sequence_output, training=training) if self.pooler is not None else None
+
+ if not return_dict:
+ head_outputs = (sequence_output, pooled_output) if pooled_output is not None else (sequence_output,)
+ return head_outputs + encoder_outputs[1:]
+
+ return TFBaseModelOutputWithPooling(
+ last_hidden_state=sequence_output,
+ pooler_output=pooled_output,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ )
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "embeddings", None) is not None:
+ with tf.name_scope(self.embeddings.name):
+ self.embeddings.build(None)
+ if getattr(self, "encoder", None) is not None:
+ with tf.name_scope(self.encoder.name):
+ self.encoder.build(None)
+ if getattr(self, "layernorm", None) is not None:
+ with tf.name_scope(self.layernorm.name):
+ self.layernorm.build([None, None, self.config.hidden_size])
+ if getattr(self, "pooler", None) is not None:
+ with tf.name_scope(self.pooler.name):
+ self.pooler.build(None)
+
+
+# Copied from transformers.models.vit.modeling_tf_vit.TFViTPreTrainedModel with ViT->DeiT all-casing
+class TFDeiTPreTrainedModel(TFPreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = DeiTConfig
+ base_model_prefix = "deit"
+ main_input_name = "pixel_values"
+
+
+DEIT_START_DOCSTRING = r"""
+ This model is a TensorFlow
+ [keras.layers.Layer](https://www.tensorflow.org/api_docs/python/tf/keras/layers/Layer). Use it as a regular
+ TensorFlow Module and refer to the TensorFlow documentation for all matter related to general usage and behavior.
+
+ Parameters:
+ config ([`DeiTConfig`]): Model configuration class with all the parameters of the model.
+ Initializing with a config file does not load the weights associated with the model, only the
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+DEIT_INPUTS_DOCSTRING = r"""
+ Args:
+ pixel_values (`tf.Tensor` of shape `(batch_size, num_channels, height, width)`):
+ Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
+ [`DeiTImageProcessor.__call__`] for details.
+
+ head_mask (`tf.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
+ Whether to interpolate the pre-trained position encodings.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+@add_start_docstrings(
+ "The bare DeiT Model transformer outputting raw hidden-states without any specific head on top.",
+ DEIT_START_DOCSTRING,
+)
+class TFDeiTModel(TFDeiTPreTrainedModel):
+ def __init__(
+ self, config: DeiTConfig, add_pooling_layer: bool = True, use_mask_token: bool = False, **kwargs
+ ) -> None:
+ super().__init__(config, **kwargs)
+
+ self.deit = TFDeiTMainLayer(
+ config, add_pooling_layer=add_pooling_layer, use_mask_token=use_mask_token, name="deit"
+ )
+
+ @unpack_inputs
+ @add_start_docstrings_to_model_forward(DEIT_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=TFBaseModelOutputWithPooling,
+ config_class=_CONFIG_FOR_DOC,
+ modality="vision",
+ expected_output=_EXPECTED_OUTPUT_SHAPE,
+ )
+ def call(
+ self,
+ pixel_values: tf.Tensor | None = None,
+ bool_masked_pos: tf.Tensor | None = None,
+ head_mask: tf.Tensor | None = None,
+ output_attentions: bool | None = None,
+ output_hidden_states: bool | None = None,
+ return_dict: bool | None = None,
+ interpolate_pos_encoding: bool = False,
+ training: bool = False,
+ ) -> tuple | TFBaseModelOutputWithPooling:
+ outputs = self.deit(
+ pixel_values=pixel_values,
+ bool_masked_pos=bool_masked_pos,
+ head_mask=head_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ interpolate_pos_encoding=interpolate_pos_encoding,
+ training=training,
+ )
+ return outputs
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "deit", None) is not None:
+ with tf.name_scope(self.deit.name):
+ self.deit.build(None)
+
+
+# Copied from transformers.models.vit.modeling_tf_vit.TFViTPooler with ViT->DeiT
+class TFDeiTPooler(keras.layers.Layer):
+ def __init__(self, config: DeiTConfig, **kwargs):
+ super().__init__(**kwargs)
+
+ self.dense = keras.layers.Dense(
+ units=config.pooler_output_size,
+ kernel_initializer=get_initializer(config.initializer_range),
+ activation=config.pooler_act,
+ name="dense",
+ )
+ self.config = config
+
+ def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
+ # We "pool" the model by simply taking the hidden state corresponding
+ # to the first token.
+ first_token_tensor = hidden_states[:, 0]
+ pooled_output = self.dense(inputs=first_token_tensor)
+
+ return pooled_output
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "dense", None) is not None:
+ with tf.name_scope(self.dense.name):
+ self.dense.build([None, None, self.config.hidden_size])
+
+
+class TFDeitPixelShuffle(keras.layers.Layer):
+ """TF layer implementation of torch.nn.PixelShuffle"""
+
+ def __init__(self, upscale_factor: int, **kwargs) -> None:
+ super().__init__(**kwargs)
+ if not isinstance(upscale_factor, int) or upscale_factor < 2:
+ raise ValueError(f"upscale_factor must be an integer value >= 2 got {upscale_factor}")
+ self.upscale_factor = upscale_factor
+
+ def call(self, x: tf.Tensor) -> tf.Tensor:
+ hidden_states = x
+ batch_size, _, _, num_input_channels = shape_list(hidden_states)
+ block_size_squared = self.upscale_factor**2
+ output_depth = int(num_input_channels / block_size_squared)
+ # When the number of output channels >= 2, PyTorch's PixelShuffle and
+ # TF's depth_to_space differ in their output as the order of channels selected for combining
+ # is a permutation of the other c.f.
+ # https://stackoverflow.com/questions/68272502/tf-depth-to-space-not-same-as-torchs-pixelshuffle-when-output-channels-1
+ permutation = tf.constant(
+ [[i + j * block_size_squared for i in range(block_size_squared) for j in range(output_depth)]]
+ )
+ hidden_states = tf.gather(params=hidden_states, indices=tf.tile(permutation, [batch_size, 1]), batch_dims=-1)
+ hidden_states = tf.nn.depth_to_space(hidden_states, block_size=self.upscale_factor, data_format="NHWC")
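+ # `hidden_states` now has shape (batch_size, height * upscale_factor, width * upscale_factor, output_depth)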
+ return hidden_states
+
+
+class TFDeitDecoder(keras.layers.Layer):
+ def __init__(self, config: DeiTConfig, **kwargs) -> None:
+ super().__init__(**kwargs)
+ self.conv2d = keras.layers.Conv2D(
+ filters=config.encoder_stride**2 * config.num_channels, kernel_size=1, name="0"
+ )
+ self.pixel_shuffle = TFDeitPixelShuffle(config.encoder_stride, name="1")
+ self.config = config
+
+ def call(self, inputs: tf.Tensor, training: bool = False) -> tf.Tensor:
+ hidden_states = inputs
+ hidden_states = self.conv2d(hidden_states)
+ hidden_states = self.pixel_shuffle(hidden_states)
+ return hidden_states
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "conv2d", None) is not None:
+ with tf.name_scope(self.conv2d.name):
+ self.conv2d.build([None, None, None, self.config.hidden_size])
+ if getattr(self, "pixel_shuffle", None) is not None:
+ with tf.name_scope(self.pixel_shuffle.name):
+ self.pixel_shuffle.build(None)
+
+
+@add_start_docstrings(
+ "DeiT Model with a decoder on top for masked image modeling, as proposed in"
+ " [SimMIM](https://huggingface.co/papers/2111.09886).",
+ DEIT_START_DOCSTRING,
+)
+class TFDeiTForMaskedImageModeling(TFDeiTPreTrainedModel):
+ def __init__(self, config: DeiTConfig) -> None:
+ super().__init__(config)
+
+ self.deit = TFDeiTMainLayer(config, add_pooling_layer=False, use_mask_token=True, name="deit")
+ self.decoder = TFDeitDecoder(config, name="decoder")
+
+ @unpack_inputs
+ @add_start_docstrings_to_model_forward(DEIT_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=TFMaskedImageModelingOutput, config_class=_CONFIG_FOR_DOC)
+ def call(
+ self,
+ pixel_values: tf.Tensor | None = None,
+ bool_masked_pos: tf.Tensor | None = None,
+ head_mask: tf.Tensor | None = None,
+ output_attentions: bool | None = None,
+ output_hidden_states: bool | None = None,
+ return_dict: bool | None = None,
+ interpolate_pos_encoding: bool = False,
+ training: bool = False,
+ ) -> tuple | TFMaskedImageModelingOutput:
+ r"""
+ bool_masked_pos (`tf.Tensor` of type bool and shape `(batch_size, num_patches)`):
+ Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
+
+ Returns:
+
+ Examples:
+ ```python
+ >>> from transformers import AutoImageProcessor, TFDeiTForMaskedImageModeling
+ >>> import tensorflow as tf
+ >>> from PIL import Image
+ >>> import requests
+
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> image_processor = AutoImageProcessor.from_pretrained("facebook/deit-base-distilled-patch16-224")
+ >>> model = TFDeiTForMaskedImageModeling.from_pretrained("facebook/deit-base-distilled-patch16-224")
+
+ >>> num_patches = (model.config.image_size // model.config.patch_size) ** 2
+ >>> pixel_values = image_processor(images=image, return_tensors="tf").pixel_values
+ >>> # create random boolean mask of shape (batch_size, num_patches)
+ >>> bool_masked_pos = tf.cast(tf.random.uniform((1, num_patches), minval=0, maxval=2, dtype=tf.int32), tf.bool)
+
+ >>> outputs = model(pixel_values, bool_masked_pos=bool_masked_pos)
+ >>> loss, reconstructed_pixel_values = outputs.loss, outputs.reconstruction
+ >>> list(reconstructed_pixel_values.shape)
+ [1, 3, 224, 224]
+ ```"""
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.deit(
+ pixel_values,
+ bool_masked_pos=bool_masked_pos,
+ head_mask=head_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ interpolate_pos_encoding=interpolate_pos_encoding,
+ training=training,
+ )
+
+ sequence_output = outputs[0]
+
+ # Reshape to (batch_size, num_channels, height, width)
+ sequence_output = sequence_output[:, 1:-1]
+ batch_size, sequence_length, num_channels = shape_list(sequence_output)
+ height = width = int(sequence_length**0.5)
+ sequence_output = tf.reshape(sequence_output, (batch_size, height, width, num_channels))
+
+ # Reconstruct pixel values
+ reconstructed_pixel_values = self.decoder(sequence_output, training=training)
+ # TF 2.0 image layers can't use NCHW format when running on CPU, so intermediate layers use NHWC,
+ # including the decoder. We transpose to compute the loss against the pixel values
+ # (batch_size, height, width, num_channels) -> (batch_size, num_channels, height, width)
+ reconstructed_pixel_values = tf.transpose(reconstructed_pixel_values, (0, 3, 1, 2))
+
+ masked_im_loss = None
+ if bool_masked_pos is not None:
+ size = self.config.image_size // self.config.patch_size
+ bool_masked_pos = tf.reshape(bool_masked_pos, (-1, size, size))
+ mask = tf.repeat(bool_masked_pos, self.config.patch_size, 1)
+ mask = tf.repeat(mask, self.config.patch_size, 2)
+ mask = tf.expand_dims(mask, 1)
+ mask = tf.cast(mask, tf.float32)
+
+ reconstruction_loss = keras.losses.mean_absolute_error(
+ # Swap axes as metric calculation reduces over the final dimension
+ tf.transpose(pixel_values, (1, 2, 3, 0)),
+ tf.transpose(reconstructed_pixel_values, (1, 2, 3, 0)),
+ )
+ reconstruction_loss = tf.expand_dims(reconstruction_loss, 0)
+ total_loss = tf.reduce_sum(reconstruction_loss * mask)
+ num_masked_pixels = (tf.reduce_sum(mask) + 1e-5) * self.config.num_channels
+ masked_im_loss = total_loss / num_masked_pixels
+ masked_im_loss = tf.reshape(masked_im_loss, (1,))
+
+ if not return_dict:
+ output = (reconstructed_pixel_values,) + outputs[1:]
+ return ((masked_im_loss,) + output) if masked_im_loss is not None else output
+
+ return TFMaskedImageModelingOutput(
+ loss=masked_im_loss,
+ reconstruction=reconstructed_pixel_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "deit", None) is not None:
+ with tf.name_scope(self.deit.name):
+ self.deit.build(None)
+ if getattr(self, "decoder", None) is not None:
+ with tf.name_scope(self.decoder.name):
+ self.decoder.build(None)
+
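+# Illustrative sketch of the masked-image-modeling loss above, assuming the default DeiT configuration
+# (image_size=224, patch_size=16, num_channels=3); not part of the module itself. The patch-level
+# boolean mask is upsampled to pixel level before normalizing the L1 reconstruction loss:
+#
+# size = 224 // 16 # 14 patches per side
+# mask = tf.reshape(bool_masked_pos, (-1, 14, 14)) # (batch, 14, 14)
+# mask = tf.repeat(mask, 16, axis=1) # (batch, 224, 14)
+# mask = tf.repeat(mask, 16, axis=2) # (batch, 224, 224)
+# mask = tf.expand_dims(mask, 1) # (batch, 1, 224, 224), broadcast over the 3 channels
+#
+# The per-pixel reconstruction error is then kept only where the mask is 1 and divided by
+# (number of masked pixels * num_channels), so the loss scale does not depend on how many
+# patches happen to be masked.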
+
+@add_start_docstrings(
+ """
+ DeiT Model transformer with an image classification head on top (a linear layer on top of the final hidden state of
+ the [CLS] token) e.g. for ImageNet.
+ """,
+ DEIT_START_DOCSTRING,
+)
+class TFDeiTForImageClassification(TFDeiTPreTrainedModel, TFSequenceClassificationLoss):
+ def __init__(self, config: DeiTConfig):
+ super().__init__(config)
+
+ self.num_labels = config.num_labels
+ self.deit = TFDeiTMainLayer(config, add_pooling_layer=False, name="deit")
+
+ # Classifier head
+ self.classifier = (
+ keras.layers.Dense(config.num_labels, name="classifier")
+ if config.num_labels > 0
+ else keras.layers.Activation("linear", name="classifier")
+ )
+ self.config = config
+
+ @unpack_inputs
+ @add_start_docstrings_to_model_forward(DEIT_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=TFImageClassifierOutput, config_class=_CONFIG_FOR_DOC)
+ def call(
+ self,
+ pixel_values: tf.Tensor | None = None,
+ head_mask: tf.Tensor | None = None,
+ labels: tf.Tensor | None = None,
+ output_attentions: bool | None = None,
+ output_hidden_states: bool | None = None,
+ return_dict: bool | None = None,
+ interpolate_pos_encoding: bool = False,
+ training: bool = False,
+ ) -> tf.Tensor | TFImageClassifierOutput:
+ r"""
+ labels (`tf.Tensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+
+ Returns:
+
+ Examples:
+
+ ```python
+ >>> from transformers import AutoImageProcessor, TFDeiTForImageClassification
+ >>> import tensorflow as tf
+ >>> from PIL import Image
+ >>> import requests
+
+ >>> tf.keras.utils.set_random_seed(3) # doctest: +IGNORE_RESULT
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> # note: this checkpoint on the hub was trained as a TFDeiTForImageClassificationWithTeacher,
+ >>> # so the classification head loaded here is randomly initialized, hence the predictions will be random
+ >>> image_processor = AutoImageProcessor.from_pretrained("facebook/deit-base-distilled-patch16-224")
+ >>> model = TFDeiTForImageClassification.from_pretrained("facebook/deit-base-distilled-patch16-224")
+
+ >>> inputs = image_processor(images=image, return_tensors="tf")
+ >>> outputs = model(**inputs)
+ >>> logits = outputs.logits
+ >>> # model predicts one of the 1000 ImageNet classes
+ >>> predicted_class_idx = tf.math.argmax(logits, axis=-1)[0]
+ >>> print("Predicted class:", model.config.id2label[int(predicted_class_idx)])
+ Predicted class: little blue heron, Egretta caerulea
+ ```"""
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.deit(
+ pixel_values,
+ head_mask=head_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ interpolate_pos_encoding=interpolate_pos_encoding,
+ training=training,
+ )
+
+ sequence_output = outputs[0]
+
+ logits = self.classifier(sequence_output[:, 0, :])
+ # we don't use the distillation token
+
+ loss = None if labels is None else self.hf_compute_loss(labels, logits)
+
+ if not return_dict:
+ output = (logits,) + outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return TFImageClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "deit", None) is not None:
+ with tf.name_scope(self.deit.name):
+ self.deit.build(None)
+ if getattr(self, "classifier", None) is not None:
+ with tf.name_scope(self.classifier.name):
+ self.classifier.build([None, None, self.config.hidden_size])
+
+
+@add_start_docstrings(
+ """
+ DeiT Model transformer with image classification heads on top (a linear layer on top of the final hidden state of
+ the [CLS] token and a linear layer on top of the final hidden state of the distillation token) e.g. for ImageNet.
+
+ .. warning::
+
+ This model supports inference only. Fine-tuning with distillation (i.e. with a teacher) is not yet
+ supported.
+ """,
+ DEIT_START_DOCSTRING,
+)
+class TFDeiTForImageClassificationWithTeacher(TFDeiTPreTrainedModel):
+ def __init__(self, config: DeiTConfig) -> None:
+ super().__init__(config)
+
+ self.num_labels = config.num_labels
+ self.deit = TFDeiTMainLayer(config, add_pooling_layer=False, name="deit")
+
+ # Classifier heads
+ self.cls_classifier = (
+ keras.layers.Dense(config.num_labels, name="cls_classifier")
+ if config.num_labels > 0
+ else keras.layers.Activation("linear", name="cls_classifier")
+ )
+ self.distillation_classifier = (
+ keras.layers.Dense(config.num_labels, name="distillation_classifier")
+ if config.num_labels > 0
+ else keras.layers.Activation("linear", name="distillation_classifier")
+ )
+ self.config = config
+
+ @unpack_inputs
+ @add_start_docstrings_to_model_forward(DEIT_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ checkpoint=_IMAGE_CLASS_CHECKPOINT,
+ output_type=TFDeiTForImageClassificationWithTeacherOutput,
+ config_class=_CONFIG_FOR_DOC,
+ expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
+ )
+ def call(
+ self,
+ pixel_values: tf.Tensor | None = None,
+ head_mask: tf.Tensor | None = None,
+ output_attentions: bool | None = None,
+ output_hidden_states: bool | None = None,
+ return_dict: bool | None = None,
+ interpolate_pos_encoding: bool = False,
+ training: bool = False,
+ ) -> tuple | TFDeiTForImageClassificationWithTeacherOutput:
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.deit(
+ pixel_values,
+ head_mask=head_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ interpolate_pos_encoding=interpolate_pos_encoding,
+ training=training,
+ )
+
+ sequence_output = outputs[0]
+
+ cls_logits = self.cls_classifier(sequence_output[:, 0, :])
+ distillation_logits = self.distillation_classifier(sequence_output[:, 1, :])
+
+ # during inference, return the average of both classifier predictions
+ logits = (cls_logits + distillation_logits) / 2
+
+ if not return_dict:
+ output = (logits, cls_logits, distillation_logits) + outputs[1:]
+ return output
+
+ return TFDeiTForImageClassificationWithTeacherOutput(
+ logits=logits,
+ cls_logits=cls_logits,
+ distillation_logits=distillation_logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "deit", None) is not None:
+ with tf.name_scope(self.deit.name):
+ self.deit.build(None)
+ if getattr(self, "cls_classifier", None) is not None:
+ with tf.name_scope(self.cls_classifier.name):
+ self.cls_classifier.build([None, None, self.config.hidden_size])
+ if getattr(self, "distillation_classifier", None) is not None:
+ with tf.name_scope(self.distillation_classifier.name):
+ self.distillation_classifier.build([None, None, self.config.hidden_size])
+
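+# Illustrative usage sketch, not part of the module: the "facebook/deit-base-distilled-patch16-224"
+# checkpoint referenced in the docstrings above was trained with both heads, so it is the natural fit
+# for this class, which averages the [CLS] and distillation logits at inference time:
+#
+# from transformers import AutoImageProcessor, TFDeiTForImageClassificationWithTeacher
+#
+# processor = AutoImageProcessor.from_pretrained("facebook/deit-base-distilled-patch16-224")
+# model = TFDeiTForImageClassificationWithTeacher.from_pretrained("facebook/deit-base-distilled-patch16-224")
+# inputs = processor(images=image, return_tensors="tf") # `image` is any PIL image (placeholder)
+# logits = model(**inputs).logits # == (cls_logits + distillation_logits) / 2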
+
+__all__ = [
+ "TFDeiTForImageClassification",
+ "TFDeiTForImageClassificationWithTeacher",
+ "TFDeiTForMaskedImageModeling",
+ "TFDeiTModel",
+ "TFDeiTPreTrainedModel",
+]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/deprecated/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/deprecated/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e293c354e1e92a431a601da77d7555f2ecfe29ef
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/deprecated/__init__.py
@@ -0,0 +1,49 @@
+# Copyright 2020 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .bort import *
+ from .deta import *
+ from .efficientformer import *
+ from .ernie_m import *
+ from .gptsan_japanese import *
+ from .graphormer import *
+ from .jukebox import *
+ from .mctct import *
+ from .mega import *
+ from .mmbt import *
+ from .nat import *
+ from .nezha import *
+ from .open_llama import *
+ from .qdqbert import *
+ from .realm import *
+ from .retribert import *
+ from .speech_to_text_2 import *
+ from .tapex import *
+ from .trajectory_transformer import *
+ from .transfo_xl import *
+ from .tvlt import *
+ from .van import *
+ from .vit_hybrid import *
+ from .xlm_prophetnet import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/dia/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/dia/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d738fbc087888597da19735271366d4e35ab708c
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/dia/__init__.py
@@ -0,0 +1,31 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_dia import *
+ from .feature_extraction_dia import *
+ from .generation_dia import *
+ from .modeling_dia import *
+ from .processing_dia import *
+ from .tokenization_dia import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/dia/configuration_dia.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/dia/configuration_dia.py
new file mode 100644
index 0000000000000000000000000000000000000000..d4dec60b3e4853574e4d528e7b641507a8c0b414
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/dia/configuration_dia.py
@@ -0,0 +1,376 @@
+# coding=utf-8
+# Copyright 2025 The Nari Labs and HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Dia model configuration"""
+
+from typing import Optional
+
+from ...configuration_utils import PretrainedConfig
+from ...modeling_rope_utils import rope_config_validation
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class DiaEncoderConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`DiaEncoder`]. It is used to instantiate a Dia
+ encoder according to the specified arguments, defining the encoder architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ max_position_embeddings (`int`, *optional*, defaults to 1024):
+ The maximum sequence length that this model might ever be used with.
+ num_hidden_layers (`int`, *optional*, defaults to 12):
+ Number of hidden layers in the Transformer encoder.
+ hidden_size (`int`, *optional*, defaults to 1024):
+ Dimensionality of the encoder layers and the pooler layer.
+ num_attention_heads (`int`, *optional*, defaults to 16):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ num_key_value_heads (`int`, *optional*, defaults to 16):
+ Number of key and value heads for each attention layer in the Transformer encoder.
+ head_dim (`int`, *optional*, defaults to 128):
+ Dimensionality of the attention head.
+ intermediate_size (`int`, *optional*, defaults to 4096):
+ Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder.
+ norm_eps (`float`, *optional*, defaults to 1e-05):
+ The epsilon used by the normalization layers.
+ vocab_size (`int`, *optional*, defaults to 256):
+ Vocabulary size of the Dia model. Defines the number of different tokens that can be represented by the
+ `inputs_ids` passed when calling [`DiaModel`].
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+ `"relu"`, `"swish"` and `"gelu_new"` are supported.
+ rope_theta (`float`, *optional*, defaults to 10000.0):
+ The base period of the RoPE embeddings.
+ rope_scaling (`dict`, *optional*):
+ Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply a new
+ rope type and expect the model to work on longer `max_position_embeddings`, we recommend updating
+ this value accordingly.
+ Expected contents:
+ `rope_type` (`str`):
+ The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
+ 'llama3'], with 'default' being the original RoPE implementation.
+ `factor` (`float`, *optional*):
+ Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
+ most scaling types, a `factor` of x will enable the model to handle sequences of length x *
+ original maximum pre-trained length.
+ `original_max_position_embeddings` (`int`, *optional*):
+ Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
+ pretraining.
+ `attention_factor` (`float`, *optional*):
+ Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
+ computation. If unspecified, it defaults to value recommended by the implementation, using the
+ `factor` field to infer the suggested value.
+ `beta_fast` (`float`, *optional*):
+ Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
+ ramp function. If unspecified, it defaults to 32.
+ `beta_slow` (`float`, *optional*):
+ Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
+ ramp function. If unspecified, it defaults to 1.
+ `short_factor` (`List[float]`, *optional*):
+ Only used with 'longrope'. The scaling factor to be applied to short contexts (<
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+ size divided by the number of attention heads divided by 2
+ `long_factor` (`List[float]`, *optional*):
+ Only used with 'longrope'. The scaling factor to be applied to long contexts (>
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+ size divided by the number of attention heads divided by 2
+ `low_freq_factor` (`float`, *optional*):
+ Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
+ `high_freq_factor` (`float`, *optional*):
+ Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ """
+
+ model_type = "dia_encoder"
+
+ def __init__(
+ self,
+ max_position_embeddings: int = 1024,
+ num_hidden_layers: int = 12,
+ hidden_size: int = 1024,
+ num_attention_heads: int = 16,
+ num_key_value_heads: int = 16,
+ head_dim: int = 128,
+ intermediate_size: int = 4096,
+ norm_eps: float = 1e-5,
+ vocab_size: int = 256,
+ hidden_act: str = "silu",
+ rope_theta: float = 10000.0,
+ rope_scaling: Optional[dict] = None,
+ initializer_range: float = 0.02,
+ **kwargs,
+ ):
+ self.max_position_embeddings = max_position_embeddings
+ self.num_hidden_layers = num_hidden_layers
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.num_attention_heads = num_attention_heads
+ self.head_dim = head_dim
+ self.norm_eps = norm_eps
+ self.vocab_size = vocab_size
+ self.num_key_value_heads = num_key_value_heads
+ self.hidden_act = hidden_act
+ self.rope_theta = rope_theta
+ self.rope_scaling = rope_scaling
+ # Validate the correctness of rotary position embeddings parameters
+ # BC: if there is a 'type' field, copy it to 'rope_type'.
+ if self.rope_scaling is not None and "type" in self.rope_scaling:
+ self.rope_scaling["rope_type"] = self.rope_scaling["type"]
+ rope_config_validation(self)
+ self.initializer_range = initializer_range
+ super().__init__(**kwargs)
+
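+# Illustrative sketch, not part of the module: instantiating the encoder config with a RoPE scaling
+# dict; the legacy `type` key is also accepted and copied to `rope_type` before validation, as above.
+# The values below are example assumptions, not shipped defaults.
+#
+# from transformers.models.dia.configuration_dia import DiaEncoderConfig
+#
+# encoder_config = DiaEncoderConfig(
+#     max_position_embeddings=2048,
+#     rope_scaling={"rope_type": "linear", "factor": 2.0},
+# )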
+
+class DiaDecoderConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`DiaDecoder`]. It is used to instantiate a Dia
+ decoder according to the specified arguments, defining the decoder architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ max_position_embeddings (`int`, *optional*, defaults to 3072):
+ The maximum sequence length that this model might ever be used with.
+ num_hidden_layers (`int`, *optional*, defaults to 18):
+ Number of hidden layers in the Transformer decoder.
+ hidden_size (`int`, *optional*, defaults to 2048):
+ Dimensionality of the decoder layers and the pooler layer.
+ intermediate_size (`int`, *optional*, defaults to 8192):
+ Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer decoder.
+ num_attention_heads (`int`, *optional*, defaults to 16):
+ Number of attention heads for each attention layer in the Transformer decoder.
+ num_key_value_heads (`int`, *optional*, defaults to 4):
+ Number of key and value heads for each attention layer in the Transformer decoder.
+ head_dim (`int`, *optional*, defaults to 128):
+ Dimensionality of the attention head.
+ cross_num_attention_heads (`int`, *optional*, defaults to 16):
+ Number of attention heads for each cross-attention layer in the Transformer decoder.
+ cross_head_dim (`int`, *optional*, defaults to 128):
+ Dimensionality of the cross-attention head.
+ cross_num_key_value_heads (`int`, *optional*, defaults to 16):
+ Number of key and value heads for each cross-attention layer in the Transformer decoder.
+ cross_hidden_size (`int`, *optional*, defaults to 1024):
+ Dimensionality of the cross-attention layers.
+ norm_eps (`float`, *optional*, defaults to 1e-05):
+ The epsilon used by the normalization layers.
+ vocab_size (`int`, *optional*, defaults to 1028):
+ Vocabulary size of the Dia model. Defines the number of different tokens that can be represented by the
+ `inputs_ids` passed when calling [`DiaModel`].
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+ The non-linear activation function (function or string) in the decoder. If string, `"gelu"`, `"relu"`,
+ `"swish"` and `"gelu_new"` are supported.
+ num_channels (`int`, *optional*, defaults to 9):
+ Number of channels for the Dia decoder.
+ rope_theta (`float`, *optional*, defaults to 10000.0):
+ The base period of the RoPE embeddings.
+ rope_scaling (`dict`, *optional*):
+ Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply a new
+ rope type and expect the model to work on longer `max_position_embeddings`, we recommend updating
+ this value accordingly.
+ Expected contents:
+ `rope_type` (`str`):
+ The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
+ 'llama3'], with 'default' being the original RoPE implementation.
+ `factor` (`float`, *optional*):
+ Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
+ most scaling types, a `factor` of x will enable the model to handle sequences of length x *
+ original maximum pre-trained length.
+ `original_max_position_embeddings` (`int`, *optional*):
+ Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
+ pretraining.
+ `attention_factor` (`float`, *optional*):
+ Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
+ computation. If unspecified, it defaults to value recommended by the implementation, using the
+ `factor` field to infer the suggested value.
+ `beta_fast` (`float`, *optional*):
+ Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
+ ramp function. If unspecified, it defaults to 32.
+ `beta_slow` (`float`, *optional*):
+ Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
+ ramp function. If unspecified, it defaults to 1.
+ `short_factor` (`List[float]`, *optional*):
+ Only used with 'longrope'. The scaling factor to be applied to short contexts (<
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+ size divided by the number of attention heads divided by 2
+ `long_factor` (`List[float]`, *optional*):
+ Only used with 'longrope'. The scaling factor to be applied to long contexts (>
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+ size divided by the number of attention heads divided by 2
+ `low_freq_factor` (`float`, *optional*):
+ Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
+ `high_freq_factor` (`float`, *optional*):
+ Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return the last key/values attentions (not used by all models).
+ is_encoder_decoder (`bool`, *optional*, defaults to `True`):
+ Indicating that this model is part of an encoder-decoder architecture.
+ """
+
+ model_type = "dia_decoder"
+
+ def __init__(
+ self,
+ max_position_embeddings: int = 3072,
+ num_hidden_layers: int = 18,
+ hidden_size: int = 2048,
+ intermediate_size: int = 8192,
+ num_attention_heads: int = 16,
+ num_key_value_heads: int = 4,
+ head_dim: int = 128,
+ cross_num_attention_heads: int = 16,
+ cross_head_dim: int = 128,
+ cross_num_key_value_heads: int = 16,
+ cross_hidden_size: int = 1024,
+ norm_eps: float = 1e-5,
+ vocab_size: int = 1028,
+ hidden_act: str = "silu",
+ num_channels: int = 9,
+ rope_theta: float = 10000.0,
+ rope_scaling: Optional[dict] = None,
+ initializer_range: float = 0.02,
+ use_cache: bool = True,
+ is_encoder_decoder: bool = True,
+ **kwargs,
+ ):
+ self.max_position_embeddings = max_position_embeddings
+ self.num_hidden_layers = num_hidden_layers
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.num_attention_heads = num_attention_heads
+ self.num_key_value_heads = num_key_value_heads
+ self.head_dim = head_dim
+ self.cross_num_key_value_heads = cross_num_key_value_heads
+ self.cross_num_attention_heads = cross_num_attention_heads
+ self.cross_head_dim = cross_head_dim
+ self.cross_hidden_size = cross_hidden_size
+ self.norm_eps = norm_eps
+ self.vocab_size = vocab_size
+ self.hidden_act = hidden_act
+ self.num_channels = num_channels
+ self.rope_theta = rope_theta
+ self.rope_scaling = rope_scaling
+ # Validate the correctness of rotary position embeddings parameters
+ # BC: if there is a 'type' field, copy it to 'rope_type'.
+ if self.rope_scaling is not None and "type" in self.rope_scaling:
+ self.rope_scaling["rope_type"] = self.rope_scaling["type"]
+ rope_config_validation(self)
+ self.initializer_range = initializer_range
+ self.use_cache = use_cache
+ super().__init__(is_encoder_decoder=is_encoder_decoder, **kwargs)
+
+
+class DiaConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`DiaModel`]. It is used to instantiate a
+ Dia model according to the specified arguments, defining the model architecture. Instantiating a configuration
+ with the defaults will yield a similar configuration to that of the
+ [nari-labs/Dia-1.6B](https://huggingface.co/nari-labs/Dia-1.6B) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ encoder_config (`DiaEncoderConfig`, *optional*):
+ Configuration for the encoder part of the model. If not provided, a default `DiaEncoderConfig` will be used.
+ decoder_config (`DiaDecoderConfig`, *optional*):
+ Configuration for the decoder part of the model. If not provided, a default `DiaDecoderConfig` will be used.
+ norm_eps (`float`, *optional*, defaults to 1e-05):
+ The epsilon used by the normalization layers.
+ is_encoder_decoder (`bool`, *optional*, defaults to `True`):
+ Indicating that this model uses an encoder-decoder architecture.
+ pad_token_id (`int`, *optional*, defaults to 1025):
+ Padding token id.
+ eos_token_id (`int`, *optional*, defaults to 1024):
+ End of stream token id.
+ bos_token_id (`int`, *optional*, defaults to 1026):
+ Beginning of stream token id.
+ delay_pattern (`list[int]`, *optional*, defaults to `[0, 8, 9, 10, 11, 12, 13, 14, 15]`):
+ The delay pattern for the decoder. The length of this list must match `decoder_config.num_channels`.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return the last key/values attentions (not used by all models).
+
+ Example:
+
+ ```python
+ >>> from transformers import DiaConfig, DiaModel
+
+ >>> # Initializing a DiaConfig with default values
+ >>> configuration = DiaConfig()
+
+ >>> # Initializing a DiaModel (with random weights) from the configuration
+ >>> model = DiaModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```
+ """
+
+ model_type = "dia"
+ keys_to_ignore_at_inference = ["past_key_values"]
+ sub_configs = {"encoder_config": DiaEncoderConfig, "decoder_config": DiaDecoderConfig}
+
+ def __init__(
+ self,
+ encoder_config: Optional[DiaEncoderConfig] = None,
+ decoder_config: Optional[DiaDecoderConfig] = None,
+ norm_eps: float = 1e-5,
+ is_encoder_decoder: bool = True,
+ pad_token_id: int = 1025,
+ eos_token_id: int = 1024,
+ bos_token_id: int = 1026,
+ delay_pattern: Optional[list[int]] = None,
+ initializer_range: float = 0.02,
+ use_cache: bool = True,
+ **kwargs,
+ ):
+ if isinstance(encoder_config, dict):
+ encoder_config = DiaEncoderConfig(**encoder_config)
+ if isinstance(decoder_config, dict):
+ decoder_config = DiaDecoderConfig(**decoder_config)
+ self.encoder_config = encoder_config if encoder_config is not None else DiaEncoderConfig()
+ self.decoder_config = decoder_config if decoder_config is not None else DiaDecoderConfig()
+ self.norm_eps = norm_eps
+ self.delay_pattern = delay_pattern if delay_pattern is not None else [0, 8, 9, 10, 11, 12, 13, 14, 15]
+ self.initializer_range = initializer_range
+ self.use_cache = use_cache
+
+ assert self.decoder_config.num_channels == len(self.delay_pattern), (
+ "Number of channels must match delay pattern length."
+ )
+
+ super().__init__(
+ pad_token_id=pad_token_id,
+ eos_token_id=eos_token_id,
+ bos_token_id=bos_token_id,
+ is_encoder_decoder=is_encoder_decoder,
+ **kwargs,
+ )
+
+ def get_text_config(self, *args, **kwargs):
+ """Defaulting to audio config as it's the decoder in this case which is usually the text backbone"""
+ return self.decoder_config
+
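+# Illustrative sketch, not part of the module: `DiaConfig` requires `len(delay_pattern)` to equal
+# `decoder_config.num_channels`, so a custom delay pattern must be paired with a matching decoder
+# config (example values below are assumptions, not shipped defaults):
+#
+# from transformers.models.dia.configuration_dia import DiaConfig, DiaDecoderConfig
+#
+# config = DiaConfig(
+#     decoder_config=DiaDecoderConfig(num_channels=4),
+#     delay_pattern=[0, 1, 2, 3], # one offset per audio channel/codebook
+# )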
+
+__all__ = ["DiaConfig", "DiaEncoderConfig", "DiaDecoderConfig"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/dia/feature_extraction_dia.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/dia/feature_extraction_dia.py
new file mode 100644
index 0000000000000000000000000000000000000000..b4376b773b27365932774a35746b2928cf0af707
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/dia/feature_extraction_dia.py
@@ -0,0 +1,183 @@
+# coding=utf-8
+# Copyright 2025 The Nari Labs and HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Feature extractor class for Dia"""
+
+from typing import Optional, Union
+
+import numpy as np
+
+from ...feature_extraction_sequence_utils import SequenceFeatureExtractor
+from ...feature_extraction_utils import BatchFeature
+from ...utils import PaddingStrategy, TensorType, logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class DiaFeatureExtractor(SequenceFeatureExtractor):
+ r"""
+ Constructs a Dia feature extractor.
+
+ This feature extractor inherits from [`~feature_extraction_sequence_utils.SequenceFeatureExtractor`] which contains
+ most of the main methods. Users should refer to this superclass for more information regarding those methods.
+
+ Args:
+ feature_size (`int`, *optional*, defaults to 1):
+ The feature dimension of the extracted features. Use 1 for mono, 2 for stereo.
+ sampling_rate (`int`, *optional*, defaults to 16000):
+ The sampling rate at which the audio waveform should be digitized, expressed in hertz (Hz).
+ padding_value (`float`, *optional*, defaults to 0.0):
+ The value that is used for padding.
+ hop_length (`int`, *optional*, defaults to 512):
+ Hop length between successive analysis windows; inputs are padded up to a multiple of this value.
+ """
+
+ model_input_names = ["input_values", "n_quantizers"]
+
+ def __init__(
+ self,
+ feature_size: int = 1,
+ sampling_rate: int = 16000,
+ padding_value: float = 0.0,
+ hop_length: int = 512,
+ **kwargs,
+ ):
+ super().__init__(feature_size=feature_size, sampling_rate=sampling_rate, padding_value=padding_value, **kwargs)
+ self.hop_length = hop_length
+
+ def __call__(
+ self,
+ raw_audio: Union[np.ndarray, list[float], list[np.ndarray], list[list[float]]],
+ padding: Optional[Union[bool, str, PaddingStrategy]] = None,
+ truncation: Optional[bool] = False,
+ max_length: Optional[int] = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ sampling_rate: Optional[int] = None,
+ ) -> BatchFeature:
+ """
+ Main method to featurize and prepare for the model one or several sequence(s).
+
+ Args:
+ raw_audio (`np.ndarray`, `list[float]`, `list[np.ndarray]`, `list[list[float]]`):
+ The sequence or batch of sequences to be processed. Each sequence can be a numpy array, a list of float
+ values, a list of numpy arrays or a list of lists of float values. The numpy array must be of shape
+ `(num_samples,)` for mono audio (`feature_size = 1`), or `(2, num_samples)` for stereo audio
+ (`feature_size = 2`).
+ padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`):
+ Select a strategy to pad the returned sequences (according to the model's padding side and padding
+ index) among:
+
+ - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
+ sequence is provided).
+ - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
+ acceptable input length for the model if that argument is not provided.
+ - `False` or `'do_not_pad'`: No padding (i.e., can output a batch with sequences of different
+ lengths).
+ truncation (`bool`, *optional*, defaults to `False`):
+ Activates truncation to cut input sequences longer than `max_length` to `max_length`.
+ max_length (`int`, *optional*):
+ Maximum length of the returned list and optionally padding length (see above).
+ return_tensors (`str` or [`~utils.TensorType`], *optional*):
+ If set, will return tensors instead of list of python integers. Acceptable values are:
+
+ - `'tf'`: Return TensorFlow `tf.constant` objects.
+ - `'pt'`: Return PyTorch `torch.Tensor` objects.
+ - `'np'`: Return Numpy `np.ndarray` objects.
+ sampling_rate (`int`, *optional*):
+ The sampling rate at which the `audio` input was sampled. It is strongly recommended to pass
+ `sampling_rate` at the forward call to prevent silent errors.
+ """
+ if sampling_rate is not None:
+ if sampling_rate != self.sampling_rate:
+ raise ValueError(
+ f"The model corresponding to this feature extractor: {self} was trained using a sampling rate of"
+ f" {self.sampling_rate}. Please make sure that the provided audio input was sampled with"
+ f" {self.sampling_rate} and not {sampling_rate}."
+ )
+ else:
+ logger.warning(
+ f"It is strongly recommended to pass the `sampling_rate` argument to `{self.__class__.__name__}()`. "
+ "Failing to do so can result in silent errors that might be hard to debug."
+ )
+
+ if padding and truncation:
+ raise ValueError("Both padding and truncation were set. Make sure you only set one.")
+ elif padding is None:
+ # by default let's pad the inputs
+ padding = True
+
+ is_batched = bool(
+ isinstance(raw_audio, (list, tuple)) and (isinstance(raw_audio[0], (np.ndarray, tuple, list)))
+ )
+
+ if is_batched:
+ raw_audio = [np.asarray(audio, dtype=np.float32).T for audio in raw_audio]
+ elif not is_batched and not isinstance(raw_audio, np.ndarray):
+ raw_audio = np.asarray(raw_audio, dtype=np.float32)
+ elif isinstance(raw_audio, np.ndarray) and raw_audio.dtype is np.dtype(np.float64):
+ raw_audio = raw_audio.astype(np.float32)
+
+ # always return batch
+ if not is_batched:
+ raw_audio = [np.asarray(raw_audio).T]
+
+ # convert stereo to mono if necessary, unique to Dia
+ for idx, example in enumerate(raw_audio):
+ if self.feature_size == 2 and example.ndim == 2:
+ raw_audio[idx] = np.mean(example, -1)
+
+ # verify inputs are valid
+ for idx, example in enumerate(raw_audio):
+ if example.ndim > 2:
+ raise ValueError(f"Expected input shape (channels, length) but got shape {example.shape}")
+ if self.feature_size == 1 and example.ndim != 1:
+ raise ValueError(f"Expected mono audio but example has {example.shape[-1]} channels")
+ if self.feature_size == 2 and example.ndim != 1: # stereo was converted to mono above, so 1D is expected
+ raise ValueError(f"Expected stereo audio but example has {example.shape[-1]} channels")
+
+ input_values = BatchFeature({"input_values": raw_audio})
+
+ # temporarily treat it as if we were mono as we also convert stereo to mono
+ original_feature_size = self.feature_size
+ self.feature_size = 1
+
+ # normal padding on batch
+ padded_inputs = self.pad(
+ input_values,
+ max_length=max_length,
+ truncation=truncation,
+ padding=padding,
+ return_attention_mask=True,
+ pad_to_multiple_of=self.hop_length,
+ )
+ padded_inputs["padding_mask"] = padded_inputs.pop("attention_mask")
+
+ input_values = []
+ for example in padded_inputs.pop("input_values"):
+ if self.feature_size == 1:
+ example = example[..., None]
+ input_values.append(example.T)
+
+ padded_inputs["input_values"] = input_values
+ if return_tensors is not None:
+ padded_inputs = padded_inputs.convert_to_tensors(return_tensors)
+
+ # rewrite back to original feature size
+ self.feature_size = original_feature_size
+
+ return padded_inputs
+
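+# Illustrative usage sketch, not part of the module (shapes below assume the defaults above:
+# mono audio, sampling_rate=16000, hop_length=512):
+#
+# import numpy as np
+# from transformers.models.dia.feature_extraction_dia import DiaFeatureExtractor
+#
+# extractor = DiaFeatureExtractor()
+# audio = [np.zeros(10_000, dtype=np.float32), np.zeros(8_000, dtype=np.float32)]
+# features = extractor(audio, sampling_rate=16000, return_tensors="np")
+# features["input_values"].shape # (2, 1, 10240): padded up to the next multiple of hop_length
+# features["padding_mask"].shape # (2, 10240): 1 for real samples, 0 for padding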
+
+__all__ = ["DiaFeatureExtractor"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/dia/generation_dia.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/dia/generation_dia.py
new file mode 100644
index 0000000000000000000000000000000000000000..c297de7203d4e5b30a189047233976d310179907
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/dia/generation_dia.py
@@ -0,0 +1,463 @@
+# coding=utf-8
+# Copyright 2025 The Nari Labs and HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Callable, Optional, Union
+
+import torch
+import torch.distributed as dist
+
+from ...generation.logits_process import (
+ DiaClassifierFreeGuidanceLogitsProcessor,
+ DiaEOSChannelFilterLogitsProcessor,
+ DiaEOSDelayPatternLogitsProcessor,
+ LogitsProcessorList,
+ TemperatureLogitsWarper,
+)
+from ...generation.stopping_criteria import StoppingCriteriaList
+from ...generation.streamers import BaseStreamer
+from ...generation.utils import GenerateOutput, GenerationConfig, GenerationMixin, GenerationMode
+from ...integrations.deepspeed import is_deepspeed_zero3_enabled
+from ...integrations.fsdp import is_fsdp_managed_module
+from ...modeling_utils import PreTrainedModel
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class DiaGenerationMixin(GenerationMixin):
+ # Indicates CFG which needs preparation to be properly handled by repeats
+ _uses_cfg = None
+
+ def _get_logits_processor(
+ self,
+ generation_config: GenerationConfig,
+ input_ids_seq_length: Optional[int] = None,
+ encoder_input_ids: Optional[torch.LongTensor] = None,
+ prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], list[int]]] = None,
+ logits_processor: Optional[LogitsProcessorList] = None,
+ device: Optional[str] = None,
+ model_kwargs: Optional[dict[str, Any]] = None,
+ negative_prompt_ids: Optional[torch.Tensor] = None,
+ negative_prompt_attention_mask: Optional[torch.Tensor] = None,
+ ) -> LogitsProcessorList:
+ # These either need a custom position in the pipeline or a custom processor altogether,
+ # so they are temporarily disabled before calling the parent method
+ original_guidance_scale = generation_config.guidance_scale
+ original_temperature = generation_config.temperature
+ generation_config.guidance_scale = None
+ generation_config.temperature = None
+
+ # Get base processors and those we can integrate easily
+ custom_processors = LogitsProcessorList()
+
+ if original_temperature is not None and original_temperature != 1.0:
+ custom_processors.append(TemperatureLogitsWarper(original_temperature))
+
+ custom_processors.append(
+ DiaEOSChannelFilterLogitsProcessor(
+ num_channels=len(self.config.delay_pattern),
+ eos_token_id=self.config.eos_token_id,
+ )
+ )
+
+ merged_processors = super()._get_logits_processor(
+ generation_config=generation_config,
+ input_ids_seq_length=input_ids_seq_length,
+ encoder_input_ids=encoder_input_ids,
+ prefix_allowed_tokens_fn=None,
+ logits_processor=custom_processors,
+ device=device,
+ model_kwargs=model_kwargs,
+ negative_prompt_ids=negative_prompt_ids,
+ negative_prompt_attention_mask=negative_prompt_attention_mask,
+ )
+
+ # Custom processors we need at specific positions
+ if original_guidance_scale is not None and original_guidance_scale != 1:
+ cfg_processor = DiaClassifierFreeGuidanceLogitsProcessor(
+ guidance_scale=original_guidance_scale,
+ guidance_top_k=generation_config.top_k,
+ )
+ merged_processors.insert(0, cfg_processor)
+
+ merged_processors.append(
+ DiaEOSDelayPatternLogitsProcessor(
+ delay_pattern=self.config.delay_pattern,
+ eos_token_id=self.config.eos_token_id,
+ max_generation_len=generation_config.max_length,
+ device=device,
+ )
+ )
+
+ # Enable temporarily disabled values back
+ generation_config.guidance_scale = original_guidance_scale
+ generation_config.temperature = original_temperature
+
+ return merged_processors
+
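+ # Illustrative note, not part of the module: the resulting processor order is, roughly,
+ #
+ # [DiaClassifierFreeGuidanceLogitsProcessor, # only when guidance_scale not in (None, 1)
+ #  ...standard processors, TemperatureLogitsWarper, DiaEOSChannelFilterLogitsProcessor,
+ #  DiaEOSDelayPatternLogitsProcessor]
+ #
+ # i.e. CFG sees the raw doubled-batch logits first and the delay-pattern EOS handling runs last.
+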
+ def _prepare_generation_config(
+ self, generation_config: Optional[GenerationConfig], use_model_defaults: Optional[bool] = None, **kwargs: Any
+ ) -> tuple[GenerationConfig, dict]:
+ generation_config, model_kwargs = super()._prepare_generation_config(
+ generation_config, use_model_defaults, **kwargs
+ )
+
+ # We allow generation up to max length + max delay pattern
+ # (will revert back to max length after generation)
+ generation_config.max_length += max(self.config.delay_pattern)
+
+ # Internal flag to indicate CFG that needs to prepare unconditioned input
+ self._uses_cfg = generation_config.guidance_scale is not None and generation_config.guidance_scale != 1
+
+ return generation_config, model_kwargs
+
+ def _prepare_model_inputs(
+ self,
+ inputs: Optional[torch.Tensor] = None,
+ bos_token_id: Optional[torch.Tensor] = None,
+ model_kwargs: Optional[dict[str, torch.Tensor]] = None,
+ ) -> tuple[torch.Tensor, Optional[str], dict[str, torch.Tensor]]:
+ inputs, input_name, model_kwargs = super()._prepare_model_inputs(
+ inputs=inputs,
+ bos_token_id=bos_token_id,
+ model_kwargs=model_kwargs,
+ )
+
+ # If CFG is requested we fill in the unconditioned parts
+ if self._uses_cfg:
+ unconditioned_inputs = torch.zeros_like(inputs)
+ inputs = torch.cat([inputs, unconditioned_inputs], dim=0)
+
+ if model_kwargs.get("attention_mask", None) is not None:
+ model_kwargs["attention_mask"] = model_kwargs["attention_mask"].repeat(2, 1)
+
+ return inputs, input_name, model_kwargs
+
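+ # Illustrative sketch, not part of the module: with classifier-free guidance the encoder sees a
+ # doubled batch -- the conditioned inputs followed by an all-zero "unconditioned" copy. The CFG
+ # logits processor later recombines the two halves, roughly:
+ #
+ # cond, uncond = logits.chunk(2, dim=0)
+ # guided = uncond + guidance_scale * (cond - uncond)
+ #
+ # (the exact recombination, including `guidance_top_k`, lives in DiaClassifierFreeGuidanceLogitsProcessor).
+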
+ def _prepare_decoder_input_ids_for_generation(
+ self,
+ batch_size: int,
+ model_input_name: str,
+ model_kwargs: dict[str, torch.Tensor],
+ decoder_start_token_id: torch.Tensor,
+ device: Optional[torch.device] = None,
+ ) -> tuple[torch.LongTensor, dict[str, torch.Tensor]]:
+ """Prepares `decoder_input_ids` for generation with encoder-decoder models"""
+ # 1. Check whether the user has defined `decoder_input_ids` and `decoder_attention_mask`; if not error out
+ decoder_input_ids = decoder_attention_mask = None
+ if model_kwargs is not None and "decoder_input_ids" in model_kwargs:
+ decoder_input_ids = model_kwargs.pop("decoder_input_ids")
+ if model_kwargs is not None and "decoder_attention_mask" in model_kwargs:
+ decoder_attention_mask = model_kwargs.pop("decoder_attention_mask")
+
+ # We allow generating without preparation (no proper delay) but discourage it
+ if decoder_input_ids is None or decoder_attention_mask is None:
+ logger.warning_once(
+ "In order to generate with Dia, we need the processed audio input: Got `decoder_input_ids`:"
+ f" {decoder_input_ids is not None} and got `decoder_attention_mask`={decoder_attention_mask is not None}."
+ f" This can be achieved via the [`DiaProcessor`] but now defaulting to non-delayed generation."
+ )
+
+ num_channels = self.config.decoder_config.num_channels
+ real_batch_size = batch_size // 2 if self._uses_cfg else batch_size
+
+ if decoder_input_ids is None:
+ decoder_input_ids = torch.full(
+ (real_batch_size, 1, num_channels), decoder_start_token_id, dtype=torch.long, device=device
+ )
+
+ decoder_attention_mask = torch.ones(
+ size=(real_batch_size, decoder_input_ids.shape[1]), dtype=torch.long, device=device
+ )
+
+ # 2. Determine the valid input and what works as mask within the input
+ delay_mask = decoder_input_ids.long()
+ valid_input_size = (
+ decoder_input_ids.shape[1] - (decoder_input_ids[:, :, 0] == self.config.pad_token_id).sum(dim=-1).max()
+ )
+ decoder_input_ids = delay_mask[:, :valid_input_size].transpose(1, 2).long()
+ decoder_attention_mask = decoder_attention_mask[:, :valid_input_size].long()
+
+ # 3. Overwrite into model kwargs
+ model_kwargs["decoder_attention_mask"] = decoder_attention_mask
+ model_kwargs["decoder_delay_mask"] = delay_mask
+
+ return decoder_input_ids, model_kwargs
+
+ def prepare_inputs_for_generation(
+ self,
+ input_ids,
+ encoder_outputs=None, # Using this to easily get the batch size
+ decoder_delay_mask=None,
+ **kwargs,
+ ):
+ # Reshape decoder input_ids to 3D to be compile friendly and to fit the expected model input shape
+ batch_size = encoder_outputs[0].shape[0] // 2 if self._uses_cfg else encoder_outputs[0].shape[0]
+ input_ids = input_ids.reshape(batch_size, self.config.decoder_config.num_channels, -1).transpose(1, 2)
+
+ # Base method handles most things except CFG and the delay pattern mask
+ model_inputs = super().prepare_inputs_for_generation(input_ids, encoder_outputs=encoder_outputs, **kwargs)
+
+ # Post processing for CFG and overwriting via delay pattern mask
+ # 1. Delay pattern mask -- force tokens if not allowed to predict (!= pad_token in mask)
+ model_inputs["decoder_input_ids"] = self.apply_delay_mask(
+ input_ids, self.config.pad_token_id, decoder_delay_mask
+ )
+
+ # Depending on cache usage we need to pass all or just one
+ if model_inputs.get("use_cache", False) and model_inputs["cache_position"][0] > 0:
+ model_inputs["decoder_input_ids"] = model_inputs["decoder_input_ids"][:, -1, :][:, None, :]
+
+ # Be compile friendly
+ model_inputs["decoder_input_ids"] = model_inputs["decoder_input_ids"].contiguous()
+
+ # 2. Apply CFG duplication if needed
+ if self._uses_cfg:
+ for key in ["decoder_input_ids", "decoder_attention_mask", "decoder_position_ids"]:
+ if model_inputs.get(key, None) is not None:
+ # double first dimension and keep everything else the same
+ repeat_pattern = tuple([2] + [1] * (model_inputs[key].ndim - 1))
+ model_inputs[key] = model_inputs[key].repeat(*repeat_pattern)
+
+ return model_inputs
+
+ @staticmethod
+ def apply_delay_mask(input_ids: torch.Tensor, pad_id: int, delay_mask: Optional[torch.Tensor]) -> torch.Tensor:
+ if delay_mask is None:
+ return input_ids
+
+ mask_len = min(input_ids.shape[1], delay_mask.shape[1])
+ valid_mask = delay_mask[:, :mask_len, :]
+ valid_input = input_ids[:, :mask_len, :]
+
+ # Overwrite the respective parts of the input
+ input_ids[:, :mask_len, :] = torch.where(valid_mask == pad_id, valid_input, valid_mask)
+
+ return input_ids
+
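+ # Illustrative sketch, not part of the module: `apply_delay_mask` only overwrites positions where
+ # the mask carries a real token; positions equal to `pad_id` keep whatever was generated. With
+ # pad_id=1025 and a single channel (shapes simplified to 2D for readability):
+ #
+ # input_ids = [[ 50, 51, 52]]
+ # delay_mask = [[1026, 1025, 1025]] # BOS forced at step 0, rest left free
+ # result = [[1026, 51, 52]] # torch.where(mask == pad_id, input, mask)
+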
+ def _main_generate_loop(
+ self,
+ inputs: Optional[torch.Tensor] = None,
+ generation_config: Optional[GenerationConfig] = None,
+ logits_processor: Optional[LogitsProcessorList] = None,
+ stopping_criteria: Optional[StoppingCriteriaList] = None,
+ prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], list[int]]] = None,
+ synced_gpus: Optional[bool] = None,
+ assistant_model: Optional["PreTrainedModel"] = None,
+ streamer: Optional["BaseStreamer"] = None,
+ negative_prompt_ids: Optional[torch.Tensor] = None,
+ negative_prompt_attention_mask: Optional[torch.Tensor] = None,
+ use_model_defaults: Optional[bool] = None,
+ custom_generate: Optional[str] = None,
+ **kwargs,
+ ):
+ # ********** mostly taken from main generate function up to calling the different methods (see NOTE) **********
+ # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
+ generation_mode_kwargs = self._extract_generation_mode_kwargs(
+ custom_generate,
+ kwargs,
+ synced_gpus,
+ assistant_model,
+ streamer,
+ )
+ generation_config, model_kwargs = self._prepare_generation_config(
+ generation_config, use_model_defaults, **kwargs
+ )
+ generation_mode = generation_config.get_generation_mode(assistant_model)
+
+ self._validate_model_kwargs(model_kwargs.copy())
+ self._validate_generation_mode(generation_mode, generation_config, generation_mode_kwargs)
+
+ # 2. Set generation parameters if not already defined
+ if synced_gpus is None:
+ synced_gpus = (is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)) and dist.get_world_size() > 1
+
+ logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
+ stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
+
+ # 3. Define model inputs
+ kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None
+ inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(
+ inputs, generation_config.bos_token_id, model_kwargs
+ )
+ batch_size = inputs_tensor.shape[0]
+
+ device = inputs_tensor.device
+ self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=device)
+
+ # 4. Define other model kwargs
+ if "encoder_outputs" not in model_kwargs:
+ # if model is encoder decoder encoder_outputs are created and added to `model_kwargs`
+ model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(
+ inputs_tensor, model_kwargs, model_input_name, generation_config
+ )
+
+ # 5. Prepare `input_ids` which will be used for auto-regressive generation
+ input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation(
+ batch_size=batch_size,
+ model_input_name=model_input_name,
+ model_kwargs=model_kwargs,
+ decoder_start_token_id=generation_config._decoder_start_token_tensor,
+ device=inputs_tensor.device,
+ )
+
+ if generation_config.token_healing:
+ input_ids = self.heal_tokens(input_ids, generation_mode_kwargs.get("tokenizer"))
+
+ if streamer is not None:
+ streamer.put(input_ids.cpu())
+
+ # 6. Prepare `max_length` depending on other stopping criteria.
+ # NOTE: incorrect `input_ids.shape[1]` previously
+ input_ids_length = input_ids.shape[-1]
+ has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
+ has_default_min_length = kwargs.get("min_length") is None and generation_config.min_length is not None
+ generation_config = self._prepare_generated_length(
+ generation_config=generation_config,
+ has_default_max_length=has_default_max_length,
+ has_default_min_length=has_default_min_length,
+ model_input_name=model_input_name,
+ inputs_tensor=inputs_tensor,
+ input_ids_length=input_ids_length,
+ )
+
+ # If the model supports `logits_to_keep` in forward(), set it to 1 to avoid computing the whole
+ # logit matrix. This can save a lot of memory during the first forward pass. Note that assisted decoding
+ # dynamically overrides this value as it can need more than the last token logits
+ if self._supports_logits_to_keep() and "logits_to_keep" not in model_kwargs:
+ model_kwargs["logits_to_keep"] = 1
+
+ self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)
+
+ # 7. Prepare the cache.
+ # - `model_kwargs` may be updated in place with a cache as defined by the parameters in `generation_config`.
+ # - different models have a different cache name expected by the model (default = "past_key_values")
+ # - `max_length`, prepared above, is used to determine the maximum cache length
+ max_cache_length = generation_config.max_length - 1
+ if (
+ inputs_tensor.shape[1] != input_ids_length
+ and model_input_name == "inputs_embeds"
+ and not self.config.is_encoder_decoder
+ ):
+ max_cache_length += inputs_tensor.shape[1]
+ self._prepare_cache_for_generation(
+ generation_config, model_kwargs, generation_mode, batch_size, max_cache_length
+ )
+
+ # 8. prepare logits processors and stopping criteria
+ prepared_logits_processor = self._get_logits_processor(
+ generation_config=generation_config,
+ input_ids_seq_length=input_ids_length,
+ encoder_input_ids=inputs_tensor,
+ prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
+ logits_processor=logits_processor,
+ device=inputs_tensor.device,
+ model_kwargs=model_kwargs,
+ negative_prompt_ids=negative_prompt_ids,
+ negative_prompt_attention_mask=negative_prompt_attention_mask,
+ )
+ prepared_stopping_criteria = self._get_stopping_criteria(
+ generation_config=generation_config,
+ stopping_criteria=stopping_criteria,
+ tokenizer=generation_mode_kwargs.get("tokenizer"),
+ )
+
+ # Set model_kwargs `use_cache` so we can use it later in forward runs
+ model_kwargs["use_cache"] = generation_config.use_cache
+ # ******************* taken from main generate function up to calling the different methods *******************
+
+ # Prepare inner 2D logic in generation loop
+ input_ids = input_ids.reshape(-1, input_ids.shape[-1])
+
+ # 10. go into different generation modes
+ if generation_mode in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH):
+ # 11. expand input_ids with `num_return_sequences` additional sequences per batch
+ if generation_config.num_return_sequences > 1:
+ raise ValueError("`num_return_sequences>1` is incompatible with Dia.")
+
+ # 12. run sample (it degenerates to greedy search when `generation_config.do_sample=False`)
+ return self._sample(
+ input_ids,
+ logits_processor=prepared_logits_processor,
+ stopping_criteria=prepared_stopping_criteria,
+ generation_config=generation_config,
+ **generation_mode_kwargs,
+ **model_kwargs,
+ )
+ else:
+ raise ValueError(
+ "Got incompatible mode for generation, should be one of greedy or sampling. "
+ "Ensure that beam search is de-activated by setting `num_beams=1`."
+ )
+
+ @torch.no_grad()
+ def generate(
+ self,
+ inputs: Optional[torch.Tensor] = None,
+ generation_config: Optional[GenerationConfig] = None,
+ logits_processor: Optional[LogitsProcessorList] = None,
+ stopping_criteria: Optional[StoppingCriteriaList] = None,
+ prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], list[int]]] = None,
+ synced_gpus: Optional[bool] = None,
+ assistant_model: Optional["PreTrainedModel"] = None,
+ streamer: Optional["BaseStreamer"] = None,
+ negative_prompt_ids: Optional[torch.Tensor] = None,
+ negative_prompt_attention_mask: Optional[torch.Tensor] = None,
+ use_model_defaults: Optional[bool] = None,
+ custom_generate: Optional[str] = None,
+ **kwargs,
+ ) -> Union[GenerateOutput, torch.LongTensor]:
+ # We expect the initial input ids to be the complete mask (delayed input)
+ delay_mask = kwargs.get("decoder_input_ids")
+ if delay_mask is not None:
+ delay_mask = delay_mask.clone()
+
+ output = self._main_generate_loop(
+ inputs=inputs,
+ generation_config=generation_config,
+ logits_processor=logits_processor,
+ stopping_criteria=stopping_criteria,
+ prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
+ synced_gpus=synced_gpus,
+ assistant_model=assistant_model,
+ streamer=streamer,
+ negative_prompt_ids=negative_prompt_ids,
+ negative_prompt_attention_mask=negative_prompt_attention_mask,
+ use_model_defaults=use_model_defaults,
+ custom_generate=custom_generate,
+ **kwargs,
+ )
+
+ return_dict_in_generate = not isinstance(output, torch.Tensor)
+
+ if return_dict_in_generate:
+ output_sequences = output.sequences
+ else:
+ output_sequences = output
+
+ # Reshape from 2D (bsz * channels, seq_len) to 3D (bsz, seq_len, channels)
+ num_channels = self.config.decoder_config.num_channels
+ bsz = output_sequences.shape[0] // num_channels
+ output_sequences = output_sequences.reshape(bsz, num_channels, -1).transpose(1, 2)
+
+ # Apply delay mask
+ output_sequences = self.apply_delay_mask(output_sequences, self.config.pad_token_id, delay_mask)
+
+ if return_dict_in_generate:
+ output.sequences = output_sequences
+ else:
+ output = output_sequences
+
+ return output
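
A brief standalone sketch of the post-processing at the end of `generate` above: the flat 2D sequences coming out of `_sample` are folded back into a 3D `(batch, sequence, channels)` layout and then combined with the delay mask. The sizes, the pad id of 1025, and the simple `torch.where` fill below are made up for illustration only; the real behaviour is whatever `DiaGenerationMixin.apply_delay_mask` implements.

```python
import torch

# Toy sizes, purely illustrative (not the real Dia configuration).
bsz, num_channels, seq_len, pad_token_id = 2, 3, 5, 1025

# `_sample` works on flattened ids of shape (bsz * num_channels, seq_len).
flat_sequences = torch.randint(0, 1024, (bsz * num_channels, seq_len))

# `generate` folds them back into (bsz, seq_len, num_channels), as in the code above.
sequences_3d = flat_sequences.reshape(bsz, num_channels, -1).transpose(1, 2)
assert sequences_3d.shape == (bsz, seq_len, num_channels)

# Hypothetical stand-in for `apply_delay_mask`: wherever the delay mask holds a real
# token (anything other than the pad placeholder), that token overrides the output.
delay_mask = torch.full_like(sequences_3d, pad_token_id)
delay_mask[:, 0, :] = 42  # pretend the first step was part of the audio prompt
masked = torch.where(delay_mask == pad_token_id, sequences_3d, delay_mask)
```
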
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/dia/modeling_dia.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/dia/modeling_dia.py
new file mode 100644
index 0000000000000000000000000000000000000000..cf662b224aabd884521d7e14b8a167886377f4b5
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/dia/modeling_dia.py
@@ -0,0 +1,958 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/dia/modular_dia.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_dia.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2025 The Nari Labs and HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Callable, Optional, Union
+
+import torch
+from torch import nn
+
+from ...activations import ACT2FN
+from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
+from ...integrations import use_kernel_forward_from_hub
+from ...masking_utils import create_causal_mask
+from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa
+from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_outputs import (
+ BaseModelOutput,
+ BaseModelOutputWithPastAndCrossAttentions,
+ Seq2SeqLMOutput,
+ Seq2SeqModelOutput,
+)
+from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
+from ...utils import (
+ TransformersKwargs,
+ auto_docstring,
+ can_return_tuple,
+ is_torch_flex_attn_available,
+ is_torchdynamo_compiling,
+ logging,
+)
+from ...utils.deprecation import deprecate_kwarg
+from .configuration_dia import DiaConfig, DiaDecoderConfig, DiaEncoderConfig
+from .generation_dia import DiaGenerationMixin
+
+
+if is_torch_flex_attn_available():
+ from ...integrations.flex_attention import make_flex_block_causal_mask
+
+
+logger = logging.get_logger(__name__)
+
+
+@auto_docstring
+class DiaPreTrainedModel(PreTrainedModel):
+ config: DiaConfig
+ base_model_prefix = "model"
+ supports_gradient_checkpointing = True
+ _supports_flash_attn = True
+ _supports_sdpa = True
+ _supports_flex_attn = True
+ _can_compile_fullgraph = True
+ main_input_name = "input_ids"
+ _no_split_modules = ["DiaEncoderLayer", "DiaDecoderLayer"]
+
+
+class DiaMultiChannelEmbedding(nn.Module):
+ """In order to efficiently compute the audio embedding from the 9 different channels,
+ we vectorize the embedding process by using a single embedding layer and an offset.
+ Example:
+ - num_embeds = 4
+ - vocab_size = 8
+ - num_channels = 3
+ We would have offsets = [0, 8, 16]
+ If audio_codes = [0, 1, 2, 3], [1, 3, 4, 7], [5, 6, 7, 8],
+ then tokens = audio_codes + offsets
+ = [0, 1, 2, 3, 9, 11, 12, 15, 21, 22, 23, 24]
+ This allows us to use a single embedding layer for all channels.
+ """
+
+ def __init__(self, config: DiaDecoderConfig):
+ super().__init__()
+ self.embed = nn.Embedding(config.vocab_size * config.num_channels, config.hidden_size)
+ self.hidden_size = config.hidden_size
+ self.num_channels = config.num_channels
+ offsets = torch.arange(config.num_channels, dtype=torch.long) * config.vocab_size # (C,)
+ self.register_buffer("offsets", offsets, persistent=False)
+
+ def forward(self, audio_codes: torch.Tensor) -> torch.Tensor:
+ tokens = (audio_codes + self.offsets.to(audio_codes.device)).squeeze(1)
+ embeds = self.embed(tokens).view(tokens.shape[0], audio_codes.shape[1], -1, self.hidden_size)
+ return embeds.sum(dim=2)
+
+
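
As a quick sanity check of the offset trick described in the `DiaMultiChannelEmbedding` docstring above, the snippet below reproduces the example numbers (`vocab_size=8`, `num_channels=3`) with plain tensor arithmetic; it does not depend on the Dia classes themselves.

```python
import torch

vocab_size, num_channels = 8, 3
offsets = torch.arange(num_channels) * vocab_size   # tensor([ 0,  8, 16])

audio_codes = torch.tensor([[0, 1, 2, 3],
                            [1, 3, 4, 7],
                            [5, 6, 7, 8]])           # (num_channels, seq_len)

tokens = audio_codes + offsets[:, None]              # shift each channel into its own id range
print(tokens.flatten().tolist())
# [0, 1, 2, 3, 9, 11, 12, 15, 21, 22, 23, 24]  -> matches the docstring example
```
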
+class DiaMLP(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+
+ self.config = config
+ self.gate_up_proj = nn.Linear(config.hidden_size, 2 * config.intermediate_size, bias=False)
+ self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
+ self.activation_fn = ACT2FN[config.hidden_act]
+
+ def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
+ up_states = self.gate_up_proj(hidden_states)
+
+ gate, up_states = up_states.chunk(2, dim=-1)
+ up_states = up_states * self.activation_fn(gate)
+
+ return self.down_proj(up_states)
+
+
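
`DiaMLP` above uses the common fused gated-MLP layout: a single `gate_up_proj` matmul produces both halves, which are then split with `chunk` and recombined. A minimal shape walk-through with made-up sizes follows; SiLU only stands in for whatever activation `config.hidden_act` actually selects.

```python
import torch

hidden_size, intermediate_size, seq_len = 4, 6, 2
x = torch.randn(1, seq_len, hidden_size)

gate_up_proj = torch.nn.Linear(hidden_size, 2 * intermediate_size, bias=False)
down_proj = torch.nn.Linear(intermediate_size, hidden_size, bias=False)

gate, up = gate_up_proj(x).chunk(2, dim=-1)            # two (1, seq_len, intermediate_size) halves
out = down_proj(up * torch.nn.functional.silu(gate))   # activation is applied to the gate half
assert out.shape == x.shape
```
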
+@use_kernel_forward_from_hub("RMSNorm")
+class DiaRMSNorm(nn.Module):
+ def __init__(self, hidden_size, eps=1e-6):
+ """
+ DiaRMSNorm is equivalent to T5LayerNorm
+ """
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+ self.variance_epsilon = eps
+
+ def forward(self, hidden_states):
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+ return self.weight * hidden_states.to(input_dtype)
+
+ def extra_repr(self):
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
+
+
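
In equation form, `DiaRMSNorm.forward` above (temporary float32 upcast aside) computes

$$\mathrm{RMSNorm}(x) = w \odot \frac{x}{\sqrt{\tfrac{1}{d}\sum_{i=1}^{d} x_i^{2} + \varepsilon}}$$

where `w` is `self.weight`, `d` is the hidden size, and ε is `variance_epsilon`.
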
+class DiaRotaryEmbedding(nn.Module):
+ inv_freq: torch.Tensor # fix linting for `register_buffer`
+
+ def __init__(self, config: DiaConfig, device=None):
+ super().__init__()
+ # BC: "rope_type" was originally "type"
+ if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
+ else:
+ self.rope_type = "default"
+ self.max_seq_len_cached = config.max_position_embeddings
+ self.original_max_seq_len = config.max_position_embeddings
+
+ self.config = config
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
+ self.original_inv_freq = self.inv_freq
+
+ @torch.no_grad()
+ @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
+ def forward(self, x, position_ids):
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
+ position_ids_expanded = position_ids[:, None, :].float()
+
+ device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
+ with torch.autocast(device_type=device_type, enabled=False): # Force float32
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+ emb = torch.cat((freqs, freqs), dim=-1)
+ cos = emb.cos() * self.attention_scaling
+ sin = emb.sin() * self.attention_scaling
+
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
+def rotate_half(x):
+ """Rotates half the hidden dims of the input."""
+ x1 = x[..., : x.shape[-1] // 2]
+ x2 = x[..., x.shape[-1] // 2 :]
+ return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+ """Applies Rotary Position Embedding to the query and key tensors.
+
+ Args:
+ q (`torch.Tensor`): The query tensor.
+ k (`torch.Tensor`): The key tensor.
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
+ position_ids (`torch.Tensor`, *optional*):
+ Deprecated and unused.
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+ Returns:
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+ """
+ cos = cos.unsqueeze(unsqueeze_dim)
+ sin = sin.unsqueeze(unsqueeze_dim)
+ q_embed = (q * cos) + (rotate_half(q) * sin)
+ k_embed = (k * cos) + (rotate_half(k) * sin)
+ return q_embed, k_embed
+
+
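
A minimal shape check for `apply_rotary_pos_emb` above, with made-up dimensions; it assumes this module is importable as `transformers.models.dia.modeling_dia`. The random `cos`/`sin` tensors only stand in for the `(batch, seq_len, head_dim)` outputs of `DiaRotaryEmbedding`.

```python
import torch
from transformers.models.dia.modeling_dia import apply_rotary_pos_emb

b, h, s, d = 1, 2, 3, 4
q = torch.randn(b, h, s, d)
k = torch.randn(b, h, s, d)
cos = torch.randn(b, s, d)  # stand-in for DiaRotaryEmbedding output
sin = torch.randn(b, s, d)

# The default unsqueeze_dim=1 broadcasts cos/sin over the heads dimension.
q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)
assert q_rot.shape == q.shape and k_rot.shape == k.shape
```
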
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+ """
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+ """
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+ if n_rep == 1:
+ return hidden_states
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
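
And a quick check that `repeat_kv` above really is the `repeat_interleave` equivalent its docstring claims (again assuming the module import path below):

```python
import torch
from transformers.models.dia.modeling_dia import repeat_kv

batch, kv_heads, n_rep, seq_len, head_dim = 1, 2, 3, 4, 8
kv = torch.randn(batch, kv_heads, seq_len, head_dim)

expanded = repeat_kv(kv, n_rep)
assert expanded.shape == (batch, kv_heads * n_rep, seq_len, head_dim)
assert torch.equal(expanded, kv.repeat_interleave(n_rep, dim=1))
```
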
+def eager_attention_forward(
+ module: nn.Module,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_mask: Optional[torch.Tensor],
+ scaling: float,
+ dropout: float = 0.0,
+ **kwargs: Unpack[TransformersKwargs],
+):
+ key_states = repeat_kv(key, module.num_key_value_groups)
+ value_states = repeat_kv(value, module.num_key_value_groups)
+
+ attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
+ if attention_mask is not None:
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+ attn_weights = attn_weights + causal_mask
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+ attn_output = torch.matmul(attn_weights, value_states)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ return attn_output, attn_weights
+
+
+class DiaSelfAttention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(self, config: Union[DiaEncoderConfig, DiaDecoderConfig], layer_idx: int, is_causal: bool = False):
+ super().__init__()
+ self.config = config
+ self.layer_idx = layer_idx
+ self.hidden_size = config.hidden_size
+ self.num_heads = self.config.num_attention_heads
+ self.num_key_value_heads = self.config.num_key_value_heads or self.num_heads
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
+ self.head_dim = getattr(config, "head_dim", config.hidden_size // self.num_heads)
+ self.scaling = 1
+ self.attention_dropout = 0.0
+ self.is_causal = is_causal
+
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
+
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
+ attention_mask: Optional[torch.Tensor],
+ past_key_values: Optional[Cache] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ input_shape = hidden_states.shape[:-1]
+ hidden_shape = (*input_shape, -1, self.head_dim)
+
+ query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+ key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+ value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+
+ cos, sin = position_embeddings
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_values is not None:
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ attn_output, attn_weights = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ dropout=0.0 if not self.training else self.attention_dropout,
+ scaling=self.scaling,
+ **kwargs,
+ )
+
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+ attn_output = self.o_proj(attn_output)
+ return attn_output, attn_weights
+
+
+class DiaCrossAttention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(self, config: DiaDecoderConfig, layer_idx: int):
+ super().__init__()
+ self.config = config
+ self.layer_idx = layer_idx
+ self.hidden_size = config.hidden_size
+ self.cross_hidden_size = config.cross_hidden_size
+ self.num_heads = self.config.cross_num_attention_heads
+ self.num_key_value_heads = self.config.cross_num_key_value_heads
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
+ self.head_dim = config.cross_head_dim
+ self.scaling = 1
+ self.attention_dropout = 0.0
+ self.is_causal = False
+
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
+ self.k_proj = nn.Linear(self.cross_hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
+ self.v_proj = nn.Linear(self.cross_hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ cross_attention_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ past_key_values: Optional[EncoderDecoderCache] = None,
+ **kwargs: Unpack[FlashAttentionKwargs],
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+ input_shape = hidden_states.shape[:-1]
+ hidden_shape = (*input_shape, -1, self.head_dim)
+ cross_shape = (*cross_attention_states.shape[:-1], -1, self.head_dim)
+
+ query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+
+ is_updated = past_key_values.is_updated.get(self.layer_idx) if past_key_values is not None else False
+ if past_key_values is not None and is_updated:
+            # reuse cached cross-attention k/v states
+ key_states = past_key_values.cross_attention_cache.layers[self.layer_idx].keys
+ value_states = past_key_values.cross_attention_cache.layers[self.layer_idx].values
+ else:
+ key_states = self.k_proj(cross_attention_states).view(cross_shape).transpose(1, 2)
+ value_states = self.v_proj(cross_attention_states).view(cross_shape).transpose(1, 2)
+
+ if past_key_values is not None:
+ # save all states to the cache
+ key_states, value_states = past_key_values.cross_attention_cache.update(
+ key_states,
+ value_states,
+ self.layer_idx,
+ )
+ # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
+ past_key_values.is_updated[self.layer_idx] = True
+
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ attn_output, attn_weights = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ scaling=self.scaling,
+ **kwargs,
+ )
+
+ attn_output = attn_output.reshape((*input_shape, -1)).contiguous()
+ attn_output = self.o_proj(attn_output)
+ return attn_output, attn_weights
+
+
+class DiaEncoderLayer(GradientCheckpointingLayer):
+ def __init__(self, config: DiaEncoderConfig, layer_idx: int):
+ super().__init__()
+ self.pre_sa_norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps)
+ self.self_attention = DiaSelfAttention(config, layer_idx, is_causal=False)
+ self.post_sa_norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps)
+ self.mlp = DiaMLP(config)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
+ attention_mask: Optional[torch.Tensor] = None,
+ **kwargs: Unpack[FlashAttentionKwargs],
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+ residual = hidden_states
+ normed_states = self.pre_sa_norm(hidden_states)
+ self_attn_output, self_attn_weights = self.self_attention(
+ normed_states,
+ position_embeddings=position_embeddings,
+ attention_mask=attention_mask,
+ **kwargs,
+ )
+ hidden_states = residual + self_attn_output
+
+ residual = hidden_states
+ normed_states = self.post_sa_norm(hidden_states)
+ mlp_out = self.mlp(normed_states)
+ hidden_states = residual + mlp_out
+
+ return hidden_states, self_attn_weights
+
+
+class DiaEncoder(DiaPreTrainedModel):
+ def __init__(self, config: DiaEncoderConfig):
+ super().__init__(config)
+ self.config = config
+
+ self.embedding = nn.Embedding(config.vocab_size, config.hidden_size)
+ self.layers = nn.ModuleList(
+ [DiaEncoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+ )
+ self.norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps)
+ self.rotary_embeddings = DiaRotaryEmbedding(config)
+
+ @auto_docstring
+ @can_return_tuple
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = False,
+ output_hidden_states: Optional[bool] = False,
+ **kwargs: Unpack[FlashAttentionKwargs],
+ ) -> Union[BaseModelOutput, tuple]:
+ hidden_states = self.embedding(input_ids)
+
+ # RoPE
+ # Note: We expect right padding and hence always generate
+ # the position ids on the fly to reduce preparation overhead
+ position_ids = torch.arange(input_ids.shape[-1], device=input_ids.device)[None, :]
+ position_embeddings = self.rotary_embeddings(hidden_states, position_ids)
+
+ attention_mask = self._update_full_mask(
+ attention_mask,
+ hidden_states,
+ )
+
+ encoder_states = () if output_hidden_states else None
+ all_attentions = () if output_attentions else None
+
+ for encoder_layer in self.layers:
+ if output_hidden_states:
+ encoder_states = encoder_states + (hidden_states,)
+
+ layer_outputs = encoder_layer(
+ hidden_states,
+ position_embeddings=position_embeddings,
+ attention_mask=attention_mask,
+ **kwargs,
+ )
+ hidden_states = layer_outputs[0]
+
+ if output_attentions:
+ all_attentions = all_attentions + (layer_outputs[1],)
+
+ hidden_states = self.norm(hidden_states)
+
+ if output_hidden_states:
+ encoder_states += (hidden_states,)
+
+ return BaseModelOutput(
+ last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
+ )
+
+ # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_full_mask
+ def _update_full_mask(
+ self,
+ attention_mask: Union[torch.Tensor, None],
+ inputs_embeds: torch.Tensor,
+ ):
+ if attention_mask is not None:
+ if self.config._attn_implementation == "flash_attention_2":
+ attention_mask = attention_mask if 0 in attention_mask else None
+ elif self.config._attn_implementation == "sdpa":
+ # output_attentions=True & head_mask can not be supported when using SDPA, fall back to
+ # the manual implementation that requires a 4D causal mask in all cases.
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype)
+ elif self.config._attn_implementation == "flex_attention":
+ if isinstance(attention_mask, torch.Tensor):
+ attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False)
+ else:
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
+
+ return attention_mask
+
+
+class DiaDecoderLayer(GradientCheckpointingLayer):
+ def __init__(self, config: DiaDecoderConfig, layer_idx: int):
+ super().__init__()
+ self.embed_dim = config.hidden_size
+ self.self_attention = DiaSelfAttention(config, layer_idx, is_causal=True)
+ self.cross_attention = DiaCrossAttention(config, layer_idx)
+ self.pre_sa_norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps)
+ self.pre_ca_norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps)
+ self.pre_mlp_norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps)
+ self.mlp = DiaMLP(config)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
+ attention_mask: Optional[torch.Tensor] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ past_key_values: Optional[EncoderDecoderCache] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs,
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
+ self_attn_cache = past_key_values
+ if isinstance(self_attn_cache, EncoderDecoderCache):
+ self_attn_cache = self_attn_cache.self_attention_cache
+
+ residual = hidden_states
+ normed_states = self.pre_sa_norm(hidden_states)
+ self_attn_output, self_attn_weights = self.self_attention(
+ normed_states,
+ position_embeddings,
+ attention_mask,
+            # Needs to be passed positionally so that in-place cache updates
+            # are carried out correctly (e.g. under compile)
+ self_attn_cache,
+ cache_position=cache_position,
+ **kwargs,
+ )
+ hidden_states = residual + self_attn_output
+
+ residual = hidden_states
+ normed_states = self.pre_ca_norm(hidden_states)
+ cross_states, cross_attn_weights = self.cross_attention(
+ normed_states,
+ encoder_hidden_states,
+ attention_mask=encoder_attention_mask,
+ past_key_values=past_key_values,
+ **kwargs,
+ )
+ hidden_states = residual + cross_states
+
+ residual = hidden_states
+ normed_states = self.pre_mlp_norm(hidden_states)
+ mlp_out = self.mlp(normed_states)
+ hidden_states = residual + mlp_out
+
+ return hidden_states, self_attn_weights, cross_attn_weights
+
+
+class DiaDecoder(DiaPreTrainedModel):
+ """Transformer Decoder Stack using DenseGeneral."""
+
+ def __init__(self, config: DiaDecoderConfig):
+ super().__init__(config)
+ self.num_channels = config.num_channels
+ self.vocab_size = config.vocab_size
+ self.embeddings = DiaMultiChannelEmbedding(config)
+ self.rotary_embeddings = DiaRotaryEmbedding(config)
+ self.layers = nn.ModuleList(
+ [DiaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+ )
+ self.norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps)
+
+ @auto_docstring
+ @can_return_tuple
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ position_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ encoder_attention_mask: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[EncoderDecoderCache] = None,
+ output_attentions: Optional[bool] = False,
+ output_hidden_states: Optional[bool] = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs,
+ ) -> Union[BaseModelOutputWithPastAndCrossAttentions, tuple]:
+ r"""
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length, num_codebooks)`):
+ The original `decoder_input_ids` in 3D shape to facilitate more efficient computations.
+
+ [What are input IDs?](../glossary#input-ids)
+ """
+
+ batch_size, seq_length = input_ids.size()[:-1]
+ past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
+ if cache_position is None:
+ cache_position = torch.arange(
+ past_key_values_length, past_key_values_length + seq_length, device=input_ids.device
+ )
+ if position_ids is None:
+ position_ids = cache_position[None, :]
+
+ # RoPE
+ hidden_states = self.embeddings(input_ids)
+ position_embeddings = self.rotary_embeddings(hidden_states, position_ids)
+
+ if attention_mask is None and not is_torchdynamo_compiling():
+ # required mask seq length can be calculated via length of past cache
+ mask_seq_length = past_key_values_length + seq_length
+ attention_mask = torch.ones(batch_size, mask_seq_length, device=input_ids.device)
+
+ attention_mask = create_causal_mask(
+ config=self.config,
+ input_embeds=hidden_states,
+ attention_mask=attention_mask,
+ cache_position=cache_position,
+ past_key_values=past_key_values,
+ position_ids=position_ids,
+ )
+ encoder_attention_mask = self._update_cross_attn_mask(
+ encoder_hidden_states,
+ encoder_attention_mask,
+ hidden_states.shape[:2],
+ hidden_states,
+ )
+
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+ all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
+
+ for layer in self.layers:
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ layer_outputs = layer(
+ hidden_states,
+ position_embeddings,
+ attention_mask,
+ encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ past_key_values=past_key_values,
+ cache_position=cache_position,
+ **kwargs,
+ )
+ hidden_states = layer_outputs[0]
+
+ if output_attentions:
+ all_self_attns = all_self_attns + (layer_outputs[1],)
+
+ if encoder_hidden_states is not None:
+ all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
+
+ hidden_states = self.norm(hidden_states)
+
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ return BaseModelOutputWithPastAndCrossAttentions(
+ last_hidden_state=hidden_states,
+ past_key_values=past_key_values,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ cross_attentions=all_cross_attentions,
+ )
+
+ # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_cross_attn_mask
+ def _update_cross_attn_mask(
+ self,
+ encoder_hidden_states: Union[torch.Tensor, None],
+ encoder_attention_mask: Union[torch.Tensor, None],
+ input_shape: torch.Size,
+ inputs_embeds: torch.Tensor,
+ ):
+ # expand encoder attention mask
+ if encoder_hidden_states is not None and encoder_attention_mask is not None:
+ if self.config._attn_implementation == "flash_attention_2":
+ encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None
+ elif self.config._attn_implementation == "sdpa":
+ # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on
+ # the manual implementation that requires a 4D causal mask in all cases.
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
+ encoder_attention_mask,
+ inputs_embeds.dtype,
+ tgt_len=input_shape[-1],
+ )
+ elif self.config._attn_implementation == "flex_attention":
+ if isinstance(encoder_attention_mask, torch.Tensor):
+ encoder_attention_mask = make_flex_block_causal_mask(
+ encoder_attention_mask,
+ query_length=input_shape[-1],
+ is_causal=False,
+ )
+ else:
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ encoder_attention_mask = _prepare_4d_attention_mask(
+ encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
+ )
+
+ return encoder_attention_mask
+
+
+@auto_docstring(
+ custom_intro="""
+ The bare Dia model outputting raw hidden-states without any specific head on top.
+ """
+)
+class DiaModel(DiaPreTrainedModel):
+ def __init__(self, config: DiaConfig):
+ super().__init__(config)
+ self.config = config
+ self.encoder = DiaEncoder(config.encoder_config)
+ self.decoder = DiaDecoder(config.decoder_config)
+ self.post_init()
+
+ def get_encoder(self):
+ return self.encoder
+
+ @auto_docstring
+ @can_return_tuple
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.LongTensor] = None,
+ decoder_input_ids: Optional[torch.LongTensor] = None,
+ decoder_position_ids: Optional[torch.LongTensor] = None,
+ decoder_attention_mask: Optional[torch.LongTensor] = None,
+ encoder_outputs: Optional[Union[BaseModelOutput, tuple]] = None,
+ past_key_values: Optional[EncoderDecoderCache] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs,
+ ) -> Union[tuple, Seq2SeqModelOutput]:
+ r"""
+ decoder_input_ids (`torch.LongTensor` of shape `(batch_size * num_codebooks, target_sequence_length)
+ or (batch_size, target_sequence_length, num_codebooks)`, *optional*):
+ 1. (batch_size * num_codebooks, target_sequence_length): corresponds to the general use case where
+            the audio input codebooks are flattened into the batch dimension. This also aligns with the
+            flattened audio logits which are used to calculate the loss.
+
+ 2. (batch_size, sequence_length, num_codebooks): corresponds to the internally used shape of
+ Dia to calculate embeddings and subsequent steps more efficiently.
+
+ If no `decoder_input_ids` are provided, it will create a tensor of `bos_token_id` with shape
+ `(batch_size, 1, num_codebooks)`. Indices can be obtained using the [`DiaProcessor`]. See
+ [`DiaProcessor.__call__`] for more details.
+
+ [What are decoder input IDs?](../glossary#decoder-input-ids)
+ decoder_position_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`):
+            Indices of positions of each input sequence token in the position embeddings.
+ Used to calculate the position embeddings up to `config.decoder_config.max_position_embeddings`.
+
+ [What are position IDs?](../glossary#position-ids)
+ """
+
+ if input_ids is None and encoder_outputs is None:
+ raise ValueError(
+ "You should either provide text ids or the cached text encodings. Neither has been found."
+ )
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+ if self.is_gradient_checkpointing and self.training:
+ if use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+ )
+ use_cache = False
+
+ if use_cache and past_key_values is None:
+ past_key_values = EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
+
+ if encoder_outputs is None:
+ encoder_outputs = self.encoder(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ **kwargs,
+ )
+ # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput
+ elif not isinstance(encoder_outputs, BaseModelOutput):
+ encoder_outputs = BaseModelOutput(
+ last_hidden_state=encoder_outputs[0],
+ hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
+ attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
+ )
+
+        # By default, we initialize the decoder with bos tokens if nothing has been provided
+ bsz, seq_len, channels = (encoder_outputs[0].shape[0], -1, self.config.decoder_config.num_channels)
+ if decoder_input_ids is None:
+ decoder_input_ids = torch.full(
+ size=(bsz, 1, channels), fill_value=self.config.bos_token_id, device=self.device
+ )
+ # Ensure 3D
+ if decoder_input_ids.ndim == 2:
+ decoder_input_ids = decoder_input_ids.reshape(bsz, channels, seq_len).transpose(1, 2)
+
+ decoder_outputs = self.decoder(
+ input_ids=decoder_input_ids,
+ position_ids=decoder_position_ids,
+ attention_mask=decoder_attention_mask,
+ encoder_hidden_states=encoder_outputs[0],
+ encoder_attention_mask=attention_mask,
+ past_key_values=past_key_values,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+ return Seq2SeqModelOutput(
+ last_hidden_state=decoder_outputs.last_hidden_state,
+ past_key_values=decoder_outputs.past_key_values,
+ decoder_hidden_states=decoder_outputs.hidden_states,
+ decoder_attentions=decoder_outputs.attentions,
+ cross_attentions=decoder_outputs.cross_attentions,
+ encoder_last_hidden_state=encoder_outputs[0],
+ encoder_hidden_states=encoder_outputs.hidden_states,
+ encoder_attentions=encoder_outputs.attentions,
+ )
+
+
+@auto_docstring(
+ custom_intro="""
+ The Dia model consisting of a (byte) text encoder and audio decoder with a prediction head on top.
+ """
+)
+class DiaForConditionalGeneration(DiaPreTrainedModel, DiaGenerationMixin):
+ base_model_prefix = "model"
+
+ def __init__(self, config: DiaConfig):
+ super().__init__(config)
+ self.config = config
+ self.model = DiaModel(config)
+
+ self.num_channels = config.decoder_config.num_channels
+ self.vocab_size = config.decoder_config.vocab_size
+ self.logits_dense = nn.Linear(
+ config.decoder_config.hidden_size, (self.num_channels * self.vocab_size), bias=False
+ )
+ self.loss_type = "ForMaskedLM"
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_encoder(self):
+ return self.model.get_encoder()
+
+ def get_decoder(self):
+ return self.model.get_decoder()
+
+ @auto_docstring
+ @can_return_tuple
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.LongTensor] = None,
+ decoder_input_ids: Optional[torch.LongTensor] = None,
+ decoder_position_ids: Optional[torch.LongTensor] = None,
+ decoder_attention_mask: Optional[torch.LongTensor] = None,
+ encoder_outputs: Optional[Union[BaseModelOutput, tuple]] = None,
+ past_key_values: Optional[EncoderDecoderCache] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ labels: Optional[torch.LongTensor] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs,
+ ) -> Union[tuple, Seq2SeqLMOutput]:
+ r"""
+ decoder_input_ids (`torch.LongTensor` of shape `(batch_size * num_codebooks, target_sequence_length)
+ or (batch_size, target_sequence_length, num_codebooks)`, *optional*):
+ 1. (batch_size * num_codebooks, target_sequence_length): corresponds to the general use case where
+            the audio input codebooks are flattened into the batch dimension. This also aligns with the
+            flattened audio logits which are used to calculate the loss.
+
+ 2. (batch_size, sequence_length, num_codebooks): corresponds to the internally used shape of
+ Dia to calculate embeddings and subsequent steps more efficiently.
+
+ If no `decoder_input_ids` are provided, it will create a tensor of `bos_token_id` with shape
+ `(batch_size, 1, num_codebooks)`. Indices can be obtained using the [`DiaProcessor`]. See
+ [`DiaProcessor.__call__`] for more details.
+
+ [What are decoder input IDs?](../glossary#decoder-input-ids)
+ decoder_position_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`):
+            Indices of positions of each input sequence token in the position embeddings.
+ Used to calculate the position embeddings up to `config.decoder_config.max_position_embeddings`.
+
+ [What are position IDs?](../glossary#position-ids)
+ labels (`torch.LongTensor` of shape `(batch_size * num_codebooks,)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in
+ `[0, ..., config.decoder_config.vocab_size - 1]` or -100. Tokens with indices set to `-100`
+ are ignored (masked).
+ """
+
+ outputs = self.model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ decoder_input_ids=decoder_input_ids,
+ decoder_position_ids=decoder_position_ids,
+ decoder_attention_mask=decoder_attention_mask,
+ encoder_outputs=encoder_outputs,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+ last_hidden_state = outputs[0]
+ batch_size = last_hidden_state.shape[0]
+ # 3D <-> 2D makes it necessary to prioritize channel dim
+ audio_logits = (
+ self.logits_dense(last_hidden_state)
+ .view((batch_size, -1, self.num_channels, self.vocab_size))
+ .transpose(1, 2)
+ .contiguous()
+ .view(batch_size * self.num_channels, -1, self.vocab_size)
+ )
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(logits=audio_logits, labels=labels, vocab_size=self.vocab_size, **kwargs)
+
+ return Seq2SeqLMOutput(
+ loss=loss,
+ logits=audio_logits,
+ past_key_values=outputs.past_key_values,
+ decoder_hidden_states=outputs.decoder_hidden_states,
+ decoder_attentions=outputs.decoder_attentions,
+ cross_attentions=outputs.cross_attentions,
+ encoder_last_hidden_state=outputs.encoder_last_hidden_state,
+ encoder_hidden_states=outputs.encoder_hidden_states,
+ encoder_attentions=outputs.encoder_attentions,
+ )
+
+
+__all__ = ["DiaModel", "DiaPreTrainedModel", "DiaForConditionalGeneration"]
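
To make the `(batch, channels, vocab)` bookkeeping in `DiaForConditionalGeneration.forward` above easier to follow, here is the logit reshaping traced in isolation with made-up sizes; only the shapes matter, the values are random.

```python
import torch

bsz, seq_len, hidden, num_channels, vocab = 2, 5, 16, 3, 11

last_hidden_state = torch.randn(bsz, seq_len, hidden)
logits_dense = torch.nn.Linear(hidden, num_channels * vocab, bias=False)

audio_logits = (
    logits_dense(last_hidden_state)           # (bsz, seq_len, C * V)
    .view(bsz, -1, num_channels, vocab)       # (bsz, seq_len, C, V)
    .transpose(1, 2)                          # (bsz, C, seq_len, V)
    .contiguous()
    .view(bsz * num_channels, -1, vocab)      # (bsz * C, seq_len, V) -- matches the flattened labels
)
assert audio_logits.shape == (bsz * num_channels, seq_len, vocab)
```
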
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/dia/modular_dia.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/dia/modular_dia.py
new file mode 100644
index 0000000000000000000000000000000000000000..f99d32a01d9cffab79bd72852d776447e084681b
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/dia/modular_dia.py
@@ -0,0 +1,773 @@
+# coding=utf-8
+# Copyright 2025 The Nari Labs and HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch Dia model."""
+
+from typing import Callable, Optional, Union
+
+import torch
+from torch import nn
+
+from ...cache_utils import DynamicCache, EncoderDecoderCache
+from ...masking_utils import create_causal_mask
+from ...modeling_attn_mask_utils import (
+ _prepare_4d_attention_mask,
+ _prepare_4d_attention_mask_for_sdpa,
+)
+from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_outputs import (
+ BaseModelOutput,
+ BaseModelOutputWithPastAndCrossAttentions,
+ Seq2SeqLMOutput,
+ Seq2SeqModelOutput,
+)
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
+from ...utils import auto_docstring, can_return_tuple, is_torch_flex_attn_available, is_torchdynamo_compiling, logging
+from ..llama.modeling_llama import (
+ LlamaAttention,
+ LlamaRMSNorm,
+ LlamaRotaryEmbedding,
+ eager_attention_forward,
+)
+from ..phi3.modeling_phi3 import Phi3MLP
+from .configuration_dia import DiaConfig, DiaDecoderConfig, DiaEncoderConfig
+from .generation_dia import DiaGenerationMixin
+
+
+if is_torch_flex_attn_available():
+ from ...integrations.flex_attention import make_flex_block_causal_mask
+
+
+logger = logging.get_logger(__name__)
+
+
+@auto_docstring
+class DiaPreTrainedModel(PreTrainedModel):
+ config: DiaConfig
+ base_model_prefix = "model"
+ supports_gradient_checkpointing = True
+ _supports_flash_attn = True
+ _supports_sdpa = True
+ _supports_flex_attn = True
+ _can_compile_fullgraph = True
+ main_input_name = "input_ids"
+ _no_split_modules = ["DiaEncoderLayer", "DiaDecoderLayer"]
+
+
+class DiaMultiChannelEmbedding(nn.Module):
+ """In order to efficiently compute the audio embedding from the 9 different channels,
+ we vectorize the embedding process by using a single embedding layer and an offset.
+ Example:
+ - num_embeds = 4
+ - vocab_size = 8
+ - num_channels = 3
+ We would have offsets = [0, 8, 16]
+ If audio_codes = [0, 1, 2, 3], [1, 3, 4, 7], [5, 6, 7, 8],
+ then tokens = audio_codes + offsets
+ = [0, 1, 2, 3, 9, 11, 12, 15, 21, 22, 23, 24]
+ This allows us to use a single embedding layer for all channels.
+ """
+
+ def __init__(self, config: DiaDecoderConfig):
+ super().__init__()
+ self.embed = nn.Embedding(config.vocab_size * config.num_channels, config.hidden_size)
+ self.hidden_size = config.hidden_size
+ self.num_channels = config.num_channels
+ offsets = torch.arange(config.num_channels, dtype=torch.long) * config.vocab_size # (C,)
+ self.register_buffer("offsets", offsets, persistent=False)
+
+ def forward(self, audio_codes: torch.Tensor) -> torch.Tensor:
+ tokens = (audio_codes + self.offsets.to(audio_codes.device)).squeeze(1)
+ embeds = self.embed(tokens).view(tokens.shape[0], audio_codes.shape[1], -1, self.hidden_size)
+ return embeds.sum(dim=2)
+
+
+class DiaMLP(Phi3MLP):
+ pass
+
+
+class DiaRMSNorm(LlamaRMSNorm):
+ pass
+
+
+class DiaRotaryEmbedding(LlamaRotaryEmbedding):
+ pass
+
+
+class DiaSelfAttention(LlamaAttention):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(self, config: Union[DiaEncoderConfig, DiaDecoderConfig], layer_idx: int, is_causal: bool = False):
+ nn.Module.__init__(self)
+ self.config = config
+ self.layer_idx = layer_idx
+ self.hidden_size = config.hidden_size
+ self.num_heads = self.config.num_attention_heads
+ self.num_key_value_heads = self.config.num_key_value_heads or self.num_heads
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
+ self.head_dim = getattr(config, "head_dim", config.hidden_size // self.num_heads)
+ self.scaling = 1
+ self.attention_dropout = 0.0
+ self.is_causal = is_causal
+
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
+
+
+class DiaCrossAttention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(self, config: DiaDecoderConfig, layer_idx: int):
+ super().__init__()
+ self.config = config
+ self.layer_idx = layer_idx
+ self.hidden_size = config.hidden_size
+ self.cross_hidden_size = config.cross_hidden_size
+ self.num_heads = self.config.cross_num_attention_heads
+ self.num_key_value_heads = self.config.cross_num_key_value_heads
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
+ self.head_dim = config.cross_head_dim
+ self.scaling = 1
+ self.attention_dropout = 0.0
+ self.is_causal = False
+
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
+ self.k_proj = nn.Linear(self.cross_hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
+ self.v_proj = nn.Linear(self.cross_hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ cross_attention_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ past_key_values: Optional[EncoderDecoderCache] = None,
+ **kwargs: Unpack[FlashAttentionKwargs],
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+ input_shape = hidden_states.shape[:-1]
+ hidden_shape = (*input_shape, -1, self.head_dim)
+ cross_shape = (*cross_attention_states.shape[:-1], -1, self.head_dim)
+
+ query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+
+ is_updated = past_key_values.is_updated.get(self.layer_idx) if past_key_values is not None else False
+ if past_key_values is not None and is_updated:
+            # reuse cached cross-attention k/v states
+ key_states = past_key_values.cross_attention_cache.layers[self.layer_idx].keys
+ value_states = past_key_values.cross_attention_cache.layers[self.layer_idx].values
+ else:
+ key_states = self.k_proj(cross_attention_states).view(cross_shape).transpose(1, 2)
+ value_states = self.v_proj(cross_attention_states).view(cross_shape).transpose(1, 2)
+
+ if past_key_values is not None:
+ # save all states to the cache
+ key_states, value_states = past_key_values.cross_attention_cache.update(
+ key_states,
+ value_states,
+ self.layer_idx,
+ )
+ # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
+ past_key_values.is_updated[self.layer_idx] = True
+
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ attn_output, attn_weights = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ scaling=self.scaling,
+ **kwargs,
+ )
+
+ attn_output = attn_output.reshape((*input_shape, -1)).contiguous()
+ attn_output = self.o_proj(attn_output)
+ return attn_output, attn_weights
+
+
+class DiaEncoderLayer(GradientCheckpointingLayer):
+ def __init__(self, config: DiaEncoderConfig, layer_idx: int):
+ super().__init__()
+ self.pre_sa_norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps)
+ self.self_attention = DiaSelfAttention(config, layer_idx, is_causal=False)
+ self.post_sa_norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps)
+ self.mlp = DiaMLP(config)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
+ attention_mask: Optional[torch.Tensor] = None,
+ **kwargs: Unpack[FlashAttentionKwargs],
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+ residual = hidden_states
+ normed_states = self.pre_sa_norm(hidden_states)
+ self_attn_output, self_attn_weights = self.self_attention(
+ normed_states,
+ position_embeddings=position_embeddings,
+ attention_mask=attention_mask,
+ **kwargs,
+ )
+ hidden_states = residual + self_attn_output
+
+ residual = hidden_states
+ normed_states = self.post_sa_norm(hidden_states)
+ mlp_out = self.mlp(normed_states)
+ hidden_states = residual + mlp_out
+
+ return hidden_states, self_attn_weights
+
+
+class DiaEncoder(DiaPreTrainedModel):
+ def __init__(self, config: DiaEncoderConfig):
+ super().__init__(config)
+ self.config = config
+
+ self.embedding = nn.Embedding(config.vocab_size, config.hidden_size)
+ self.layers = nn.ModuleList(
+ [DiaEncoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+ )
+ self.norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps)
+ self.rotary_embeddings = DiaRotaryEmbedding(config)
+
+ @auto_docstring
+ @can_return_tuple
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = False,
+ output_hidden_states: Optional[bool] = False,
+ **kwargs: Unpack[FlashAttentionKwargs],
+ ) -> Union[BaseModelOutput, tuple]:
+ hidden_states = self.embedding(input_ids)
+
+ # RoPE
+ # Note: We expect right padding and hence always generate
+ # the position ids on the fly to reduce preparation overhead
+ position_ids = torch.arange(input_ids.shape[-1], device=input_ids.device)[None, :]
+ position_embeddings = self.rotary_embeddings(hidden_states, position_ids)
+
+ attention_mask = self._update_full_mask(
+ attention_mask,
+ hidden_states,
+ )
+
+ encoder_states = () if output_hidden_states else None
+ all_attentions = () if output_attentions else None
+
+ for encoder_layer in self.layers:
+ if output_hidden_states:
+ encoder_states = encoder_states + (hidden_states,)
+
+ layer_outputs = encoder_layer(
+ hidden_states,
+ position_embeddings=position_embeddings,
+ attention_mask=attention_mask,
+ **kwargs,
+ )
+ hidden_states = layer_outputs[0]
+
+ if output_attentions:
+ all_attentions = all_attentions + (layer_outputs[1],)
+
+ hidden_states = self.norm(hidden_states)
+
+ if output_hidden_states:
+ encoder_states += (hidden_states,)
+
+ return BaseModelOutput(
+ last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
+ )
+
+ # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_full_mask
+ def _update_full_mask(
+ self,
+ attention_mask: Union[torch.Tensor, None],
+ inputs_embeds: torch.Tensor,
+ ):
+ if attention_mask is not None:
+ if self.config._attn_implementation == "flash_attention_2":
+ attention_mask = attention_mask if 0 in attention_mask else None
+ elif self.config._attn_implementation == "sdpa":
+ # output_attentions=True & head_mask can not be supported when using SDPA, fall back to
+ # the manual implementation that requires a 4D causal mask in all cases.
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype)
+ elif self.config._attn_implementation == "flex_attention":
+ if isinstance(attention_mask, torch.Tensor):
+ attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False)
+ else:
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
+
+ return attention_mask
+
+
+class DiaDecoderLayer(GradientCheckpointingLayer):
+ def __init__(self, config: DiaDecoderConfig, layer_idx: int):
+ super().__init__()
+ self.embed_dim = config.hidden_size
+ self.self_attention = DiaSelfAttention(config, layer_idx, is_causal=True)
+ self.cross_attention = DiaCrossAttention(config, layer_idx)
+ self.pre_sa_norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps)
+ self.pre_ca_norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps)
+ self.pre_mlp_norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps)
+ self.mlp = DiaMLP(config)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
+ attention_mask: Optional[torch.Tensor] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ past_key_values: Optional[EncoderDecoderCache] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs,
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
+ self_attn_cache = past_key_values
+ if isinstance(self_attn_cache, EncoderDecoderCache):
+ self_attn_cache = self_attn_cache.self_attention_cache
+
+ residual = hidden_states
+ normed_states = self.pre_sa_norm(hidden_states)
+ self_attn_output, self_attn_weights = self.self_attention(
+ normed_states,
+ position_embeddings,
+ attention_mask,
+            # Needs to be passed positionally so that in-place cache updates
+            # are carried out correctly (e.g. under compile)
+ self_attn_cache,
+ cache_position=cache_position,
+ **kwargs,
+ )
+ hidden_states = residual + self_attn_output
+
+ residual = hidden_states
+ normed_states = self.pre_ca_norm(hidden_states)
+ cross_states, cross_attn_weights = self.cross_attention(
+ normed_states,
+ encoder_hidden_states,
+ attention_mask=encoder_attention_mask,
+ past_key_values=past_key_values,
+ **kwargs,
+ )
+ hidden_states = residual + cross_states
+
+ residual = hidden_states
+ normed_states = self.pre_mlp_norm(hidden_states)
+ mlp_out = self.mlp(normed_states)
+ hidden_states = residual + mlp_out
+
+ return hidden_states, self_attn_weights, cross_attn_weights
+
+
+class DiaDecoder(DiaPreTrainedModel):
+ """Transformer Decoder Stack using DenseGeneral."""
+
+ def __init__(self, config: DiaDecoderConfig):
+ super().__init__(config)
+ self.num_channels = config.num_channels
+ self.vocab_size = config.vocab_size
+ self.embeddings = DiaMultiChannelEmbedding(config)
+ self.rotary_embeddings = DiaRotaryEmbedding(config)
+ self.layers = nn.ModuleList(
+ [DiaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+ )
+ self.norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps)
+
+ @auto_docstring
+ @can_return_tuple
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ position_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ encoder_attention_mask: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[EncoderDecoderCache] = None,
+ output_attentions: Optional[bool] = False,
+ output_hidden_states: Optional[bool] = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs,
+ ) -> Union[BaseModelOutputWithPastAndCrossAttentions, tuple]:
+ r"""
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length, num_codebooks)`):
+ The original `decoder_input_ids` in 3D shape to facilitate more efficient computations.
+
+ [What are input IDs?](../glossary#input-ids)
+ """
+
+ batch_size, seq_length = input_ids.size()[:-1]
+ past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
+ if cache_position is None:
+ cache_position = torch.arange(
+ past_key_values_length, past_key_values_length + seq_length, device=input_ids.device
+ )
+ if position_ids is None:
+ position_ids = cache_position[None, :]
+
+ # RoPE
+ hidden_states = self.embeddings(input_ids)
+ position_embeddings = self.rotary_embeddings(hidden_states, position_ids)
+
+ if attention_mask is None and not is_torchdynamo_compiling():
+ # required mask seq length can be calculated via length of past cache
+ mask_seq_length = past_key_values_length + seq_length
+ attention_mask = torch.ones(batch_size, mask_seq_length, device=input_ids.device)
+
+ attention_mask = create_causal_mask(
+ config=self.config,
+ input_embeds=hidden_states,
+ attention_mask=attention_mask,
+ cache_position=cache_position,
+ past_key_values=past_key_values,
+ position_ids=position_ids,
+ )
+ encoder_attention_mask = self._update_cross_attn_mask(
+ encoder_hidden_states,
+ encoder_attention_mask,
+ hidden_states.shape[:2],
+ hidden_states,
+ )
+
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+ all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
+
+ for layer in self.layers:
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ layer_outputs = layer(
+ hidden_states,
+ position_embeddings,
+ attention_mask,
+ encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ past_key_values=past_key_values,
+ cache_position=cache_position,
+ **kwargs,
+ )
+ hidden_states = layer_outputs[0]
+
+ if output_attentions:
+ all_self_attns = all_self_attns + (layer_outputs[1],)
+
+ if encoder_hidden_states is not None:
+ all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
+
+ hidden_states = self.norm(hidden_states)
+
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ return BaseModelOutputWithPastAndCrossAttentions(
+ last_hidden_state=hidden_states,
+ past_key_values=past_key_values,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ cross_attentions=all_cross_attentions,
+ )
+
+ # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_cross_attn_mask
+ def _update_cross_attn_mask(
+ self,
+ encoder_hidden_states: Union[torch.Tensor, None],
+ encoder_attention_mask: Union[torch.Tensor, None],
+ input_shape: torch.Size,
+ inputs_embeds: torch.Tensor,
+ ):
+ # expand encoder attention mask
+ if encoder_hidden_states is not None and encoder_attention_mask is not None:
+ if self.config._attn_implementation == "flash_attention_2":
+ encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None
+ elif self.config._attn_implementation == "sdpa":
+ # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on
+ # the manual implementation that requires a 4D causal mask in all cases.
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
+ encoder_attention_mask,
+ inputs_embeds.dtype,
+ tgt_len=input_shape[-1],
+ )
+ elif self.config._attn_implementation == "flex_attention":
+ if isinstance(encoder_attention_mask, torch.Tensor):
+ encoder_attention_mask = make_flex_block_causal_mask(
+ encoder_attention_mask,
+ query_length=input_shape[-1],
+ is_causal=False,
+ )
+ else:
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ encoder_attention_mask = _prepare_4d_attention_mask(
+ encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
+ )
+
+ return encoder_attention_mask
+
+
+@auto_docstring(
+ custom_intro="""
+ The bare Dia model outputting raw hidden-states without any specific head on top.
+ """
+)
+class DiaModel(DiaPreTrainedModel):
+ def __init__(self, config: DiaConfig):
+ super().__init__(config)
+ self.config = config
+ self.encoder = DiaEncoder(config.encoder_config)
+ self.decoder = DiaDecoder(config.decoder_config)
+ self.post_init()
+
+ def get_encoder(self):
+ return self.encoder
+
+ @auto_docstring
+ @can_return_tuple
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.LongTensor] = None,
+ decoder_input_ids: Optional[torch.LongTensor] = None,
+ decoder_position_ids: Optional[torch.LongTensor] = None,
+ decoder_attention_mask: Optional[torch.LongTensor] = None,
+ encoder_outputs: Optional[Union[BaseModelOutput, tuple]] = None,
+ past_key_values: Optional[EncoderDecoderCache] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs,
+ ) -> Union[tuple, Seq2SeqModelOutput]:
+ r"""
+ decoder_input_ids (`torch.LongTensor` of shape `(batch_size * num_codebooks, target_sequence_length)
+ or (batch_size, target_sequence_length, num_codebooks)`, *optional*):
+ 1. (batch_size * num_codebooks, target_sequence_length): corresponds to the general use case where
+ the audio input codebooks are flattened into the batch dimension. This also aligns with the
+ flattened audio logits which are used to calculate the loss.
+
+ 2. (batch_size, target_sequence_length, num_codebooks): corresponds to the shape used internally
+ by Dia to calculate embeddings and subsequent steps more efficiently.
+
+ If no `decoder_input_ids` are provided, it will create a tensor of `bos_token_id` with shape
+ `(batch_size, 1, num_codebooks)`. Indices can be obtained using the [`DiaProcessor`]. See
+ [`DiaProcessor.__call__`] for more details.
+
+ [What are decoder input IDs?](../glossary#decoder-input-ids)
+ decoder_position_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`):
+ Indices of positions of each input sequence token in the position embeddings.
+ Used to calculate the position embeddings up to `config.decoder_config.max_position_embeddings`.
+
+ [What are position IDs?](../glossary#position-ids)
+ """
+
+ if input_ids is None and encoder_outputs is None:
+ raise ValueError(
+ "You should either provide text ids or the cached text encodings. Neither has been found."
+ )
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+ if self.is_gradient_checkpointing and self.training:
+ if use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+ )
+ use_cache = False
+
+ if use_cache and past_key_values is None:
+ past_key_values = EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
+
+ if encoder_outputs is None:
+ encoder_outputs = self.encoder(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ **kwargs,
+ )
+ # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput
+ elif not isinstance(encoder_outputs, BaseModelOutput):
+ encoder_outputs = BaseModelOutput(
+ last_hidden_state=encoder_outputs[0],
+ hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
+ attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
+ )
+
+ # On default we initialize the decoder with bos tokens if nothing has been provided
+ bsz, seq_len, channels = (encoder_outputs[0].shape[0], -1, self.config.decoder_config.num_channels)
+ if decoder_input_ids is None:
+ decoder_input_ids = torch.full(
+ size=(bsz, 1, channels), fill_value=self.config.bos_token_id, device=self.device
+ )
+ # Ensure 3D
+ if decoder_input_ids.ndim == 2:
+ decoder_input_ids = decoder_input_ids.reshape(bsz, channels, seq_len).transpose(1, 2)
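+ # e.g. with batch_size=2 and num_channels=9, flattened ids of shape (18, seq_len) are viewed as
+ # (2, 9, seq_len) and transposed to the internal (2, seq_len, 9) layout expected by the decoder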
+
+ decoder_outputs = self.decoder(
+ input_ids=decoder_input_ids,
+ position_ids=decoder_position_ids,
+ attention_mask=decoder_attention_mask,
+ encoder_hidden_states=encoder_outputs[0],
+ encoder_attention_mask=attention_mask,
+ past_key_values=past_key_values,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+ return Seq2SeqModelOutput(
+ last_hidden_state=decoder_outputs.last_hidden_state,
+ past_key_values=decoder_outputs.past_key_values,
+ decoder_hidden_states=decoder_outputs.hidden_states,
+ decoder_attentions=decoder_outputs.attentions,
+ cross_attentions=decoder_outputs.cross_attentions,
+ encoder_last_hidden_state=encoder_outputs[0],
+ encoder_hidden_states=encoder_outputs.hidden_states,
+ encoder_attentions=encoder_outputs.attentions,
+ )
+
+
+@auto_docstring(
+ custom_intro="""
+ The Dia model consisting of a (byte) text encoder and audio decoder with a prediction head on top.
+ """
+)
+class DiaForConditionalGeneration(DiaPreTrainedModel, DiaGenerationMixin):
+ base_model_prefix = "model"
+
+ def __init__(self, config: DiaConfig):
+ super().__init__(config)
+ self.config = config
+ self.model = DiaModel(config)
+
+ self.num_channels = config.decoder_config.num_channels
+ self.vocab_size = config.decoder_config.vocab_size
+ self.logits_dense = nn.Linear(
+ config.decoder_config.hidden_size, (self.num_channels * self.vocab_size), bias=False
+ )
+ self.loss_type = "ForMaskedLM"
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_encoder(self):
+ return self.model.get_encoder()
+
+ def get_decoder(self):
+ return self.model.get_decoder()
+
+ @auto_docstring
+ @can_return_tuple
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.LongTensor] = None,
+ decoder_input_ids: Optional[torch.LongTensor] = None,
+ decoder_position_ids: Optional[torch.LongTensor] = None,
+ decoder_attention_mask: Optional[torch.LongTensor] = None,
+ encoder_outputs: Optional[Union[BaseModelOutput, tuple]] = None,
+ past_key_values: Optional[EncoderDecoderCache] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ labels: Optional[torch.LongTensor] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs,
+ ) -> Union[tuple, Seq2SeqLMOutput]:
+ r"""
+ decoder_input_ids (`torch.LongTensor` of shape `(batch_size * num_codebooks, target_sequence_length)
+ or (batch_size, target_sequence_length, num_codebooks)`, *optional*):
+ 1. (batch_size * num_codebooks, target_sequence_length): corresponds to the general use case where
+ the audio input codebooks are flattened into the batch dimension. This also aligns with the
+ flattened audio logits which are used to calculate the loss.
+
+ 2. (batch_size, target_sequence_length, num_codebooks): corresponds to the shape used internally
+ by Dia to calculate embeddings and subsequent steps more efficiently.
+
+ If no `decoder_input_ids` are provided, it will create a tensor of `bos_token_id` with shape
+ `(batch_size, 1, num_codebooks)`. Indices can be obtained using the [`DiaProcessor`]. See
+ [`DiaProcessor.__call__`] for more details.
+
+ [What are decoder input IDs?](../glossary#decoder-input-ids)
+ decoder_position_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`):
+ Indices of positions of each input sequence token in the position embeddings.
+ Used to calculate the position embeddings up to `config.decoder_config.max_position_embeddings`.
+
+ [What are position IDs?](../glossary#position-ids)
+ labels (`torch.LongTensor` of shape `(batch_size * num_codebooks, target_sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in
+ `[0, ..., config.decoder_config.vocab_size - 1]` or -100. Tokens with indices set to `-100`
+ are ignored (masked).
+ """
+
+ outputs = self.model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ decoder_input_ids=decoder_input_ids,
+ decoder_position_ids=decoder_position_ids,
+ decoder_attention_mask=decoder_attention_mask,
+ encoder_outputs=encoder_outputs,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+ last_hidden_state = outputs[0]
+ batch_size = last_hidden_state.shape[0]
+ # 3D <-> 2D makes it necessary to prioritize channel dim
+ audio_logits = (
+ self.logits_dense(last_hidden_state)
+ .view((batch_size, -1, self.num_channels, self.vocab_size))
+ .transpose(1, 2)
+ .contiguous()
+ .view(batch_size * self.num_channels, -1, self.vocab_size)
+ )
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(logits=audio_logits, labels=labels, vocab_size=self.vocab_size, **kwargs)
+
+ return Seq2SeqLMOutput(
+ loss=loss,
+ logits=audio_logits,
+ past_key_values=outputs.past_key_values,
+ decoder_hidden_states=outputs.decoder_hidden_states,
+ decoder_attentions=outputs.decoder_attentions,
+ cross_attentions=outputs.cross_attentions,
+ encoder_last_hidden_state=outputs.encoder_last_hidden_state,
+ encoder_hidden_states=outputs.encoder_hidden_states,
+ encoder_attentions=outputs.encoder_attentions,
+ )
+
+
+__all__ = ["DiaModel", "DiaPreTrainedModel", "DiaForConditionalGeneration"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/dia/processing_dia.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/dia/processing_dia.py
new file mode 100644
index 0000000000000000000000000000000000000000..402f5152a64bda378ccdf5edd512c86fe643145c
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/dia/processing_dia.py
@@ -0,0 +1,474 @@
+# coding=utf-8
+# Copyright 2025 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Processor class for Dia"""
+
+import math
+from pathlib import Path
+from typing import Optional, Union
+
+from ...audio_utils import AudioInput, make_list_of_audio
+from ...feature_extraction_utils import BatchFeature
+from ...processing_utils import AudioKwargs, ProcessingKwargs, ProcessorMixin, Unpack
+from ...utils import is_soundfile_available, is_torch_available
+
+
+if is_torch_available():
+ import torch
+
+if is_soundfile_available():
+ import soundfile as sf
+
+
+class DiaAudioKwargs(AudioKwargs, total=False):
+ bos_token_id: int
+ eos_token_id: int
+ pad_token_id: int
+ delay_pattern: list[int]
+ generation: bool
+
+
+class DiaProcessorKwargs(ProcessingKwargs, total=False):
+ audio_kwargs: DiaAudioKwargs
+ _defaults = {
+ "text_kwargs": {
+ "padding": True,
+ "padding_side": "right",
+ "add_special_tokens": False,
+ },
+ "audio_kwargs": {
+ "eos_token_id": 1024,
+ "pad_token_id": 1025,
+ "bos_token_id": 1026,
+ "delay_pattern": [0, 8, 9, 10, 11, 12, 13, 14, 15],
+ "generation": True,
+ "sampling_rate": 44100,
+ },
+ "common_kwargs": {"return_tensors": "pt"},
+ }
+
+
+class DiaProcessor(ProcessorMixin):
+ r"""
+ Constructs a Dia processor which wraps a [`DiaFeatureExtractor`], [`DiaTokenizer`], and a [`DacModel`] into
+ a single processor. It inherits the audio feature extraction, tokenizer, and audio encode/decode
+ functionalities. See [`~DiaProcessor.__call__`], [`~DiaProcessor.encode`], and [`~DiaProcessor.decode`] for
+ more information.
+
+ Args:
+ feature_extractor (`DiaFeatureExtractor`):
+ An instance of [`DiaFeatureExtractor`]. The feature extractor is a required input.
+ tokenizer (`DiaTokenizer`):
+ An instance of [`DiaTokenizer`]. The tokenizer is a required input.
+ audio_tokenizer (`DacModel`):
+ An instance of [`DacModel`] used to encode/decode audio into/from codebooks. It is a required input.
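+
+ A minimal usage sketch (the checkpoint id below is a hypothetical placeholder; substitute an actual Dia checkpoint):
+
+ ```python
+ >>> from transformers import DiaProcessor, DiaForConditionalGeneration
+
+ >>> processor = DiaProcessor.from_pretrained("path/to/dia-checkpoint")  # hypothetical checkpoint id
+ >>> model = DiaForConditionalGeneration.from_pretrained("path/to/dia-checkpoint")
+
+ >>> inputs = processor(text="[S1] Hello there!", return_tensors="pt")
+ >>> outputs = model(**inputs)  # logits of shape (batch_size * num_codebooks, seq_len, vocab_size)
+ ```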
+ """
+
+ feature_extractor_class = "DiaFeatureExtractor"
+ tokenizer_class = "DiaTokenizer"
+ audio_tokenizer_class = "DacModel"
+
+ def __init__(self, feature_extractor, tokenizer, audio_tokenizer):
+ super().__init__(feature_extractor, tokenizer, audio_tokenizer=audio_tokenizer)
+
+ def __call__(
+ self,
+ text: Union[str, list[str]],
+ audio: Optional[AudioInput] = None,
+ output_labels: Optional[bool] = False,
+ **kwargs: Unpack[DiaProcessorKwargs],
+ ):
+ """
+ Main method to prepare text(s) and audio to be fed as input to the model. The `audio` argument is
+ forwarded to the DiaFeatureExtractor's [`~DiaFeatureExtractor.__call__`] and subsequently to the
+ DacModel's [`~DacModel.encode`]. The `text` argument is forwarded to [`~DiaTokenizer.__call__`].
+ Please refer to the docstrings of the above methods for more information.
+ """
+ if not is_torch_available():
+ raise ValueError(
+ "The `DiaProcessor` relies on the `audio_tokenizer` which requires `torch` but we couldn't "
+ "find it in your environment. You can install torch via `pip install torch`."
+ )
+
+ if text is None:
+ raise ValueError("You need to specify the `text` input to process.")
+
+ output_kwargs = self._merge_kwargs(
+ DiaProcessorKwargs,
+ **kwargs,
+ )
+
+ text_kwargs = output_kwargs["text_kwargs"]
+ audio_kwargs = output_kwargs["audio_kwargs"]
+ common_kwargs = output_kwargs["common_kwargs"]
+
+ return_tensors = common_kwargs.pop("return_tensors", None)
+ if return_tensors != "pt":
+ raise ValueError(f"{self.__class__.__name__} only supports `return_tensors='pt'`.")
+
+ data = {}
+
+ # Text
+ if isinstance(text, str):
+ text = [text]
+ elif not (isinstance(text, (list, tuple)) and all(isinstance(t, str) for t in text)):
+ raise ValueError("Invalid input text. Please provide a string, or a list of strings")
+
+ encodings = self.tokenizer(text, **text_kwargs)
+ data.update(encodings)
+
+ # Audio
+ delay_pattern = audio_kwargs.pop("delay_pattern", None)
+ audio_bos_token_id = audio_kwargs.pop("bos_token_id", None)
+ audio_eos_token_id = audio_kwargs.pop("eos_token_id", None)
+ audio_pad_token_id = audio_kwargs.pop("pad_token_id", None)
+ generation = audio_kwargs.pop("generation", True)
+ if (
+ audio_bos_token_id is None
+ or audio_eos_token_id is None
+ or audio_pad_token_id is None
+ or delay_pattern is None
+ ):
+ raise ValueError(
+ "To enable processing for Dia, we need the `bos_token_id`, `eos_token_id`, "
+ "`pad_token_id`, and `delay_pattern`. You may have accidentally overwritten one of those."
+ )
+
+ if generation and output_labels:
+ raise ValueError(
+ f"Labels with `generation` is incompatible, got generation={generation}, output_labels={output_labels}."
+ )
+
+ batch_size = data["input_ids"].shape[0]
+ num_channels = len(delay_pattern)
+ max_delay = max(delay_pattern)
+
+ # Voice cloning generation / general training
+ if audio is not None:
+ audio = make_list_of_audio(audio)
+ input_audios = self.feature_extractor(audio, **audio_kwargs)
+
+ compression_rate = math.prod(self.audio_tokenizer.config.downsampling_ratios)
+ max_encoded_sequence_len = input_audios["padding_mask"][0].shape[-1] // compression_rate
+
+ decoder_input_ids = []
+ decoder_attention_mask = []
+ # TODO: dac with batching is currently broken, but non-batch is working
+ # refer to https://gist.github.com/vasqu/643a45b680cf39fd7467271ee2eb6f80 for a validation script
+ for padding_mask, audio in zip(input_audios["padding_mask"], input_audios["input_values"]):
+ # get current length with hop length in mind (as if it were sampled as a single audio)
+ base_pad_len = self.feature_extractor.hop_length
+ current_audio_len = math.ceil(padding_mask.sum(dim=-1) / base_pad_len) * base_pad_len
+
+ encoded_sequence_len = current_audio_len // compression_rate
+ padding_len = max_encoded_sequence_len - encoded_sequence_len
+
+ # compute non-padded forward pass; one extra bos (and eos if training) is added
+ with torch.no_grad():
+ audio = audio[None, ..., :current_audio_len].to(self.audio_tokenizer.device)
+ input_ids = self.audio_tokenizer.encode(audio).audio_codes.transpose(1, 2)
+
+ if not generation:
+ input_ids = torch.nn.functional.pad(
+ input_ids, pad=(0, 0, 0, 1, 0, 0), mode="constant", value=audio_eos_token_id
+ )
+
+ # apply padding
+ # +1 for the bos within the real sequence
+ input_ids = torch.nn.functional.pad(
+ input_ids, pad=(0, 0, padding_len + 1, 0, 0, 0), mode="constant", value=audio_bos_token_id
+ )
+ num_valid_inputs = encoded_sequence_len + 1 + max_delay # sequence + bos + delay
+ num_valid_inputs += 0 if generation else 1 # eos if training
+ attention_mask = torch.tensor([0] * padding_len + [1] * num_valid_inputs, dtype=torch.long)[None, :]
+
+ decoder_input_ids.append(input_ids)
+ decoder_attention_mask.append(attention_mask)
+
+ decoder_input_ids = torch.cat(decoder_input_ids, dim=0)
+ decoder_attention_mask = torch.cat(decoder_attention_mask, dim=0)
+ # TTS generation
+ elif generation:
+ # all bos to start with TTS
+ decoder_input_ids = torch.full((batch_size, 1, num_channels), audio_bos_token_id, dtype=torch.long)
+
+ # we preemptively add the delay
+ decoder_attention_mask = torch.ones(size=(batch_size, 1 + max_delay), dtype=torch.long)
+ else:
+ raise ValueError("If you try to train, you should provide audio data as well.")
+
+ if batch_size != decoder_input_ids.shape[0]:
+ raise ValueError(
+ f"Need the same amount of samples for both text and audio, but got text samples={batch_size} and "
+ f"audio samples = {decoder_input_ids.shape[0]} instead."
+ )
+
+ # prepare shift indices per delay
+ max_seq_len = decoder_attention_mask.shape[-1]
+ max_audio_len = max_seq_len - max_delay
+ precomputed_idx = self.build_indices(
+ bsz=batch_size,
+ seq_len=max_seq_len,
+ num_channels=num_channels,
+ delay_pattern=delay_pattern,
+ revert=False,
+ )
+
+ # create delay pattern input
+ # the pad token will be used for masking which input is valid for prediction during generation
+ prefill = torch.full(
+ (batch_size, max_seq_len, num_channels),
+ fill_value=audio_pad_token_id,
+ dtype=torch.int,
+ )
+ prefill[:, :max_audio_len] = decoder_input_ids
+
+ delayed_decoder_input_ids = self.apply_audio_delay(
+ audio=prefill,
+ pad_token_id=audio_pad_token_id,
+ bos_token_id=audio_bos_token_id,
+ precomputed_idx=precomputed_idx,
+ )
+
+ data.update({"decoder_input_ids": delayed_decoder_input_ids, "decoder_attention_mask": decoder_attention_mask})
+
+ if output_labels:
+ # Base idea is to shift on the sequence dim
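+ # i.e. position t of the (truncated) decoder inputs predicts position t + 1; pad/bos targets are
+ # masked to -100 and channels are flattened to (batch_size * num_codebooks, seq_len - 1) to match the logits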
+ labels = data["decoder_input_ids"].clone()[:, 1:]
+ labels[labels == audio_pad_token_id] = -100
+ labels[labels == audio_bos_token_id] = -100
+
+ data["labels"] = labels.transpose(1, 2).reshape(batch_size * num_channels, -1).contiguous().long()
+ data["decoder_input_ids"] = data["decoder_input_ids"][:, :-1]
+ data["decoder_attention_mask"] = data["decoder_attention_mask"][:, :-1]
+
+ return BatchFeature(data=data, tensor_type=return_tensors)
+
+ def batch_decode(
+ self,
+ decoder_input_ids: "torch.Tensor",
+ audio_prompt_len: Optional[int] = None,
+ **kwargs: Unpack[DiaProcessorKwargs],
+ ) -> list["torch.Tensor"]:
+ """
+ Decodes a batch of audio codebook sequences into their respective audio waveforms via the
+ `audio_tokenizer`. See [`~DacModel.decode`] for more information.
+
+ Args:
+ decoder_input_ids (`torch.Tensor`): The complete output sequence of the decoder.
+ audio_prompt_len (`int`, *optional*): The audio prefix length (e.g. when using voice cloning).
+ """
+ output_kwargs = self._merge_kwargs(
+ DiaProcessorKwargs,
+ **kwargs,
+ )
+ audio_kwargs = output_kwargs["audio_kwargs"]
+
+ delay_pattern = audio_kwargs.pop("delay_pattern", None)
+ audio_bos_token_id = audio_kwargs.pop("bos_token_id", None)
+ audio_pad_token_id = audio_kwargs.pop("pad_token_id", None)
+ if audio_bos_token_id is None or audio_pad_token_id is None or delay_pattern is None:
+ raise ValueError(
+ "To enable decoding for Dia, we need the `bos_token_id`, `pad_token_id`, "
+ "and `delay_pattern`. You may have accidentally overwritten one of those."
+ )
+
+ # either decode the whole audio sequence or only the generated parts
+ if audio_prompt_len is not None:
+ audio_prompt_len = torch.tensor(audio_prompt_len, device=decoder_input_ids.device, dtype=torch.long)
+ start_of_generation_idx = audio_prompt_len[None].expand(decoder_input_ids.shape[0])
+ else:
+ start_of_generation_idx = (decoder_input_ids[:, :, 0] == audio_bos_token_id).sum(dim=-1)
+ # -1 for the eos token
+ end_of_generation_idx = (
+ decoder_input_ids.shape[1] - (decoder_input_ids[:, :, 0] == audio_pad_token_id).sum(dim=-1) - 1
+ )
+
+ # revert delay
+ bsz, seq_len, num_channels = decoder_input_ids.shape
+ precomputed_idx = self.build_indices(
+ bsz=bsz,
+ seq_len=seq_len,
+ num_channels=num_channels,
+ delay_pattern=delay_pattern,
+ revert=True,
+ )
+
+ output_sequences = self.apply_audio_delay(
+ audio=decoder_input_ids,
+ # We do not care about these values as we cut them out
+ # with `start_of_generation_idx` and `end_of_generation_idx`
+ pad_token_id=-1,
+ bos_token_id=-1,
+ precomputed_idx=precomputed_idx,
+ ).transpose(1, 2)
+
+ # retrieve the correct sequences each
+ audios = []
+ # TODO: see above, dac doesn't work in batches yet
+ with torch.no_grad():
+ for i in range(start_of_generation_idx.shape[0]):
+ output_i = output_sequences[i, :, start_of_generation_idx[i] : end_of_generation_idx[i]][None, ...]
+ output_i = output_i.to(self.audio_tokenizer.device)
+ audio_i = self.audio_tokenizer.decode(audio_codes=output_i).audio_values.cpu().squeeze()
+ audios.append(audio_i)
+
+ return audios
+
+ def decode(
+ self,
+ decoder_input_ids: "torch.Tensor",
+ audio_prompt_len: Optional[int] = None,
+ **kwargs: Unpack[DiaProcessorKwargs],
+ ) -> "torch.Tensor":
+ """
+ Decodes a single sequence of audio codebooks into the respective audio waveform via the
+ `audio_tokenizer`. See [`~DacModel.decode`] and [`~DiaProcessor.batch_decode`] for more information.
+ """
+ if decoder_input_ids.shape[0] != 1:
+ raise ValueError(
+ f"Expecting a single output to be decoded but received {decoder_input_ids.shape[0]} samples instead."
+ )
+
+ return self.batch_decode(decoder_input_ids, audio_prompt_len, **kwargs)[0]
+
+ def get_audio_prompt_len(
+ self,
+ decoder_attention_mask: "torch.Tensor",
+ **kwargs: Unpack[DiaProcessorKwargs],
+ ) -> int:
+ """Utility function to get the audio prompt length."""
+ output_kwargs = self._merge_kwargs(
+ DiaProcessorKwargs,
+ **kwargs,
+ )
+ audio_kwargs = output_kwargs["audio_kwargs"]
+
+ delay_pattern = audio_kwargs.pop("delay_pattern", None)
+ if delay_pattern is None:
+ raise ValueError(
+ "To enable the utility of retrieving the prompt length for Dia, we need the "
+ "`delay_pattern`. You may have accidentally overwritten this."
+ )
+ return decoder_attention_mask.shape[1] - max(delay_pattern)
+
+ # Copied from transformers.models.csm.processing_csm.CsmProcessor.save_audio with Csm->Dia
+ def save_audio(
+ self,
+ audio: AudioInput,
+ saving_path: Union[str, Path, list[Union[str, Path]]],
+ **kwargs: Unpack[DiaProcessorKwargs],
+ ):
+ # TODO: @eustlb, this should be in AudioProcessor
+ if not is_soundfile_available():
+ raise ImportError("Please install `soundfile` to save audio files.")
+
+ # ensure correct audio input
+ audio = make_list_of_audio(audio)
+
+ # ensure correct saving path
+ if isinstance(saving_path, (str, Path)):
+ saving_path = [saving_path]
+ elif not (isinstance(saving_path, (list, tuple)) and all(isinstance(p, (str, Path)) for p in saving_path)):
+ raise ValueError("Invalid input path. Please provide a string, or a list of strings")
+
+ if len(audio) != len(saving_path):
+ raise ValueError("The number of audio and saving paths must be the same")
+
+ output_kwargs = self._merge_kwargs(
+ DiaProcessorKwargs,
+ **kwargs,
+ )
+ audio_kwargs = output_kwargs["audio_kwargs"]
+ sampling_rate = audio_kwargs["sampling_rate"]
+
+ for audio_value, p in zip(audio, saving_path):
+ if isinstance(audio_value, torch.Tensor):
+ audio_value = audio_value.cpu().float().numpy()
+ sf.write(p, audio_value, sampling_rate)
+
+ @staticmethod
+ def build_indices(
+ bsz: int,
+ seq_len: int,
+ num_channels: int,
+ delay_pattern: list[int],
+ revert: bool = False,
+ ) -> tuple["torch.Tensor", "torch.Tensor"]:
+ """
+ Precompute (sequence_idx, all_idx) so that out[seq, channel] = in[seq - delay[channel], channel]
+ or in[seq, channel] = out[seq + delay[channel], channel] if `revert`.
+ Negative sequence_idx => BOS; sequence_idx >= seq_len => PAD.
+ """
+ delay_array = torch.tensor(delay_pattern, dtype=torch.int32)
+
+ # (0..seq_len-1)
+ sequence_idx = torch.arange(seq_len, dtype=torch.int32)[None, :].expand(bsz, seq_len)[..., None]
+ # + or - delay depending if we delay or revert the delay
+ if not revert:
+ sequence_idx = sequence_idx - delay_array[None, None, :]
+ else:
+ sequence_idx = sequence_idx + delay_array[None, None, :]
+ # if delay goes over the range we clamp back to valid values
+ valid_sequence_idx = torch.clamp(sequence_idx, 0, seq_len - 1)
+
+ batch_idx = torch.arange(bsz, dtype=torch.int32)[:, None, None].expand(bsz, seq_len, num_channels)
+ channel_idx = torch.arange(num_channels, dtype=torch.int32)[None, None, :].expand(bsz, seq_len, num_channels)
+
+ all_idx = torch.stack(
+ [batch_idx.reshape(-1), valid_sequence_idx.reshape(-1), channel_idx.reshape(-1)],
+ dim=1,
+ ).long()
+
+ return sequence_idx, all_idx
+
+ @staticmethod
+ def apply_audio_delay(
+ audio: "torch.Tensor",
+ pad_token_id: int,
+ bos_token_id: int,
+ precomputed_idx: tuple["torch.Tensor", "torch.Tensor"],
+ ) -> "torch.Tensor":
+ """
+ Applies or reverts the delay pattern to batched audio tokens using precomputed indices,
+ inserting BOS where sequence_idx < 0 and PAD where sequence_idx >= seq_len.
+
+ Args:
+ audio: audio tokens of shape [bsz, seq_len, num_channels]
+ pad_token_id: the PAD token
+ bos_token_id: the BOS token
+ precomputed_idx: from `build_indices`
+
+ Returns:
+ final_audio: delayed or reverted audio tokens of shape [bsz, seq_len, num_channels]
+ """
+ # Move everything to the same device
+ device = audio.device
+ sequence_idx, all_idx = precomputed_idx
+ sequence_idx = sequence_idx.to(device)
+ all_idx = all_idx.to(device)
+
+ # Gather per precomputed indices
+ batch_idx, valid_sequence_idx, channel_idx = torch.unbind(all_idx, dim=-1)
+ gathered_audio = audio[batch_idx, valid_sequence_idx, channel_idx].view(audio.size())
+
+ # Mask according to negative sequence_idx => BOS; sequence_idx >= seq_len => PAD
+ mask_bos = sequence_idx < 0
+ mask_pad = sequence_idx >= audio.shape[1]
+ final_audio = torch.where(mask_bos, bos_token_id, torch.where(mask_pad, pad_token_id, gathered_audio))
+
+ return final_audio
+
+
+__all__ = ["DiaProcessor"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/dia/tokenization_dia.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/dia/tokenization_dia.py
new file mode 100644
index 0000000000000000000000000000000000000000..4e205906ea709ee2c20f25b0bf6f4fa66ab1f4a4
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/dia/tokenization_dia.py
@@ -0,0 +1,118 @@
+# coding=utf-8
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tokenization class for Dia."""
+
+from typing import Optional
+
+from ...tokenization_utils import AddedToken, PreTrainedTokenizer
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class DiaTokenizer(PreTrainedTokenizer):
+ """
+ Construct a Dia tokenizer. Dia simply uses raw bytes utf-8 encoding except for special tokens `[S1]` and `[S2]`.
+
+ This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should
+ refer to this superclass for more information regarding those methods.
+
+ Args:
+ pad_token (`str`, *optional*, defaults to `""`):
+ The token used for padding, for example when batching sequences of different lengths.
+ unk_token (`str`, *optional*, defaults to `""`):
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+ token instead.
+ max_length (`int`, *optional*, defaults to 1024):
+ The maximum length of the sequences when encoding. Sequences longer than this will be truncated.
+ offset (`int`, *optional*, defaults to 0):
+ The offset added to byte values when converting tokens to ids (reserving space for added special tokens).
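+
+ A minimal sketch of the underlying byte-level mapping (illustrative only; final ids also depend on
+ `offset` and the added special tokens):
+
+ ```python
+ >>> # what `_tokenize` produces for plain text: one token per utf-8 byte
+ >>> [chr(b) for b in "Hi".encode("utf-8")]
+ ['H', 'i']
+ ```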
+ """
+
+ model_input_names = ["input_ids", "attention_mask"]
+
+ def __init__(
+ self,
+ pad_token: Optional[str] = "",
+ unk_token: Optional[str] = "",
+ max_length: Optional[int] = 1024,
+ offset: int = 0,
+ **kwargs,
+ ):
+ # We have no eos/bos tokens but allow padding -- no l/r strip as we treat them as tokens as well
+ pad_token = AddedToken(pad_token) if isinstance(pad_token, str) else pad_token
+ unk_token = AddedToken(unk_token) if isinstance(unk_token, str) else unk_token
+
+ self._utf_vocab_size = 2**8 # utf is 8 bits
+ self._added_tokens_decoder = {0: pad_token, 1: AddedToken("[S1]"), 2: AddedToken("[S2]")}
+ self.offset = offset
+ super().__init__(
+ unk_token=unk_token,
+ pad_token=pad_token,
+ max_length=max_length,
+ **kwargs,
+ )
+
+ @property
+ def vocab_size(self):
+ return self._utf_vocab_size
+
+ def get_vocab(self):
+ vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size + self.offset)}
+ vocab.update(self.added_tokens_encoder)
+ return vocab
+
+ def _tokenize(self, text: str) -> list[str]:
+ """Take as input a string and return a list of strings (tokens) for words/sub-words"""
+ tokens = [chr(i) for i in text.encode("utf-8")]
+ return tokens
+
+ def _convert_token_to_id(self, token):
+ """Converts a token (str) in an id using the vocab."""
+
+ if len(token) != 1:
+ token_id = None
+ else:
+ token_id = ord(token) + self.offset
+
+ return token_id
+
+ def _convert_id_to_token(self, index):
+ """Converts an index (integer) in a token (str) using the vocab."""
+ token = chr(index - self.offset)
+ return token
+
+ def convert_tokens_to_string(self, tokens: list[str]) -> str:
+ """Converts a sequence of tokens (string) in a single string."""
+ bstring = b""
+ for token in tokens:
+ if token in self.added_tokens_decoder:
+ added_token_obj = self.added_tokens_decoder[token]
+ tok_string = str(added_token_obj).encode("utf-8")
+ elif token in self.added_tokens_encoder:
+ tok_string = token.encode("utf-8")
+ else:
+ tok_string = token.encode("utf-8") # Assume general string token
+ bstring += tok_string
+ string = bstring.decode("utf-8", errors="ignore")
+ return string
+
+ # No vocab file
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]:
+ return ()
+
+
+__all__ = ["DiaTokenizer"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/dialogpt/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/dialogpt/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/diffllama/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/diffllama/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c162fce0a48bd164bd0e0a615b942ee4805a12aa
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/diffllama/__init__.py
@@ -0,0 +1,27 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_diffllama import *
+ from .modeling_diffllama import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/diffllama/configuration_diffllama.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/diffllama/configuration_diffllama.py
new file mode 100644
index 0000000000000000000000000000000000000000..210607271927ab2f3a7aa1ec1e874fb296c32a73
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/diffllama/configuration_diffllama.py
@@ -0,0 +1,199 @@
+# coding=utf-8
+# Copyright 2024 weak-kajuma and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on Llama implementations in this library and Microsoft's
+# Differential Transformer implementations.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""DiffLlama model configuration"""
+
+from ...configuration_utils import PretrainedConfig
+from ...modeling_rope_utils import rope_config_validation
+
+
+class DiffLlamaConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`DiffLlamaModel`]. It is used to instantiate an DiffLlama
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults
+ will yield a similar configuration to that of the [kajuma/DiffLlama-0.3B-handcut](https://huggingface.co/kajuma/DiffLlama-0.3B-handcut).
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 32000):
+ Vocabulary size of the DiffLlama model. Defines the number of different tokens that can be represented by the
+ `inputs_ids` passed when calling [`DiffLlamaModel`]
+ hidden_size (`int`, *optional*, defaults to 2048):
+ Dimension of the hidden representations.
+ intermediate_size (`int`, *optional*, defaults to 8192):
+ Dimension of the MLP representations.
+ num_hidden_layers (`int`, *optional*, defaults to 16):
+ Number of hidden layers in the Transformer decoder.
+ num_attention_heads (`int`, *optional*, defaults to 32):
+ Number of attention heads for each attention layer in the Transformer decoder.
+ num_key_value_heads (`int`, *optional*):
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
+ `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+ by meanpooling all the original heads within that group. For more details, check out [this
+ paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to
+ `num_attention_heads`.
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+ The non-linear activation function (function or string) in the decoder.
+ max_position_embeddings (`int`, *optional*, defaults to 2048):
+ The maximum sequence length that this model might ever be used with.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ rms_norm_eps (`float`, *optional*, defaults to 1e-05):
+ The epsilon used by the rms normalization layers.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
+ relevant if `config.is_decoder=True`.
+ pad_token_id (`int`, *optional*):
+ Padding token id.
+ bos_token_id (`int`, *optional*, defaults to 1):
+ Beginning of stream token id.
+ eos_token_id (`int`, *optional*, defaults to 2):
+ End of stream token id.
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+ Whether to tie weight embeddings
+ rope_theta (`float`, *optional*, defaults to 10000.0):
+ The base period of the RoPE embeddings.
+ rope_scaling (`Dict`, *optional*):
+ Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply a new rope type
+ and you expect the model to work on a longer `max_position_embeddings`, we recommend updating this value
+ accordingly.
+ Expected contents:
+ `rope_type` (`str`):
+ The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
+ 'llama3'], with 'default' being the original RoPE implementation.
+ `factor` (`float`, *optional*):
+ Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
+ most scaling types, a `factor` of x will enable the model to handle sequences of length x *
+ original maximum pre-trained length.
+ `original_max_position_embeddings` (`int`, *optional*):
+ Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
+ pretraining.
+ `attention_factor` (`float`, *optional*):
+ Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
+ computation. If unspecified, it defaults to value recommended by the implementation, using the
+ `factor` field to infer the suggested value.
+ `beta_fast` (`float`, *optional*):
+ Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
+ ramp function. If unspecified, it defaults to 32.
+ `beta_slow` (`float`, *optional*):
+ Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
+ ramp function. If unspecified, it defaults to 1.
+ `short_factor` (`list[float]`, *optional*):
+ Only used with 'longrope'. The scaling factor to be applied to short contexts (<
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+ size divided by the number of attention heads divided by 2
+ `long_factor` (`list[float]`, *optional*):
+ Only used with 'longrope'. The scaling factor to be applied to long contexts (>
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+ size divided by the number of attention heads divided by 2
+ `low_freq_factor` (`float`, *optional*):
+ Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
+ `high_freq_factor` (`float`, *optional*):
+ Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
+ attention_bias (`bool`, *optional*, defaults to `False`):
+ Whether to use a bias in the query, key, value and output projection layers during self-attention.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ lambda_std_dev (`float`, *optional*, defaults to 0.1):
+ The standard deviation for initialization of parameter lambda in attention layer.
+ head_dim (`int`, *optional*):
+ The attention head dimension. If None, it will default to hidden_size // num_heads
+
+ ```python
+ >>> from transformers import DiffLlamaModel, DiffLlamaConfig
+
+ >>> # Initializing a DiffLlama diffllama-7b style configuration
+ >>> configuration = DiffLlamaConfig()
+
+ >>> # Initializing a model from the diffllama-7b style configuration
+ >>> model = DiffLlamaModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "diffllama"
+ keys_to_ignore_at_inference = ["past_key_values"]
+
+ def __init__(
+ self,
+ vocab_size=32000,
+ hidden_size=2048,
+ intermediate_size=8192,
+ num_hidden_layers=16,
+ num_attention_heads=32,
+ num_key_value_heads=None,
+ hidden_act="silu",
+ max_position_embeddings=2048,
+ initializer_range=0.02,
+ rms_norm_eps=1e-5,
+ use_cache=True,
+ pad_token_id=None,
+ bos_token_id=1,
+ eos_token_id=2,
+ tie_word_embeddings=False,
+ rope_theta=10000.0,
+ rope_scaling=None,
+ attention_bias=False,
+ attention_dropout=0.0,
+ lambda_std_dev=0.1,
+ head_dim=None,
+ **kwargs,
+ ):
+ self.vocab_size = vocab_size
+ self.max_position_embeddings = max_position_embeddings
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+
+ # for backward compatibility
+ if num_key_value_heads is None:
+ num_key_value_heads = num_attention_heads
+
+ self.num_key_value_heads = num_key_value_heads
+ self.hidden_act = hidden_act
+ self.initializer_range = initializer_range
+ self.rms_norm_eps = rms_norm_eps
+ self.use_cache = use_cache
+ self.rope_theta = rope_theta
+ self.rope_scaling = rope_scaling
+ self.attention_bias = attention_bias
+ self.attention_dropout = attention_dropout
+ self.lambda_std_dev = lambda_std_dev
+ self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads
+ # Validate the correctness of rotary position embeddings parameters
+ # BC: if there is a 'type' field, copy it to 'rope_type'.
+ if self.rope_scaling is not None and "type" in self.rope_scaling:
+ self.rope_scaling["rope_type"] = self.rope_scaling["type"]
+ rope_config_validation(self)
+
+ super().__init__(
+ pad_token_id=pad_token_id,
+ bos_token_id=bos_token_id,
+ eos_token_id=eos_token_id,
+ tie_word_embeddings=tie_word_embeddings,
+ **kwargs,
+ )
+
+
+__all__ = ["DiffLlamaConfig"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/diffllama/modeling_diffllama.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/diffllama/modeling_diffllama.py
new file mode 100644
index 0000000000000000000000000000000000000000..094cc375057f71eb51644bf2b49c524613ed22e1
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/diffllama/modeling_diffllama.py
@@ -0,0 +1,767 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/diffllama/modular_diffllama.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_diffllama.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2024 weak-kajuma and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on Llama implementations in this library and Microsoft's
+# Differential Transformer implementations.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import math
+from typing import Optional, Union
+
+import torch
+from torch import nn
+
+from ...activations import ACT2FN
+from ...cache_utils import Cache, DynamicCache, StaticCache
+from ...generation import GenerationMixin
+from ...integrations import use_kernel_forward_from_hub
+from ...masking_utils import create_causal_mask
+from ...modeling_flash_attention_utils import _flash_attention_forward, flash_attn_supports_top_left_mask
+from ...modeling_layers import (
+ GenericForQuestionAnswering,
+ GenericForSequenceClassification,
+ GenericForTokenClassification,
+ GradientCheckpointingLayer,
+)
+from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
+from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
+from ...modeling_utils import PreTrainedModel
+from ...processing_utils import Unpack
+from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
+from ...utils.deprecation import deprecate_kwarg
+from ...utils.generic import check_model_inputs
+from .configuration_diffllama import DiffLlamaConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+class DiffLlamaMLP(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.hidden_size = config.hidden_size
+ self.intermediate_size = config.intermediate_size
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+ self.act_fn = ACT2FN[config.hidden_act]
+
+ def forward(self, x):
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+ return down_proj
+
+
+def rotate_half(x):
+ """Rotates half the hidden dims of the input."""
+ x1 = x[..., : x.shape[-1] // 2]
+ x2 = x[..., x.shape[-1] // 2 :]
+ return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+ """Applies Rotary Position Embedding to the query and key tensors.
+
+ Args:
+ q (`torch.Tensor`): The query tensor.
+ k (`torch.Tensor`): The key tensor.
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
+ position_ids (`torch.Tensor`, *optional*):
+ Deprecated and unused.
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+ Returns:
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+ """
+ cos = cos.unsqueeze(unsqueeze_dim)
+ sin = sin.unsqueeze(unsqueeze_dim)
+ q_embed = (q * cos) + (rotate_half(q) * sin)
+ k_embed = (k * cos) + (rotate_half(k) * sin)
+ return q_embed, k_embed
+
+
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+ """
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+ """
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+ if n_rep == 1:
+ return hidden_states
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
+def lambda_init_fn(layer_idx):
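+ # depth-dependent initialization of lambda: e.g. layer_idx=0 -> 0.2, approaching 0.8 for deeper layers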
+ return 0.8 - 0.6 * math.exp(-0.3 * layer_idx)
+
+
+class DiffLlamaAttention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(self, config: DiffLlamaConfig, layer_idx: Optional[int] = None):
+ super().__init__()
+ self.config = config
+ self.layer_idx = layer_idx
+ if layer_idx is None:
+ logger.warning_once(
+ f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
+ "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
+ "when creating this class."
+ )
+
+ self.attention_dropout = config.attention_dropout
+ self.hidden_size = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads)
+ self.num_key_value_heads = config.num_key_value_heads
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
+ # the attributes below are not used by the attention computation itself
+ self.max_position_embeddings = config.max_position_embeddings
+ self.rope_theta = config.rope_theta
+ self.is_causal = True
+
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
+
+ self.lambda_init = lambda_init_fn(layer_idx)
+ self.lambda_q1 = nn.Parameter(torch.normal(0, config.lambda_std_dev, size=(self.head_dim,)))
+ self.lambda_k1 = nn.Parameter(torch.normal(0, config.lambda_std_dev, size=(self.head_dim,)))
+ self.lambda_q2 = nn.Parameter(torch.normal(0, config.lambda_std_dev, size=(self.head_dim,)))
+ self.lambda_k2 = nn.Parameter(torch.normal(0, config.lambda_std_dev, size=(self.head_dim,)))
+ self.groupnorm = nn.RMSNorm(2 * self.head_dim, eps=config.rms_norm_eps, elementwise_affine=False)
+
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs,
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
+ bsz, target_len, _ = hidden_states.size()
+ q_len = target_len
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ cos, sin = position_embeddings
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_values is not None:
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+ value_states = torch.cat(torch.chunk(value_states, 2, dim=1), dim=-1)
+ value_states = value_states.repeat(1, 2, 1, 1)
+
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+
+ if attention_mask is not None: # no matter the length, we just slice it
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+ attn_weights = attn_weights + causal_mask
+
+ # upcast attention to fp32
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+ attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
+ lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1, dtype=torch.float32)).to(
+ query_states.dtype
+ )
+ lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1, dtype=torch.float32)).to(
+ query_states.dtype
+ )
+ lambda_full = lambda_1 - lambda_2 + self.lambda_init
+
+ attn_output = torch.matmul(attn_weights, value_states)
+ attn_output1, attn_output2 = torch.chunk(attn_output, 2, dim=1)
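+ # Differential attention: the two attention maps (stacked along the head dimension) are combined as
+ # map1 - lambda_full * map2, which is meant to cancel shared attention noise before the group norm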
+
+ attn_output = attn_output1 - lambda_full * attn_output2
+ attn_output = (1 - self.lambda_init) * self.groupnorm(attn_output)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+ attn_output = attn_output.reshape(bsz, q_len, -1)
+ attn_output = self.o_proj(attn_output)
+ return attn_output, attn_weights
+
+
+class DiffLlamaFlashAttention2(DiffLlamaAttention):
+ """
+ DiffLlama flash attention module. This module inherits from `DiffLlamaAttention` as the weights of the module stay
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
+ flash attention and deal with padding tokens in case the input contains any of them.
+ """
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
+ # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
+ self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask()
+
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
+ attention_mask: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ ) -> tuple[torch.Tensor, None]:
+ if isinstance(past_key_values, StaticCache):
+ raise ValueError(
+ "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
+ "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
+ )
+
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ # Flash attention requires the input to have the shape
+ # batch_size x seq_length x num_heads x head_dim
+ # therefore we just need to keep the original shape
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ if position_embeddings is None:
+ logger.warning_once(
+ "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
+ "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
+ "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
+ "removed and `position_embeddings` will be mandatory."
+ )
+ cos, sin = self.rotary_emb(value_states, position_ids)
+ else:
+ cos, sin = position_embeddings
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_values is not None:
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ # TODO: These transposes are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
+ # to be able to avoid many of these transpose/reshape/view.
+ query_states = query_states.transpose(1, 2)
+ key_states = key_states.transpose(1, 2)
+ value_states = value_states.transpose(1, 2)
+
+ dropout_rate = self.attention_dropout if self.training else 0.0
+
+ # In PEFT, the layer norms are usually cast to float32 for training stability reasons,
+ # therefore the input hidden states get silently cast to float32. Hence, we need to
+ # cast them back to the correct dtype just to be sure everything works as expected.
+ # This might slow down training & inference, so it is recommended not to cast the LayerNorms
+ # to fp32. (DiffLlamaRMSNorm handles it correctly)
+
+ input_dtype = query_states.dtype
+ device_type = query_states.device.type if query_states.device.type != "mps" else "cpu"
+ if input_dtype == torch.float32:
+ if torch.is_autocast_enabled():
+ target_dtype = (
+ torch.get_autocast_dtype(device_type)
+ if hasattr(torch, "get_autocast_dtype")
+ else torch.get_autocast_gpu_dtype()
+ )
+ # Handle the case where the model is quantized
+ elif hasattr(self.config, "_pre_quantization_dtype"):
+ target_dtype = self.config._pre_quantization_dtype
+ else:
+ target_dtype = self.q_proj.weight.dtype
+
+ logger.warning_once(
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+ f" {target_dtype}."
+ )
+
+ query_states = query_states.to(target_dtype)
+ key_states = key_states.to(target_dtype)
+ value_states = value_states.to(target_dtype)
+
+ value_states1, value_states2 = torch.chunk(value_states, 2, dim=2)
+ value_states1 = value_states1.repeat(1, 1, 2, 1)
+ value_states2 = value_states2.repeat(1, 1, 2, 1)
+
+ attn_output1 = _flash_attention_forward(
+ query_states,
+ key_states,
+ value_states1,
+ attention_mask,
+ q_len,
+ position_ids=position_ids,
+ dropout=dropout_rate,
+ sliding_window=getattr(self, "sliding_window", None),
+ use_top_left_mask=self._flash_attn_uses_top_left_mask,
+ is_causal=self.is_causal,
+ )
+
+ attn_output2 = _flash_attention_forward(
+ query_states,
+ key_states,
+ value_states2,
+ attention_mask,
+ q_len,
+ position_ids=position_ids,
+ dropout=dropout_rate,
+ sliding_window=getattr(self, "sliding_window", None),
+ use_top_left_mask=self._flash_attn_uses_top_left_mask,
+ is_causal=self.is_causal,
+ )
+
+ attn_output = torch.cat([attn_output1, attn_output2], dim=-1)
+ attn_output1, attn_output2 = torch.chunk(attn_output, 2, dim=2)
+
+ lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1, dtype=torch.float32)).to(
+ query_states.dtype
+ )
+ lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1, dtype=torch.float32)).to(
+ query_states.dtype
+ )
+ lambda_full = lambda_1 - lambda_2 + self.lambda_init
+
+ attn_output = attn_output1 - lambda_full * attn_output2
+ attn_output = (1 - self.lambda_init) * self.groupnorm(attn_output)
+ attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
+ attn_output = self.o_proj(attn_output)
+ return attn_output, None
+
+
+class DiffLlamaSdpaAttention(DiffLlamaAttention):
+ """
+ DiffLlama attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
+ `DiffLlamaAttention`, as the weights of the module stay untouched. The only changes are in the forward pass, to adapt
+ to the SDPA API.
+ """
+
+ # Adapted from DiffLlamaAttention.forward
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs,
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ cos, sin = position_embeddings
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_values is not None:
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+ value_states = torch.cat(torch.chunk(value_states, 2, dim=1), dim=-1)
+ value_states = value_states.repeat(1, 2, 1, 1)
+
+ causal_mask = attention_mask
+ if attention_mask is not None:
+ causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
+
+ # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
+ # Reference: https://github.com/pytorch/pytorch/issues/112577.
+ if query_states.device.type == "cuda" and causal_mask is not None:
+ query_states = query_states.contiguous()
+ key_states = key_states.contiguous()
+ value_states = value_states.contiguous()
+
+ # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+ # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
+ is_causal = causal_mask is None and q_len > 1
+
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
+ query_states,
+ key_states,
+ value_states,
+ attn_mask=causal_mask,
+ dropout_p=self.attention_dropout if self.training else 0.0,
+ is_causal=is_causal,
+ )
+
+ attn_output1, attn_output2 = torch.chunk(attn_output, 2, dim=1)
+
+ lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1, dtype=torch.float32)).to(
+ query_states.dtype
+ )
+ lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1, dtype=torch.float32)).to(
+ query_states.dtype
+ )
+ lambda_full = lambda_1 - lambda_2 + self.lambda_init
+
+ attn_output = attn_output1 - lambda_full * attn_output2
+ attn_output = (1 - self.lambda_init) * self.groupnorm(attn_output)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+ attn_output = attn_output.view(bsz, q_len, -1)
+ attn_output = self.o_proj(attn_output)
+ return attn_output, None
+
+
+@use_kernel_forward_from_hub("RMSNorm")
+class DiffLlamaRMSNorm(nn.Module):
+ def __init__(self, hidden_size, eps=1e-6):
+ """
+ DiffLlamaRMSNorm is equivalent to T5LayerNorm
+ """
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+ self.variance_epsilon = eps
+
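+ # RMSNorm: y = weight * x / sqrt(mean(x^2) + eps), with the statistics computed in float32.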
+ def forward(self, hidden_states):
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+ return self.weight * hidden_states.to(input_dtype)
+
+ def extra_repr(self):
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
+
+
+DIFFLLAMA_ATTENTION_CLASSES = {
+ "eager": DiffLlamaAttention,
+ "flash_attention_2": DiffLlamaFlashAttention2,
+ "sdpa": DiffLlamaSdpaAttention,
+}
+
+
+class DiffLlamaDecoderLayer(GradientCheckpointingLayer):
+ def __init__(self, config: DiffLlamaConfig, layer_idx: int):
+ super().__init__()
+ self.hidden_size = config.hidden_size
+
+ self.self_attn = DIFFLLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
+
+ self.mlp = DiffLlamaMLP(config)
+ self.input_layernorm = DiffLlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.post_attention_layernorm = DiffLlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ use_cache: Optional[bool] = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> torch.Tensor:
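+ # Standard pre-norm transformer block: x = x + Attn(RMSNorm(x)), then x = x + MLP(RMSNorm(x)).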
+ residual = hidden_states
+ hidden_states = self.input_layernorm(hidden_states)
+ # Self Attention
+ hidden_states, _ = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ **kwargs,
+ )
+ hidden_states = residual + hidden_states
+
+ # Fully Connected
+ residual = hidden_states
+ hidden_states = self.post_attention_layernorm(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = residual + hidden_states
+ return hidden_states
+
+
+@auto_docstring
+class DiffLlamaPreTrainedModel(PreTrainedModel):
+ config: DiffLlamaConfig
+ base_model_prefix = "model"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["DiffLlamaDecoderLayer"]
+ _skip_keys_device_placement = ["past_key_values"]
+ _supports_flash_attn = True
+ _supports_sdpa = True
+ _supports_flex_attn = False
+
+ _can_compile_fullgraph = True
+ _supports_attention_backend = False
+ _can_record_outputs = {
+ "hidden_states": DiffLlamaDecoderLayer,
+ "attentions": DiffLlamaAttention,
+ }
+
+ def _init_weights(self, module):
+ super()._init_weights(module)
+ if isinstance(module, DiffLlamaAttention):
+ module.lambda_q1.data.normal_(0, self.config.lambda_std_dev)
+ module.lambda_k1.data.normal_(0, self.config.lambda_std_dev)
+ module.lambda_q2.data.normal_(0, self.config.lambda_std_dev)
+ module.lambda_k2.data.normal_(0, self.config.lambda_std_dev)
+
+
+class DiffLlamaRotaryEmbedding(nn.Module):
+ inv_freq: torch.Tensor # fix linting for `register_buffer`
+
+ def __init__(self, config: DiffLlamaConfig, device=None):
+ super().__init__()
+ # BC: "rope_type" was originally "type"
+ if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
+ else:
+ self.rope_type = "default"
+ self.max_seq_len_cached = config.max_position_embeddings
+ self.original_max_seq_len = config.max_position_embeddings
+
+ self.config = config
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
+ self.original_inv_freq = self.inv_freq
+
+ @torch.no_grad()
+ @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
+ def forward(self, x, position_ids):
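+ # The outer product of inverse frequencies and positions gives the rotation angles; cos/sin are duplicated
+ # along the last dimension to match the rotate_half convention expected by apply_rotary_pos_emb.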
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
+ position_ids_expanded = position_ids[:, None, :].float()
+
+ device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
+ with torch.autocast(device_type=device_type, enabled=False): # Force float32
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+ emb = torch.cat((freqs, freqs), dim=-1)
+ cos = emb.cos() * self.attention_scaling
+ sin = emb.sin() * self.attention_scaling
+
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
+@auto_docstring
+class DiffLlamaModel(DiffLlamaPreTrainedModel):
+ def __init__(self, config: DiffLlamaConfig):
+ super().__init__(config)
+ self.padding_idx = config.pad_token_id
+ self.vocab_size = config.vocab_size
+
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+ self.layers = nn.ModuleList(
+ [DiffLlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+ )
+ self.norm = DiffLlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.rotary_emb = DiffLlamaRotaryEmbedding(config=config)
+ self.gradient_checkpointing = False
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @check_model_inputs
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> BaseModelOutputWithPast:
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ if inputs_embeds is None:
+ inputs_embeds: torch.Tensor = self.embed_tokens(input_ids)
+
+ if use_cache and past_key_values is None:
+ past_key_values = DynamicCache(config=self.config)
+
+ if cache_position is None:
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ cache_position: torch.Tensor = torch.arange(
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+ )
+
+ if position_ids is None:
+ position_ids = cache_position.unsqueeze(0)
+
+ causal_mask = create_causal_mask(
+ config=self.config,
+ input_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ cache_position=cache_position,
+ past_key_values=past_key_values,
+ position_ids=position_ids,
+ )
+
+ hidden_states = inputs_embeds
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+ for decoder_layer in self.layers[: self.config.num_hidden_layers]:
+ hidden_states = decoder_layer(
+ hidden_states,
+ attention_mask=causal_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ **kwargs,
+ )
+
+ hidden_states = self.norm(hidden_states)
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=past_key_values,
+ )
+
+
+@auto_docstring
+class DiffLlamaForCausalLM(DiffLlamaPreTrainedModel, GenerationMixin):
+ _tied_weights_keys = ["lm_head.weight"]
+ _tp_plan = {"lm_head": "colwise_rep"}
+ _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.model = DiffLlamaModel(config)
+ self.vocab_size = config.vocab_size
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ logits_to_keep: Union[int, torch.Tensor] = 0,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> CausalLMOutputWithPast:
+ r"""
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, DiffLlamaForCausalLM
+
+ >>> model = DiffLlamaForCausalLM.from_pretrained("google/diffllama-7b")
+ >>> tokenizer = AutoTokenizer.from_pretrained("google/diffllama-7b")
+
+ >>> prompt = "What is your favorite condiment?"
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+ >>> # Generate
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ "What is your favorite condiment?"
+ ```"""
+ outputs: BaseModelOutputWithPast = self.model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+ hidden_states = outputs.last_hidden_state
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
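+ # logits_to_keep == 0 keeps every position, a positive int keeps only the last `logits_to_keep` positions,
+ # and a tensor selects explicit indices.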
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
+
+ return CausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+class DiffLlamaForSequenceClassification(GenericForSequenceClassification, DiffLlamaPreTrainedModel):
+ pass
+
+
+class DiffLlamaForQuestionAnswering(GenericForQuestionAnswering, DiffLlamaPreTrainedModel):
+ base_model_prefix = "transformer" # For BC, where `transformer` was used instead of `model`
+
+
+class DiffLlamaForTokenClassification(GenericForTokenClassification, DiffLlamaPreTrainedModel):
+ pass
+
+
+__all__ = [
+ "DiffLlamaPreTrainedModel",
+ "DiffLlamaModel",
+ "DiffLlamaForCausalLM",
+ "DiffLlamaForSequenceClassification",
+ "DiffLlamaForQuestionAnswering",
+ "DiffLlamaForTokenClassification",
+]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/diffllama/modular_diffllama.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/diffllama/modular_diffllama.py
new file mode 100644
index 0000000000000000000000000000000000000000..253b99edff0d7a557da404fa680ce8403e22ccf1
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/diffllama/modular_diffllama.py
@@ -0,0 +1,447 @@
+# coding=utf-8
+# Copyright 2024 weak-kajuma and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on Llama implementations in this library and Microsoft's
+# Differential Transformer implementations.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import math
+from typing import Optional
+
+import torch
+from torch import nn
+
+from ...cache_utils import Cache, StaticCache
+from ...modeling_flash_attention_utils import _flash_attention_forward, flash_attn_supports_top_left_mask
+from ...modeling_utils import PreTrainedModel
+from ...utils import logging
+from ...utils.deprecation import deprecate_kwarg
+from ..gemma.modeling_gemma import GemmaForCausalLM
+from ..llama.modeling_llama import (
+ LlamaDecoderLayer,
+ LlamaForQuestionAnswering,
+ LlamaForSequenceClassification,
+ LlamaForTokenClassification,
+ LlamaModel,
+ LlamaPreTrainedModel,
+ apply_rotary_pos_emb,
+ repeat_kv,
+)
+from ..mistral.modeling_mistral import MistralMLP
+from .configuration_diffllama import DiffLlamaConfig
+
+
+logger = logging.get_logger(__name__)
+
+_CHECKPOINT_FOR_DOC = "kajuma/DiffLlama-0.3B-handcut"
+_CONFIG_FOR_DOC = "DiffLlamaConfig"
+
+
+class DiffLlamaMLP(MistralMLP):
+ pass
+
+
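+# The depth-dependent lambda initialization used by the Differential Transformer: 0.2 for the first layer
+# (layer_idx == 0), increasing towards 0.8 for deeper layers.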
+def lambda_init_fn(layer_idx):
+ return 0.8 - 0.6 * math.exp(-0.3 * layer_idx)
+
+
+class DiffLlamaAttention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(self, config: DiffLlamaConfig, layer_idx: Optional[int] = None):
+ super().__init__()
+ self.config = config
+ self.layer_idx = layer_idx
+ if layer_idx is None:
+ logger.warning_once(
+ f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
+ "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
+ "when creating this class."
+ )
+
+ self.attention_dropout = config.attention_dropout
+ self.hidden_size = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads)
+ self.num_key_value_heads = config.num_key_value_heads
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
+ # The attributes below are not used by this module
+ self.max_position_embeddings = config.max_position_embeddings
+ self.rope_theta = config.rope_theta
+ self.is_causal = True
+
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
+
+ self.lambda_init = lambda_init_fn(layer_idx)
+ self.lambda_q1 = nn.Parameter(torch.normal(0, config.lambda_std_dev, size=(self.head_dim,)))
+ self.lambda_k1 = nn.Parameter(torch.normal(0, config.lambda_std_dev, size=(self.head_dim,)))
+ self.lambda_q2 = nn.Parameter(torch.normal(0, config.lambda_std_dev, size=(self.head_dim,)))
+ self.lambda_k2 = nn.Parameter(torch.normal(0, config.lambda_std_dev, size=(self.head_dim,)))
+ self.groupnorm = nn.RMSNorm(2 * self.head_dim, eps=config.rms_norm_eps, elementwise_affine=False)
+
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs,
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+ bsz, target_len, _ = hidden_states.size()
+ q_len = target_len
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ cos, sin = position_embeddings
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_values is not None:
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+ value_states = torch.cat(torch.chunk(value_states, 2, dim=1), dim=-1)
+ value_states = value_states.repeat(1, 2, 1, 1)
+
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+
+ if attention_mask is not None: # no matter the length, we just slice it
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+ attn_weights = attn_weights + causal_mask
+
+ # upcast attention to fp32
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+ attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
+ lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1, dtype=torch.float32)).to(
+ query_states.dtype
+ )
+ lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1, dtype=torch.float32)).to(
+ query_states.dtype
+ )
+ lambda_full = lambda_1 - lambda_2 + self.lambda_init
+
+ attn_output = torch.matmul(attn_weights, value_states)
+ attn_output1, attn_output2 = torch.chunk(attn_output, 2, dim=1)
+
+ attn_output = attn_output1 - lambda_full * attn_output2
+ attn_output = (1 - self.lambda_init) * self.groupnorm(attn_output)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+ attn_output = attn_output.reshape(bsz, q_len, -1)
+ attn_output = self.o_proj(attn_output)
+ return attn_output, attn_weights
+
+
+class DiffLlamaFlashAttention2(DiffLlamaAttention):
+ """
+ DiffLlama flash attention module. This module inherits from `DiffLlamaAttention`, as the weights of the module stay
+ untouched. The only required change is in the forward pass, where it needs to correctly call the public API of
+ flash attention and handle padding tokens in case the input contains any.
+ """
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
+ # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
+ self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask()
+
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
+ attention_mask: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ ) -> tuple[torch.Tensor, None]:
+ if isinstance(past_key_values, StaticCache):
+ raise ValueError(
+ "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
+ "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
+ )
+
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ # Flash attention requires the input to have the shape
+ # batch_size x seq_length x num_heads x head_dim
+ # therefore we just need to keep the original shape
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ if position_embeddings is None:
+ logger.warning_once(
+ "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
+ "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
+ "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
+ "removed and `position_embeddings` will be mandatory."
+ )
+ cos, sin = self.rotary_emb(value_states, position_ids)
+ else:
+ cos, sin = position_embeddings
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_values is not None:
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ # TODO: These transposes are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
+ # to be able to avoid many of these transpose/reshape/view.
+ query_states = query_states.transpose(1, 2)
+ key_states = key_states.transpose(1, 2)
+ value_states = value_states.transpose(1, 2)
+
+ dropout_rate = self.attention_dropout if self.training else 0.0
+
+ # In PEFT, the layer norms are usually cast to float32 for training stability reasons,
+ # therefore the input hidden states get silently cast to float32. Hence, we need to
+ # cast them back to the correct dtype just to be sure everything works as expected.
+ # This might slow down training & inference, so it is recommended not to cast the LayerNorms
+ # to fp32. (DiffLlamaRMSNorm handles it correctly)
+
+ input_dtype = query_states.dtype
+ device_type = query_states.device.type if query_states.device.type != "mps" else "cpu"
+ if input_dtype == torch.float32:
+ if torch.is_autocast_enabled():
+ target_dtype = (
+ torch.get_autocast_dtype(device_type)
+ if hasattr(torch, "get_autocast_dtype")
+ else torch.get_autocast_gpu_dtype()
+ )
+ # Handle the case where the model is quantized
+ elif hasattr(self.config, "_pre_quantization_dtype"):
+ target_dtype = self.config._pre_quantization_dtype
+ else:
+ target_dtype = self.q_proj.weight.dtype
+
+ logger.warning_once(
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+ f" {target_dtype}."
+ )
+
+ query_states = query_states.to(target_dtype)
+ key_states = key_states.to(target_dtype)
+ value_states = value_states.to(target_dtype)
+
+ value_states1, value_states2 = torch.chunk(value_states, 2, dim=2)
+ value_states1 = value_states1.repeat(1, 1, 2, 1)
+ value_states2 = value_states2.repeat(1, 1, 2, 1)
+
+ attn_output1 = _flash_attention_forward(
+ query_states,
+ key_states,
+ value_states1,
+ attention_mask,
+ q_len,
+ position_ids=position_ids,
+ dropout=dropout_rate,
+ sliding_window=getattr(self, "sliding_window", None),
+ use_top_left_mask=self._flash_attn_uses_top_left_mask,
+ is_causal=self.is_causal,
+ )
+
+ attn_output2 = _flash_attention_forward(
+ query_states,
+ key_states,
+ value_states2,
+ attention_mask,
+ q_len,
+ position_ids=position_ids,
+ dropout=dropout_rate,
+ sliding_window=getattr(self, "sliding_window", None),
+ use_top_left_mask=self._flash_attn_uses_top_left_mask,
+ is_causal=self.is_causal,
+ )
+
+ attn_output = torch.cat([attn_output1, attn_output2], dim=-1)
+ attn_output1, attn_output2 = torch.chunk(attn_output, 2, dim=2)
+
+ lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1, dtype=torch.float32)).to(
+ query_states.dtype
+ )
+ lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1, dtype=torch.float32)).to(
+ query_states.dtype
+ )
+ lambda_full = lambda_1 - lambda_2 + self.lambda_init
+
+ attn_output = attn_output1 - lambda_full * attn_output2
+ attn_output = (1 - self.lambda_init) * self.groupnorm(attn_output)
+ attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
+ attn_output = self.o_proj(attn_output)
+ return attn_output, None
+
+
+class DiffLlamaSdpaAttention(DiffLlamaAttention):
+ """
+ DiffLlama attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
+ `DiffLlamaAttention`, as the weights of the module stay untouched. The only changes are in the forward pass, to adapt
+ to the SDPA API.
+ """
+
+ # Adapted from DiffLlamaAttention.forward
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs,
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ cos, sin = position_embeddings
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_values is not None:
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+ value_states = torch.cat(torch.chunk(value_states, 2, dim=1), dim=-1)
+ value_states = value_states.repeat(1, 2, 1, 1)
+
+ causal_mask = attention_mask
+ if attention_mask is not None:
+ causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
+
+ # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
+ # Reference: https://github.com/pytorch/pytorch/issues/112577.
+ if query_states.device.type == "cuda" and causal_mask is not None:
+ query_states = query_states.contiguous()
+ key_states = key_states.contiguous()
+ value_states = value_states.contiguous()
+
+ # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+ # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
+ is_causal = causal_mask is None and q_len > 1
+
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
+ query_states,
+ key_states,
+ value_states,
+ attn_mask=causal_mask,
+ dropout_p=self.attention_dropout if self.training else 0.0,
+ is_causal=is_causal,
+ )
+
+ attn_output1, attn_output2 = torch.chunk(attn_output, 2, dim=1)
+
+ lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1, dtype=torch.float32)).to(
+ query_states.dtype
+ )
+ lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1, dtype=torch.float32)).to(
+ query_states.dtype
+ )
+ lambda_full = lambda_1 - lambda_2 + self.lambda_init
+
+ attn_output = attn_output1 - lambda_full * attn_output2
+ attn_output = (1 - self.lambda_init) * self.groupnorm(attn_output)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+ attn_output = attn_output.view(bsz, q_len, -1)
+ attn_output = self.o_proj(attn_output)
+ return attn_output, None
+
+
+DIFFLLAMA_ATTENTION_CLASSES = {
+ "eager": DiffLlamaAttention,
+ "flash_attention_2": DiffLlamaFlashAttention2,
+ "sdpa": DiffLlamaSdpaAttention,
+}
+
+
+class DiffLlamaDecoderLayer(LlamaDecoderLayer):
+ def __init__(self, config: DiffLlamaConfig, layer_idx: int):
+ super().__init__(config, layer_idx)
+
+ self.self_attn = DIFFLLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
+
+
+class DiffLlamaPreTrainedModel(LlamaPreTrainedModel):
+ _supports_flex_attn = False
+ _supports_attention_backend = False
+
+ def _init_weights(self, module):
+ PreTrainedModel._init_weights(self, module)
+ if isinstance(module, DiffLlamaAttention):
+ module.lambda_q1.data.normal_(0, self.config.lambda_std_dev)
+ module.lambda_k1.data.normal_(0, self.config.lambda_std_dev)
+ module.lambda_q2.data.normal_(0, self.config.lambda_std_dev)
+ module.lambda_k2.data.normal_(0, self.config.lambda_std_dev)
+
+
+class DiffLlamaModel(LlamaModel):
+ pass
+
+
+class DiffLlamaForCausalLM(GemmaForCausalLM):
+ pass
+
+
+class DiffLlamaForSequenceClassification(LlamaForSequenceClassification):
+ pass
+
+
+class DiffLlamaForQuestionAnswering(LlamaForQuestionAnswering):
+ pass
+
+
+class DiffLlamaForTokenClassification(LlamaForTokenClassification):
+ pass
+
+
+__all__ = [
+ "DiffLlamaPreTrainedModel",
+ "DiffLlamaModel",
+ "DiffLlamaForCausalLM",
+ "DiffLlamaForSequenceClassification",
+ "DiffLlamaForQuestionAnswering",
+ "DiffLlamaForTokenClassification",
+]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/dinat/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/dinat/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b64cdbb3c7eb0467f6112225b8c0d9e1f65f9e99
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/dinat/__init__.py
@@ -0,0 +1,27 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_dinat import *
+ from .modeling_dinat import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/dinat/configuration_dinat.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/dinat/configuration_dinat.py
new file mode 100644
index 0000000000000000000000000000000000000000..f7d7fa509c5a3b2f5efc3b936cf1761b4ab0e107
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/dinat/configuration_dinat.py
@@ -0,0 +1,152 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Dilated Neighborhood Attention Transformer model configuration"""
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+from ...utils.backbone_utils import BackboneConfigMixin, get_aligned_output_features_output_indices
+
+
+logger = logging.get_logger(__name__)
+
+
+class DinatConfig(BackboneConfigMixin, PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`DinatModel`]. It is used to instantiate a Dinat
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
+ defaults will yield a similar configuration to that of the Dinat
+ [shi-labs/dinat-mini-in1k-224](https://huggingface.co/shi-labs/dinat-mini-in1k-224) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ patch_size (`int`, *optional*, defaults to 4):
+ The size (resolution) of each patch. NOTE: Only patch size of 4 is supported at the moment.
+ num_channels (`int`, *optional*, defaults to 3):
+ The number of input channels.
+ embed_dim (`int`, *optional*, defaults to 64):
+ Dimensionality of patch embedding.
+ depths (`list[int]`, *optional*, defaults to `[3, 4, 6, 5]`):
+ Number of layers in each level of the encoder.
+ num_heads (`list[int]`, *optional*, defaults to `[2, 4, 8, 16]`):
+ Number of attention heads in each layer of the Transformer encoder.
+ kernel_size (`int`, *optional*, defaults to 7):
+ Neighborhood Attention kernel size.
+ dilations (`list[list[int]]`, *optional*, defaults to `[[1, 8, 1], [1, 4, 1, 4], [1, 2, 1, 2, 1, 2], [1, 1, 1, 1, 1]]`):
+ Dilation value of each NA layer in the Transformer encoder.
+ mlp_ratio (`float`, *optional*, defaults to 3.0):
+ Ratio of MLP hidden dimensionality to embedding dimensionality.
+ qkv_bias (`bool`, *optional*, defaults to `True`):
+ Whether or not a learnable bias should be added to the queries, keys and values.
+ hidden_dropout_prob (`float`, *optional*, defaults to 0.0):
+ The dropout probability for all fully connected layers in the embeddings and encoder.
+ attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ drop_path_rate (`float`, *optional*, defaults to 0.1):
+ Stochastic depth rate.
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
+ The non-linear activation function (function or string) in the encoder. If string, `"gelu"`, `"relu"`,
+ `"selu"` and `"gelu_new"` are supported.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ layer_norm_eps (`float`, *optional*, defaults to 1e-05):
+ The epsilon used by the layer normalization layers.
+ layer_scale_init_value (`float`, *optional*, defaults to 0.0):
+ The initial value for the layer scale. Disabled if <=0.
+ out_features (`list[str]`, *optional*):
+ If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc.
+ (depending on how many stages the model has). If unset and `out_indices` is set, will default to the
+ corresponding stages. If unset and `out_indices` is unset, will default to the last stage. Must be in the
+ same order as defined in the `stage_names` attribute.
+ out_indices (`list[int]`, *optional*):
+ If used as backbone, list of indices of features to output. Can be any of 0, 1, 2, etc. (depending on how
+ many stages the model has). If unset and `out_features` is set, will default to the corresponding stages.
+ If unset and `out_features` is unset, will default to the last stage. Must be in the
+ same order as defined in the `stage_names` attribute.
+
+ Example:
+
+ ```python
+ >>> from transformers import DinatConfig, DinatModel
+
+ >>> # Initializing a Dinat shi-labs/dinat-mini-in1k-224 style configuration
+ >>> configuration = DinatConfig()
+
+ >>> # Initializing a model (with random weights) from the shi-labs/dinat-mini-in1k-224 style configuration
+ >>> model = DinatModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "dinat"
+
+ attribute_map = {
+ "num_attention_heads": "num_heads",
+ "num_hidden_layers": "num_layers",
+ }
+
+ def __init__(
+ self,
+ patch_size=4,
+ num_channels=3,
+ embed_dim=64,
+ depths=[3, 4, 6, 5],
+ num_heads=[2, 4, 8, 16],
+ kernel_size=7,
+ dilations=[[1, 8, 1], [1, 4, 1, 4], [1, 2, 1, 2, 1, 2], [1, 1, 1, 1, 1]],
+ mlp_ratio=3.0,
+ qkv_bias=True,
+ hidden_dropout_prob=0.0,
+ attention_probs_dropout_prob=0.0,
+ drop_path_rate=0.1,
+ hidden_act="gelu",
+ initializer_range=0.02,
+ layer_norm_eps=1e-5,
+ layer_scale_init_value=0.0,
+ out_features=None,
+ out_indices=None,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+
+ self.patch_size = patch_size
+ self.num_channels = num_channels
+ self.embed_dim = embed_dim
+ self.depths = depths
+ self.num_layers = len(depths)
+ self.num_heads = num_heads
+ self.kernel_size = kernel_size
+ self.dilations = dilations
+ self.mlp_ratio = mlp_ratio
+ self.qkv_bias = qkv_bias
+ self.hidden_dropout_prob = hidden_dropout_prob
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
+ self.drop_path_rate = drop_path_rate
+ self.hidden_act = hidden_act
+ self.layer_norm_eps = layer_norm_eps
+ self.initializer_range = initializer_range
+ # we set the hidden_size attribute in order to make Dinat work with VisionEncoderDecoderModel
+ # this indicates the channel dimension after the last stage of the model
+ self.hidden_size = int(embed_dim * 2 ** (len(depths) - 1))
+ self.layer_scale_init_value = layer_scale_init_value
+ self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, len(depths) + 1)]
+ self._out_features, self._out_indices = get_aligned_output_features_output_indices(
+ out_features=out_features, out_indices=out_indices, stage_names=self.stage_names
+ )
+
+
+__all__ = ["DinatConfig"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/dinat/modeling_dinat.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/dinat/modeling_dinat.py
new file mode 100644
index 0000000000000000000000000000000000000000..4b7ec37b0ea8489e5db87df22c1876fd4548fe86
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/dinat/modeling_dinat.py
@@ -0,0 +1,855 @@
+# coding=utf-8
+# Copyright 2022 SHI Labs and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch Dilated Neighborhood Attention Transformer model."""
+
+import math
+from dataclasses import dataclass
+from typing import Optional, Union
+
+import torch
+from torch import nn
+
+from ...activations import ACT2FN
+from ...modeling_outputs import BackboneOutput
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
+from ...utils import (
+ ModelOutput,
+ OptionalDependencyNotAvailable,
+ auto_docstring,
+ is_natten_available,
+ logging,
+ requires_backends,
+)
+from ...utils.backbone_utils import BackboneMixin
+from .configuration_dinat import DinatConfig
+
+
+if is_natten_available():
+ from natten.functional import natten2dav, natten2dqkrpb
+else:
+
+ def natten2dqkrpb(*args, **kwargs):
+ raise OptionalDependencyNotAvailable()
+
+ def natten2dav(*args, **kwargs):
+ raise OptionalDependencyNotAvailable()
+
+
+logger = logging.get_logger(__name__)
+
+
+# drop_path and DinatDropPath are from the timm library.
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Dinat encoder's outputs, with potential hidden states and attentions.
+ """
+)
+class DinatEncoderOutput(ModelOutput):
+ r"""
+ reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
+ shape `(batch_size, hidden_size, height, width)`.
+
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
+ include the spatial dimensions.
+ """
+
+ last_hidden_state: Optional[torch.FloatTensor] = None
+ hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+ attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+ reshaped_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Dinat model's outputs that also contains a pooling of the last hidden states.
+ """
+)
+class DinatModelOutput(ModelOutput):
+ r"""
+ pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`, *optional*, returned when `add_pooling_layer=True` is passed):
+ Average pooling of the last layer hidden-state.
+ reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
+ shape `(batch_size, hidden_size, height, width)`.
+
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
+ include the spatial dimensions.
+ """
+
+ last_hidden_state: Optional[torch.FloatTensor] = None
+ pooler_output: Optional[torch.FloatTensor] = None
+ hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+ attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+ reshaped_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Dinat outputs for image classification.
+ """
+)
+class DinatImageClassifierOutput(ModelOutput):
+ r"""
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Classification (or regression if config.num_labels==1) loss.
+ logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
+ Classification (or regression if config.num_labels==1) scores (before SoftMax).
+ reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
+ shape `(batch_size, hidden_size, height, width)`.
+
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
+ include the spatial dimensions.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ logits: Optional[torch.FloatTensor] = None
+ hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+ attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+ reshaped_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+
+
+class DinatEmbeddings(nn.Module):
+ """
+ Construct the patch and position embeddings.
+ """
+
+ def __init__(self, config):
+ super().__init__()
+
+ self.patch_embeddings = DinatPatchEmbeddings(config)
+
+ self.norm = nn.LayerNorm(config.embed_dim)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ def forward(self, pixel_values: Optional[torch.FloatTensor]) -> tuple[torch.Tensor]:
+ embeddings = self.patch_embeddings(pixel_values)
+ embeddings = self.norm(embeddings)
+
+ embeddings = self.dropout(embeddings)
+
+ return embeddings
+
+
+class DinatPatchEmbeddings(nn.Module):
+ """
+ This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
+ `hidden_states` (patch embeddings) of shape `(batch_size, height, width, hidden_size)` to be consumed by a
+ Transformer.
+ """
+
+ def __init__(self, config):
+ super().__init__()
+ patch_size = config.patch_size
+ num_channels, hidden_size = config.num_channels, config.embed_dim
+ self.num_channels = num_channels
+
+ if patch_size != 4:
+ # TODO: Support arbitrary patch sizes.
+ raise ValueError("Dinat only supports patch size of 4 at the moment.")
+
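+ # Two stride-2 convolutions downsample the input by a factor of 4 overall, matching the patch size of 4.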
+ self.projection = nn.Sequential(
+ nn.Conv2d(self.num_channels, hidden_size // 2, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),
+ nn.Conv2d(hidden_size // 2, hidden_size, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),
+ )
+
+ def forward(self, pixel_values: Optional[torch.FloatTensor]) -> torch.Tensor:
+ _, num_channels, height, width = pixel_values.shape
+ if num_channels != self.num_channels:
+ raise ValueError(
+ "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
+ )
+ embeddings = self.projection(pixel_values)
+ embeddings = embeddings.permute(0, 2, 3, 1)
+
+ return embeddings
+
+
+class DinatDownsampler(nn.Module):
+ """
+ Convolutional Downsampling Layer.
+
+ Args:
+ dim (`int`):
+ Number of input channels.
+ norm_layer (`nn.Module`, *optional*, defaults to `nn.LayerNorm`):
+ Normalization layer class.
+ """
+
+ def __init__(self, dim: int, norm_layer: nn.Module = nn.LayerNorm) -> None:
+ super().__init__()
+ self.dim = dim
+ self.reduction = nn.Conv2d(dim, 2 * dim, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
+ self.norm = norm_layer(2 * dim)
+
+ def forward(self, input_feature: torch.Tensor) -> torch.Tensor:
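+ # Hidden states are channels-last (batch, height, width, channels); permute to channels-first for the
+ # strided convolution, then back before the layer norm.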
+ input_feature = self.reduction(input_feature.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
+ input_feature = self.norm(input_feature)
+ return input_feature
+
+
+# Copied from transformers.models.beit.modeling_beit.drop_path
+def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
+ """
+ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+
+ Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
+ however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
+ See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
+ layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
+ argument.
+ """
+ if drop_prob == 0.0 or not training:
+ return input
+ keep_prob = 1 - drop_prob
+ shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
+ random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
+ random_tensor.floor_() # binarize
+ output = input.div(keep_prob) * random_tensor
+ return output
+
+
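+# A numerical sketch of the drop_path formula above, assuming drop_prob=0.2 during training:
+# keep_prob = 0.8, and random_tensor is 0 or 1 per sample (the floor of 0.8 + U[0, 1)), so each
+# sample is either zeroed or rescaled by 1 / 0.8 = 1.25, keeping the expected value unchanged.
+
+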
+# Copied from transformers.models.beit.modeling_beit.BeitDropPath with Beit->Dinat
+class DinatDropPath(nn.Module):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
+
+ def __init__(self, drop_prob: Optional[float] = None) -> None:
+ super().__init__()
+ self.drop_prob = drop_prob
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ return drop_path(hidden_states, self.drop_prob, self.training)
+
+ def extra_repr(self) -> str:
+ return f"p={self.drop_prob}"
+
+
+class NeighborhoodAttention(nn.Module):
+ def __init__(self, config, dim, num_heads, kernel_size, dilation):
+ super().__init__()
+ if dim % num_heads != 0:
+ raise ValueError(
+ f"The hidden size ({dim}) is not a multiple of the number of attention heads ({num_heads})"
+ )
+
+ self.num_attention_heads = num_heads
+ self.attention_head_size = int(dim / num_heads)
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
+ self.kernel_size = kernel_size
+ self.dilation = dilation
+
+        # rpb holds the learnable relative positional biases; the same concept is used in Swin.
+ self.rpb = nn.Parameter(torch.zeros(num_heads, (2 * self.kernel_size - 1), (2 * self.kernel_size - 1)))
+
+ self.query = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias)
+ self.key = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias)
+ self.value = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias)
+
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ output_attentions: Optional[bool] = False,
+ ) -> tuple[torch.Tensor]:
+ batch_size, seq_length, _ = hidden_states.shape
+ query_layer = (
+ self.query(hidden_states)
+ .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
+ .transpose(1, 2)
+ )
+ key_layer = (
+ self.key(hidden_states)
+ .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
+ .transpose(1, 2)
+ )
+ value_layer = (
+ self.value(hidden_states)
+ .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
+ .transpose(1, 2)
+ )
+
+        # Apply the scale factor before computing attention weights. This is usually more efficient
+        # because the attention weights are typically a larger tensor than the queries. It gives
+        # identical results because scalar multiplication commutes with matrix multiplication.
+ query_layer = query_layer / math.sqrt(self.attention_head_size)
+
+ # Compute NA between "query" and "key" to get the raw attention scores, and add relative positional biases.
+ attention_scores = natten2dqkrpb(query_layer, key_layer, self.rpb, self.kernel_size, self.dilation)
+
+ # Normalize the attention scores to probabilities.
+ attention_probs = nn.functional.softmax(attention_scores, dim=-1)
+
+ # This is actually dropping out entire tokens to attend to, which might
+ # seem a bit unusual, but is taken from the original Transformer paper.
+ attention_probs = self.dropout(attention_probs)
+
+ context_layer = natten2dav(attention_probs, value_layer, self.kernel_size, self.dilation)
+ context_layer = context_layer.permute(0, 2, 3, 1, 4).contiguous()
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+ context_layer = context_layer.view(new_context_layer_shape)
+
+ outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
+
+ return outputs
+
+
+class NeighborhoodAttentionOutput(nn.Module):
+ def __init__(self, config, dim):
+ super().__init__()
+ self.dense = nn.Linear(dim, dim)
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
+
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+
+ return hidden_states
+
+
+class NeighborhoodAttentionModule(nn.Module):
+ def __init__(self, config, dim, num_heads, kernel_size, dilation):
+ super().__init__()
+ self.self = NeighborhoodAttention(config, dim, num_heads, kernel_size, dilation)
+ self.output = NeighborhoodAttentionOutput(config, dim)
+ self.pruned_heads = set()
+
+ def prune_heads(self, heads):
+ if len(heads) == 0:
+ return
+ heads, index = find_pruneable_heads_and_indices(
+ heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
+ )
+
+ # Prune linear layers
+ self.self.query = prune_linear_layer(self.self.query, index)
+ self.self.key = prune_linear_layer(self.self.key, index)
+ self.self.value = prune_linear_layer(self.self.value, index)
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
+
+ # Update hyper params and store pruned heads
+ self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
+ self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
+ self.pruned_heads = self.pruned_heads.union(heads)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ output_attentions: Optional[bool] = False,
+ ) -> tuple[torch.Tensor]:
+ self_outputs = self.self(hidden_states, output_attentions)
+ attention_output = self.output(self_outputs[0], hidden_states)
+ outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
+ return outputs
+
+
+class DinatIntermediate(nn.Module):
+ def __init__(self, config, dim):
+ super().__init__()
+ self.dense = nn.Linear(dim, int(config.mlp_ratio * dim))
+ if isinstance(config.hidden_act, str):
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
+ else:
+ self.intermediate_act_fn = config.hidden_act
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.intermediate_act_fn(hidden_states)
+ return hidden_states
+
+
+class DinatOutput(nn.Module):
+ def __init__(self, config, dim):
+ super().__init__()
+ self.dense = nn.Linear(int(config.mlp_ratio * dim), dim)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ return hidden_states
+
+
+class DinatLayer(nn.Module):
+ def __init__(self, config, dim, num_heads, dilation, drop_path_rate=0.0):
+ super().__init__()
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
+ self.kernel_size = config.kernel_size
+ self.dilation = dilation
+ self.window_size = self.kernel_size * self.dilation
+ self.layernorm_before = nn.LayerNorm(dim, eps=config.layer_norm_eps)
+ self.attention = NeighborhoodAttentionModule(
+ config, dim, num_heads, kernel_size=self.kernel_size, dilation=self.dilation
+ )
+ self.drop_path = DinatDropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
+ self.layernorm_after = nn.LayerNorm(dim, eps=config.layer_norm_eps)
+ self.intermediate = DinatIntermediate(config, dim)
+ self.output = DinatOutput(config, dim)
+ self.layer_scale_parameters = (
+ nn.Parameter(config.layer_scale_init_value * torch.ones((2, dim)), requires_grad=True)
+ if config.layer_scale_init_value > 0
+ else None
+ )
+
+ def maybe_pad(self, hidden_states, height, width):
+ window_size = self.window_size
+ pad_values = (0, 0, 0, 0, 0, 0)
+ if height < window_size or width < window_size:
+ pad_l = pad_t = 0
+ pad_r = max(0, window_size - width)
+ pad_b = max(0, window_size - height)
+ pad_values = (0, 0, pad_l, pad_r, pad_t, pad_b)
+ hidden_states = nn.functional.pad(hidden_states, pad_values)
+ return hidden_states, pad_values
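+
+    # A worked example for `maybe_pad` above, assuming kernel_size=7 and dilation=2
+    # (window_size=14): a 10x12 feature map is padded by 2 on the right and 4 on the
+    # bottom so that neighborhood attention always sees a full 14x14 window.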
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ output_attentions: Optional[bool] = False,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ batch_size, height, width, channels = hidden_states.size()
+ shortcut = hidden_states
+
+ hidden_states = self.layernorm_before(hidden_states)
+ # pad hidden_states if they are smaller than kernel size x dilation
+ hidden_states, pad_values = self.maybe_pad(hidden_states, height, width)
+
+ _, height_pad, width_pad, _ = hidden_states.shape
+
+ attention_outputs = self.attention(hidden_states, output_attentions=output_attentions)
+
+ attention_output = attention_outputs[0]
+
+ was_padded = pad_values[3] > 0 or pad_values[5] > 0
+ if was_padded:
+ attention_output = attention_output[:, :height, :width, :].contiguous()
+
+ if self.layer_scale_parameters is not None:
+ attention_output = self.layer_scale_parameters[0] * attention_output
+
+ hidden_states = shortcut + self.drop_path(attention_output)
+
+ layer_output = self.layernorm_after(hidden_states)
+ layer_output = self.output(self.intermediate(layer_output))
+
+ if self.layer_scale_parameters is not None:
+ layer_output = self.layer_scale_parameters[1] * layer_output
+
+ layer_output = hidden_states + self.drop_path(layer_output)
+
+ layer_outputs = (layer_output, attention_outputs[1]) if output_attentions else (layer_output,)
+ return layer_outputs
+
+
+class DinatStage(nn.Module):
+ def __init__(self, config, dim, depth, num_heads, dilations, drop_path_rate, downsample):
+ super().__init__()
+ self.config = config
+ self.dim = dim
+ self.layers = nn.ModuleList(
+ [
+ DinatLayer(
+ config=config,
+ dim=dim,
+ num_heads=num_heads,
+ dilation=dilations[i],
+ drop_path_rate=drop_path_rate[i],
+ )
+ for i in range(depth)
+ ]
+ )
+
+ # patch merging layer
+ if downsample is not None:
+ self.downsample = downsample(dim=dim, norm_layer=nn.LayerNorm)
+ else:
+ self.downsample = None
+
+ self.pointing = False
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ output_attentions: Optional[bool] = False,
+ ) -> tuple[torch.Tensor]:
+ _, height, width, _ = hidden_states.size()
+ for i, layer_module in enumerate(self.layers):
+ layer_outputs = layer_module(hidden_states, output_attentions)
+ hidden_states = layer_outputs[0]
+
+ hidden_states_before_downsampling = hidden_states
+ if self.downsample is not None:
+ hidden_states = self.downsample(hidden_states_before_downsampling)
+
+ stage_outputs = (hidden_states, hidden_states_before_downsampling)
+
+ if output_attentions:
+ stage_outputs += layer_outputs[1:]
+ return stage_outputs
+
+
+class DinatEncoder(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.num_levels = len(config.depths)
+ self.config = config
+ dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths), device="cpu")]
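+        # Each layer gets a stochastic depth rate that grows linearly from 0 to config.drop_path_rate
+        # across all layers of all stages.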
+ self.levels = nn.ModuleList(
+ [
+ DinatStage(
+ config=config,
+ dim=int(config.embed_dim * 2**i_layer),
+ depth=config.depths[i_layer],
+ num_heads=config.num_heads[i_layer],
+ dilations=config.dilations[i_layer],
+ drop_path_rate=dpr[sum(config.depths[:i_layer]) : sum(config.depths[: i_layer + 1])],
+ downsample=DinatDownsampler if (i_layer < self.num_levels - 1) else None,
+ )
+ for i_layer in range(self.num_levels)
+ ]
+ )
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ output_attentions: Optional[bool] = False,
+ output_hidden_states: Optional[bool] = False,
+ output_hidden_states_before_downsampling: Optional[bool] = False,
+ return_dict: Optional[bool] = True,
+ ) -> Union[tuple, DinatEncoderOutput]:
+ all_hidden_states = () if output_hidden_states else None
+ all_reshaped_hidden_states = () if output_hidden_states else None
+ all_self_attentions = () if output_attentions else None
+
+ if output_hidden_states:
+ # rearrange b h w c -> b c h w
+ reshaped_hidden_state = hidden_states.permute(0, 3, 1, 2)
+ all_hidden_states += (hidden_states,)
+ all_reshaped_hidden_states += (reshaped_hidden_state,)
+
+ for i, layer_module in enumerate(self.levels):
+ layer_outputs = layer_module(hidden_states, output_attentions)
+
+ hidden_states = layer_outputs[0]
+ hidden_states_before_downsampling = layer_outputs[1]
+
+ if output_hidden_states and output_hidden_states_before_downsampling:
+ # rearrange b h w c -> b c h w
+ reshaped_hidden_state = hidden_states_before_downsampling.permute(0, 3, 1, 2)
+ all_hidden_states += (hidden_states_before_downsampling,)
+ all_reshaped_hidden_states += (reshaped_hidden_state,)
+ elif output_hidden_states and not output_hidden_states_before_downsampling:
+ # rearrange b h w c -> b c h w
+ reshaped_hidden_state = hidden_states.permute(0, 3, 1, 2)
+ all_hidden_states += (hidden_states,)
+ all_reshaped_hidden_states += (reshaped_hidden_state,)
+
+ if output_attentions:
+ all_self_attentions += layer_outputs[2:]
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
+
+ return DinatEncoderOutput(
+ last_hidden_state=hidden_states,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attentions,
+ reshaped_hidden_states=all_reshaped_hidden_states,
+ )
+
+
+@auto_docstring
+class DinatPreTrainedModel(PreTrainedModel):
+ config: DinatConfig
+ base_model_prefix = "dinat"
+ main_input_name = "pixel_values"
+
+ def _init_weights(self, module):
+ """Initialize the weights"""
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
+ # Slightly different from the TF version which uses truncated_normal for initialization
+ # cf https://github.com/pytorch/pytorch/pull/5617
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+
+
+@auto_docstring
+class DinatModel(DinatPreTrainedModel):
+ def __init__(self, config, add_pooling_layer=True):
+ r"""
+ add_pooling_layer (bool, *optional*, defaults to `True`):
+ Whether to add a pooling layer
+ """
+ super().__init__(config)
+
+ requires_backends(self, ["natten"])
+
+ self.config = config
+ self.num_levels = len(config.depths)
+ self.num_features = int(config.embed_dim * 2 ** (self.num_levels - 1))
+
+ self.embeddings = DinatEmbeddings(config)
+ self.encoder = DinatEncoder(config)
+
+ self.layernorm = nn.LayerNorm(self.num_features, eps=config.layer_norm_eps)
+ self.pooler = nn.AdaptiveAvgPool1d(1) if add_pooling_layer else None
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.embeddings.patch_embeddings
+
+ def _prune_heads(self, heads_to_prune):
+ """
+        Prunes heads of the model. `heads_to_prune` is a dict of {layer_num: list of heads to prune in this layer}.
+        See the base class `PreTrainedModel` for details.
+ """
+ for layer, heads in heads_to_prune.items():
+ self.encoder.layer[layer].attention.prune_heads(heads)
+
+ @auto_docstring
+ def forward(
+ self,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[tuple, DinatModelOutput]:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if pixel_values is None:
+ raise ValueError("You have to specify pixel_values")
+
+ embedding_output = self.embeddings(pixel_values)
+
+ encoder_outputs = self.encoder(
+ embedding_output,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ sequence_output = encoder_outputs[0]
+ sequence_output = self.layernorm(sequence_output)
+
+ pooled_output = None
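+        # The pooler averages over all spatial positions: (batch, H, W, C) is flattened to
+        # (batch, H*W, C), transposed to (batch, C, H*W), average-pooled to (batch, C, 1) and
+        # flattened back to (batch, C).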
+ if self.pooler is not None:
+ pooled_output = self.pooler(sequence_output.flatten(1, 2).transpose(1, 2))
+ pooled_output = torch.flatten(pooled_output, 1)
+
+ if not return_dict:
+ output = (sequence_output, pooled_output) + encoder_outputs[1:]
+
+ return output
+
+ return DinatModelOutput(
+ last_hidden_state=sequence_output,
+ pooler_output=pooled_output,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ reshaped_hidden_states=encoder_outputs.reshaped_hidden_states,
+ )
+
+
+@auto_docstring(
+ custom_intro="""
+    Dinat Model transformer with an image classification head on top (a linear layer on top of the average-pooled
+    final hidden state) e.g. for ImageNet.
+ """
+)
+class DinatForImageClassification(DinatPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+
+ requires_backends(self, ["natten"])
+
+ self.num_labels = config.num_labels
+ self.dinat = DinatModel(config)
+
+ # Classifier head
+ self.classifier = (
+ nn.Linear(self.dinat.num_features, config.num_labels) if config.num_labels > 0 else nn.Identity()
+ )
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @auto_docstring
+ def forward(
+ self,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[tuple, DinatImageClassifierOutput]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
+            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss); if
+            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.dinat(
+ pixel_values,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ pooled_output = outputs[1]
+
+ logits = self.classifier(pooled_output)
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(labels, logits, self.config)
+
+ if not return_dict:
+ output = (logits,) + outputs[2:]
+ return ((loss,) + output) if loss is not None else output
+
+ return DinatImageClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ reshaped_hidden_states=outputs.reshaped_hidden_states,
+ )
+
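+# A minimal usage sketch for DinatForImageClassification, assuming an ImageNet-pretrained Dinat
+# checkpoint (the model id below is a placeholder, not a verified checkpoint name):
+#
+#   from transformers import AutoImageProcessor, DinatForImageClassification
+#
+#   processor = AutoImageProcessor.from_pretrained("<dinat-imagenet-checkpoint>")
+#   model = DinatForImageClassification.from_pretrained("<dinat-imagenet-checkpoint>")
+#   inputs = processor(images=image, return_tensors="pt")
+#   logits = model(**inputs).logits
+#   predicted_class = logits.argmax(-1).item()
+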
+
+@auto_docstring(
+ custom_intro="""
+    Dinat backbone, to be used with frameworks like DETR and MaskFormer.
+ """
+)
+class DinatBackbone(DinatPreTrainedModel, BackboneMixin):
+ def __init__(self, config):
+ super().__init__(config)
+ super()._init_backbone(config)
+
+ requires_backends(self, ["natten"])
+
+ self.embeddings = DinatEmbeddings(config)
+ self.encoder = DinatEncoder(config)
+ self.num_features = [config.embed_dim] + [int(config.embed_dim * 2**i) for i in range(len(config.depths))]
+
+ # Add layer norms to hidden states of out_features
+ hidden_states_norms = {}
+ for stage, num_channels in zip(self._out_features, self.channels):
+ hidden_states_norms[stage] = nn.LayerNorm(num_channels)
+ self.hidden_states_norms = nn.ModuleDict(hidden_states_norms)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.embeddings.patch_embeddings
+
+ @auto_docstring
+ def forward(
+ self,
+ pixel_values: torch.Tensor,
+ output_hidden_states: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> BackboneOutput:
+ r"""
+ Examples:
+
+ ```python
+ >>> from transformers import AutoImageProcessor, AutoBackbone
+ >>> import torch
+ >>> from PIL import Image
+ >>> import requests
+
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> processor = AutoImageProcessor.from_pretrained("shi-labs/nat-mini-in1k-224")
+ >>> model = AutoBackbone.from_pretrained(
+ ... "shi-labs/nat-mini-in1k-224", out_features=["stage1", "stage2", "stage3", "stage4"]
+ ... )
+
+ >>> inputs = processor(image, return_tensors="pt")
+
+ >>> outputs = model(**inputs)
+
+ >>> feature_maps = outputs.feature_maps
+ >>> list(feature_maps[-1].shape)
+ [1, 512, 7, 7]
+ ```"""
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+
+ embedding_output = self.embeddings(pixel_values)
+
+ outputs = self.encoder(
+ embedding_output,
+ output_attentions=output_attentions,
+ output_hidden_states=True,
+ output_hidden_states_before_downsampling=True,
+ return_dict=True,
+ )
+
+ hidden_states = outputs.reshaped_hidden_states
+
+ feature_maps = ()
+ for stage, hidden_state in zip(self.stage_names, hidden_states):
+ if stage in self.out_features:
+ batch_size, num_channels, height, width = hidden_state.shape
+ hidden_state = hidden_state.permute(0, 2, 3, 1).contiguous()
+ hidden_state = hidden_state.view(batch_size, height * width, num_channels)
+ hidden_state = self.hidden_states_norms[stage](hidden_state)
+ hidden_state = hidden_state.view(batch_size, height, width, num_channels)
+ hidden_state = hidden_state.permute(0, 3, 1, 2).contiguous()
+ feature_maps += (hidden_state,)
+
+ if not return_dict:
+ output = (feature_maps,)
+ if output_hidden_states:
+ output += (outputs.hidden_states,)
+ return output
+
+ return BackboneOutput(
+ feature_maps=feature_maps,
+ hidden_states=outputs.hidden_states if output_hidden_states else None,
+ attentions=outputs.attentions,
+ )
+
+
+__all__ = ["DinatForImageClassification", "DinatModel", "DinatPreTrainedModel", "DinatBackbone"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/dinov2_with_registers/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/dinov2_with_registers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2d10027b6a3b6375235a6785df044e8f0ce5fb33
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/dinov2_with_registers/__init__.py
@@ -0,0 +1,27 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_dinov2_with_registers import *
+ from .modeling_dinov2_with_registers import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/dinov2_with_registers/configuration_dinov2_with_registers.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/dinov2_with_registers/configuration_dinov2_with_registers.py
new file mode 100644
index 0000000000000000000000000000000000000000..ec4f446fc684f40d634927c1e7a52b64c5732b12
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/dinov2_with_registers/configuration_dinov2_with_registers.py
@@ -0,0 +1,159 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/dinov2_with_registers/modular_dinov2_with_registers.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_dinov2_with_registers.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2024 Meta Inc. and the HuggingFace Inc. team. All rights reserved.
+#
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ...configuration_utils import PretrainedConfig
+from ...utils.backbone_utils import BackboneConfigMixin, get_aligned_output_features_output_indices
+
+
+class Dinov2WithRegistersConfig(BackboneConfigMixin, PretrainedConfig):
+ r"""
+    This is the configuration class to store the configuration of a [`Dinov2WithRegistersModel`]. It is used to instantiate a
+ Dinov2WithRegisters model according to the specified arguments, defining the model architecture. Instantiating a configuration
+ with the defaults will yield a similar configuration to that of the DINOv2 with Registers
+ [facebook/dinov2-with-registers-base](https://huggingface.co/facebook/dinov2-with-registers-base) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ hidden_size (`int`, *optional*, defaults to 768):
+ Dimensionality of the encoder layers and the pooler layer.
+ num_hidden_layers (`int`, *optional*, defaults to 12):
+ Number of hidden layers in the Transformer encoder.
+ num_attention_heads (`int`, *optional*, defaults to 12):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ mlp_ratio (`int`, *optional*, defaults to 4):
+ Ratio of the hidden size of the MLPs relative to the `hidden_size`.
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+ `"relu"`, `"selu"` and `"gelu_new"` are supported.
+ hidden_dropout_prob (`float`, *optional*, defaults to 0.0):
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+ attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ layer_norm_eps (`float`, *optional*, defaults to 1e-06):
+ The epsilon used by the layer normalization layers.
+ image_size (`int`, *optional*, defaults to 224):
+ The size (resolution) of each image.
+ patch_size (`int`, *optional*, defaults to 16):
+ The size (resolution) of each patch.
+ num_channels (`int`, *optional*, defaults to 3):
+ The number of input channels.
+ qkv_bias (`bool`, *optional*, defaults to `True`):
+ Whether to add a bias to the queries, keys and values.
+ layerscale_value (`float`, *optional*, defaults to 1.0):
+ Initial value to use for layer scale.
+ drop_path_rate (`float`, *optional*, defaults to 0.0):
+ Stochastic depth rate per sample (when applied in the main path of residual layers).
+ use_swiglu_ffn (`bool`, *optional*, defaults to `False`):
+ Whether to use the SwiGLU feedforward neural network.
+ num_register_tokens (`int`, *optional*, defaults to 4):
+ Number of register tokens to use.
+ out_features (`list[str]`, *optional*):
+ If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc.
+ (depending on how many stages the model has). If unset and `out_indices` is set, will default to the
+ corresponding stages. If unset and `out_indices` is unset, will default to the last stage. Must be in the
+ same order as defined in the `stage_names` attribute.
+ out_indices (`list[int]`, *optional*):
+ If used as backbone, list of indices of features to output. Can be any of 0, 1, 2, etc. (depending on how
+ many stages the model has). If unset and `out_features` is set, will default to the corresponding stages.
+ If unset and `out_features` is unset, will default to the last stage. Must be in the
+ same order as defined in the `stage_names` attribute.
+ apply_layernorm (`bool`, *optional*, defaults to `True`):
+ Whether to apply layer normalization to the feature maps in case the model is used as backbone.
+ reshape_hidden_states (`bool`, *optional*, defaults to `True`):
+ Whether to reshape the feature maps to 4D tensors of shape `(batch_size, hidden_size, height, width)` in
+ case the model is used as backbone. If `False`, the feature maps will be 3D tensors of shape `(batch_size,
+ seq_len, hidden_size)`.
+
+ Example:
+
+ ```python
+ >>> from transformers import Dinov2WithRegistersConfig, Dinov2WithRegistersModel
+
+ >>> # Initializing a Dinov2WithRegisters base style configuration
+ >>> configuration = Dinov2WithRegistersConfig()
+
+ >>> # Initializing a model (with random weights) from the base style configuration
+ >>> model = Dinov2WithRegistersModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "dinov2_with_registers"
+
+ def __init__(
+ self,
+ hidden_size=768,
+ num_hidden_layers=12,
+ num_attention_heads=12,
+ mlp_ratio=4,
+ hidden_act="gelu",
+ hidden_dropout_prob=0.0,
+ attention_probs_dropout_prob=0.0,
+ initializer_range=0.02,
+ layer_norm_eps=1e-6,
+ image_size=224,
+ patch_size=16,
+ num_channels=3,
+ qkv_bias=True,
+ layerscale_value=1.0,
+ drop_path_rate=0.0,
+ use_swiglu_ffn=False,
+ num_register_tokens=4,
+ out_features=None,
+ out_indices=None,
+ apply_layernorm=True,
+ reshape_hidden_states=True,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+
+ self.hidden_size = hidden_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.mlp_ratio = mlp_ratio
+ self.hidden_act = hidden_act
+ self.hidden_dropout_prob = hidden_dropout_prob
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
+ self.initializer_range = initializer_range
+ self.layer_norm_eps = layer_norm_eps
+ self.image_size = image_size
+ self.patch_size = patch_size
+ self.num_channels = num_channels
+ self.qkv_bias = qkv_bias
+ self.layerscale_value = layerscale_value
+ self.drop_path_rate = drop_path_rate
+ self.use_swiglu_ffn = use_swiglu_ffn
+ self.num_register_tokens = num_register_tokens
+ self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, num_hidden_layers + 1)]
+ self._out_features, self._out_indices = get_aligned_output_features_output_indices(
+ out_features=out_features, out_indices=out_indices, stage_names=self.stage_names
+ )
+ self.apply_layernorm = apply_layernorm
+ self.reshape_hidden_states = reshape_hidden_states
+
+
+__all__ = ["Dinov2WithRegistersConfig"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/dinov2_with_registers/modeling_dinov2_with_registers.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/dinov2_with_registers/modeling_dinov2_with_registers.py
new file mode 100644
index 0000000000000000000000000000000000000000..9c6f4c335f58b2ddd37bf4042d5f8e51c474cee9
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/dinov2_with_registers/modeling_dinov2_with_registers.py
@@ -0,0 +1,712 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/dinov2_with_registers/modular_dinov2_with_registers.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_dinov2_with_registers.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2024 Meta Inc. and the HuggingFace Inc. team. All rights reserved.
+#
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import collections.abc
+from typing import Callable, Optional, Union
+
+import torch
+from torch import nn
+
+from ...activations import ACT2FN
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_outputs import BackboneOutput, BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
+from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
+from ...utils import TransformersKwargs, auto_docstring, torch_int
+from ...utils.backbone_utils import BackboneMixin
+from ...utils.generic import can_return_tuple, check_model_inputs
+from .configuration_dinov2_with_registers import Dinov2WithRegistersConfig
+
+
+class Dinov2WithRegistersPatchEmbeddings(nn.Module):
+ """
+ This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
+ `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
+ Transformer.
+ """
+
+ def __init__(self, config):
+ super().__init__()
+ image_size, patch_size = config.image_size, config.patch_size
+ num_channels, hidden_size = config.num_channels, config.hidden_size
+
+ image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
+ patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
+ num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
+ self.image_size = image_size
+ self.patch_size = patch_size
+ self.num_channels = num_channels
+ self.num_patches = num_patches
+
+ self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
+
+ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
+ num_channels = pixel_values.shape[1]
+ if num_channels != self.num_channels:
+ raise ValueError(
+ "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
+ f" Expected {self.num_channels} but got {num_channels}."
+ )
+ embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
+ return embeddings
+
+
+class Dinov2WithRegistersEmbeddings(nn.Module):
+ """
+ Construct the CLS token, mask token, register tokens, position and patch embeddings.
+ """
+
+ def __init__(self, config: Dinov2WithRegistersConfig) -> None:
+ super().__init__()
+
+ self.cls_token = nn.Parameter(torch.randn(1, 1, config.hidden_size))
+ self.mask_token = nn.Parameter(torch.zeros(1, config.hidden_size))
+ self.register_tokens = nn.Parameter(torch.zeros(1, config.num_register_tokens, config.hidden_size))
+ self.patch_embeddings = Dinov2WithRegistersPatchEmbeddings(config)
+ num_patches = self.patch_embeddings.num_patches
+ self.position_embeddings = nn.Parameter(torch.randn(1, num_patches + 1, config.hidden_size))
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+ self.patch_size = config.patch_size
+ self.config = config
+
+ def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
+ """
+        This method interpolates the pre-trained position encodings so that the model can be used on
+        higher-resolution images. This implementation supports torch.jit tracing while maintaining backwards
+        compatibility with the original implementation.
+
+ Adapted from:
+ - https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
+ - https://github.com/facebookresearch/dinov2/blob/main/dinov2/models/vision_transformer.py
+ """
+ num_patches = embeddings.shape[1] - 1
+ num_positions = self.position_embeddings.shape[1] - 1
+
+ # Skip interpolation for matching dimensions (unless tracing)
+ if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
+ return self.position_embeddings
+
+ # Handle class token and patch embeddings separately
+ class_pos_embed = self.position_embeddings[:, 0]
+ patch_pos_embed = self.position_embeddings[:, 1:]
+ dim = embeddings.shape[-1]
+
+ # Calculate new dimensions
+ height = height // self.config.patch_size
+ width = width // self.config.patch_size
+
+ # Reshape for interpolation
+ sqrt_num_positions = torch_int(num_positions**0.5)
+ patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
+ patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
+
+ # Store original dtype for restoration after interpolation
+ target_dtype = patch_pos_embed.dtype
+
+ # Interpolate at float32 precision
+ patch_pos_embed = nn.functional.interpolate(
+ patch_pos_embed.to(dtype=torch.float32),
+ size=(torch_int(height), torch_int(width)), # Explicit size instead of scale_factor
+ mode="bicubic",
+ align_corners=False,
+ antialias=True,
+ ).to(dtype=target_dtype)
+
+ # Validate output dimensions if not tracing
+ if not torch.jit.is_tracing():
+ if int(height) != patch_pos_embed.shape[-2] or int(width) != patch_pos_embed.shape[-1]:
+ raise ValueError("Width or height does not match with the interpolated position embeddings")
+
+ # Reshape back to original format
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+
+ # Combine class and patch embeddings
+ return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
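+
+    # A worked example of the interpolation above, assuming the default configuration
+    # (image_size=224, patch_size=16, i.e. a 14x14 grid of 196 patch position embeddings):
+    # for a 336x336 input the embeddings are laid out on a 14x14 grid, bicubically resized
+    # to 21x21, and flattened back to 441 positions before being added to the patch tokens.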
+
+ def forward(self, pixel_values: torch.Tensor, bool_masked_pos: Optional[torch.Tensor] = None) -> torch.Tensor:
+ batch_size, _, height, width = pixel_values.shape
+ target_dtype = self.patch_embeddings.projection.weight.dtype
+ embeddings = self.patch_embeddings(pixel_values.to(dtype=target_dtype))
+
+ if bool_masked_pos is not None:
+ embeddings = torch.where(
+ bool_masked_pos.unsqueeze(-1), self.mask_token.to(embeddings.dtype).unsqueeze(0), embeddings
+ )
+
+ # add the [CLS] token to the embedded patch tokens
+ cls_tokens = self.cls_token.expand(batch_size, -1, -1)
+ embeddings = torch.cat((cls_tokens, embeddings), dim=1)
+
+ # add positional encoding to each token
+ embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
+
+ # add register tokens
+ embeddings = torch.cat(
+ (embeddings[:, :1], self.register_tokens.expand(embeddings.shape[0], -1, -1), embeddings[:, 1:]), dim=1
+ )
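+        # Resulting token order: [CLS], register tokens, patch tokens (e.g. 1 + 4 + 196 = 201 tokens
+        # with the default configuration: num_register_tokens=4, a 224x224 input and patch_size=16).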
+
+ embeddings = self.dropout(embeddings)
+
+ return embeddings
+
+
+def eager_attention_forward(
+ module: nn.Module,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_mask: Optional[torch.Tensor],
+ scaling: float,
+ dropout: float = 0.0,
+ **kwargs,
+):
+ # Take the dot product between "query" and "key" to get the raw attention scores.
+ attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
+
+ # Normalize the attention scores to probabilities.
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+
+ # This is actually dropping out entire tokens to attend to, which might
+ # seem a bit unusual, but is taken from the original Transformer paper.
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+
+ # Mask heads if we want to
+ if attention_mask is not None:
+ attn_weights = attn_weights * attention_mask
+
+ attn_output = torch.matmul(attn_weights, value)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ return attn_output, attn_weights
+
+
+class Dinov2WithRegistersSelfAttention(nn.Module):
+ def __init__(self, config: Dinov2WithRegistersConfig):
+ super().__init__()
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
+ raise ValueError(
+ f"The hidden size {config.hidden_size} is not a multiple of the number of attention "
+ f"heads {config.num_attention_heads}."
+ )
+
+ self.config = config
+ self.num_attention_heads = config.num_attention_heads
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
+ self.dropout_prob = config.attention_probs_dropout_prob
+ self.scaling = self.attention_head_size**-0.5
+ self.is_causal = False
+
+ self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
+ self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
+ self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
+
+ def forward(
+ self, hidden_states: torch.Tensor, head_mask: Optional[torch.Tensor] = None
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ batch_size = hidden_states.shape[0]
+ new_shape = batch_size, -1, self.num_attention_heads, self.attention_head_size
+
+ key_layer = self.key(hidden_states).view(*new_shape).transpose(1, 2)
+ value_layer = self.value(hidden_states).view(*new_shape).transpose(1, 2)
+ query_layer = self.query(hidden_states).view(*new_shape).transpose(1, 2)
+
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ context_layer, attention_probs = attention_interface(
+ self,
+ query_layer,
+ key_layer,
+ value_layer,
+ head_mask,
+ is_causal=self.is_causal,
+ scaling=self.scaling,
+ dropout=0.0 if not self.training else self.dropout_prob,
+ )
+
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+ context_layer = context_layer.reshape(new_context_layer_shape)
+
+ return context_layer, attention_probs
+
+
+class Dinov2WithRegistersSelfOutput(nn.Module):
+ """
+ The residual connection is defined in Dinov2WithRegistersLayer instead of here (as is the case with other models), due to the
+ layernorm applied before each block.
+ """
+
+ def __init__(self, config: Dinov2WithRegistersConfig):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ return hidden_states
+
+
+class Dinov2WithRegistersAttention(nn.Module):
+ def __init__(self, config: Dinov2WithRegistersConfig):
+ super().__init__()
+ self.attention = Dinov2WithRegistersSelfAttention(config)
+ self.output = Dinov2WithRegistersSelfOutput(config)
+ self.pruned_heads = set()
+
+ def prune_heads(self, heads: set[int]):
+ if len(heads) == 0:
+ return
+ heads, index = find_pruneable_heads_and_indices(
+ heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads
+ )
+
+ # Prune linear layers
+ self.attention.query = prune_linear_layer(self.attention.query, index)
+ self.attention.key = prune_linear_layer(self.attention.key, index)
+ self.attention.value = prune_linear_layer(self.attention.value, index)
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
+
+ # Update hyper params and store pruned heads
+ self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads)
+ self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
+ self.pruned_heads = self.pruned_heads.union(heads)
+
+ def forward(self, hidden_states: torch.Tensor, head_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+ self_attn_output, _ = self.attention(hidden_states, head_mask)
+ output = self.output(self_attn_output, hidden_states)
+ return output
+
+
+class Dinov2WithRegistersLayerScale(nn.Module):
+ def __init__(self, config) -> None:
+ super().__init__()
+ self.lambda1 = nn.Parameter(config.layerscale_value * torch.ones(config.hidden_size))
+
+ def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
+ return hidden_state * self.lambda1
+
+
+def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
+ """
+ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+
+ Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
+ however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
+ See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
+ layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
+ argument.
+ """
+ if drop_prob == 0.0 or not training:
+ return input
+ keep_prob = 1 - drop_prob
+ shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
+ random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
+ random_tensor.floor_() # binarize
+ output = input.div(keep_prob) * random_tensor
+ return output
+
+
+class Dinov2WithRegistersDropPath(nn.Module):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
+
+ def __init__(self, drop_prob: Optional[float] = None) -> None:
+ super().__init__()
+ self.drop_prob = drop_prob
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ return drop_path(hidden_states, self.drop_prob, self.training)
+
+ def extra_repr(self) -> str:
+ return f"p={self.drop_prob}"
+
+
+class Dinov2WithRegistersMLP(nn.Module):
+ def __init__(self, config) -> None:
+ super().__init__()
+ in_features = out_features = config.hidden_size
+ hidden_features = int(config.hidden_size * config.mlp_ratio)
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=True)
+ if isinstance(config.hidden_act, str):
+ self.activation = ACT2FN[config.hidden_act]
+ else:
+ self.activation = config.hidden_act
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=True)
+
+ def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
+ hidden_state = self.fc1(hidden_state)
+ hidden_state = self.activation(hidden_state)
+ hidden_state = self.fc2(hidden_state)
+ return hidden_state
+
+
+class Dinov2WithRegistersSwiGLUFFN(nn.Module):
+ def __init__(self, config) -> None:
+ super().__init__()
+ in_features = out_features = config.hidden_size
+ hidden_features = int(config.hidden_size * config.mlp_ratio)
+ hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
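+        # e.g. with hidden_size=768 and mlp_ratio=4: int(3072 * 2 / 3) = 2048, which is already a
+        # multiple of 8, so the SwiGLU hidden width is 2048.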
+
+ self.weights_in = nn.Linear(in_features, 2 * hidden_features, bias=True)
+ self.weights_out = nn.Linear(hidden_features, out_features, bias=True)
+
+ def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
+ hidden_state = self.weights_in(hidden_state)
+ x1, x2 = hidden_state.chunk(2, dim=-1)
+ hidden = nn.functional.silu(x1) * x2
+ return self.weights_out(hidden)
+
+
+class Dinov2WithRegistersLayer(GradientCheckpointingLayer):
+ """This corresponds to the Block class in the original implementation."""
+
+ def __init__(self, config: Dinov2WithRegistersConfig) -> None:
+ super().__init__()
+
+ self.norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.attention = Dinov2WithRegistersAttention(config)
+ self.layer_scale1 = Dinov2WithRegistersLayerScale(config)
+ self.drop_path = (
+ Dinov2WithRegistersDropPath(config.drop_path_rate) if config.drop_path_rate > 0.0 else nn.Identity()
+ )
+
+ self.norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+ if config.use_swiglu_ffn:
+ self.mlp = Dinov2WithRegistersSwiGLUFFN(config)
+ else:
+ self.mlp = Dinov2WithRegistersMLP(config)
+ self.layer_scale2 = Dinov2WithRegistersLayerScale(config)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ head_mask: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ hidden_states_norm = self.norm1(hidden_states)
+ self_attention_output = self.attention(hidden_states_norm, head_mask)
+ self_attention_output = self.layer_scale1(self_attention_output)
+
+ # first residual connection
+ hidden_states = self.drop_path(self_attention_output) + hidden_states
+
+ # in Dinov2WithRegisters, layernorm is also applied after self-attention
+ layer_output = self.norm2(hidden_states)
+ layer_output = self.mlp(layer_output)
+ layer_output = self.layer_scale2(layer_output)
+
+ # second residual connection
+ layer_output = self.drop_path(layer_output) + hidden_states
+
+ return layer_output
+
+
+class Dinov2WithRegistersEncoder(nn.Module):
+ def __init__(self, config: Dinov2WithRegistersConfig):
+ super().__init__()
+ self.config = config
+ self.layer = nn.ModuleList([Dinov2WithRegistersLayer(config) for _ in range(config.num_hidden_layers)])
+ self.gradient_checkpointing = False
+
+ def forward(
+ self, hidden_states: torch.Tensor, head_mask: Optional[torch.Tensor] = None, output_hidden_states: bool = False
+ ) -> BaseModelOutput:
+ all_hidden_states = [hidden_states] if output_hidden_states else None
+ for i, layer_module in enumerate(self.layer):
+ layer_head_mask = head_mask[i] if head_mask is not None else None
+ hidden_states = layer_module(hidden_states, layer_head_mask)
+ if all_hidden_states:
+ all_hidden_states.append(hidden_states)
+
+ return BaseModelOutput(
+ last_hidden_state=hidden_states,
+ hidden_states=tuple(all_hidden_states) if all_hidden_states else None,
+ )
+
+
+@auto_docstring
+class Dinov2WithRegistersPreTrainedModel(PreTrainedModel):
+ config: Dinov2WithRegistersConfig
+ base_model_prefix = "dinov2_with_registers"
+ main_input_name = "pixel_values"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["Dinov2WithRegistersLayer"]
+ _supports_sdpa = True
+ _supports_flash_attn = True
+ _supports_flex_attn = True
+ _supports_attention_backend = True
+ _can_record_outputs = {
+ "attentions": Dinov2WithRegistersSelfAttention,
+ }
+
+ def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
+ """Initialize the weights"""
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
+ # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid
+ # `trunc_normal_cpu` not implemented in `half` issues
+ module.weight.data = nn.init.trunc_normal_(
+ module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range
+ ).to(module.weight.dtype)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+ elif isinstance(module, Dinov2WithRegistersEmbeddings):
+ module.position_embeddings.data = nn.init.trunc_normal_(
+ module.position_embeddings.data.to(torch.float32),
+ mean=0.0,
+ std=self.config.initializer_range,
+ ).to(module.position_embeddings.dtype)
+
+ module.cls_token.data = nn.init.trunc_normal_(
+ module.cls_token.data.to(torch.float32),
+ mean=0.0,
+ std=self.config.initializer_range,
+ ).to(module.cls_token.dtype)
+
+ module.mask_token.data.zero_()
+ module.register_tokens.data.zero_()
+ elif isinstance(module, Dinov2WithRegistersLayerScale): # noqa: F821
+ module.lambda1.data.fill_(self.config.layerscale_value)
+
+
+@auto_docstring
+class Dinov2WithRegistersModel(Dinov2WithRegistersPreTrainedModel):
+ def __init__(self, config: Dinov2WithRegistersConfig):
+ super().__init__(config)
+ self.config = config
+
+ self.embeddings = Dinov2WithRegistersEmbeddings(config)
+ self.encoder = Dinov2WithRegistersEncoder(config)
+
+ self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self) -> Dinov2WithRegistersPatchEmbeddings:
+ return self.embeddings.patch_embeddings
+
+ def _prune_heads(self, heads_to_prune: dict[int, list[int]]) -> None:
+ """
+        Prunes heads of the model. `heads_to_prune` is a dict of {layer_num: list of heads to prune in this layer}.
+        See the base class `PreTrainedModel` for details.
+ """
+ for layer, heads in heads_to_prune.items():
+ self.encoder.layer[layer].attention.prune_heads(heads)
+
+ @check_model_inputs(tie_last_hidden_states=False)
+ @auto_docstring
+ def forward(
+ self,
+ pixel_values: Optional[torch.Tensor] = None,
+ bool_masked_pos: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ output_hidden_states: Optional[bool] = None,
+ **kwargs,
+ ) -> BaseModelOutputWithPooling:
+ r"""
+ bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, sequence_length)`):
+ Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). Only relevant for
+ pre-training.
+ """
+ if output_hidden_states is None:
+ output_hidden_states = self.config.output_hidden_states
+
+ if pixel_values is None:
+ raise ValueError("You have to specify pixel_values")
+
+ # Prepare head mask if needed
+ # 1.0 in head_mask indicate we keep the head
+ # attention_probs has shape bsz x n_heads x N x N
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
+
+ embedding_output = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos)
+
+ encoder_outputs: BaseModelOutput = self.encoder(
+ embedding_output, head_mask=head_mask, output_hidden_states=output_hidden_states
+ )
+ sequence_output = encoder_outputs.last_hidden_state
+ sequence_output = self.layernorm(sequence_output)
+ pooled_output = sequence_output[:, 0, :]
+
+ return BaseModelOutputWithPooling(
+ last_hidden_state=sequence_output,
+ pooler_output=pooled_output,
+ hidden_states=encoder_outputs.hidden_states,
+ )
+
+
+@auto_docstring(
+ custom_intro="""
+    Dinov2WithRegisters Model transformer with an image classification head on top (a linear layer on top of the final hidden
+    state of the [CLS] token concatenated with the mean of the patch tokens) e.g. for ImageNet.
+ """
+)
+class Dinov2WithRegistersForImageClassification(Dinov2WithRegistersPreTrainedModel):
+ def __init__(self, config: Dinov2WithRegistersConfig) -> None:
+ super().__init__(config)
+
+ self.num_labels = config.num_labels
+ self.dinov2_with_registers = Dinov2WithRegistersModel(config)
+
+ # Classifier head
+ self.classifier = (
+ nn.Linear(config.hidden_size * 2, config.num_labels) if config.num_labels > 0 else nn.Identity()
+ )
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ pixel_values: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> ImageClassifierOutput:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
+            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss); if
+            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+ """
+
+ outputs: BaseModelOutputWithPooling = self.dinov2_with_registers(pixel_values, head_mask=head_mask, **kwargs)
+ sequence_output = outputs.last_hidden_state # batch_size, sequence_length, hidden_size
+
+ cls_token = sequence_output[:, 0]
+ # cls and register tokens should not be included in patch tokens variable
+ patch_tokens = sequence_output[:, 1 + self.config.num_register_tokens :]
+
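+        # The CLS token is concatenated with the mean patch token, which is why the classifier above
+        # is a Linear layer with in_features equal to 2 * hidden_size.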
+ linear_input = torch.cat([cls_token, patch_tokens.mean(dim=1)], dim=1)
+ logits = self.classifier(linear_input)
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(labels, logits, self.config, **kwargs)
+
+ return ImageClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+@auto_docstring(
+ custom_intro="""
+ Dinov2WithRegisters backbone, to be used with frameworks like DETR and MaskFormer.
+ """
+)
+class Dinov2WithRegistersBackbone(Dinov2WithRegistersPreTrainedModel, BackboneMixin):
+ def __init__(self, config):
+ super().__init__(config)
+ super()._init_backbone(config)
+ self.num_features = [config.hidden_size for _ in range(config.num_hidden_layers + 1)]
+ self.embeddings = Dinov2WithRegistersEmbeddings(config)
+ self.encoder = Dinov2WithRegistersEncoder(config)
+
+ self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+ self.num_register_tokens = config.num_register_tokens
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self) -> Dinov2WithRegistersPatchEmbeddings:
+ return self.embeddings.patch_embeddings
+
+ @check_model_inputs
+ @auto_docstring
+ def forward(
+ self,
+ pixel_values: torch.Tensor,
+ output_hidden_states: Optional[bool] = None,
+ **kwargs,
+ ) -> BackboneOutput:
+ r"""
+ Examples:
+
+ ```python
+ >>> from transformers import AutoImageProcessor, AutoBackbone
+ >>> import torch
+ >>> from PIL import Image
+ >>> import requests
+
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> processor = AutoImageProcessor.from_pretrained("facebook/dinov2-with-registers-base")
+ >>> model = AutoBackbone.from_pretrained(
+ ... "facebook/dinov2-with-registers-base", out_features=["stage2", "stage5", "stage8", "stage11"]
+ ... )
+
+ >>> inputs = processor(image, return_tensors="pt")
+
+ >>> outputs = model(**inputs)
+ >>> feature_maps = outputs.feature_maps
+ >>> list(feature_maps[-1].shape)
+ [1, 768, 16, 16]
+ ```"""
+ if output_hidden_states is None:
+ output_hidden_states = self.config.output_hidden_states
+
+ embedding_output = self.embeddings(pixel_values)
+ output: BaseModelOutput = self.encoder(embedding_output, output_hidden_states=True)
+ hidden_states = output.hidden_states
+
+ feature_maps = []
+ for stage, hidden_state in zip(self.stage_names, hidden_states):
+ if stage in self.out_features:
+ if self.config.apply_layernorm:
+ hidden_state = self.layernorm(hidden_state)
+ if self.config.reshape_hidden_states:
+ hidden_state = hidden_state[:, 1 + self.num_register_tokens :]
+ # this was actually a bug in the original implementation that we copied here,
+ # because normally the order is height, width
+ batch_size, _, height, width = pixel_values.shape
+ patch_size = self.config.patch_size
+ hidden_state = hidden_state.reshape(batch_size, height // patch_size, width // patch_size, -1)
+ hidden_state = hidden_state.permute(0, 3, 1, 2).contiguous()
+ feature_maps.append(hidden_state)
+
+ return BackboneOutput(
+ feature_maps=tuple(feature_maps),
+ hidden_states=hidden_states if output_hidden_states else None,
+ )
+
+
+__all__ = [
+ "Dinov2WithRegistersPreTrainedModel",
+ "Dinov2WithRegistersModel",
+ "Dinov2WithRegistersForImageClassification",
+ "Dinov2WithRegistersBackbone",
+]
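Editor's note, not part of the diff: the classification head above concatenates the [CLS] token with the mean of the patch tokens (register tokens excluded), which is why the classifier is `nn.Linear(config.hidden_size * 2, num_labels)`. A minimal sketch with hypothetical shapes:

```python
import torch

# Hypothetical sizes for illustration only
batch_size, hidden_size, num_register_tokens, num_patches = 2, 768, 4, 256

# last_hidden_state layout: [CLS] token, register tokens, patch tokens
sequence_output = torch.randn(batch_size, 1 + num_register_tokens + num_patches, hidden_size)

cls_token = sequence_output[:, 0]                            # (batch_size, hidden_size)
patch_tokens = sequence_output[:, 1 + num_register_tokens:]  # register tokens are skipped
linear_input = torch.cat([cls_token, patch_tokens.mean(dim=1)], dim=1)

assert linear_input.shape == (batch_size, 2 * hidden_size)   # feeds nn.Linear(hidden_size * 2, num_labels)
```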
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/dinov2_with_registers/modular_dinov2_with_registers.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/dinov2_with_registers/modular_dinov2_with_registers.py
new file mode 100644
index 0000000000000000000000000000000000000000..686528002b09c9689d66a057ed55eb1a43b0d256
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/dinov2_with_registers/modular_dinov2_with_registers.py
@@ -0,0 +1,435 @@
+# coding=utf-8
+# Copyright 2024 Meta Inc. and the HuggingFace Inc. team. All rights reserved.
+#
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional, Union
+
+import torch
+from torch import nn
+
+from transformers.models.dinov2.modeling_dinov2 import (
+ Dinov2Backbone,
+ Dinov2Encoder,
+ Dinov2ForImageClassification,
+ Dinov2Model,
+ Dinov2PatchEmbeddings,
+ Dinov2PreTrainedModel,
+)
+from ...configuration_utils import PretrainedConfig
+from ...modeling_outputs import BackboneOutput, BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput
+from ...processing_utils import Unpack
+from ...utils import TransformersKwargs, logging, torch_int
+from ...utils.backbone_utils import BackboneConfigMixin, get_aligned_output_features_output_indices
+
+
+logger = logging.get_logger(__name__)
+
+
+class Dinov2WithRegistersConfig(BackboneConfigMixin, PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`Dinov2WithRegistersModel`]. It is used to instantiate an
+ Dinov2WithRegisters model according to the specified arguments, defining the model architecture. Instantiating a configuration
+ with the defaults will yield a similar configuration to that of the DINOv2 with Registers
+ [facebook/dinov2-with-registers-base](https://huggingface.co/facebook/dinov2-with-registers-base) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ hidden_size (`int`, *optional*, defaults to 768):
+ Dimensionality of the encoder layers and the pooler layer.
+ num_hidden_layers (`int`, *optional*, defaults to 12):
+ Number of hidden layers in the Transformer encoder.
+ num_attention_heads (`int`, *optional*, defaults to 12):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ mlp_ratio (`int`, *optional*, defaults to 4):
+ Ratio of the hidden size of the MLPs relative to the `hidden_size`.
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+ `"relu"`, `"selu"` and `"gelu_new"` are supported.
+ hidden_dropout_prob (`float`, *optional*, defaults to 0.0):
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+ attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ layer_norm_eps (`float`, *optional*, defaults to 1e-06):
+ The epsilon used by the layer normalization layers.
+ image_size (`int`, *optional*, defaults to 224):
+ The size (resolution) of each image.
+ patch_size (`int`, *optional*, defaults to 16):
+ The size (resolution) of each patch.
+ num_channels (`int`, *optional*, defaults to 3):
+ The number of input channels.
+ qkv_bias (`bool`, *optional*, defaults to `True`):
+ Whether to add a bias to the queries, keys and values.
+ layerscale_value (`float`, *optional*, defaults to 1.0):
+ Initial value to use for layer scale.
+ drop_path_rate (`float`, *optional*, defaults to 0.0):
+ Stochastic depth rate per sample (when applied in the main path of residual layers).
+ use_swiglu_ffn (`bool`, *optional*, defaults to `False`):
+ Whether to use the SwiGLU feedforward neural network.
+ num_register_tokens (`int`, *optional*, defaults to 4):
+ Number of register tokens to use.
+ out_features (`list[str]`, *optional*):
+ If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc.
+ (depending on how many stages the model has). If unset and `out_indices` is set, will default to the
+ corresponding stages. If unset and `out_indices` is unset, will default to the last stage. Must be in the
+ same order as defined in the `stage_names` attribute.
+ out_indices (`list[int]`, *optional*):
+ If used as backbone, list of indices of features to output. Can be any of 0, 1, 2, etc. (depending on how
+ many stages the model has). If unset and `out_features` is set, will default to the corresponding stages.
+ If unset and `out_features` is unset, will default to the last stage. Must be in the
+ same order as defined in the `stage_names` attribute.
+ apply_layernorm (`bool`, *optional*, defaults to `True`):
+ Whether to apply layer normalization to the feature maps in case the model is used as backbone.
+ reshape_hidden_states (`bool`, *optional*, defaults to `True`):
+ Whether to reshape the feature maps to 4D tensors of shape `(batch_size, hidden_size, height, width)` in
+ case the model is used as backbone. If `False`, the feature maps will be 3D tensors of shape `(batch_size,
+ seq_len, hidden_size)`.
+
+ Example:
+
+ ```python
+ >>> from transformers import Dinov2WithRegistersConfig, Dinov2WithRegistersModel
+
+ >>> # Initializing a Dinov2WithRegisters base style configuration
+ >>> configuration = Dinov2WithRegistersConfig()
+
+ >>> # Initializing a model (with random weights) from the base style configuration
+ >>> model = Dinov2WithRegistersModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "dinov2_with_registers"
+
+ def __init__(
+ self,
+ hidden_size=768,
+ num_hidden_layers=12,
+ num_attention_heads=12,
+ mlp_ratio=4,
+ hidden_act="gelu",
+ hidden_dropout_prob=0.0,
+ attention_probs_dropout_prob=0.0,
+ initializer_range=0.02,
+ layer_norm_eps=1e-6,
+ image_size=224,
+ patch_size=16,
+ num_channels=3,
+ qkv_bias=True,
+ layerscale_value=1.0,
+ drop_path_rate=0.0,
+ use_swiglu_ffn=False,
+ num_register_tokens=4,
+ out_features=None,
+ out_indices=None,
+ apply_layernorm=True,
+ reshape_hidden_states=True,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+
+ self.hidden_size = hidden_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.mlp_ratio = mlp_ratio
+ self.hidden_act = hidden_act
+ self.hidden_dropout_prob = hidden_dropout_prob
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
+ self.initializer_range = initializer_range
+ self.layer_norm_eps = layer_norm_eps
+ self.image_size = image_size
+ self.patch_size = patch_size
+ self.num_channels = num_channels
+ self.qkv_bias = qkv_bias
+ self.layerscale_value = layerscale_value
+ self.drop_path_rate = drop_path_rate
+ self.use_swiglu_ffn = use_swiglu_ffn
+ self.num_register_tokens = num_register_tokens
+ self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, num_hidden_layers + 1)]
+ self._out_features, self._out_indices = get_aligned_output_features_output_indices(
+ out_features=out_features, out_indices=out_indices, stage_names=self.stage_names
+ )
+ self.apply_layernorm = apply_layernorm
+ self.reshape_hidden_states = reshape_hidden_states
+
+
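A quick illustration of the configuration built above (editor's sketch, not part of the diff): `stage_names` runs from `"stem"` through `"stage{num_hidden_layers}"`, and `out_features`/`out_indices` are aligned against that list. The printed values assume the default 12 hidden layers:

```python
from transformers import Dinov2WithRegistersConfig

config = Dinov2WithRegistersConfig(out_features=["stage2", "stage5", "stage8", "stage11"])
print(config.stage_names[:3])  # ['stem', 'stage1', 'stage2']
print(config.out_features)     # ['stage2', 'stage5', 'stage8', 'stage11']
print(config.out_indices)      # indices aligned to stage_names: 2, 5, 8, 11
```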
+class Dinov2WithRegistersPatchEmbeddings(Dinov2PatchEmbeddings):
+ pass
+
+
+class Dinov2WithRegistersEmbeddings(nn.Module):
+ """
+ Construct the CLS token, mask token, register tokens, position and patch embeddings.
+ """
+
+ def __init__(self, config: Dinov2WithRegistersConfig) -> None:
+ super().__init__()
+
+ self.cls_token = nn.Parameter(torch.randn(1, 1, config.hidden_size))
+ self.mask_token = nn.Parameter(torch.zeros(1, config.hidden_size))
+ self.register_tokens = nn.Parameter(torch.zeros(1, config.num_register_tokens, config.hidden_size))
+ self.patch_embeddings = Dinov2WithRegistersPatchEmbeddings(config)
+ num_patches = self.patch_embeddings.num_patches
+ self.position_embeddings = nn.Parameter(torch.randn(1, num_patches + 1, config.hidden_size))
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+ self.patch_size = config.patch_size
+ self.config = config
+
+ def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
+ """
+ This method interpolates the pre-trained position encodings so that the model can be used on higher
+ resolution images. This implementation supports torch.jit tracing while maintaining backwards compatibility
+ with the original implementation.
+
+ Adapted from:
+ - https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
+ - https://github.com/facebookresearch/dinov2/blob/main/dinov2/models/vision_transformer.py
+ """
+ num_patches = embeddings.shape[1] - 1
+ num_positions = self.position_embeddings.shape[1] - 1
+
+ # Skip interpolation for matching dimensions (unless tracing)
+ if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
+ return self.position_embeddings
+
+ # Handle class token and patch embeddings separately
+ class_pos_embed = self.position_embeddings[:, 0]
+ patch_pos_embed = self.position_embeddings[:, 1:]
+ dim = embeddings.shape[-1]
+
+ # Calculate new dimensions
+ height = height // self.config.patch_size
+ width = width // self.config.patch_size
+
+ # Reshape for interpolation
+ sqrt_num_positions = torch_int(num_positions**0.5)
+ patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
+ patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
+
+ # Store original dtype for restoration after interpolation
+ target_dtype = patch_pos_embed.dtype
+
+ # Interpolate at float32 precision
+ patch_pos_embed = nn.functional.interpolate(
+ patch_pos_embed.to(dtype=torch.float32),
+ size=(torch_int(height), torch_int(width)), # Explicit size instead of scale_factor
+ mode="bicubic",
+ align_corners=False,
+ antialias=True,
+ ).to(dtype=target_dtype)
+
+ # Validate output dimensions if not tracing
+ if not torch.jit.is_tracing():
+ if int(height) != patch_pos_embed.shape[-2] or int(width) != patch_pos_embed.shape[-1]:
+ raise ValueError("Width or height does not match with the interpolated position embeddings")
+
+ # Reshape back to original format
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+
+ # Combine class and patch embeddings
+ return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
+
+ def forward(self, pixel_values: torch.Tensor, bool_masked_pos: Optional[torch.Tensor] = None) -> torch.Tensor:
+ batch_size, _, height, width = pixel_values.shape
+ target_dtype = self.patch_embeddings.projection.weight.dtype
+ embeddings = self.patch_embeddings(pixel_values.to(dtype=target_dtype))
+
+ if bool_masked_pos is not None:
+ embeddings = torch.where(
+ bool_masked_pos.unsqueeze(-1), self.mask_token.to(embeddings.dtype).unsqueeze(0), embeddings
+ )
+
+ # add the [CLS] token to the embedded patch tokens
+ cls_tokens = self.cls_token.expand(batch_size, -1, -1)
+ embeddings = torch.cat((cls_tokens, embeddings), dim=1)
+
+ # add positional encoding to each token
+ embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
+
+ # add register tokens
+ embeddings = torch.cat(
+ (embeddings[:, :1], self.register_tokens.expand(embeddings.shape[0], -1, -1), embeddings[:, 1:]), dim=1
+ )
+
+ embeddings = self.dropout(embeddings)
+
+ return embeddings
+
+
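One detail of the embeddings `forward` above worth spelling out (editor's note): position embeddings are added before the register tokens are spliced in, so the registers carry no positional encoding, and the final token order is `[CLS], register_1..register_R, patch_1..patch_N`. A standalone sketch of that splice with hypothetical shapes:

```python
import torch

batch_size, hidden_size, num_register_tokens, num_patches = 2, 768, 4, 256

# [CLS] + patch tokens, already position-encoded at this point
embeddings = torch.randn(batch_size, 1 + num_patches, hidden_size)
register_tokens = torch.zeros(1, num_register_tokens, hidden_size)

# splice the registers between the [CLS] token and the patch tokens
embeddings = torch.cat(
    (embeddings[:, :1], register_tokens.expand(batch_size, -1, -1), embeddings[:, 1:]), dim=1
)
assert embeddings.shape == (batch_size, 1 + num_register_tokens + num_patches, hidden_size)
```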
+class Dinov2WithRegistersEncoder(Dinov2Encoder):
+ pass
+
+
+class Dinov2WithRegistersPreTrainedModel(Dinov2PreTrainedModel):
+ def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
+ """Initialize the weights"""
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
+ # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid
+ # `trunc_normal_cpu` not implemented in `half` issues
+ module.weight.data = nn.init.trunc_normal_(
+ module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range
+ ).to(module.weight.dtype)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+ elif isinstance(module, Dinov2WithRegistersEmbeddings):
+ module.position_embeddings.data = nn.init.trunc_normal_(
+ module.position_embeddings.data.to(torch.float32),
+ mean=0.0,
+ std=self.config.initializer_range,
+ ).to(module.position_embeddings.dtype)
+
+ module.cls_token.data = nn.init.trunc_normal_(
+ module.cls_token.data.to(torch.float32),
+ mean=0.0,
+ std=self.config.initializer_range,
+ ).to(module.cls_token.dtype)
+
+ module.mask_token.data.zero_()
+ module.register_tokens.data.zero_()
+ elif isinstance(module, Dinov2WithRegistersLayerScale): # noqa: F821
+ module.lambda1.data.fill_(self.config.layerscale_value)
+
+
+class Dinov2WithRegistersModel(Dinov2Model):
+ pass
+
+
+class Dinov2WithRegistersForImageClassification(Dinov2ForImageClassification):
+ def forward(
+ self,
+ pixel_values: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> ImageClassifierOutput:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+ """
+
+ outputs: BaseModelOutputWithPooling = self.dinov2_with_registers(pixel_values, head_mask=head_mask, **kwargs)
+ sequence_output = outputs.last_hidden_state # batch_size, sequence_length, hidden_size
+
+ cls_token = sequence_output[:, 0]
+ # cls and register tokens should not be included in patch tokens variable
+ patch_tokens = sequence_output[:, 1 + self.config.num_register_tokens :]
+
+ linear_input = torch.cat([cls_token, patch_tokens.mean(dim=1)], dim=1)
+ logits = self.classifier(linear_input)
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(labels, logits, self.config, **kwargs)
+
+ return ImageClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+class Dinov2WithRegistersBackbone(Dinov2Backbone):
+ def __init__(self, config):
+ super().__init__(config)
+ super()._init_backbone(config)
+
+ self.num_register_tokens = config.num_register_tokens
+ self.num_features = [config.hidden_size for _ in range(config.num_hidden_layers + 1)]
+ self.embeddings = Dinov2WithRegistersEmbeddings(config)
+ self.encoder = Dinov2WithRegistersEncoder(config)
+
+ self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self) -> Dinov2WithRegistersPatchEmbeddings:
+ return self.embeddings.patch_embeddings
+
+ def forward(
+ self,
+ pixel_values: torch.Tensor,
+ output_hidden_states: Optional[bool] = None,
+ **kwargs,
+ ) -> BackboneOutput:
+ r"""
+ Examples:
+
+ ```python
+ >>> from transformers import AutoImageProcessor, AutoBackbone
+ >>> import torch
+ >>> from PIL import Image
+ >>> import requests
+
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> processor = AutoImageProcessor.from_pretrained("facebook/dinov2-with-registers-base")
+ >>> model = AutoBackbone.from_pretrained(
+ ... "facebook/dinov2-with-registers-base", out_features=["stage2", "stage5", "stage8", "stage11"]
+ ... )
+
+ >>> inputs = processor(image, return_tensors="pt")
+
+ >>> outputs = model(**inputs)
+ >>> feature_maps = outputs.feature_maps
+ >>> list(feature_maps[-1].shape)
+ [1, 768, 16, 16]
+ ```"""
+ if output_hidden_states is None:
+ output_hidden_states = self.config.output_hidden_states
+
+ embedding_output = self.embeddings(pixel_values)
+ output: BaseModelOutput = self.encoder(embedding_output, output_hidden_states=True)
+ hidden_states = output.hidden_states
+
+ feature_maps = []
+ for stage, hidden_state in zip(self.stage_names, hidden_states):
+ if stage in self.out_features:
+ if self.config.apply_layernorm:
+ hidden_state = self.layernorm(hidden_state)
+ if self.config.reshape_hidden_states:
+ hidden_state = hidden_state[:, 1 + self.num_register_tokens :]
+ # this was actually a bug in the original implementation that we copied here,
+ # because normally the order is height, width
+ batch_size, _, height, width = pixel_values.shape
+ patch_size = self.config.patch_size
+ hidden_state = hidden_state.reshape(batch_size, height // patch_size, width // patch_size, -1)
+ hidden_state = hidden_state.permute(0, 3, 1, 2).contiguous()
+ feature_maps.append(hidden_state)
+
+ return BackboneOutput(
+ feature_maps=tuple(feature_maps),
+ hidden_states=hidden_states if output_hidden_states else None,
+ )
+
+
+__all__ = [
+ "Dinov2WithRegistersConfig",
+ "Dinov2WithRegistersPreTrainedModel",
+ "Dinov2WithRegistersModel",
+ "Dinov2WithRegistersForImageClassification",
+ "Dinov2WithRegistersBackbone",
+]
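Because `interpolate_pos_encoding` bicubically resizes the pre-trained position grid, the model also accepts images larger than the pre-training resolution, provided each side is divisible by the patch size. A hedged usage sketch (editor's example; the 448×448 input size is arbitrary):

```python
import torch
from transformers import Dinov2WithRegistersModel

model = Dinov2WithRegistersModel.from_pretrained("facebook/dinov2-with-registers-base")
model.eval()

# A larger-than-default input; position embeddings are interpolated on the fly
pixel_values = torch.randn(1, 3, 448, 448)
with torch.no_grad():
    outputs = model(pixel_values=pixel_values)

# sequence length = 1 ([CLS]) + num_register_tokens + (448 // patch_size) ** 2
print(outputs.last_hidden_state.shape)
```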
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/dit/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/dit/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/encoder_decoder/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/encoder_decoder/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c786feb9213fdd31640c0fdeaead5164026ad37a
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/encoder_decoder/__init__.py
@@ -0,0 +1,29 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_encoder_decoder import *
+ from .modeling_encoder_decoder import *
+ from .modeling_flax_encoder_decoder import *
+ from .modeling_tf_encoder_decoder import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/encoder_decoder/configuration_encoder_decoder.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/encoder_decoder/configuration_encoder_decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..af57b2596cee99eefe0493cc4aea51c845036d2e
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/encoder_decoder/configuration_encoder_decoder.py
@@ -0,0 +1,111 @@
+# coding=utf-8
+# Copyright 2020 The HuggingFace Inc. team.
+# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+from ..auto import AutoConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+class EncoderDecoderConfig(PretrainedConfig):
+ r"""
+ [`EncoderDecoderConfig`] is the configuration class to store the configuration of a [`EncoderDecoderModel`]. It is
+ used to instantiate an Encoder Decoder model according to the specified arguments, defining the encoder and decoder
+ configs.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ kwargs (*optional*):
+ Dictionary of keyword arguments. Notably:
+
+ - **encoder** ([`PretrainedConfig`], *optional*) -- An instance of a configuration object that defines
+ the encoder config.
+ - **decoder** ([`PretrainedConfig`], *optional*) -- An instance of a configuration object that defines
+ the decoder config.
+
+ Examples:
+
+ ```python
+ >>> from transformers import BertConfig, EncoderDecoderConfig, EncoderDecoderModel
+
+ >>> # Initializing a BERT google-bert/bert-base-uncased style configuration
+ >>> config_encoder = BertConfig()
+ >>> config_decoder = BertConfig()
+
+ >>> config = EncoderDecoderConfig.from_encoder_decoder_configs(config_encoder, config_decoder)
+
+ >>> # Initializing a Bert2Bert model (with random weights) from the google-bert/bert-base-uncased style configurations
+ >>> model = EncoderDecoderModel(config=config)
+
+ >>> # Accessing the model configuration
+ >>> config_encoder = model.config.encoder
+ >>> config_decoder = model.config.decoder
+ >>> # set decoder config to causal lm
+ >>> config_decoder.is_decoder = True
+ >>> config_decoder.add_cross_attention = True
+
+ >>> # Saving the model, including its configuration
+ >>> model.save_pretrained("my-model")
+
+ >>> # loading model and config from pretrained folder
+ >>> encoder_decoder_config = EncoderDecoderConfig.from_pretrained("my-model")
+ >>> model = EncoderDecoderModel.from_pretrained("my-model", config=encoder_decoder_config)
+ ```"""
+
+ model_type = "encoder-decoder"
+ sub_configs = {"encoder": AutoConfig, "decoder": AutoConfig}
+ has_no_defaults_at_init = True
+
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+ if "encoder" not in kwargs or "decoder" not in kwargs:
+ raise ValueError(
+ f"A configuration of type {self.model_type} cannot be instantiated because "
+ f"both `encoder` and `decoder` sub-configurations must be passed, but only {kwargs} was given"
+ )
+ encoder_config = kwargs.pop("encoder")
+ encoder_model_type = encoder_config.pop("model_type")
+ decoder_config = kwargs.pop("decoder")
+ decoder_model_type = decoder_config.pop("model_type")
+
+ self.encoder = AutoConfig.for_model(encoder_model_type, **encoder_config)
+ self.decoder = AutoConfig.for_model(decoder_model_type, **decoder_config)
+ self.is_encoder_decoder = True
+
+ @classmethod
+ def from_encoder_decoder_configs(
+ cls, encoder_config: PretrainedConfig, decoder_config: PretrainedConfig, **kwargs
+ ) -> PretrainedConfig:
+ r"""
+ Instantiate a [`EncoderDecoderConfig`] (or a derived class) from a pre-trained encoder model configuration and
+ decoder model configuration.
+
+ Returns:
+ [`EncoderDecoderConfig`]: An instance of a configuration object
+ """
+ logger.info("Set `config.is_decoder=True` and `config.add_cross_attention=True` for decoder_config")
+ decoder_config.is_decoder = True
+ decoder_config.add_cross_attention = True
+
+ return cls(encoder=encoder_config.to_dict(), decoder=decoder_config.to_dict(), **kwargs)
+
+
+__all__ = ["EncoderDecoderConfig"]
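The encoder and decoder sub-configurations do not need to share a model type; a brief editor's sketch pairing a BERT encoder with a GPT-2 decoder (model choice is purely illustrative):

```python
from transformers import BertConfig, GPT2Config, EncoderDecoderConfig

config = EncoderDecoderConfig.from_encoder_decoder_configs(BertConfig(), GPT2Config())

print(config.encoder.model_type, config.decoder.model_type)            # bert gpt2
print(config.decoder.is_decoder, config.decoder.add_cross_attention)   # True True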
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/encoder_decoder/modeling_encoder_decoder.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/encoder_decoder/modeling_encoder_decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..30e2370b2240d2f2f00182abea548cd6a72b5626
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/encoder_decoder/modeling_encoder_decoder.py
@@ -0,0 +1,609 @@
+# coding=utf-8
+# Copyright 2018 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Classes to support Encoder-Decoder architectures"""
+
+import gc
+import inspect
+import os
+import tempfile
+import warnings
+from typing import Optional, Union
+
+import torch
+from torch import nn
+from torch.nn import CrossEntropyLoss
+
+from ...cache_utils import Cache
+from ...configuration_utils import PretrainedConfig
+from ...generation import GenerationMixin
+from ...modeling_outputs import BaseModelOutput, Seq2SeqLMOutput
+from ...modeling_utils import PreTrainedModel
+from ...utils import auto_docstring, logging
+from ..auto.configuration_auto import AutoConfig
+from ..auto.modeling_auto import AutoModel, AutoModelForCausalLM
+from .configuration_encoder_decoder import EncoderDecoderConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+DEPRECATION_WARNING = (
+ "Version v4.12.0 introduces a better way to train encoder-decoder models by computing the loss inside the"
+ " encoder-decoder framework rather than in the decoder itself. You may observe training discrepancies if"
+ " fine-tuning a model trained with versions prior to 4.12.0. The decoder_input_ids are now created based on the"
+ " labels, no need to pass them yourself anymore."
+)
+
+
+def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
+ """
+ Shift input ids one token to the right.
+ """
+ shifted_input_ids = input_ids.new_zeros(input_ids.shape)
+ shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
+ if decoder_start_token_id is None:
+ raise ValueError("Make sure to set the decoder_start_token_id attribute of the model's configuration.")
+ shifted_input_ids[:, 0] = decoder_start_token_id
+
+ if pad_token_id is None:
+ raise ValueError("Make sure to set the pad_token_id attribute of the model's configuration.")
+ # replace possible -100 values in labels by `pad_token_id`
+ shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
+
+ return shifted_input_ids
+
+
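A quick worked example of the `shift_tokens_right` helper defined above (editor's note; the token ids, `pad_token_id=0`, and `decoder_start_token_id=101` are arbitrary values chosen for illustration):

```python
import torch

labels = torch.tensor([[7, 8, 9, -100, -100]])
shifted = shift_tokens_right(labels, pad_token_id=0, decoder_start_token_id=101)
print(shifted)  # tensor([[101,   7,   8,   9,   0]])
```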
+@auto_docstring
+class EncoderDecoderModel(PreTrainedModel, GenerationMixin):
+ r"""
+ [`EncoderDecoderModel`] is a generic model class that will be instantiated as a transformer architecture with one
+ of the base model classes of the library as encoder and another one as decoder when created with the
+ [`~AutoModel.from_pretrained`] class method for the encoder and
+ [`~AutoModelForCausalLM.from_pretrained`] class method for the decoder.
+ """
+
+ config: EncoderDecoderConfig
+ base_model_prefix = "encoder_decoder"
+ main_input_name = "input_ids"
+ supports_gradient_checkpointing = True
+ _supports_param_buffer_assignment = False
+ _supports_flash_attn = True
+ _supports_sdpa = True
+
+ def __init__(
+ self,
+ config: Optional[PretrainedConfig] = None,
+ encoder: Optional[PreTrainedModel] = None,
+ decoder: Optional[PreTrainedModel] = None,
+ ):
+ r"""
+ encoder (`PreTrainedModel`, *optional*):
+ The encoder model to use.
+ decoder (`PreTrainedModel`, *optional*):
+ The decoder model to use.
+ """
+ if config is None and (encoder is None or decoder is None):
+ raise ValueError("Either a configuration or an encoder and a decoder has to be provided.")
+ if config is None:
+ config = EncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config)
+ else:
+ if not isinstance(config, self.config_class):
+ raise ValueError(f"Config: {config} has to be of type {self.config_class}")
+
+ if config.decoder.cross_attention_hidden_size is not None:
+ if config.decoder.cross_attention_hidden_size != config.encoder.hidden_size:
+ raise ValueError(
+ "If `cross_attention_hidden_size` is specified in the decoder's configuration, it has to be equal"
+ f" to the encoder's `hidden_size`. Got {config.decoder.cross_attention_hidden_size} for"
+ f" `config.decoder.cross_attention_hidden_size` and {config.encoder.hidden_size} for"
+ " `config.encoder.hidden_size`."
+ )
+
+ # initialize with config
+ super().__init__(config)
+
+ if encoder is None:
+ from ..auto.modeling_auto import AutoModel
+
+ encoder = AutoModel.from_config(config.encoder)
+
+ if decoder is None:
+ from ..auto.modeling_auto import AutoModelForCausalLM
+
+ decoder = AutoModelForCausalLM.from_config(config.decoder)
+
+ self.encoder = encoder
+ self.decoder = decoder
+
+ if self.encoder.config.to_dict() != self.config.encoder.to_dict():
+ logger.warning(
+ f"Config of the encoder: {self.encoder.__class__} is overwritten by shared encoder config:"
+ f" {self.config.encoder}"
+ )
+ if self.decoder.config.to_dict() != self.config.decoder.to_dict():
+ logger.warning(
+ f"Config of the decoder: {self.decoder.__class__} is overwritten by shared decoder config:"
+ f" {self.config.decoder}"
+ )
+
+ # make sure that the individual model's config refers to the shared config
+ # so that the updates to the config will be synced
+ # update `_attn_implementation` because the attn is set in a deepcopied config within PreTrainedModel
+ self.config.encoder._attn_implementation = self.encoder.config._attn_implementation
+ self.config.decoder._attn_implementation = self.decoder.config._attn_implementation
+ self.encoder.config = self.config.encoder
+ self.decoder.config = self.config.decoder
+
+ # encoder outputs might need to be projected to different dimension for decoder
+ if (
+ self.encoder.config.hidden_size != self.decoder.config.hidden_size
+ and self.decoder.config.cross_attention_hidden_size is None
+ ):
+ self.enc_to_dec_proj = nn.Linear(self.encoder.config.hidden_size, self.decoder.config.hidden_size)
+
+ if self.encoder.get_output_embeddings() is not None:
+ raise ValueError(
+ f"The encoder {self.encoder} should not have an LM head. Please use a model without an LM head"
+ )
+
+ decoder_signature = set(inspect.signature(self.decoder.forward).parameters.keys())
+ if "encoder_hidden_states" not in decoder_signature:
+ raise ValueError(
+ "The selected decoder is not prepared for the encoder hidden states to be passed. Please see the "
+ "following discussion on GitHub: https://github.com/huggingface/transformers/issues/23350"
+ )
+
+ # tie encoder, decoder weights if config set accordingly
+ self.tie_weights()
+
+ def tie_weights(self):
+ self.encoder.tie_weights()
+ self.decoder.tie_weights()
+ # tie encoder & decoder if needed
+ if self.config.tie_encoder_decoder:
+ # tie encoder and decoder base model
+ decoder_base_model_prefix = self.decoder.base_model_prefix
+ tied_weights = self._tie_encoder_decoder_weights(
+ self.encoder,
+ self.decoder._modules[decoder_base_model_prefix],
+ self.decoder.base_model_prefix,
+ "encoder",
+ )
+ # Setting a dynamic variable instead of `_tied_weights_keys` because it's a class
+ # attribute, not an instance member; modifying it would modify the entire class,
+ # leading to issues on subsequent calls by different tests.
+ self._dynamic_tied_weights_keys = tied_weights
+
+ def _init_weights(self, module):
+ if module in self.encoder.modules():
+ self.encoder._init_weights(module)
+ elif module in self.decoder.modules():
+ self.decoder._init_weights(module)
+
+ def get_encoder(self):
+ return self.encoder
+
+ def get_input_embeddings(self):
+ return self.encoder.get_input_embeddings()
+
+ def get_output_embeddings(self):
+ return self.decoder.get_output_embeddings()
+
+ def set_output_embeddings(self, new_embeddings):
+ return self.decoder.set_output_embeddings(new_embeddings)
+
+ @classmethod
+ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
+ r"""
+ Example:
+
+ ```python
+ >>> from transformers import EncoderDecoderModel
+
+ >>> model = EncoderDecoderModel.from_pretrained("patrickvonplaten/bert2bert-cnn_dailymail-fp16")
+ ```"""
+
+ from_tf = kwargs.pop("from_tf", False)
+ if from_tf:
+ from transformers import TFEncoderDecoderModel
+
+ # a workaround to load from tensorflow checkpoint
+ # Using `_tf_model` won't work, because the weight names in the encoder/decoder of `_tf_model` get
+ # extended before saving those components. For example, The name of `_tf_model.encoder.vit` is
+ # `[top model name]/encoder/vit`, but the name of `tf_model.encoder.vit` is `[top model name]/vit`. The
+ # [top model name] is handled (stripped) by the conversion method, and the former case gets extra `encoder`,
+ # which should not occur when we want to save the components alone.
+ # There was a (very) ugly potential fix, which wasn't integrated into `transformers`: see
+ # https://github.com/huggingface/transformers/pull/13222/commits/dbb3c9de76eee235791d2064094654637c99f36d#r697304245
+ # (the change in `src/transformers/modeling_tf_utils.py`)
+ _tf_model = TFEncoderDecoderModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
+ config = _tf_model.config
+
+ # Using `tf_model` instead
+ encoder = _tf_model.encoder.__class__(_tf_model.config.encoder)
+ decoder = _tf_model.decoder.__class__(_tf_model.config.decoder)
+ # Make sure models are built
+ encoder(encoder.dummy_inputs)
+ decoder(decoder.dummy_inputs)
+
+ # Get the variable correspondence between `_tf_model` and `encoder` and `decoder`
+ encoder_variables = {}
+ for v in encoder.trainable_variables + encoder.non_trainable_variables:
+ encoder_variables["/".join(v.name.split("/")[1:])] = v
+ decoder_variables = {}
+ for v in decoder.trainable_variables + decoder.non_trainable_variables:
+ decoder_variables["/".join(v.name.split("/")[1:])] = v
+
+ _encoder_variables = {}
+ for v in _tf_model.encoder.trainable_variables + _tf_model.encoder.non_trainable_variables:
+ _encoder_variables["/".join(v.name.split("/")[2:])] = v
+ _decoder_variables = {}
+ for v in _tf_model.decoder.trainable_variables + _tf_model.decoder.non_trainable_variables:
+ _decoder_variables["/".join(v.name.split("/")[2:])] = v
+
+ # assign weight values to `encoder` and `decoder` from `_tf_model`
+ for name, v in encoder_variables.items():
+ v.assign(_encoder_variables[name])
+ for name, v in decoder_variables.items():
+ v.assign(_decoder_variables[name])
+
+ tf_model = TFEncoderDecoderModel(encoder=encoder, decoder=decoder)
+
+ # Deal with `enc_to_dec_proj`
+ if hasattr(_tf_model, "enc_to_dec_proj"):
+ tf_model(tf_model.dummy_inputs)
+ tf_model.enc_to_dec_proj.kernel.assign(_tf_model.enc_to_dec_proj.kernel)
+ tf_model.enc_to_dec_proj.bias.assign(_tf_model.enc_to_dec_proj.bias)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ encoder_dir = os.path.join(tmpdirname, "encoder")
+ decoder_dir = os.path.join(tmpdirname, "decoder")
+ tf_model.encoder.save_pretrained(encoder_dir)
+ tf_model.decoder.save_pretrained(decoder_dir)
+
+ if hasattr(tf_model, "enc_to_dec_proj"):
+ enc_to_dec_proj_weight = torch.transpose(
+ torch.from_numpy(tf_model.enc_to_dec_proj.kernel.numpy()), 1, 0
+ )
+ enc_to_dec_proj_bias = torch.from_numpy(tf_model.enc_to_dec_proj.bias.numpy())
+
+ del _tf_model
+ del tf_model
+ gc.collect()
+
+ model = EncoderDecoderModel.from_encoder_decoder_pretrained(
+ encoder_dir, decoder_dir, encoder_from_tf=True, decoder_from_tf=True
+ )
+ # This is only for copying some specific attributes of this particular model.
+ model.config = config
+
+ if hasattr(model, "enc_to_dec_proj"):
+ model.enc_to_dec_proj.weight.data = enc_to_dec_proj_weight.contiguous()
+ model.enc_to_dec_proj.bias.data = enc_to_dec_proj_bias.contiguous()
+
+ return model
+
+ return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
+
+ @classmethod
+ def from_encoder_decoder_pretrained(
+ cls,
+ encoder_pretrained_model_name_or_path: Optional[str] = None,
+ decoder_pretrained_model_name_or_path: Optional[str] = None,
+ *model_args,
+ **kwargs,
+ ) -> PreTrainedModel:
+ r"""
+ Instantiate an encoder and a decoder from one or two base classes of the library from pretrained model
+ checkpoints.
+
+
+ The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train
+ the model, you need to first set it back in training mode with `model.train()`.
+
+ Params:
+ encoder_pretrained_model_name_or_path (`str`, *optional*):
+ Information necessary to initiate the encoder. Can be either:
+
+ - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
+ - A path to a *directory* containing model weights saved using
+ [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
+ - A path or url to a *tensorflow index checkpoint file* (e.g, `./tf_model/model.ckpt.index`). In
+ this case, `from_tf` should be set to `True` and a configuration object should be provided as
+ `config` argument. This loading path is slower than converting the TensorFlow checkpoint in a
+ PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
+
+ decoder_pretrained_model_name_or_path (`str`, *optional*, defaults to `None`):
+ Information necessary to initiate the decoder. Can be either:
+
+ - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
+ - A path to a *directory* containing model weights saved using
+ [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
+ - A path or url to a *tensorflow index checkpoint file* (e.g, `./tf_model/model.ckpt.index`). In
+ this case, `from_tf` should be set to `True` and a configuration object should be provided as
+ `config` argument. This loading path is slower than converting the TensorFlow checkpoint in a
+ PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
+
+ model_args (remaining positional arguments, *optional*):
+ All remaining positional arguments will be passed to the underlying model's `__init__` method.
+
+ kwargs (remaining dictionary of keyword arguments, *optional*):
+ Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
+ `output_attentions=True`).
+
+ - To update the encoder configuration, use the prefix *encoder_* for each configuration parameter.
+ - To update the decoder configuration, use the prefix *decoder_* for each configuration parameter.
+ - To update the parent model configuration, do not use a prefix for each configuration parameter.
+
+ Behaves differently depending on whether a `config` is provided or automatically loaded.
+
+ Example:
+
+ ```python
+ >>> from transformers import EncoderDecoderModel
+
+ >>> # initialize a bert2bert from two pretrained BERT models. Note that the cross-attention layers will be randomly initialized
+ >>> model = EncoderDecoderModel.from_encoder_decoder_pretrained("google-bert/bert-base-uncased", "google-bert/bert-base-uncased")
+ >>> # saving model after fine-tuning
+ >>> model.save_pretrained("./bert2bert")
+ >>> # load fine-tuned model
+ >>> model = EncoderDecoderModel.from_pretrained("./bert2bert")
+ ```"""
+
+ kwargs_encoder = {
+ argument[len("encoder_") :]: value for argument, value in kwargs.items() if argument.startswith("encoder_")
+ }
+
+ kwargs_decoder = {
+ argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
+ }
+
+ # remove encoder, decoder kwargs from kwargs
+ for key in kwargs_encoder:
+ del kwargs["encoder_" + key]
+ for key in kwargs_decoder:
+ del kwargs["decoder_" + key]
+
+ # Load and initialize the encoder and decoder
+ # The distinction between encoder and decoder at the model level is made
+ # by the value of the flag `is_decoder` that we need to set correctly.
+ encoder = kwargs_encoder.pop("model", None)
+ if encoder is None:
+ if encoder_pretrained_model_name_or_path is None:
+ raise ValueError(
+ "If `encoder_model` is not defined as an argument, an `encoder_pretrained_model_name_or_path` has "
+ "to be defined."
+ )
+
+ if "config" not in kwargs_encoder:
+ encoder_config, kwargs_encoder = AutoConfig.from_pretrained(
+ encoder_pretrained_model_name_or_path, **kwargs_encoder, return_unused_kwargs=True
+ )
+
+ if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True:
+ logger.info(
+ f"Initializing {encoder_pretrained_model_name_or_path} as an encoder model "
+ "from a decoder model. Cross-attention and causal mask are disabled."
+ )
+ encoder_config.is_decoder = False
+ encoder_config.add_cross_attention = False
+
+ kwargs_encoder["config"] = encoder_config
+
+ encoder = AutoModel.from_pretrained(encoder_pretrained_model_name_or_path, *model_args, **kwargs_encoder)
+
+ decoder = kwargs_decoder.pop("model", None)
+ if decoder is None:
+ if decoder_pretrained_model_name_or_path is None:
+ raise ValueError(
+ "If `decoder_model` is not defined as an argument, a `decoder_pretrained_model_name_or_path` has "
+ "to be defined."
+ )
+
+ if "config" not in kwargs_decoder:
+ decoder_config, kwargs_decoder = AutoConfig.from_pretrained(
+ decoder_pretrained_model_name_or_path, **kwargs_decoder, return_unused_kwargs=True
+ )
+
+ if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False:
+ logger.info(
+ f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. Cross attention"
+ f" layers are added to {decoder_pretrained_model_name_or_path} and randomly initialized if"
+ f" {decoder_pretrained_model_name_or_path}'s architecture allows for cross attention layers."
+ )
+ decoder_config.is_decoder = True
+ decoder_config.add_cross_attention = True
+
+ kwargs_decoder["config"] = decoder_config
+
+ if kwargs_decoder["config"].is_decoder is False or kwargs_decoder["config"].add_cross_attention is False:
+ logger.warning(
+ f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. "
+ f"In order to initialize {decoder_pretrained_model_name_or_path} as a decoder, "
+ "make sure that the attributes `is_decoder` and `add_cross_attention` of `decoder_config` "
+ "passed to `.from_encoder_decoder_pretrained(...)` are set to `True` or do not pass a "
+ "`decoder_config` to `.from_encoder_decoder_pretrained(...)`"
+ )
+
+ decoder = AutoModelForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder)
+
+ # instantiate config with corresponding kwargs
+ config = EncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config, **kwargs)
+ return cls(encoder=encoder, decoder=decoder, config=config)
+
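As the docstring above notes, keyword arguments prefixed with `encoder_` or `decoder_` are stripped and routed to the respective sub-model when loading. A hedged editor's sketch (checkpoints reused from the examples; the dropout values are arbitrary):

```python
from transformers import EncoderDecoderModel

model = EncoderDecoderModel.from_encoder_decoder_pretrained(
    "google-bert/bert-base-uncased",
    "google-bert/bert-base-uncased",
    encoder_hidden_dropout_prob=0.2,           # forwarded to the encoder config only
    decoder_attention_probs_dropout_prob=0.1,  # forwarded to the decoder config only
)
print(model.config.encoder.hidden_dropout_prob)           # 0.2
print(model.config.decoder.attention_probs_dropout_prob)  # 0.1
```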
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ decoder_input_ids: Optional[torch.LongTensor] = None,
+ decoder_attention_mask: Optional[torch.BoolTensor] = None,
+ encoder_outputs: Optional[tuple[torch.FloatTensor]] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ **kwargs,
+ ) -> Union[tuple, Seq2SeqLMOutput]:
+ r"""
+ decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
+ Indices of decoder input sequence tokens in the vocabulary.
+
+ Indices can be obtained using [`PreTrainedTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+
+ If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
+ `past_key_values`).
+
+ For training, `decoder_input_ids` are automatically created by the model by shifting the `labels` to the
+ right, replacing -100 by the `pad_token_id` and prepending them with the `decoder_start_token_id`.
+ decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
+ Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
+ be used by default.
+ decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*):
+ Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded
+ representation. This is useful if you want more control over how to convert `decoder_input_ids` indices
+ into associated vectors than the model's internal embedding lookup matrix.
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss for the decoder. Indices should be in `[-100, 0,
+ ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
+
+ Examples:
+
+ ```python
+ >>> from transformers import EncoderDecoderModel, BertTokenizer
+ >>> import torch
+
+ >>> tokenizer = BertTokenizer.from_pretrained("google-bert/bert-base-uncased")
+ >>> model = EncoderDecoderModel.from_encoder_decoder_pretrained(
+ ... "google-bert/bert-base-uncased", "google-bert/bert-base-uncased"
+ ... ) # initialize Bert2Bert from pre-trained checkpoints
+
+ >>> # training
+ >>> model.config.decoder_start_token_id = tokenizer.cls_token_id
+ >>> model.config.pad_token_id = tokenizer.pad_token_id
+ >>> model.config.vocab_size = model.config.decoder.vocab_size
+
+ >>> input_ids = tokenizer("This is a really long text", return_tensors="pt").input_ids
+ >>> labels = tokenizer("This is the corresponding summary", return_tensors="pt").input_ids
+ >>> outputs = model(input_ids=input_ids, labels=labels)
+ >>> loss, logits = outputs.loss, outputs.logits
+
+ >>> # save and load from pretrained
+ >>> model.save_pretrained("bert2bert")
+ >>> model = EncoderDecoderModel.from_pretrained("bert2bert")
+
+ >>> # generation
+ >>> generated = model.generate(input_ids)
+ ```"""
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ kwargs_encoder = {argument: value for argument, value in kwargs.items() if not argument.startswith("decoder_")}
+
+ kwargs_decoder = {
+ argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
+ }
+ if "num_items_in_batch" in kwargs_encoder:
+ kwargs_decoder["num_items_in_batch"] = kwargs_encoder.pop("num_items_in_batch", None)
+
+ if encoder_outputs is None:
+ encoder_outputs = self.encoder(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ **kwargs_encoder,
+ )
+ elif isinstance(encoder_outputs, tuple):
+ encoder_outputs = BaseModelOutput(*encoder_outputs)
+
+ encoder_hidden_states = encoder_outputs[0]
+
+ # optionally project encoder_hidden_states
+ if (
+ self.encoder.config.hidden_size != self.decoder.config.hidden_size
+ and self.decoder.config.cross_attention_hidden_size is None
+ ):
+ encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states)
+
+ if (labels is not None) and (decoder_input_ids is None and decoder_inputs_embeds is None):
+ decoder_input_ids = shift_tokens_right(
+ labels, self.config.pad_token_id, self.config.decoder_start_token_id
+ )
+ if decoder_attention_mask is None:
+ decoder_attention_mask = decoder_input_ids.new_tensor(decoder_input_ids != self.config.pad_token_id)
+
+ # Decode
+ decoder_outputs = self.decoder(
+ input_ids=decoder_input_ids,
+ attention_mask=decoder_attention_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=attention_mask,
+ inputs_embeds=decoder_inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ use_cache=use_cache,
+ past_key_values=past_key_values,
+ return_dict=return_dict,
+ **kwargs_decoder,
+ )
+
+ # Compute loss independent from decoder (as some shift the logits inside them)
+ loss = None
+ if labels is not None:
+ warnings.warn(DEPRECATION_WARNING, FutureWarning)
+ logits = decoder_outputs.logits if return_dict else decoder_outputs[0]
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(logits.reshape(-1, self.decoder.config.vocab_size), labels.view(-1))
+
+ if not return_dict:
+ if loss is not None:
+ return (loss,) + decoder_outputs + encoder_outputs
+ else:
+ return decoder_outputs + encoder_outputs
+
+ return Seq2SeqLMOutput(
+ loss=loss,
+ logits=decoder_outputs.logits,
+ past_key_values=decoder_outputs.past_key_values,
+ decoder_hidden_states=decoder_outputs.hidden_states,
+ decoder_attentions=decoder_outputs.attentions,
+ cross_attentions=decoder_outputs.cross_attentions,
+ encoder_last_hidden_state=encoder_outputs.last_hidden_state,
+ encoder_hidden_states=encoder_outputs.hidden_states,
+ encoder_attentions=encoder_outputs.attentions,
+ )
+
+ def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
+ return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
+
+ def resize_token_embeddings(self, *args, **kwargs):
+ raise NotImplementedError(
+ "Resizing the embedding layers via the EncoderDecoderModel directly is not supported. Please use the"
+ " respective methods of the wrapped objects (model.encoder.resize_token_embeddings(...) or"
+ " model.decoder.resize_token_embeddings(...))"
+ )
+
+
+__all__ = ["EncoderDecoderModel"]
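One behaviour of `__init__` worth highlighting (editor's note): when the encoder and decoder hidden sizes differ and the decoder does not define `cross_attention_hidden_size`, an `enc_to_dec_proj` linear layer is created to project the encoder states before cross-attention. A hedged sketch (the base/large pairing is chosen only to get mismatched sizes):

```python
from transformers import EncoderDecoderModel

# bert-base (hidden_size=768) as encoder, bert-large (hidden_size=1024) as decoder
model = EncoderDecoderModel.from_encoder_decoder_pretrained(
    "google-bert/bert-base-uncased", "google-bert/bert-large-uncased"
)
print(model.enc_to_dec_proj)  # Linear(in_features=768, out_features=1024, bias=True)
```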
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/encoder_decoder/modeling_flax_encoder_decoder.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/encoder_decoder/modeling_flax_encoder_decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..4a27c23c3c69ae928c73273c9397d5f5aad2b1c0
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/encoder_decoder/modeling_flax_encoder_decoder.py
@@ -0,0 +1,901 @@
+# coding=utf-8
+# Copyright 2021 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Classes to support Flax Encoder-Decoder architectures"""
+
+import os
+from typing import Optional, Union
+
+import flax.linen as nn
+import jax
+import jax.numpy as jnp
+from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
+from flax.traverse_util import flatten_dict, unflatten_dict
+from jax import lax
+from jax.random import PRNGKey
+
+from ...modeling_flax_outputs import FlaxBaseModelOutput, FlaxCausalLMOutputWithCrossAttentions, FlaxSeq2SeqLMOutput
+from ...modeling_flax_utils import FlaxPreTrainedModel
+from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
+from ..auto.configuration_auto import AutoConfig
+from ..auto.modeling_flax_auto import FlaxAutoModel, FlaxAutoModelForCausalLM
+from .configuration_encoder_decoder import EncoderDecoderConfig
+
+
+logger = logging.get_logger(__name__)
+
+_CONFIG_FOR_DOC = "EncoderDecoderConfig"
+
+ENCODER_DECODER_START_DOCSTRING = r"""
+ This class can be used to initialize a sequence-to-sequence model with any pretrained autoencoding model as the
+ encoder and any pretrained autoregressive model as the decoder. The encoder is loaded via
+ [`~AutoModel.from_pretrained`] function and the decoder is loaded via [`~AutoModelForCausalLM.from_pretrained`]
+ function. Cross-attention layers are automatically added to the decoder and should be fine-tuned on a downstream
+ generative task, like summarization.
+
+ The effectiveness of initializing sequence-to-sequence models with pretrained checkpoints for sequence generation
+ tasks was shown in [Leveraging Pre-trained Checkpoints for Sequence Generation
+ Tasks](https://huggingface.co/papers/1907.12461) by Sascha Rothe, Shashi Narayan, Aliaksei Severyn.
+
+ After such an Encoder Decoder model has been trained/fine-tuned, it can be saved/loaded just like any other models
+ (see the examples for more information).
+
+ This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
+ etc.)
+
+ This model is also a Flax Linen
+ [flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a
+ regular Flax Module and refer to the Flax documentation for all matter related to general usage and behavior.
+
+ Parameters:
+ config ([`EncoderDecoderConfig`]): Model configuration class with all the parameters of the model.
+ Initializing with a config file does not load the weights associated with the model, only the
+ configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights.
+ dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
+ The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
+ `jax.numpy.bfloat16` (on TPUs).
+
+ This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
+ specified, all the computation will be performed with the given `dtype`.
+
+ **Note that this only specifies the dtype of the computation and does not influence the dtype of model
+ parameters.**
+
+ If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and
+ [`~FlaxPreTrainedModel.to_bf16`].
+"""
+
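+# A minimal sketch (not part of the public API) of the `dtype` behaviour described above, assuming the
+# checkpoints used in the examples further below; the helper name is illustrative only. Computation runs
+# in bfloat16 while the parameters stay float32 unless they are converted explicitly.
+def _dtype_usage_sketch():
+ model = FlaxEncoderDecoderModel.from_encoder_decoder_pretrained(
+ "google-bert/bert-base-cased", "openai-community/gpt2", dtype=jnp.bfloat16
+ )
+ # `dtype` leaves the parameters untouched; convert them separately if half-precision weights are wanted
+ params_bf16 = model.to_bf16(model.params)
+ return model, params_bf16
+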
+ENCODER_DECODER_INPUTS_DOCSTRING = r"""
+ Args:
+ input_ids (`jnp.ndarray` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
+ it.
+
+ Indices can be obtained using [`PreTrainedTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ decoder_input_ids (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
+ Indices of decoder input sequence tokens in the vocabulary.
+
+ Indices can be obtained using [`PreTrainedTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are decoder input IDs?](../glossary#decoder-input-ids)
+
+ For sequence to sequence training, `decoder_input_ids` should be provided. `decoder_input_ids` should be
+ created outside of the model by shifting the `labels` to the right, replacing -100 by the `pad_token_id`
+ and prepending them with the `decoder_start_token_id`.
+ decoder_attention_mask (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
+ Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
+ be used by default.
+ position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
+ Indices of positions of each input sequence token in the position embeddings. Selected in the range `[0,
+ config.encoder.max_position_embeddings - 1]`.
+ decoder_position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
+ Indices of positions of each decoder input sequence token in the position embeddings. Selected in the
+ range `[0, config.decoder.max_position_embeddings - 1]`.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ If set to `True`, the model will return a [`~utils.FlaxSeq2SeqLMOutput`] instead of a plain tuple.
+"""
+
+ENCODER_DECODER_ENCODE_INPUTS_DOCSTRING = r"""
+ Args:
+ input_ids (`jnp.ndarray` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
+ it.
+
+ Indices can be obtained using [`PreTrainedTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
+ Indices of positions of each input sequence token in the position embeddings. Selected in the range `[0,
+ config.encoder.max_position_embeddings - 1]`.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ If set to `True`, the model will return a [`~utils.FlaxBaseModelOutput`] instead of a plain tuple.
+"""
+
+ENCODER_DECODER_DECODE_INPUTS_DOCSTRING = r"""
+ Args:
+ decoder_input_ids (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
+ Indices of decoder input sequence tokens in the vocabulary.
+
+ Indices can be obtained using [`PreTrainedTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are decoder input IDs?](../glossary#decoder-input-ids)
+
+ If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
+ `past_key_values`).
+
+ For sequence to sequence training, `decoder_input_ids` should be provided. `decoder_input_ids` should be
+ created outside of the model by shifting the `labels` to the right, replacing -100 by the `pad_token_id`
+ and prepending them with the `decoder_start_token_id`.
+ encoder_outputs (`tuple(tuple(jnp.ndarray))`):
+ Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`).
+ `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)` is a sequence of
+ hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
+ encoder_attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ decoder_attention_mask (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
+ Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
+ be used by default.
+ decoder_position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
+ Indices of positions of each decoder input sequence token in the position embeddings. Selected in the
+ range `[0, config.decoder.max_position_embeddings - 1]`.
+ past_key_values (`dict[str, np.ndarray]`, *optional*, returned by `init_cache` or when passing previous `past_key_values`):
+ Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast
+ auto-regressive decoding. Pre-computed key and value hidden-states are of shape *[batch_size, max_length]*.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ If set to `True`, the model will return a [`~utils.FlaxCausalLMOutputWithCrossAttentions`] instead of a
+ plain tuple.
+"""
+
+
+class FlaxEncoderDecoderModule(nn.Module):
+ config: EncoderDecoderConfig
+ dtype: jnp.dtype = jnp.float32
+
+ def setup(self):
+ encoder_config = self.config.encoder
+ decoder_config = self.config.decoder
+
+ # Copied from `modeling_hybrid_clip.py` with modifications.
+ from ...models.auto.modeling_flax_auto import FLAX_MODEL_FOR_CAUSAL_LM_MAPPING, FLAX_MODEL_MAPPING
+
+ encoder_module = FLAX_MODEL_MAPPING[encoder_config.__class__].module_class
+ decoder_module = FLAX_MODEL_FOR_CAUSAL_LM_MAPPING[decoder_config.__class__].module_class
+
+ self.encoder = encoder_module(encoder_config, dtype=self.dtype)
+ self.decoder = decoder_module(decoder_config, dtype=self.dtype)
+
+ # encoder outputs might need to be projected to a different dimension for the decoder
+ if (
+ self.encoder.config.hidden_size != self.decoder.config.hidden_size
+ and self.decoder.config.cross_attention_hidden_size is None
+ ):
+ self.enc_to_dec_proj = nn.Dense(
+ self.decoder.config.hidden_size,
+ kernel_init=jax.nn.initializers.normal(self.decoder.config.initializer_range),
+ dtype=self.dtype,
+ )
+ else:
+ self.enc_to_dec_proj = None
+
+ def _get_encoder_module(self):
+ return self.encoder
+
+ def _get_projection_module(self):
+ return self.enc_to_dec_proj
+
+ def _get_decoder_module(self):
+ return self.decoder
+
+ def __call__(
+ self,
+ input_ids,
+ attention_mask,
+ decoder_input_ids,
+ decoder_attention_mask,
+ position_ids,
+ decoder_position_ids,
+ output_attentions: bool = False,
+ output_hidden_states: bool = False,
+ return_dict: bool = True,
+ deterministic: bool = True,
+ ):
+ encoder_outputs = self.encoder(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ deterministic=deterministic,
+ )
+
+ encoder_hidden_states = encoder_outputs[0]
+
+ # optionally project encoder_hidden_states
+ if self.enc_to_dec_proj is not None:
+ encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states)
+
+ decoder_outputs = self.decoder(
+ input_ids=decoder_input_ids,
+ attention_mask=decoder_attention_mask,
+ position_ids=decoder_position_ids,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ deterministic=deterministic,
+ )
+
+ if not return_dict:
+ return decoder_outputs + encoder_outputs
+
+ return FlaxSeq2SeqLMOutput(
+ logits=decoder_outputs.logits,
+ decoder_hidden_states=decoder_outputs.hidden_states,
+ decoder_attentions=decoder_outputs.attentions,
+ cross_attentions=decoder_outputs.cross_attentions,
+ encoder_last_hidden_state=encoder_outputs.last_hidden_state,
+ encoder_hidden_states=encoder_outputs.hidden_states,
+ encoder_attentions=encoder_outputs.attentions,
+ )
+
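+# A minimal restatement (not part of the public API) of the projection rule used in `setup` above: a Dense layer
+# mapping encoder states to the decoder width is only created when the two hidden sizes differ and the decoder
+# does not declare `cross_attention_hidden_size`. The helper name is illustrative only.
+def _needs_enc_to_dec_projection_sketch(encoder_hidden_size, decoder_hidden_size, cross_attention_hidden_size=None):
+ # e.g. a 768-wide BERT encoder feeding a 1024-wide decoder requires the projection
+ return encoder_hidden_size != decoder_hidden_size and cross_attention_hidden_size is None
+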
+
+@add_start_docstrings(ENCODER_DECODER_START_DOCSTRING)
+class FlaxEncoderDecoderModel(FlaxPreTrainedModel):
+ r"""
+ [`FlaxEncoderDecoderModel`] is a generic model class that will be instantiated as a transformer architecture with
+ the module (flax.nn.Module) of one of the base model classes of the library as encoder module and another one as
+ decoder module when created with the [`~transformers.FlaxAutoModel.from_pretrained`] class method for the
+ encoder and the [`~transformers.FlaxAutoModelForCausalLM.from_pretrained`] class method for the decoder.
+ """
+
+ config_class = EncoderDecoderConfig
+ base_model_prefix = "encoder_decoder"
+ module_class = FlaxEncoderDecoderModule
+
+ def __init__(
+ self,
+ config: EncoderDecoderConfig,
+ input_shape: Optional[tuple] = None,
+ seed: int = 0,
+ dtype: jnp.dtype = jnp.float32,
+ _do_init: bool = True,
+ **kwargs,
+ ):
+ if input_shape is None:
+ input_shape = ((1, 1), (1, 1))
+
+ if not _do_init:
+ raise ValueError(
+ "`FlaxEncoderDecoderModel` cannot be created without initializing, `_do_init` must be `True`."
+ )
+
+ if config.decoder.cross_attention_hidden_size is not None:
+ if config.decoder.cross_attention_hidden_size != config.encoder.hidden_size:
+ raise ValueError(
+ "If `cross_attention_hidden_size` is specified in the decoder's configuration, it has to be equal"
+ f" to the encoder's `hidden_size`. Got {config.decoder.cross_attention_hidden_size} for"
+ f" `config.decoder.cross_attention_hidden_size` and {config.encoder.hidden_size} for"
+ " `config.encoder.hidden_size`."
+ )
+
+ module = self.module_class(config=config, dtype=dtype, **kwargs)
+ super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
+
+ def init_weights(self, rng: jax.random.PRNGKey, input_shape: tuple, params: FrozenDict = None) -> FrozenDict:
+ encoder_input_shape, decoder_input_shape = input_shape
+
+ # init input tensors
+ input_ids = jnp.zeros(encoder_input_shape, dtype="i4")
+ attention_mask = jnp.ones_like(input_ids)
+ decoder_input_ids = jnp.zeros(decoder_input_shape, dtype="i4")
+ decoder_attention_mask = jnp.ones_like(decoder_input_ids)
+
+ batch_size, sequence_length = input_ids.shape
+ position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))
+
+ decoder_batch_size, decoder_sequence_length = decoder_input_ids.shape
+ if decoder_batch_size != batch_size:
+ raise ValueError(
+ f"The inputs of encoder and decoder should have the same batch size, but got {batch_size} for encoder"
+ f" and {decoder_batch_size} for decoder."
+ )
+ decoder_position_ids = jnp.broadcast_to(
+ jnp.arange(decoder_sequence_length)[None, :], (decoder_batch_size, decoder_sequence_length)
+ )
+
+ params_rng, dropout_rng = jax.random.split(rng)
+ rngs = {"params": params_rng, "dropout": dropout_rng}
+
+ random_params = self.module.init(
+ rngs,
+ input_ids,
+ attention_mask,
+ decoder_input_ids,
+ decoder_attention_mask,
+ position_ids,
+ decoder_position_ids,
+ )["params"]
+
+ if params is not None:
+ random_params = flatten_dict(unfreeze(random_params))
+ params = flatten_dict(unfreeze(params))
+ for missing_key in self._missing_keys:
+ params[missing_key] = random_params[missing_key]
+ self._missing_keys = set()
+ return freeze(unflatten_dict(params))
+ else:
+ return random_params
+
+ def init_cache(self, batch_size, max_length, encoder_outputs):
+ r"""
+ Args:
+ batch_size (`int`):
+ batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
+ max_length (`int`):
+ maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized
+ cache.
+ encoder_outputs (`Union[FlaxBaseModelOutput, tuple(tuple(jnp.ndarray))]`):
+ `encoder_outputs` consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*:
+ `attentions`). `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`
+ is a sequence of hidden-states at the output of the last layer of the encoder. Used in the
+ cross-attention of the decoder.
+ """
+ # init input variables to retrieve cache
+ decoder_input_ids = jnp.ones((batch_size, max_length), dtype="i4")
+ decoder_attention_mask = jnp.ones_like(decoder_input_ids)
+ decoder_position_ids = jnp.broadcast_to(
+ jnp.arange(jnp.atleast_2d(decoder_input_ids).shape[-1]), decoder_input_ids.shape
+ )
+
+ def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs):
+ decoder_module = module._get_decoder_module()
+ return decoder_module(
+ input_ids=decoder_input_ids,
+ attention_mask=decoder_attention_mask,
+ position_ids=decoder_position_ids,
+ **kwargs,
+ )
+
+ init_variables = self.module.init(
+ jax.random.PRNGKey(0),
+ decoder_input_ids=decoder_input_ids,
+ decoder_attention_mask=decoder_attention_mask,
+ decoder_position_ids=decoder_position_ids,
+ encoder_hidden_states=encoder_outputs[0],
+ init_cache=True,
+ method=_decoder_forward, # we only need to call the decoder to init the cache
+ )
+ return unfreeze(init_variables["cache"])
+
+ @add_start_docstrings(ENCODER_DECODER_ENCODE_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=FlaxBaseModelOutput, config_class=_CONFIG_FOR_DOC)
+ def encode(
+ self,
+ input_ids: jnp.ndarray,
+ attention_mask: Optional[jnp.ndarray] = None,
+ position_ids: Optional[jnp.ndarray] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ train: bool = False,
+ params: Optional[dict] = None,
+ dropout_rng: PRNGKey = None,
+ ):
+ r"""
+ Returns:
+
+ Example:
+
+ ```python
+ >>> from transformers import FlaxEncoderDecoderModel, BertTokenizer
+
+ >>> # initialize a bert2gpt2 from pretrained BERT and GPT2 models. Note that the cross-attention layers will be randomly initialized
+ >>> model = FlaxEncoderDecoderModel.from_encoder_decoder_pretrained("google-bert/bert-base-cased", "openai-community/gpt2")
+
+ >>> tokenizer = BertTokenizer.from_pretrained("google-bert/bert-base-cased")
+
+ >>> text = "My friends are cool but they eat too many carbs."
+ >>> input_ids = tokenizer.encode(text, return_tensors="np")
+ >>> encoder_outputs = model.encode(input_ids)
+ ```"""
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
+
+ if attention_mask is None:
+ attention_mask = jnp.ones_like(input_ids)
+ if position_ids is None:
+ batch_size, sequence_length = input_ids.shape
+ position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))
+
+ # Handle any PRNG if needed
+ rngs = {}
+ if dropout_rng is not None:
+ rngs["dropout"] = dropout_rng
+
+ def _encoder_forward(module, input_ids, attention_mask, position_ids, **kwargs):
+ encode_module = module._get_encoder_module()
+ return encode_module(input_ids, attention_mask, position_ids, **kwargs)
+
+ outputs = self.module.apply(
+ {"params": params or self.params},
+ input_ids=jnp.array(input_ids, dtype="i4"),
+ attention_mask=jnp.array(attention_mask, dtype="i4"),
+ position_ids=jnp.array(position_ids, dtype="i4"),
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ deterministic=not train,
+ rngs=rngs,
+ method=_encoder_forward,
+ )
+
+ if return_dict:
+ outputs = FlaxBaseModelOutput(
+ last_hidden_state=outputs.last_hidden_state,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+ return outputs
+
+ @add_start_docstrings(ENCODER_DECODER_DECODE_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=FlaxCausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
+ def decode(
+ self,
+ decoder_input_ids,
+ encoder_outputs,
+ encoder_attention_mask: Optional[jnp.ndarray] = None,
+ decoder_attention_mask: Optional[jnp.ndarray] = None,
+ decoder_position_ids: Optional[jnp.ndarray] = None,
+ past_key_values: Optional[dict] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ train: bool = False,
+ params: Optional[dict] = None,
+ dropout_rng: PRNGKey = None,
+ ):
+ r"""
+ Returns:
+
+ Example:
+
+ ```python
+ >>> from transformers import FlaxEncoderDecoderModel, BertTokenizer
+ >>> import jax.numpy as jnp
+
+ >>> # initialize a bert2gpt2 from pretrained BERT and GPT2 models. Note that the cross-attention layers will be randomly initialized
+ >>> model = FlaxEncoderDecoderModel.from_encoder_decoder_pretrained("google-bert/bert-base-cased", "openai-community/gpt2")
+
+ >>> tokenizer = BertTokenizer.from_pretrained("google-bert/bert-base-cased")
+
+ >>> text = "My friends are cool but they eat too many carbs."
+ >>> input_ids = tokenizer.encode(text, max_length=1024, return_tensors="np")
+ >>> encoder_outputs = model.encode(input_ids)
+
+ >>> decoder_start_token_id = model.config.decoder.bos_token_id
+ >>> decoder_input_ids = jnp.ones((input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id
+
+ >>> outputs = model.decode(decoder_input_ids, encoder_outputs)
+ >>> logits = outputs.logits
+ ```"""
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
+
+ encoder_hidden_states = encoder_outputs[0]
+ if encoder_attention_mask is None:
+ batch_size, sequence_length = encoder_hidden_states.shape[:2]
+ encoder_attention_mask = jnp.ones((batch_size, sequence_length))
+
+ batch_size, sequence_length = decoder_input_ids.shape
+ if decoder_attention_mask is None:
+ decoder_attention_mask = jnp.ones((batch_size, sequence_length))
+
+ if decoder_position_ids is None:
+ if past_key_values is not None:
+ raise ValueError("Make sure to provide `decoder_position_ids` when passing `past_key_values`.")
+
+ decoder_position_ids = jnp.broadcast_to(
+ jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
+ )
+
+ # Handle any PRNG if needed
+ rngs = {}
+ if dropout_rng is not None:
+ rngs["dropout"] = dropout_rng
+
+ inputs = {"params": params or self.params}
+
+ # If `past_key_values` are passed, the cache is already initialized and a private flag `init_cache` has to be
+ # passed down to ensure the cache is used. It also has to be made sure that the cache is marked as mutable so
+ # that it can be changed by the decoder's attention module (e.g. `FlaxBartAttention`).
+ if past_key_values:
+ inputs["cache"] = past_key_values
+ mutable = ["cache"]
+ else:
+ mutable = False
+
+ def _decoder_forward(
+ module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, encoder_hidden_states, **kwargs
+ ):
+ projection_module = module._get_projection_module()
+ decoder_module = module._get_decoder_module()
+
+ # optionally project encoder_hidden_states
+ if projection_module is not None:
+ encoder_hidden_states = projection_module(encoder_hidden_states)
+
+ return decoder_module(
+ decoder_input_ids,
+ decoder_attention_mask,
+ decoder_position_ids,
+ encoder_hidden_states=encoder_hidden_states,
+ **kwargs,
+ )
+
+ outputs = self.module.apply(
+ inputs,
+ decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
+ decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
+ decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"),
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"),
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ deterministic=not train,
+ rngs=rngs,
+ mutable=mutable,
+ method=_decoder_forward,
+ )
+
+ # add updated cache to model output
+ if past_key_values is not None and return_dict:
+ outputs, past = outputs
+ outputs["past_key_values"] = unfreeze(past["cache"])
+ return outputs
+ elif past_key_values is not None and not return_dict:
+ outputs, past = outputs
+ outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:]
+
+ return outputs
+
+ @add_start_docstrings_to_model_forward(ENCODER_DECODER_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=FlaxSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
+ def __call__(
+ self,
+ input_ids: jnp.ndarray,
+ attention_mask: Optional[jnp.ndarray] = None,
+ decoder_input_ids: Optional[jnp.ndarray] = None,
+ decoder_attention_mask: Optional[jnp.ndarray] = None,
+ position_ids: Optional[jnp.ndarray] = None,
+ decoder_position_ids: Optional[jnp.ndarray] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ train: bool = False,
+ params: Optional[dict] = None,
+ dropout_rng: PRNGKey = None,
+ ):
+ r"""
+ Returns:
+
+ Examples:
+
+ ```python
+ >>> from transformers import FlaxEncoderDecoderModel, BertTokenizer, GPT2Tokenizer
+
+ >>> # load a fine-tuned bert2gpt2 model
+ >>> model = FlaxEncoderDecoderModel.from_pretrained("patrickvonplaten/bert2gpt2-cnn_dailymail-fp16")
+ >>> # load input & output tokenizer
+ >>> tokenizer_input = BertTokenizer.from_pretrained("google-bert/bert-base-cased")
+ >>> tokenizer_output = GPT2Tokenizer.from_pretrained("openai-community/gpt2")
+
+ >>> article = '''Sigma Alpha Epsilon is under fire for a video showing party-bound fraternity members
+ >>> singing a racist chant. SAE's national chapter suspended the students,
+ >>> but University of Oklahoma President David Boren took it a step further,
+ >>> saying the university's affiliation with the fraternity is permanently done.'''
+
+ >>> input_ids = tokenizer_input(article, add_special_tokens=True, return_tensors="np").input_ids
+
+ >>> # use GPT2's eos_token as the pad as well as eos token
+ >>> model.config.eos_token_id = model.config.decoder.eos_token_id
+ >>> model.config.pad_token_id = model.config.eos_token_id
+
+ >>> sequences = model.generate(input_ids, num_beams=4, max_length=12).sequences
+
+ >>> summary = tokenizer_output.batch_decode(sequences, skip_special_tokens=True)[0]
+ >>> assert summary == "SAS Alpha Epsilon suspended Sigma Alpha Epsilon members"
+ ```
+ """
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
+
+ # prepare encoder inputs
+ if attention_mask is None:
+ attention_mask = jnp.ones_like(input_ids)
+ if position_ids is None:
+ batch_size, sequence_length = input_ids.shape
+ position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))
+
+ # prepare decoder inputs
+ if decoder_input_ids is None:
+ raise ValueError(
+ "`decoder_input_ids` cannot be `None`. For sequence to sequence training, `decoder_position_ids` must"
+ " be specified as an input argument."
+ )
+ if decoder_attention_mask is None:
+ decoder_attention_mask = jnp.ones_like(decoder_input_ids)
+ if decoder_position_ids is None:
+ batch_size, sequence_length = decoder_input_ids.shape
+ decoder_position_ids = jnp.broadcast_to(
+ jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
+ )
+
+ # Handle any PRNG if needed
+ rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}
+
+ return self.module.apply(
+ {"params": params or self.params},
+ input_ids=jnp.array(input_ids, dtype="i4"),
+ attention_mask=jnp.array(attention_mask, dtype="i4"),
+ decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
+ decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
+ position_ids=jnp.array(position_ids, dtype="i4"),
+ decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"),
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ deterministic=not train,
+ rngs=rngs,
+ )
+
+ def prepare_inputs_for_generation(
+ self,
+ decoder_input_ids,
+ max_length,
+ attention_mask: Optional[jax.Array] = None,
+ decoder_attention_mask: Optional[jax.Array] = None,
+ encoder_outputs=None,
+ **kwargs,
+ ):
+ # initializing the cache
+ batch_size, seq_length = decoder_input_ids.shape
+
+ past_key_values = self.init_cache(batch_size, max_length, encoder_outputs)
+ # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length.
+ # But since the decoder uses a causal mask, those positions are masked anyway.
+ # Thus we can create a single static attention_mask here, which is more efficient for compilation.
+ extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
+ if decoder_attention_mask is not None:
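+ # the cumulative sum of the mask gives each real token its position while padded positions repeat the
+ # last valid index, e.g. a mask of [1, 1, 1, 0] yields positions [0, 1, 2, 2]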
+ decoder_position_ids = decoder_attention_mask.cumsum(axis=-1) - 1
+ extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, decoder_attention_mask, (0, 0))
+ else:
+ decoder_position_ids = jnp.broadcast_to(
+ jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length)
+ )
+
+ return {
+ "past_key_values": past_key_values,
+ "encoder_outputs": encoder_outputs,
+ "encoder_attention_mask": attention_mask,
+ "decoder_attention_mask": extended_attention_mask,
+ "decoder_position_ids": decoder_position_ids,
+ }
+
+ def update_inputs_for_generation(self, model_outputs, model_kwargs):
+ model_kwargs["past_key_values"] = model_outputs.past_key_values
+ model_kwargs["decoder_position_ids"] = model_kwargs["decoder_position_ids"][:, -1:] + 1
+ return model_kwargs
+
+ @classmethod
+ def from_encoder_decoder_pretrained(
+ cls,
+ encoder_pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None,
+ decoder_pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None,
+ *model_args,
+ **kwargs,
+ ) -> FlaxPreTrainedModel:
+ r"""
+ Instantiate an encoder and a decoder from one or two base classes of the library from pretrained model
+ checkpoints.
+
+ Params:
+ encoder_pretrained_model_name_or_path (`Union[str, os.PathLike]`, *optional*):
+ Information necessary to initiate the encoder. Can be either:
+
+ - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
+ - A path to a *directory* containing model weights saved using
+ [`~FlaxPreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
+
+ decoder_pretrained_model_name_or_path (`Union[str, os.PathLike]`, *optional*, defaults to `None`):
+ Information necessary to initiate the decoder. Can be either:
+
+ - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
+ - A path to a *directory* containing model weights saved using
+ [`~FlaxPreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
+
+ model_args (remaining positional arguments, *optional*):
+ All remaining positional arguments will be passed to the underlying model's `__init__` method.
+
+ kwargs (remaining dictionary of keyword arguments, *optional*):
+ Can be used to update the configuration object (after it has been loaded) and initiate the model (e.g.,
+ `output_attentions=True`).
+
+ - To update the encoder configuration, use the prefix *encoder_* for each configuration parameter.
+ - To update the decoder configuration, use the prefix *decoder_* for each configuration parameter.
+ - To update the parent model configuration, do not use a prefix for each configuration parameter.
+
+ Behaves differently depending on whether a `config` is provided or automatically loaded.
+
+ Example:
+
+ ```python
+ >>> from transformers import FlaxEncoderDecoderModel
+
+ >>> # initialize a bert2gpt2 from pretrained BERT and GPT2 models. Note that the cross-attention layers will be randomly initialized
+ >>> model = FlaxEncoderDecoderModel.from_encoder_decoder_pretrained("google-bert/bert-base-cased", "openai-community/gpt2")
+ >>> # saving model after fine-tuning
+ >>> model.save_pretrained("./bert2gpt2")
+ >>> # load fine-tuned model
+ >>> model = FlaxEncoderDecoderModel.from_pretrained("./bert2gpt2")
+ ```"""
+
+ kwargs_encoder = {
+ argument[len("encoder_") :]: value for argument, value in kwargs.items() if argument.startswith("encoder_")
+ }
+
+ kwargs_decoder = {
+ argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
+ }
+
+ # remove encoder, decoder kwargs from kwargs
+ for key in kwargs_encoder:
+ del kwargs["encoder_" + key]
+ for key in kwargs_decoder:
+ del kwargs["decoder_" + key]
+
+ # Load and initialize the encoder and decoder
+ # The distinction between encoder and decoder at the model level is made
+ # by the value of the flag `is_decoder` that we need to set correctly.
+ encoder = kwargs_encoder.pop("model", None)
+ if encoder is None:
+ if encoder_pretrained_model_name_or_path is None:
+ raise ValueError(
+ "If `encoder_model` is not defined as an argument, a `encoder_pretrained_model_name_or_path` has "
+ "to be defined."
+ )
+
+ if "config" not in kwargs_encoder:
+ encoder_config, kwargs_encoder = AutoConfig.from_pretrained(
+ encoder_pretrained_model_name_or_path, **kwargs_encoder, return_unused_kwargs=True
+ )
+ if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True:
+ logger.info(
+ f"Initializing {encoder_pretrained_model_name_or_path} as a encoder model "
+ "from a decoder model. Cross-attention and causal mask are disabled."
+ )
+ encoder_config.is_decoder = False
+ encoder_config.add_cross_attention = False
+
+ kwargs_encoder["config"] = encoder_config
+
+ encoder = FlaxAutoModel.from_pretrained(
+ encoder_pretrained_model_name_or_path, *model_args, **kwargs_encoder
+ )
+
+ decoder = kwargs_decoder.pop("model", None)
+ if decoder is None:
+ if decoder_pretrained_model_name_or_path is None:
+ raise ValueError(
+ "If `decoder_model` is not defined as an argument, a `decoder_pretrained_model_name_or_path` has "
+ "to be defined."
+ )
+
+ if "config" not in kwargs_decoder:
+ decoder_config, kwargs_decoder = AutoConfig.from_pretrained(
+ decoder_pretrained_model_name_or_path, **kwargs_decoder, return_unused_kwargs=True
+ )
+ if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False:
+ logger.info(
+ f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. Cross attention"
+ f" layers are added to {decoder_pretrained_model_name_or_path} and randomly initialized if"
+ f" {decoder_pretrained_model_name_or_path}'s architecture allows for cross attention layers."
+ )
+ decoder_config.is_decoder = True
+ decoder_config.add_cross_attention = True
+
+ kwargs_decoder["config"] = decoder_config
+
+ if kwargs_decoder["config"].is_decoder is False or kwargs_decoder["config"].add_cross_attention is False:
+ logger.warning(
+ f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. "
+ f"In order to initialize {decoder_pretrained_model_name_or_path} as a decoder, "
+ "make sure that the attributes `is_decoder` and `add_cross_attention` of `decoder_config` "
+ "passed to `.from_encoder_decoder_pretrained(...)` are set to `True` or do not pass a "
+ "`decoder_config` to `.from_encoder_decoder_pretrained(...)`"
+ )
+
+ decoder = FlaxAutoModelForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder)
+
+ # instantiate config with corresponding kwargs
+ dtype = kwargs.pop("dtype", jnp.float32)
+ config = EncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config, **kwargs)
+
+ # init model
+ model = cls(config, dtype=dtype)
+ model.params["encoder"] = encoder.params
+ model.params["decoder"] = decoder.params
+
+ return model
+
+
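+# A compact sketch (not part of the public API) tying the pieces above together: warm-start an encoder-decoder
+# and summarize with `generate`, using the checkpoints from the docstring examples. Setting
+# `decoder_start_token_id` explicitly is an assumption made here so that `generate` knows how to start decoding
+# with a freshly combined, not yet fine-tuned model; the helper name is illustrative only.
+def _warm_start_and_generate_sketch():
+ from transformers import BertTokenizer
+
+ model = FlaxEncoderDecoderModel.from_encoder_decoder_pretrained(
+ "google-bert/bert-base-cased", "openai-community/gpt2"
+ )
+ tokenizer = BertTokenizer.from_pretrained("google-bert/bert-base-cased")
+ # GPT2 has no pad token, so reuse its eos token for padding, as in the `__call__` example above
+ model.config.eos_token_id = model.config.decoder.eos_token_id
+ model.config.pad_token_id = model.config.eos_token_id
+ model.config.decoder_start_token_id = model.config.decoder.bos_token_id
+ input_ids = tokenizer("My friends are cool but they eat too many carbs.", return_tensors="np").input_ids
+ return model.generate(input_ids, num_beams=4, max_length=12).sequences
+
+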
+__all__ = ["FlaxEncoderDecoderModel"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/encoder_decoder/modeling_tf_encoder_decoder.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/encoder_decoder/modeling_tf_encoder_decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..7e5343d200499e1f3b8ba26f8d70924c2999a2fc
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/encoder_decoder/modeling_tf_encoder_decoder.py
@@ -0,0 +1,661 @@
+# coding=utf-8
+# Copyright 2021 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Classes to support TF Encoder-Decoder architectures"""
+
+from __future__ import annotations
+
+import inspect
+import re
+import warnings
+
+import numpy as np
+import tensorflow as tf
+
+from ...configuration_utils import PretrainedConfig
+from ...modeling_tf_outputs import TFBaseModelOutput, TFSeq2SeqLMOutput
+from ...modeling_tf_utils import (
+ TFCausalLanguageModelingLoss,
+ TFModelInputType,
+ TFPreTrainedModel,
+ get_initializer,
+ keras,
+ unpack_inputs,
+)
+from ...tf_utils import shape_list
+from ...utils import (
+ ModelOutput,
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ logging,
+ replace_return_docstrings,
+)
+from ..auto.configuration_auto import AutoConfig
+from ..auto.modeling_tf_auto import TFAutoModel, TFAutoModelForCausalLM
+from .configuration_encoder_decoder import EncoderDecoderConfig
+
+
+logger = logging.get_logger(__name__)
+
+_CONFIG_FOR_DOC = "EncoderDecoderConfig"
+
+DEPRECATION_WARNING = (
+ "Version v4.17.0 introduces a better way to train encoder-decoder models by computing the loss inside the"
+ " encoder-decoder framework rather than in the decoder itself. You may observe training discrepancies if"
+ " fine-tuning a model trained with versions anterior to 4.17.0. The decoder_input_ids are now created based on the"
+ " labels, no need to pass them yourself anymore."
+)
+
+ENCODER_DECODER_START_DOCSTRING = r"""
+ This class can be used to initialize a sequence-to-sequence model with any pretrained autoencoding model as the
+ encoder and any pretrained autoregressive model as the decoder. The encoder is loaded via the
+ [`~TFAutoModel.from_pretrained`] function and the decoder is loaded via the [`~TFAutoModelForCausalLM.from_pretrained`]
+ function. Cross-attention layers are automatically added to the decoder and should be fine-tuned on a downstream
+ generative task, like summarization.
+
+ The effectiveness of initializing sequence-to-sequence models with pretrained checkpoints for sequence generation
+ tasks was shown in [Leveraging Pre-trained Checkpoints for Sequence Generation
+ Tasks](https://huggingface.co/papers/1907.12461) by Sascha Rothe, Shashi Narayan, and Aliaksei Severyn.
+
+ After such an Encoder Decoder model has been trained/fine-tuned, it can be saved/loaded just like any other model
+ (see the examples for more information).
+
+ This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads,
+ etc.)
+
+ This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it
+ as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matters related to general usage and
+ behavior.
+
+ Parameters:
+ config ([`EncoderDecoderConfig`]): Model configuration class with all the parameters of the model.
+ Initializing with a config file does not load the weights associated with the model, only the
+ configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+ENCODER_DECODER_INPUTS_DOCSTRING = r"""
+ Args:
+ input_ids (`np.ndarray`, `tf.Tensor`, `list[tf.Tensor]`, `dict[str, tf.Tensor]` or `dict[str, np.ndarray]`, and each example must have the shape `({0})`):
+ Indices of input sequence tokens in the vocabulary.
+
+ Indices can be obtained using [`PreTrainedTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ attention_mask (`np.ndarray` or `tf.Tensor` of shape `({0})`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ decoder_input_ids (`np.ndarray` or `tf.Tensor` of shape `(batch_size, target_sequence_length)`, *optional*):
+ Indices of decoder input sequence tokens in the vocabulary.
+
+ Indices can be obtained using [`PreTrainedTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are decoder input IDs?](../glossary#decoder-input-ids)
+
+ If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
+ `past_key_values`).
+
+ Provide for sequence to sequence training to the decoder. Indices can be obtained using
+ [`PreTrainedTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for
+ details.
+ decoder_attention_mask (`np.ndarray` or `tf.Tensor` of shape `(batch_size, target_sequence_length)`, *optional*):
+ Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
+ be used by default.
+ encoder_outputs (`tuple(tuple(tf.Tensor))`, *optional*):
+ This tuple must consist of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`)
+ `last_hidden_state` (`tf.Tensor` of shape `({0}, hidden_size)`) is a tensor of hidden-states at the output
+ of the last layer of the encoder. Used in the cross-attention of the decoder.
+ past_key_values (`tuple(tuple(tf.Tensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
+
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
+ `decoder_input_ids` of shape `({0})`.
+ inputs_embeds (`np.ndarray` or `tf.Tensor` of shape `({0}, hidden_size)`, *optional*):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
+ model's internal embedding lookup matrix.
+ decoder_inputs_embeds (`np.ndarray` or `tf.Tensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*):
+ Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded
+ representation. This is useful if you want more control over how to convert `decoder_input_ids` indices
+ into associated vectors than the model's internal embedding lookup matrix.
+ labels (`np.ndarray` or `tf.Tensor` of shape `({0})`, *optional*):
+ Labels for computing the masked language modeling loss for the decoder. Indices should be in `[-100, 0,
+ ..., config.vocab_size]` (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked); the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+ `past_key_values`).
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ If set to `True`, the model will return a [`~utils.TFSeq2SeqLMOutput`] instead of a plain tuple.
+ training (`bool`, *optional*, defaults to `False`):
+ Whether or not to use the model in training mode (some modules like dropout modules have different
+ behaviors between training and evaluation).
+ kwargs (*optional*): Remaining dictionary of keyword arguments. Keyword arguments come in two flavors:
+
+ - Without a prefix, which will be input as `**encoder_kwargs` for the encoder forward function.
+ - With a *decoder_* prefix, which will be input as `**decoder_kwargs` for the decoder forward function.
+"""
+
+
+def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_token_id: int):
+ if pad_token_id is None:
+ raise ValueError("Make sure to set the pad_token_id attribute of the model's configuration.")
+ pad_token_id = tf.cast(pad_token_id, input_ids.dtype)
+
+ if decoder_start_token_id is None:
+ raise ValueError("Make sure to set the decoder_start_token_id attribute of the model's configuration.")
+ decoder_start_token_id = tf.cast(decoder_start_token_id, input_ids.dtype)
+
+ start_tokens = tf.fill((shape_list(input_ids)[0], 1), decoder_start_token_id)
+ shifted_input_ids = tf.concat([start_tokens, input_ids[:, :-1]], -1)
+ # replace possible -100 values in labels by `pad_token_id`
+ shifted_input_ids = tf.where(
+ shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids
+ )
+
+ # "Verify that `labels` has only positive values and -100"
+ assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=input_ids.dtype))
+
+ # Make sure the assertion op is called by wrapping the result in an identity no-op
+ with tf.control_dependencies([assert_gte0]):
+ shifted_input_ids = tf.identity(shifted_input_ids)
+
+ return shifted_input_ids
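+
+
+# A tiny illustration (not part of the public API) of `shift_tokens_right` above: with pad id 0 and decoder start
+# id 101, labels [[5, 6, -100]] become [[101, 5, 6]]; the sequence is shifted right, the start id is prepended,
+# and the -100 loss-ignore index is mapped to the pad id. The helper name is illustrative only.
+def _shift_tokens_right_example_sketch():
+ labels = tf.constant([[5, 6, -100]])
+ return shift_tokens_right(labels, pad_token_id=0, decoder_start_token_id=101)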
+
+
+@add_start_docstrings(ENCODER_DECODER_START_DOCSTRING)
+class TFEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLoss):
+ r"""
+ [`TFEncoderDecoderModel`] is a generic model class that will be instantiated as a transformer architecture with one
+ of the base model classes of the library as encoder and another one as decoder when created with the
+ [`~TFAutoModel.from_pretrained`] class method for the encoder and [`~TFAutoModelForCausalLM.from_pretrained`] class
+ method for the decoder.
+ """
+
+ config_class = EncoderDecoderConfig
+ base_model_prefix = "encoder_decoder"
+ load_weight_prefix = "tf_encoder_decoder_model"
+
+ def __init__(
+ self,
+ config: PretrainedConfig | None = None,
+ encoder: TFPreTrainedModel | None = None,
+ decoder: TFPreTrainedModel | None = None,
+ ):
+ if config is None and (encoder is None or decoder is None):
+ raise ValueError("Either a configuration or an encoder and a decoder has to be provided.")
+ if config is None:
+ config = EncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config)
+ else:
+ if not isinstance(config, self.config_class):
+ raise ValueError(f"config: {config} has to be of type {self.config_class}")
+
+ if config.decoder.cross_attention_hidden_size is not None:
+ if config.decoder.cross_attention_hidden_size != config.encoder.hidden_size:
+ raise ValueError(
+ "If `cross_attention_hidden_size` is specified in the decoder's configuration, it has to be equal"
+ f" to the encoder's `hidden_size`. Got {config.decoder.cross_attention_hidden_size} for"
+ f" `config.decoder.cross_attention_hidden_size` and {config.encoder.hidden_size} for"
+ " `config.encoder.hidden_size`."
+ )
+
+ # initialize with config
+ super().__init__(config)
+
+ if encoder is None:
+ encoder = TFAutoModel.from_config(config.encoder, name="encoder")
+
+ if decoder is None:
+ decoder = TFAutoModelForCausalLM.from_config(config.decoder, name="decoder")
+
+ self.encoder = encoder
+ self.decoder = decoder
+
+ if self.encoder.config.to_dict() != self.config.encoder.to_dict():
+ logger.warning(
+ f"Config of the encoder: {self.encoder.__class__} is overwritten by shared encoder config:"
+ f" {self.config.encoder}"
+ )
+ if self.decoder.config.to_dict() != self.config.decoder.to_dict():
+ logger.warning(
+ f"Config of the decoder: {self.decoder.__class__} is overwritten by shared decoder config:"
+ f" {self.config.decoder}"
+ )
+
+ # make sure that the individual model's config refers to the shared config
+ # so that the updates to the config will be synced
+ self.encoder.config = self.config.encoder
+ self.decoder.config = self.config.decoder
+
+ # encoder outputs might need to be projected to a different dimension for the decoder
+ if (
+ self.encoder.config.hidden_size != self.decoder.config.hidden_size
+ and self.decoder.config.cross_attention_hidden_size is None
+ ):
+ self.enc_to_dec_proj = keras.layers.Dense(
+ units=self.decoder.config.hidden_size,
+ kernel_initializer=get_initializer(config.encoder.initializer_range),
+ name="enc_to_dec_proj",
+ )
+
+ if self.encoder.get_output_embeddings() is not None:
+ raise ValueError(
+ f"The encoder {self.encoder} should not have a LM Head. Please use a model without LM Head"
+ )
+
+ decoder_signature = set(inspect.signature(self.decoder.call).parameters.keys())
+ if "encoder_hidden_states" not in decoder_signature:
+ raise ValueError(
+ "The selected decoder is not prepared for the encoder hidden states to be passed. Please see the "
+ "following discussion on GitHub: https://github.com/huggingface/transformers/issues/23350"
+ )
+
+ def get_encoder(self):
+ return self.encoder
+
+ def get_input_embeddings(self):
+ return self.encoder.get_input_embeddings()
+
+ def get_output_embeddings(self):
+ return self.decoder.get_output_embeddings()
+
+ def set_output_embeddings(self, new_embeddings):
+ return self.decoder.set_output_embeddings(new_embeddings)
+
+ def tf_to_pt_weight_rename(self, tf_weight):
+ # Matt: The TF and PT weights don't align because our TF base classes have an extra layer compared to PT models
+ # (the main model stem is in the MainLayer class). If we remove that layer, then weight names sync up as normal.
+ # However, the name of that extra layer is the name of the MainLayer in the base model. We make the assumption
+ # here that the config model_type is the same as the name of the MainLayer. I don't know of anywhere that's
+ # not the case, and I wasn't sure how else to go from the config to the correct MainLayer name!
+
+ # This override is only needed in the case where we're crossloading weights from PT. However, since weights are
+ # often safetensors now, we don't know if we're going to be crossloading until we sniff the weights file.
+ # Therefore, we specify tf_to_pt_weight_rename anyway, and let the super method figure out if it needs it
+ # or not.
+ encoder_model_type = self.config.encoder.model_type
+ if "encoder" in tf_weight and "decoder" not in tf_weight:
+ return (re.sub(rf"encoder\.{encoder_model_type}\.", "encoder.", tf_weight),)
+ else:
+ return (tf_weight,)
+
+ @classmethod
+ def from_encoder_decoder_pretrained(
+ cls,
+ encoder_pretrained_model_name_or_path: str | None = None,
+ decoder_pretrained_model_name_or_path: str | None = None,
+ *model_args,
+ **kwargs,
+ ) -> TFPreTrainedModel:
+ r"""
+ Instantiate an encoder and a decoder from one or two base classes of the library from pretrained model
+ checkpoints.
+
+
+ Params:
+ encoder_pretrained_model_name_or_path (`str`, *optional*):
+ Information necessary to initiate the encoder. Can be either:
+
+ - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
+ - A path to a *directory* containing model weights saved using
+ [`~TFPreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
+ - A path or url to a *pytorch index checkpoint file* (e.g., `./pt_model/`). In this case,
+ `encoder_from_pt` should be set to `True`.
+
+ decoder_pretrained_model_name_or_path (`str`, *optional*, defaults to `None`):
+ Information necessary to initiate the decoder. Can be either:
+
+ - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
+ - A path to a *directory* containing model weights saved using
+ [`~TFPreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
+ - A path or url to a *pytorch checkpoint file* (e.g., `./pt_model/`). In this case,
+ `decoder_from_pt` should be set to `True`.
+
+ model_args (remaining positional arguments, *optional*):
+ All remaining positional arguments will be passed to the underlying model's `__init__` method.
+
+ kwargs (remaining dictionary of keyword arguments, *optional*):
+ Can be used to update the configuration object (after it has been loaded) and initiate the model (e.g.,
+ `output_attentions=True`).
+
+ - To update the encoder configuration, use the prefix *encoder_* for each configuration parameter.
+ - To update the decoder configuration, use the prefix *decoder_* for each configuration parameter.
+ - To update the parent model configuration, do not use a prefix for each configuration parameter.
+
+ Behaves differently depending on whether a `config` is provided or automatically loaded.
+
+ Example:
+
+ ```python
+ >>> from transformers import TFEncoderDecoderModel
+
+ >>> # initialize a bert2gpt2 from pretrained BERT and GPT2 models. Note that the cross-attention layers will be randomly initialized
+ >>> model = TFEncoderDecoderModel.from_encoder_decoder_pretrained("google-bert/bert-base-uncased", "openai-community/gpt2")
+ >>> # saving model after fine-tuning
+ >>> model.save_pretrained("./bert2gpt2")
+ >>> # load fine-tuned model
+ >>> model = TFEncoderDecoderModel.from_pretrained("./bert2gpt2")
+ ```"""
+
+ kwargs_encoder = {
+ argument[len("encoder_") :]: value for argument, value in kwargs.items() if argument.startswith("encoder_")
+ }
+
+ kwargs_decoder = {
+ argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
+ }
+
+ # remove encoder, decoder kwargs from kwargs
+ for key in kwargs_encoder:
+ del kwargs["encoder_" + key]
+ for key in kwargs_decoder:
+ del kwargs["decoder_" + key]
+
+ # Load and initialize the encoder and decoder
+ # The distinction between encoder and decoder at the model level is made
+ # by the value of the flag `is_decoder` that we need to set correctly.
+ encoder = kwargs_encoder.pop("model", None)
+ if encoder is None:
+ if encoder_pretrained_model_name_or_path is None:
+ raise ValueError(
+ "If `encoder_model` is not defined as an argument, a `encoder_pretrained_model_name_or_path` has "
+ "to be defined."
+ )
+
+ if "config" not in kwargs_encoder:
+ encoder_config = AutoConfig.from_pretrained(encoder_pretrained_model_name_or_path)
+ if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True:
+ logger.info(
+ f"Initializing {encoder_pretrained_model_name_or_path} as a encoder model "
+ "from a decoder model. Cross-attention and causal mask are disabled."
+ )
+ encoder_config.is_decoder = False
+ encoder_config.add_cross_attention = False
+
+ kwargs_encoder["config"] = encoder_config
+
+ kwargs_encoder["name"] = "encoder"
+ kwargs_encoder["load_weight_prefix"] = cls.load_weight_prefix
+ encoder = TFAutoModel.from_pretrained(encoder_pretrained_model_name_or_path, *model_args, **kwargs_encoder)
+
+ decoder = kwargs_decoder.pop("model", None)
+ if decoder is None:
+ if decoder_pretrained_model_name_or_path is None:
+ raise ValueError(
+ "If `decoder_model` is not defined as an argument, a `decoder_pretrained_model_name_or_path` has "
+ "to be defined."
+ )
+
+ if "config" not in kwargs_decoder:
+ decoder_config = AutoConfig.from_pretrained(decoder_pretrained_model_name_or_path)
+ if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False:
+ logger.info(
+ f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. Cross attention"
+ f" layers are added to {decoder_pretrained_model_name_or_path} and randomly initialized if"
+ f" {decoder_pretrained_model_name_or_path}'s architecture allows for cross attention layers."
+ )
+ decoder_config.is_decoder = True
+ decoder_config.add_cross_attention = True
+
+ kwargs_decoder["config"] = decoder_config
+
+ if kwargs_decoder["config"].is_decoder is False or kwargs_decoder["config"].add_cross_attention is False:
+ logger.warning(
+ f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. "
+ f"In order to initialize {decoder_pretrained_model_name_or_path} as a decoder, "
+ "make sure that the attributes `is_decoder` and `add_cross_attention` of `decoder_config` "
+ "passed to `.from_encoder_decoder_pretrained(...)` are set to `True` or do not pass a "
+ "`decoder_config` to `.from_encoder_decoder_pretrained(...)`"
+ )
+
+ kwargs_decoder["name"] = "decoder"
+ kwargs_decoder["load_weight_prefix"] = cls.load_weight_prefix
+ decoder = TFAutoModelForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder)
+
+ # Make sure these 2 `keras.Model` have fixed names so `from_pretrained` could load model weights correctly.
+ if encoder.name != "encoder":
+ raise ValueError("encoder model must be created with the name `encoder`.")
+ if decoder.name != "decoder":
+ raise ValueError("decoder model must be created with the name `decoder`.")
+
+ # instantiate config with corresponding kwargs
+ config = EncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config, **kwargs)
+ return cls(encoder=encoder, decoder=decoder, config=config)
+
+ @unpack_inputs
+ @add_start_docstrings_to_model_forward(ENCODER_DECODER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+ @replace_return_docstrings(output_type=TFSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
+ def call(
+ self,
+ input_ids: TFModelInputType | None = None,
+ attention_mask: np.ndarray | tf.Tensor | None = None,
+ decoder_input_ids: np.ndarray | tf.Tensor | None = None,
+ decoder_attention_mask: np.ndarray | tf.Tensor | None = None,
+ encoder_outputs: np.ndarray | tf.Tensor | None = None,
+ past_key_values: tuple[tuple[tf.Tensor]] | None = None,
+ inputs_embeds: np.ndarray | tf.Tensor | None = None,
+ decoder_inputs_embeds: np.ndarray | tf.Tensor | None = None,
+ labels: np.ndarray | tf.Tensor | None = None,
+ use_cache: bool | None = None,
+ output_attentions: bool | None = None,
+ output_hidden_states: bool | None = None,
+ return_dict: bool | None = None,
+ training: bool = False,
+ **kwargs,
+ ) -> TFSeq2SeqLMOutput | tuple[tf.Tensor]:
+ r"""
+ Returns:
+
+ Examples:
+
+ ```python
+ >>> from transformers import TFEncoderDecoderModel, BertTokenizer
+
+ >>> # initialize a bert2gpt2 model from pretrained BERT and GPT2 checkpoints. Note that the cross-attention layers will be randomly initialized
+ >>> model = TFEncoderDecoderModel.from_encoder_decoder_pretrained("google-bert/bert-base-cased", "openai-community/gpt2")
+
+ >>> tokenizer = BertTokenizer.from_pretrained("google-bert/bert-base-cased")
+
+ >>> # forward
+ >>> input_ids = tokenizer.encode(
+ ... "Hello, my dog is cute", add_special_tokens=True, return_tensors="tf"
+ ... ) # Batch size 1
+ >>> outputs = model(input_ids=input_ids, decoder_input_ids=input_ids)
+
+ >>> # training
+ >>> outputs = model(input_ids=input_ids, decoder_input_ids=input_ids, labels=input_ids)
+ >>> loss, logits = outputs.loss, outputs.logits
+
+ >>> # save and load from pretrained
+ >>> model.save_pretrained("bert2gpt2")
+ >>> model = TFEncoderDecoderModel.from_pretrained("bert2gpt2")
+
+ >>> # generation
+ >>> generated = model.generate(input_ids, decoder_start_token_id=model.config.decoder.bos_token_id)
+ ```"""
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ kwargs_encoder = {argument: value for argument, value in kwargs.items() if not argument.startswith("decoder_")}
+
+ kwargs_decoder = {
+ argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
+ }
+
+ # Let the user be responsible for the expected format.
+ if encoder_outputs is not None:
+ if return_dict and not isinstance(encoder_outputs, ModelOutput):
+ raise ValueError(
+ "If `return_dict=True` and `encoder_outputs` is provided, it should be an instance of "
+ f"`ModelOutput`. Got an instance {type(encoder_outputs)} for `encoder_outputs`."
+ )
+
+ if encoder_outputs is None:
+ encoder_inputs = {
+ "input_ids": input_ids,
+ "attention_mask": attention_mask,
+ "inputs_embeds": inputs_embeds,
+ "output_attentions": output_attentions,
+ "output_hidden_states": output_hidden_states,
+ "return_dict": return_dict,
+ "training": training,
+ }
+
+ # Add arguments to encoder from `kwargs_encoder`
+ encoder_inputs.update(kwargs_encoder)
+
+ # Handle the case where the inputs are passed as a single dict which contains `labels`.
+ # The `labels` shouldn't be passed to `self.encoder` below, because it is a base model without this
+ # parameter (otherwise, an error occurs when `input_processing` is called inside `self.encoder.call()`).
+ if "labels" in encoder_inputs:
+ labels = encoder_inputs.pop("labels")
+
+ # handle the init case where `dummy_inputs` returns a dict containing `decoder_input_ids`.
+ if "decoder_input_ids" in encoder_inputs:
+ decoder_input_ids = encoder_inputs.pop("decoder_input_ids")
+ # handle the init case where `dummy_inputs` returns a dict containing `decoder_attention_mask`.
+ if "decoder_attention_mask" in encoder_inputs:
+ decoder_attention_mask = encoder_inputs.pop("decoder_attention_mask")
+
+ encoder_outputs = self.encoder(**encoder_inputs)
+
+ encoder_hidden_states = encoder_outputs[0]
+
+ # optionally project encoder_hidden_states
+ if (
+ self.encoder.config.hidden_size != self.decoder.config.hidden_size
+ and self.decoder.config.cross_attention_hidden_size is None
+ ):
+ encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states)
+
+ if (labels is not None) and (decoder_input_ids is None and decoder_inputs_embeds is None):
+ decoder_input_ids = shift_tokens_right(
+ labels, self.config.pad_token_id, self.config.decoder_start_token_id
+ )
+
+ decoder_inputs = {
+ "input_ids": decoder_input_ids,
+ "attention_mask": decoder_attention_mask,
+ "encoder_hidden_states": encoder_hidden_states,
+ "encoder_attention_mask": attention_mask,
+ "inputs_embeds": decoder_inputs_embeds,
+ "output_attentions": output_attentions,
+ "output_hidden_states": output_hidden_states,
+ "use_cache": use_cache,
+ "past_key_values": past_key_values,
+ "return_dict": return_dict,
+ "training": training,
+ }
+
+ # Add arguments to decoder from `kwargs_decoder`
+ decoder_inputs.update(kwargs_decoder)
+
+ decoder_outputs = self.decoder(**decoder_inputs)
+
+ logits = decoder_outputs[0]
+
+ # Compute the loss independently of the decoder (as some decoders shift the logits internally)
+ loss = None
+ if labels is not None:
+ warnings.warn(DEPRECATION_WARNING, FutureWarning)
+ loss = self.hf_compute_loss(labels, logits)
+
+ if not return_dict:
+ past_key_values = None
+ if use_cache:
+ past_key_values = decoder_outputs[1]
+ # The starting index of the remaining elements in `decoder_outputs`
+ start_index = sum([1 if x is not None else 0 for x in (loss, logits, past_key_values)])
+
+ if not isinstance(encoder_outputs, tuple):
+ encoder_outputs = encoder_outputs.to_tuple()
+ output = (loss, logits, past_key_values) + decoder_outputs[start_index:] + encoder_outputs
+ output = tuple(x for x in output if x is not None)
+ return output
+
+ return TFSeq2SeqLMOutput(
+ loss=loss,
+ logits=decoder_outputs.logits,
+ past_key_values=decoder_outputs.past_key_values,
+ decoder_hidden_states=decoder_outputs.hidden_states,
+ decoder_attentions=decoder_outputs.attentions,
+ cross_attentions=decoder_outputs.cross_attentions,
+ encoder_last_hidden_state=encoder_outputs.last_hidden_state,
+ encoder_hidden_states=encoder_outputs.hidden_states,
+ encoder_attentions=encoder_outputs.attentions,
+ )
+
+ def prepare_inputs_for_generation(
+ self, input_ids, past_key_values=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs
+ ):
+ decoder_inputs = self.decoder.prepare_inputs_for_generation(input_ids, past_key_values=past_key_values)
+ decoder_attention_mask = decoder_inputs.get("attention_mask", None)
+ past_key_values = decoder_inputs.get("past_key_values")
+ if past_key_values is None:
+ past_key_values = decoder_inputs.get("past") # e.g. on TF GPT2
+ input_dict = {
+ "input_ids": None, # needs to be passed to make Keras.layer.__call__ happy
+ "attention_mask": attention_mask,
+ "decoder_attention_mask": decoder_attention_mask,
+ "decoder_input_ids": decoder_inputs["input_ids"],
+ # TODO (joao): the `TFBaseModelOutput` wrapper should not be needed after the generate refactor is complete
+ "encoder_outputs": TFBaseModelOutput(last_hidden_state=encoder_outputs[0]),
+ "past_key_values": past_key_values,
+ "use_cache": use_cache,
+ }
+ return input_dict
+
+ def prepare_decoder_input_ids_from_labels(self, labels: tf.Tensor):
+ return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
+
+ def resize_token_embeddings(self, *args, **kwargs):
+ raise NotImplementedError(
+ "Resizing the embedding layers via the TFEncoderDecoderModel directly is not supported.Please use the"
+ " respective methods of the wrapped objects (model.encoder.resize_token_embeddings(...) or"
+ " model.decoder.resize_token_embeddings(...))"
+ )
+
+ def _reorder_cache(self, past, beam_idx):
+ # apply decoder cache reordering here
+ return self.decoder._reorder_cache(past, beam_idx)
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "enc_to_dec_proj", None) is not None:
+ with tf.name_scope(self.enc_to_dec_proj.name):
+ self.enc_to_dec_proj.build([None, None, self.encoder.config.hidden_size])
+ if getattr(self, "encoder", None) is not None:
+ with tf.name_scope(self.encoder.name):
+ self.encoder.build(None)
+ if getattr(self, "decoder", None) is not None:
+ with tf.name_scope(self.decoder.name):
+ self.decoder.build(None)
+
+
+__all__ = ["TFEncoderDecoderModel"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/esm/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/esm/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8eac54d6ddcbdae2b8ca3771ae5540522f6f29da
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/esm/__init__.py
@@ -0,0 +1,30 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_esm import *
+ from .modeling_esm import *
+ from .modeling_esmfold import *
+ from .modeling_tf_esm import *
+ from .tokenization_esm import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/esm/configuration_esm.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/esm/configuration_esm.py
new file mode 100644
index 0000000000000000000000000000000000000000..fabfb4ebd6d34a7f212af5e74a90c18d4a038156
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/esm/configuration_esm.py
@@ -0,0 +1,365 @@
+# coding=utf-8
+# Copyright 2022 Meta and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""ESM model configuration"""
+
+from dataclasses import asdict, dataclass
+from typing import Optional
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+# TODO Update this
+
+
+class EsmConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`ESMModel`]. It is used to instantiate a ESM model
+ according to the specified arguments, defining the model architecture. Instantiating a configuration with the
+ defaults will yield a similar configuration to that of the ESM
+ [facebook/esm-1b](https://huggingface.co/facebook/esm-1b) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+
+ Args:
+ vocab_size (`int`, *optional*):
+ Vocabulary size of the ESM model. Defines the number of different tokens that can be represented by the
+ `inputs_ids` passed when calling [`ESMModel`].
+ mask_token_id (`int`, *optional*):
+ The index of the mask token in the vocabulary. This must be included in the config because of the
+ "mask-dropout" scaling trick, which will scale the inputs depending on the number of masked tokens.
+ pad_token_id (`int`, *optional*):
+ The index of the padding token in the vocabulary. This must be included in the config because certain parts
+ of the ESM code use this instead of the attention mask.
+ hidden_size (`int`, *optional*, defaults to 768):
+ Dimensionality of the encoder layers and the pooler layer.
+ num_hidden_layers (`int`, *optional*, defaults to 12):
+ Number of hidden layers in the Transformer encoder.
+ num_attention_heads (`int`, *optional*, defaults to 12):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ intermediate_size (`int`, *optional*, defaults to 3072):
+ Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder.
+ hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+ attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
+ The dropout ratio for the attention probabilities.
+ max_position_embeddings (`int`, *optional*, defaults to 1026):
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
+ just in case (e.g., 512 or 1024 or 2048).
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ layer_norm_eps (`float`, *optional*, defaults to 1e-12):
+ The epsilon used by the layer normalization layers.
+ position_embedding_type (`str`, *optional*, defaults to `"absolute"`):
+ Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`, `"rotary"`.
+ For positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to
+ [Self-Attention with Relative Position Representations (Shaw et al.)](https://huggingface.co/papers/1803.02155).
+ For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models
+ with Better Relative Position Embeddings (Huang et al.)](https://huggingface.co/papers/2009.13658).
+ is_decoder (`bool`, *optional*, defaults to `False`):
+ Whether the model is used as a decoder or not. If `False`, the model is used as an encoder.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
+ relevant if `config.is_decoder=True`.
+ emb_layer_norm_before (`bool`, *optional*):
+ Whether to apply layer normalization after embeddings but before the main stem of the network.
+ token_dropout (`bool`, *optional*, defaults to `False`):
+ When this is enabled, masked tokens are treated as if they had been dropped out by input dropout.
+
+ Examples:
+
+ ```python
+ >>> from transformers import EsmModel, EsmConfig
+
+ >>> # Initializing a ESM facebook/esm-1b style configuration
+ >>> configuration = EsmConfig(vocab_size=33)
+
+ >>> # Initializing a model from the configuration
+ >>> model = EsmModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "esm"
+
+ def __init__(
+ self,
+ vocab_size=None,
+ mask_token_id=None,
+ pad_token_id=None,
+ hidden_size=768,
+ num_hidden_layers=12,
+ num_attention_heads=12,
+ intermediate_size=3072,
+ hidden_dropout_prob=0.1,
+ attention_probs_dropout_prob=0.1,
+ max_position_embeddings=1026,
+ initializer_range=0.02,
+ layer_norm_eps=1e-12,
+ position_embedding_type="absolute",
+ use_cache=True,
+ emb_layer_norm_before=None,
+ token_dropout=False,
+ is_folding_model=False,
+ esmfold_config=None,
+ vocab_list=None,
+ **kwargs,
+ ):
+ super().__init__(pad_token_id=pad_token_id, mask_token_id=mask_token_id, **kwargs)
+
+ self.vocab_size = vocab_size
+ self.hidden_size = hidden_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.intermediate_size = intermediate_size
+ self.hidden_dropout_prob = hidden_dropout_prob
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
+ self.max_position_embeddings = max_position_embeddings
+ self.initializer_range = initializer_range
+ self.layer_norm_eps = layer_norm_eps
+ self.position_embedding_type = position_embedding_type
+ self.use_cache = use_cache
+ self.emb_layer_norm_before = emb_layer_norm_before
+ self.token_dropout = token_dropout
+ self.is_folding_model = is_folding_model
+ if is_folding_model:
+ if esmfold_config is None:
+ logger.info("No esmfold_config supplied for folding model, using default values.")
+ esmfold_config = EsmFoldConfig()
+ elif isinstance(esmfold_config, dict):
+ esmfold_config = EsmFoldConfig(**esmfold_config)
+ self.esmfold_config = esmfold_config
+ if vocab_list is None:
+ logger.warning("No vocab_list supplied for folding model, assuming the ESM-2 vocabulary!")
+ self.vocab_list = get_default_vocab_list()
+ else:
+ self.vocab_list = vocab_list
+ else:
+ self.esmfold_config = None
+ self.vocab_list = None
+ if self.esmfold_config is not None and getattr(self.esmfold_config, "use_esm_attn_map", False):
+ raise ValueError("The HuggingFace port of ESMFold does not support use_esm_attn_map at this time!")
+
+ def to_dict(self):
+ """
+ Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
+
+ Returns:
+ `dict[str, any]`: Dictionary of all the attributes that make up this configuration instance.
+ """
+ output = super().to_dict()
+ if isinstance(self.esmfold_config, EsmFoldConfig):
+ output["esmfold_config"] = self.esmfold_config.to_dict()
+ return output
+
+
+@dataclass
+class EsmFoldConfig:
+ esm_type: Optional[str] = None
+ fp16_esm: bool = True
+ use_esm_attn_map: bool = False
+ esm_ablate_pairwise: bool = False
+ esm_ablate_sequence: bool = False
+ esm_input_dropout: float = 0
+
+ embed_aa: bool = True
+ bypass_lm: bool = False
+
+ lddt_head_hid_dim: int = 128
+ trunk: "TrunkConfig" = None
+
+ def __post_init__(self):
+ if self.trunk is None:
+ self.trunk = TrunkConfig()
+ elif isinstance(self.trunk, dict):
+ self.trunk = TrunkConfig(**self.trunk)
+
+ def to_dict(self):
+ """
+ Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
+
+ Returns:
+ `dict[str, any]`: Dictionary of all the attributes that make up this configuration instance.
+ """
+ output = asdict(self)
+ output["trunk"] = self.trunk.to_dict()
+ return output
+
+
+@dataclass
+class TrunkConfig:
+ num_blocks: int = 48
+ sequence_state_dim: int = 1024
+ pairwise_state_dim: int = 128
+ sequence_head_width: int = 32
+ pairwise_head_width: int = 32
+ position_bins: int = 32
+ dropout: float = 0
+ layer_drop: float = 0
+ cpu_grad_checkpoint: bool = False
+ max_recycles: int = 4
+ chunk_size: Optional[int] = 128
+ structure_module: "StructureModuleConfig" = None
+
+ def __post_init__(self):
+ if self.structure_module is None:
+ self.structure_module = StructureModuleConfig()
+ elif isinstance(self.structure_module, dict):
+ self.structure_module = StructureModuleConfig(**self.structure_module)
+
+ if self.max_recycles <= 0:
+ raise ValueError(f"`max_recycles` should be positive, got {self.max_recycles}.")
+ if self.sequence_state_dim % self.sequence_head_width != 0:
+ raise ValueError(
+ "`sequence_state_dim` should be a round multiple of `sequence_head_width`, got"
+ f" {self.sequence_state_dim} and {self.sequence_head_width}."
+ )
+ if self.pairwise_state_dim % self.pairwise_head_width != 0:
+ raise ValueError(
+ "`pairwise_state_dim` should be a round multiple of `pairwise_head_width`, got"
+ f" {self.pairwise_state_dim} and {self.pairwise_head_width}."
+ )
+
+ sequence_num_heads = self.sequence_state_dim // self.sequence_head_width
+ pairwise_num_heads = self.pairwise_state_dim // self.pairwise_head_width
+
+ if self.sequence_state_dim != sequence_num_heads * self.sequence_head_width:
+ raise ValueError(
+ "`sequence_state_dim` should be equal to `sequence_num_heads * sequence_head_width, got"
+ f" {self.sequence_state_dim} != {sequence_num_heads} * {self.sequence_head_width}."
+ )
+ if self.pairwise_state_dim != pairwise_num_heads * self.pairwise_head_width:
+ raise ValueError(
+ "`pairwise_state_dim` should be equal to `pairwise_num_heads * pairwise_head_width, got"
+ f" {self.pairwise_state_dim} != {pairwise_num_heads} * {self.pairwise_head_width}."
+ )
+ if self.pairwise_state_dim % 2 != 0:
+ raise ValueError(f"`pairwise_state_dim` should be even, got {self.pairwise_state_dim}.")
+
+ if self.dropout >= 0.4:
+ raise ValueError(f"`dropout` should not be greater than 0.4, got {self.dropout}.")
+
+ def to_dict(self):
+ """
+ Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
+
+ Returns:
+ `dict[str, any]`: Dictionary of all the attributes that make up this configuration instance.
+ """
+ output = asdict(self)
+ output["structure_module"] = self.structure_module.to_dict()
+ return output
+
+
+@dataclass
+class StructureModuleConfig:
+ """
+ Args:
+ sequence_dim:
+ Single representation channel dimension
+ pairwise_dim:
+ Pair representation channel dimension
+ ipa_dim:
+ IPA hidden channel dimension
+ resnet_dim:
+ Angle resnet (Alg. 23 lines 11-14) hidden channel dimension
+ num_heads_ipa:
+ Number of IPA heads
+ num_qk_points:
+ Number of query/key points to generate during IPA
+ num_v_points:
+ Number of value points to generate during IPA
+ dropout_rate:
+ Dropout rate used throughout the layer
+ num_blocks:
+ Number of structure module blocks
+ num_transition_layers:
+ Number of layers in the single representation transition (Alg. 23 lines 8-9)
+ num_resnet_blocks:
+ Number of blocks in the angle resnet
+ num_angles:
+ Number of angles to generate in the angle resnet
+ trans_scale_factor:
+ Scale of single representation transition hidden dimension
+ epsilon:
+ Small number used in angle resnet normalization
+ inf:
+ Large number used for attention masking
+ """
+
+ sequence_dim: int = 384
+ pairwise_dim: int = 128
+ ipa_dim: int = 16
+ resnet_dim: int = 128
+ num_heads_ipa: int = 12
+ num_qk_points: int = 4
+ num_v_points: int = 8
+ dropout_rate: float = 0.1
+ num_blocks: int = 8
+ num_transition_layers: int = 1
+ num_resnet_blocks: int = 2
+ num_angles: int = 7
+ trans_scale_factor: int = 10
+ epsilon: float = 1e-8
+ inf: float = 1e5
+
+ def to_dict(self):
+ return asdict(self)
+
+
+def get_default_vocab_list():
+ return (
+ "",
+ "",
+ "",
+ "",
+ "L",
+ "A",
+ "G",
+ "V",
+ "S",
+ "E",
+ "R",
+ "T",
+ "I",
+ "D",
+ "P",
+ "K",
+ "Q",
+ "N",
+ "F",
+ "Y",
+ "M",
+ "H",
+ "W",
+ "C",
+ "X",
+ "B",
+ "U",
+ "Z",
+ "O",
+ ".",
+ "-",
+ "",
+ "",
+ )
+
+
+__all__ = ["EsmConfig"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/esm/modeling_esm.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/esm/modeling_esm.py
new file mode 100644
index 0000000000000000000000000000000000000000..63d9344188cc83bd8ac4719db78def849afbda7f
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/esm/modeling_esm.py
@@ -0,0 +1,1058 @@
+# coding=utf-8
+# Copyright 2022 Meta and The HuggingFace Inc. team. All rights reserved.
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch ESM model."""
+
+import math
+from typing import Callable, Optional, Union
+
+import torch
+from torch import nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_outputs import (
+ BaseModelOutputWithCrossAttentions,
+ BaseModelOutputWithPoolingAndCrossAttentions,
+ MaskedLMOutput,
+ SequenceClassifierOutput,
+ TokenClassifierOutput,
+)
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
+from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
+from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
+from ...utils.generic import OutputRecorder, check_model_inputs
+from .configuration_esm import EsmConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+def rotate_half(x):
+ x1, x2 = x.chunk(2, dim=-1)
+ return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb(x, cos, sin):
+ cos = cos[:, :, : x.shape[-2], :]
+ sin = sin[:, :, : x.shape[-2], :]
+
+ return (x * cos) + (rotate_half(x) * sin)
+
+
+def gelu(x):
+ """
+ This is the gelu implementation from the original ESM repo. Using F.gelu yields subtly wrong results.
+ """
+ return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
+
+
+def symmetrize(x):
+ "Make layer symmetric in final two dimensions, used for contact prediction."
+ return x + x.transpose(-1, -2)
+
+
+def average_product_correct(x):
+ "Perform average product correct, used for contact prediction."
+ a1 = x.sum(-1, keepdims=True)
+ a2 = x.sum(-2, keepdims=True)
+ a12 = x.sum((-1, -2), keepdims=True)
+
+ avg = a1 * a2
+ avg.div_(a12) # in-place to reduce memory
+ normalized = x - avg
+ return normalized
+
+
+class RotaryEmbedding(torch.nn.Module):
+ """
+ Rotary position embeddings based on those in
+ [RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer). Query and keys are transformed by rotation
+ matrices which depend on their relative positions.
+ """
+
+ inv_freq: torch.Tensor # fix linting for `register_buffer`
+
+ def __init__(self, dim: int):
+ super().__init__()
+ # Generate and save the inverse frequency buffer (non trainable)
+ inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
+ self.register_buffer("inv_freq", inv_freq)
+
+ self._seq_len_cached = None
+ self._cos_cached = None
+ self._sin_cached = None
+
+ def _update_cos_sin_tables(self, x, seq_dimension=2):
+ seq_len = x.shape[seq_dimension]
+
+ # Reset the tables if the sequence length has changed,
+ # or if we're on a new device (possibly due to tracing for instance)
+ if seq_len != self._seq_len_cached or self._cos_cached.device != x.device:
+ self._seq_len_cached = seq_len
+ t = torch.arange(x.shape[seq_dimension], device=x.device).type_as(self.inv_freq)
+ freqs = torch.outer(t, self.inv_freq)
+ emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
+
+ self._cos_cached = emb.cos()[None, None, :, :]
+ self._sin_cached = emb.sin()[None, None, :, :]
+
+ return self._cos_cached, self._sin_cached
+
+ def forward(self, q: torch.Tensor, k: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+ self._cos_cached, self._sin_cached = self._update_cos_sin_tables(k, seq_dimension=-2)
+
+ return (
+ apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached).to(dtype=q.dtype),
+ apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached).to(dtype=k.dtype),
+ )
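
A standalone numerical check of the rotary transform used above (the helpers are re-declared inline so the snippet runs on its own): rotating query/key channels by position-dependent angles leaves each vector's norm unchanged, so only relative positions end up affecting the attention scores.

```python
import torch

def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(x, cos, sin):
    return (x * cos) + (rotate_half(x) * sin)

dim, seq_len = 8, 5
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
freqs = torch.outer(torch.arange(seq_len).float(), inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
cos, sin = emb.cos()[None, None], emb.sin()[None, None]

q = torch.randn(1, 1, seq_len, dim)
q_rot = apply_rotary_pos_emb(q, cos, sin)
print(torch.allclose(q.norm(dim=-1), q_rot.norm(dim=-1), atol=1e-5))  # True: rotation preserves norms
```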
+
+
+class EsmContactPredictionHead(nn.Module):
+ """Performs symmetrization, apc, and computes a logistic regression on the output features"""
+
+ def __init__(
+ self,
+ in_features: int,
+ bias=True,
+ eos_idx: int = 2,
+ ):
+ super().__init__()
+ self.in_features = in_features
+ self.eos_idx = eos_idx
+ self.regression = nn.Linear(in_features, 1, bias)
+ self.activation = nn.Sigmoid()
+
+ def forward(self, tokens, attentions):
+ # remove eos token attentions
+ eos_mask = tokens.ne(self.eos_idx).to(attentions)
+ eos_mask = eos_mask.unsqueeze(1) * eos_mask.unsqueeze(2)
+ attentions = attentions * eos_mask[:, None, None, :, :]
+ attentions = attentions[..., :-1, :-1]
+ # remove cls token attentions
+ attentions = attentions[..., 1:, 1:]
+ batch_size, layers, heads, seqlen, _ = attentions.size()
+ attentions = attentions.view(batch_size, layers * heads, seqlen, seqlen)
+
+ # features: batch x channels x tokens x tokens (symmetric)
+ attentions = attentions.to(
+ self.regression.weight.device
+ ) # attentions always float32, may need to convert to float16
+ attentions = average_product_correct(symmetrize(attentions))
+ attentions = attentions.permute(0, 2, 3, 1)
+ return self.activation(self.regression(attentions).squeeze(3))
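
A toy run of the two attention-map transforms this head relies on (re-declared inline so the snippet is self-contained): symmetrization plus average product correction yields symmetric, background-corrected features that the logistic regression then maps to contact probabilities.

```python
import torch

def symmetrize(x):
    return x + x.transpose(-1, -2)

def average_product_correct(x):
    a1 = x.sum(-1, keepdim=True)
    a2 = x.sum(-2, keepdim=True)
    a12 = x.sum((-1, -2), keepdim=True)
    return x - a1 * a2 / a12

attn = torch.rand(1, 4, 6, 6)  # batch x (layers * heads) x tokens x tokens
feats = average_product_correct(symmetrize(attn))
print(torch.allclose(feats, feats.transpose(-1, -2), atol=1e-6))  # True: features are symmetric
```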
+
+
+class EsmEmbeddings(nn.Module):
+ """
+ Same as BertEmbeddings with a tiny tweak for positional embeddings indexing.
+ """
+
+ def __init__(self, config):
+ super().__init__()
+ self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
+
+ if config.emb_layer_norm_before:
+ self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ else:
+ self.layer_norm = None
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
+ self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
+ self.register_buffer(
+ "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
+ )
+
+ self.padding_idx = config.pad_token_id
+ if self.position_embedding_type == "absolute":
+ self.position_embeddings = nn.Embedding(
+ config.max_position_embeddings, config.hidden_size, padding_idx=self.padding_idx
+ )
+ self.token_dropout = config.token_dropout
+ self.mask_token_id = config.mask_token_id
+
+ def forward(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ position_ids=None,
+ inputs_embeds=None,
+ ):
+ if position_ids is None:
+ if input_ids is not None:
+ # Create the position ids from the input token ids. Any padded tokens remain padded.
+ position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx)
+ else:
+ position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds)
+
+ if inputs_embeds is None:
+ inputs_embeds = self.word_embeddings(input_ids)
+
+ # Note that if we want to support ESM-1 (not 1b!) in the future, we will need to support an
+ # embedding_scale factor here.
+ embeddings = inputs_embeds
+
+ # Matt: ESM has the option to handle masking in MLM in a slightly unusual way. If the token_dropout
+ # flag is False then it is handled in the same way as BERT/RoBERTa. If it is set to True, however,
+ # masked tokens are treated as if they were selected for input dropout and zeroed out.
+ # This "mask-dropout" is compensated for when masked tokens are not present, by scaling embeddings by
+ # a factor of (fraction of unmasked tokens during training) / (fraction of unmasked tokens in sample).
+ # This is analogous to the way that dropout layers scale down outputs during evaluation when not
+ # actually dropping out values (or, equivalently, scale up their un-dropped outputs in training).
+ if self.token_dropout and input_ids is not None:
+ embeddings = embeddings.masked_fill((input_ids == self.mask_token_id).unsqueeze(-1), 0.0)
+ mask_ratio_train = 0.15 * 0.8 # Hardcoded as the ratio used in all ESM model training runs
+ src_lengths = attention_mask.sum(-1) if attention_mask is not None else input_ids.shape[1]
+ mask_ratio_observed = (input_ids == self.mask_token_id).sum(-1).float() / src_lengths
+ embeddings = (embeddings * (1 - mask_ratio_train) / (1 - mask_ratio_observed)[:, None, None]).to(
+ embeddings.dtype
+ )
+
+ if self.position_embedding_type == "absolute":
+ position_embeddings = self.position_embeddings(position_ids)
+ embeddings = embeddings + position_embeddings
+
+ if self.layer_norm is not None:
+ embeddings = self.layer_norm(embeddings)
+ if attention_mask is not None:
+ embeddings = (embeddings * attention_mask.unsqueeze(-1)).to(embeddings.dtype)
+ # Matt: I think this line was copied incorrectly from BERT, disabling it for now.
+ # embeddings = self.dropout(embeddings)
+ return embeddings
+
+ def create_position_ids_from_inputs_embeds(self, inputs_embeds):
+ """
+ We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids.
+
+ Args:
+ inputs_embeds: torch.Tensor
+
+ Returns: torch.Tensor
+ """
+ input_shape = inputs_embeds.size()[:-1]
+ sequence_length = input_shape[1]
+
+ position_ids = torch.arange(
+ self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device
+ )
+ return position_ids.unsqueeze(0).expand(input_shape)
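
To make the "mask-dropout" compensation above concrete, here is the scaling factor worked out for a toy sequence. The numbers are illustrative; the `0.15 * 0.8` train ratio is the one hardcoded in the forward pass above.

```python
mask_ratio_train = 0.15 * 0.8   # fraction of tokens replaced by <mask> during ESM pretraining
mask_ratio_observed = 2 / 10    # e.g. 2 masked tokens in a 10-token input
scale = (1 - mask_ratio_train) / (1 - mask_ratio_observed)
print(round(scale, 3))  # 1.1 -> embeddings are scaled up when more tokens are masked than seen in training
```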
+
+
+def eager_attention_forward(
+ module: nn.Module,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_mask: Optional[torch.Tensor],
+ scaling: float,
+ dropout: float = 0.0,
+ head_mask: Optional[torch.Tensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+):
+ # ESM applies relative position embeddings and we don't copy from Llama
+ attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
+
+ if hasattr(module, "position_embedding_type") and module.position_embedding_type in [
+ "relative_key",
+ "relative_key_query",
+ ]:
+ seq_length = query.shape[2]
+ position_ids_l = torch.arange(seq_length, dtype=torch.long, device=attn_weights.device).view(-1, 1)
+ position_ids_r = torch.arange(seq_length, dtype=torch.long, device=attn_weights.device).view(1, -1)
+ distance = position_ids_l - position_ids_r
+ positional_embedding = module.distance_embedding(distance + module.max_position_embeddings - 1)
+ positional_embedding = positional_embedding.to(dtype=query.dtype) # fp16 compatibility
+
+ if module.position_embedding_type == "relative_key":
+ relative_position_scores = torch.einsum("bhld,lrd->bhlr", query, positional_embedding)
+ elif module.position_embedding_type == "relative_key_query":
+ relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query, positional_embedding)
+ relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key, positional_embedding)
+ relative_position_scores = relative_position_scores_query + relative_position_scores_key
+
+ attn_weights = attn_weights + relative_position_scores
+
+ if attention_mask is not None:
+ causal_mask = attention_mask[:, :, :, : key.shape[-2]]
+ attn_weights = attn_weights + causal_mask
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+
+ if head_mask is not None:
+ attn_weights = attn_weights * head_mask
+
+ attn_output = torch.matmul(attn_weights, value)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ return attn_output, attn_weights
+
+
+class EsmSelfAttention(nn.Module):
+ def __init__(self, config, position_embedding_type=None, layer_idx=None, is_cross_attention=False):
+ super().__init__()
+ self.config = config
+
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
+ raise ValueError(
+ f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
+ f"heads ({config.num_attention_heads})"
+ )
+
+ self.num_attention_heads = config.num_attention_heads
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
+
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
+
+ self.dropout = config.attention_probs_dropout_prob
+ self.position_embedding_type = position_embedding_type or getattr(
+ config, "position_embedding_type", "absolute"
+ )
+ self.rotary_embeddings = None
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
+ self.max_position_embeddings = config.max_position_embeddings
+ self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
+ elif self.position_embedding_type == "rotary":
+ self.rotary_embeddings = RotaryEmbedding(dim=self.attention_head_size)
+
+ self.scaling = 1.0 # For BC we apply scaling before RoPE
+ self.is_decoder = config.is_decoder
+ self.layer_idx = layer_idx
+ self.is_causal = self.is_decoder and not is_cross_attention
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> tuple[torch.Tensor]:
+ batch_size, seq_length = hidden_states.shape[:-1]
+ hidden_shape = (batch_size, seq_length, -1, self.attention_head_size)
+
+ query_layer = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
+
+ is_cross_attention = encoder_hidden_states is not None
+ current_states = encoder_hidden_states if is_cross_attention else hidden_states
+ attention_mask = encoder_attention_mask if is_cross_attention else attention_mask
+ key_layer = self.key(current_states).view(hidden_shape).transpose(1, 2)
+ value_layer = self.value(current_states).view(hidden_shape).transpose(1, 2)
+
+ # Matt: Our BERT model (which this code was derived from) scales attention logits down by sqrt(head_dim).
+ # ESM scales the query down by the same factor instead. Modulo numerical stability these are equivalent,
+ # but not when rotary embeddings get involved. Therefore, we scale the query here to match the original
+ # ESM code and fix rotary embeddings.
+ query_layer = query_layer * self.attention_head_size**-0.5
+
+ if self.position_embedding_type == "rotary":
+ query_layer, key_layer = self.rotary_embeddings(query_layer, key_layer)
+
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ if self.position_embedding_type in ["relative_key", "relative_key_query"]:
+ raise ValueError(
+ f"ESM {self.config._attn_implementation} attention does not support {self.position_embedding_type} embeddings. "
+ "Set attention explicitly to 'eager' with `model.set_attn_implementation('eager')`"
+ )
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ attn_output, attn_weights = attention_interface(
+ self,
+ query_layer,
+ key_layer,
+ value_layer,
+ attention_mask,
+ dropout=0.0 if not self.training else self.dropout,
+ scaling=self.scaling,
+ head_mask=head_mask,
+ **kwargs,
+ )
+
+ attn_output = attn_output.reshape(batch_size, seq_length, -1).contiguous()
+ return attn_output, attn_weights
+
+
+class EsmSelfOutput(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ def forward(self, hidden_states, input_tensor):
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = hidden_states + input_tensor
+ return hidden_states
+
+
+class EsmAttention(nn.Module):
+ def __init__(self, config, layer_idx=None, is_cross_attention=False):
+ super().__init__()
+ self.self = EsmSelfAttention(config, layer_idx=layer_idx, is_cross_attention=is_cross_attention)
+ self.output = EsmSelfOutput(config)
+ self.pruned_heads = set()
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+ def prune_heads(self, heads):
+ if len(heads) == 0:
+ return
+ heads, index = find_pruneable_heads_and_indices(
+ heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
+ )
+
+ # Prune linear layers
+ self.self.query = prune_linear_layer(self.self.query, index)
+ self.self.key = prune_linear_layer(self.self.key, index)
+ self.self.value = prune_linear_layer(self.self.value, index)
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
+
+ # Update hyper params and store pruned heads
+ self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
+ self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
+ self.pruned_heads = self.pruned_heads.union(heads)
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask=None,
+ head_mask=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ **kwargs: Unpack[TransformersKwargs],
+ ):
+ hidden_states_ln = self.LayerNorm(hidden_states)
+ attn_output, _ = self.self(
+ hidden_states_ln,
+ attention_mask=attention_mask,
+ head_mask=head_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ **kwargs,
+ )
+ attn_output = self.output(attn_output, hidden_states)
+ return attn_output
+
+
+class EsmIntermediate(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.dense(hidden_states)
+ hidden_states = gelu(hidden_states)
+ return hidden_states
+
+
+class EsmOutput(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ def forward(self, hidden_states, input_tensor):
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = hidden_states + input_tensor
+ return hidden_states
+
+
+class EsmLayer(GradientCheckpointingLayer):
+ def __init__(self, config):
+ super().__init__()
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
+ self.seq_len_dim = 1
+ self.attention = EsmAttention(config)
+ self.is_decoder = config.is_decoder
+ self.add_cross_attention = config.add_cross_attention
+ if self.add_cross_attention:
+ if not self.is_decoder:
+ raise RuntimeError(f"{self} should be used as a decoder model if cross attention is added")
+ self.crossattention = EsmAttention(config, is_cross_attention=True)
+ self.intermediate = EsmIntermediate(config)
+ self.output = EsmOutput(config)
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask=None,
+ head_mask=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ **kwargs: Unpack[TransformersKwargs],
+ ):
+ attention_output = self.attention(
+ hidden_states,
+ attention_mask=attention_mask,
+ head_mask=head_mask,
+ **kwargs,
+ )
+
+ if self.is_decoder and encoder_hidden_states is not None:
+ if not hasattr(self, "crossattention"):
+ raise AttributeError(
+ f"If `encoder_hidden_states` are passed, {self} has to be instantiated"
+ " with cross-attention layers by setting `config.add_cross_attention=True`"
+ )
+
+ attention_output = self.crossattention(
+ attention_output,
+ attention_mask=attention_mask,
+ head_mask=head_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ **kwargs,
+ )
+
+ layer_output = self.feed_forward_chunk(attention_output)
+ return layer_output
+
+ def feed_forward_chunk(self, attention_output):
+ attention_output_ln = self.LayerNorm(attention_output)
+ intermediate_output = self.intermediate(attention_output_ln)
+ layer_output = self.output(intermediate_output, attention_output)
+ return layer_output
+
+
+class EsmEncoder(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.layer = nn.ModuleList([EsmLayer(config) for _ in range(config.num_hidden_layers)])
+ self.emb_layer_norm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.gradient_checkpointing = False
+
+ @can_return_tuple
+ def forward(
+ self,
+ hidden_states,
+ attention_mask=None,
+ head_mask=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ **kwargs: Unpack[TransformersKwargs],
+ ):
+ for i, layer_module in enumerate(self.layer):
+ layer_head_mask = head_mask[i] if head_mask is not None else None
+ hidden_states = layer_module(
+ hidden_states,
+ attention_mask=attention_mask,
+ head_mask=layer_head_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ **kwargs,
+ )
+
+ if self.emb_layer_norm_after:
+ hidden_states = self.emb_layer_norm_after(hidden_states)
+
+ return BaseModelOutputWithCrossAttentions(last_hidden_state=hidden_states)
+
+
+# Copied from transformers.models.bert.modeling_bert.BertPooler
+class EsmPooler(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ self.activation = nn.Tanh()
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ # We "pool" the model by simply taking the hidden state corresponding
+ # to the first token.
+ first_token_tensor = hidden_states[:, 0]
+ pooled_output = self.dense(first_token_tensor)
+ pooled_output = self.activation(pooled_output)
+ return pooled_output
+
+
+@auto_docstring
+class EsmPreTrainedModel(PreTrainedModel):
+ config: EsmConfig
+ base_model_prefix = "esm"
+ supports_gradient_checkpointing = True
+ accepts_loss_kwargs = False
+ _no_split_modules = ["EsmLayer", "EsmFoldTriangularSelfAttentionBlock", "EsmEmbeddings"]
+ _keys_to_ignore_on_load_unexpected = ["position_embeddings.weight"]
+ _supports_flash_attn = True
+ _supports_sdpa = True
+ _supports_flex_attn = True
+ _supports_attention_backend = True
+
+ _can_record_outputs = {
+ "hidden_states": EsmLayer,
+ "attentions": [OutputRecorder(EsmSelfAttention, index=1, layer_name="attention")],
+ "cross_attentions": [
+ OutputRecorder(EsmSelfAttention, index=1, layer_name="crossattention"),
+ ],
+ }
+
+ # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights with BertLMPredictionHead->EsmLMHead
+ def _init_weights(self, module):
+ """Initialize the weights"""
+ if isinstance(module, nn.Linear):
+ # Slightly different from the TF version which uses truncated_normal for initialization
+ # cf https://github.com/pytorch/pytorch/pull/5617
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+ elif isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+ elif isinstance(module, EsmLMHead):
+ module.bias.data.zero_()
+
+ def get_output_embeddings(self):
+ # NOTE: get_output_embeddings() must return None to prevent accidental weight tying.
+ # See e.g. https://github.com/huggingface/transformers/pull/39339#discussion_r2219126400
+ return None
+
+
+@auto_docstring
+class EsmModel(EsmPreTrainedModel):
+ """
+
+ The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
+ cross-attention is added between the self-attention layers, following the architecture described in [Attention is
+ all you need](https://huggingface.co/papers/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
+ Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
+
+ To behave as a decoder the model needs to be initialized with the `is_decoder` argument of the configuration set
+ to `True`. To be used in a Seq2Seq model, the model needs to be initialized with both the `is_decoder` argument and
+ `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass.
+ """
+
+ def __init__(self, config, add_pooling_layer=True):
+ r"""
+ add_pooling_layer (bool, *optional*, defaults to `True`):
+ Whether to add a pooling layer
+ """
+ super().__init__(config)
+ self.config = config
+
+ self.embeddings = EsmEmbeddings(config)
+ self.encoder = EsmEncoder(config)
+
+ self.pooler = EsmPooler(config) if add_pooling_layer else None
+
+ self.contact_head = EsmContactPredictionHead(
+ in_features=config.num_hidden_layers * config.num_attention_heads, bias=True
+ )
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.embeddings.word_embeddings
+
+ def set_input_embeddings(self, value):
+ self.embeddings.word_embeddings = value
+
+ def _prune_heads(self, heads_to_prune):
+ """
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
+ class PreTrainedModel
+ """
+ for layer, heads in heads_to_prune.items():
+ self.encoder.layer[layer].attention.prune_heads(heads)
+
+ @check_model_inputs
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
+ r"""
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+ config.max_position_embeddings - 1]`.
+
+ [What are position IDs?](../glossary#position-ids)
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
+ model's internal embedding lookup matrix.
+ """
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embeddings(
+ input_ids=input_ids,
+ position_ids=position_ids,
+ )
+
+ if self.config._attn_implementation != "flash_attention_2":
+ batch_size, seq_length = inputs_embeds.shape[:-1]
+ if attention_mask is None:
+ attention_mask = torch.ones(((batch_size, seq_length)), device=inputs_embeds.device)
+
+ attention_mask: torch.Tensor = self.get_extended_attention_mask(
+ attention_mask, input_shape=(batch_size, seq_length)
+ )
+
+ # If a 2D or 3D attention mask is provided for the cross-attention
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
+ if self.config.is_decoder and encoder_hidden_states is not None:
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
+ if encoder_attention_mask is None:
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=inputs_embeds.device)
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
+ else:
+ encoder_extended_attention_mask = None
+
+ # Prepare head mask if needed
+ # 1.0 in head_mask indicate we keep the head
+ # attention_probs has shape bsz x n_heads x N x N
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
+
+ encoder_outputs = self.encoder(
+ inputs_embeds,
+ attention_mask=attention_mask,
+ head_mask=head_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_extended_attention_mask,
+ **kwargs,
+ )
+ sequence_output = encoder_outputs[0]
+ pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
+
+ return BaseModelOutputWithPoolingAndCrossAttentions(
+ last_hidden_state=sequence_output,
+ pooler_output=pooled_output,
+ )
+
+ def predict_contacts(self, tokens, attention_mask):
+ attns = self(tokens, attention_mask=attention_mask, return_dict=True, output_attentions=True).attentions
+ attns = torch.stack(attns, dim=1) # Matches the original model layout
+ # In the original model, attentions for padding tokens are completely zeroed out.
+ # This makes no difference most of the time because the other tokens won't attend to them,
+ # but it does for the contact prediction task, which takes attentions as input,
+ # so we have to mimic that here.
+ attns *= attention_mask.unsqueeze(1).unsqueeze(2).unsqueeze(3)
+ attns *= attention_mask.unsqueeze(1).unsqueeze(2).unsqueeze(4)
+ return self.contact_head(tokens, attns)
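
A hedged usage sketch for `predict_contacts`; the checkpoint name and the protein sequence are illustrative assumptions, not something this file mandates.

```python
import torch
from transformers import AutoTokenizer, EsmModel

name = "facebook/esm2_t6_8M_UR50D"  # assumed small ESM-2 checkpoint
tokenizer = AutoTokenizer.from_pretrained(name)
model = EsmModel.from_pretrained(name)

inputs = tokenizer("MKTAYIAKQR", return_tensors="pt")  # toy protein sequence
with torch.no_grad():
    contacts = model.predict_contacts(inputs["input_ids"], inputs["attention_mask"])
print(contacts.shape)  # (batch, residues, residues) after stripping <cls>/<eos>
```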
+
+
+@auto_docstring
+class EsmForMaskedLM(EsmPreTrainedModel):
+ _tied_weights_keys = ["lm_head.decoder.weight"]
+
+ def __init__(self, config):
+ super().__init__(config)
+
+ if config.is_decoder:
+ logger.warning(
+ "If you want to use `EsmForMaskedLM` make sure `config.is_decoder=False` for "
+ "bi-directional self-attention."
+ )
+
+ self.esm = EsmModel(config, add_pooling_layer=False)
+ self.lm_head = EsmLMHead(config)
+
+ self.init_weights()
+
+ self.post_init()
+
+ def get_output_embeddings(self):
+ return self.lm_head.decoder
+
+ def set_output_embeddings(self, new_embeddings):
+ self.lm_head.decoder = new_embeddings
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> Union[tuple, MaskedLMOutput]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
+ config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
+ loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
+ """
+
+ outputs = self.esm(
+ input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ **kwargs,
+ )
+ sequence_output = outputs[0]
+ prediction_scores = self.lm_head(sequence_output)
+
+ masked_lm_loss = None
+ if labels is not None:
+ loss_fct = CrossEntropyLoss()
+
+ labels = labels.to(prediction_scores.device)
+ masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
+
+ return MaskedLMOutput(
+ loss=masked_lm_loss,
+ logits=prediction_scores,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+ def predict_contacts(self, tokens, attention_mask):
+ return self.esm.predict_contacts(tokens, attention_mask=attention_mask)
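
And a hedged sketch of `EsmForMaskedLM` predicting a masked residue; again, the checkpoint name and sequence are illustrative assumptions.

```python
import torch
from transformers import AutoTokenizer, EsmForMaskedLM

name = "facebook/esm2_t6_8M_UR50D"  # assumed checkpoint
tokenizer = AutoTokenizer.from_pretrained(name)
model = EsmForMaskedLM.from_pretrained(name)

inputs = tokenizer("MKTAYIAK" + tokenizer.mask_token + "R", return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits

mask_pos = (inputs["input_ids"] == tokenizer.mask_token_id).nonzero(as_tuple=True)[1]
print(tokenizer.decode(logits[0, mask_pos].argmax(-1)))  # the model's best guess for the masked residue
```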
+
+
+class EsmLMHead(nn.Module):
+ """ESM Head for masked language modeling."""
+
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+ self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+ self.bias = nn.Parameter(torch.zeros(config.vocab_size))
+
+ def forward(self, features, **kwargs):
+ x = self.dense(features)
+ x = gelu(x)
+ x = self.layer_norm(x)
+
+ # project back to size of vocabulary with bias
+ x = self.decoder(x) + self.bias
+ return x
+
+
+@auto_docstring(
+ custom_intro="""
+ ESM Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
+ output) e.g. for GLUE tasks.
+ """
+)
+class EsmForSequenceClassification(EsmPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+ self.config = config
+
+ self.esm = EsmModel(config, add_pooling_layer=False)
+ self.classifier = EsmClassificationHead(config)
+
+ self.init_weights()
+
+ self.post_init()
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> Union[tuple, SequenceClassifierOutput]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss); if
+            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+ """
+
+ outputs = self.esm(
+ input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ **kwargs,
+ )
+ sequence_output = outputs[0]
+ logits = self.classifier(sequence_output)
+
+ loss = None
+ if labels is not None:
+ labels = labels.to(logits.device)
+
+ if self.config.problem_type is None:
+ if self.num_labels == 1:
+ self.config.problem_type = "regression"
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+ self.config.problem_type = "single_label_classification"
+ else:
+ self.config.problem_type = "multi_label_classification"
+
+ if self.config.problem_type == "regression":
+ loss_fct = MSELoss()
+ if self.num_labels == 1:
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
+ else:
+ loss = loss_fct(logits, labels)
+ elif self.config.problem_type == "single_label_classification":
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+ elif self.config.problem_type == "multi_label_classification":
+ loss_fct = BCEWithLogitsLoss()
+ loss = loss_fct(logits, labels)
+
+ return SequenceClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
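+
+# NOTE: Illustrative sketch, not part of the upstream module. It restates how the
+# head above picks a loss when `config.problem_type` is unset: one label means
+# regression (MSE), integer labels mean single-label classification (cross-entropy),
+# and anything else is treated as multi-label classification (BCE with logits).
+def _example_problem_type_inference(num_labels, labels):
+    import torch
+
+    if num_labels == 1:
+        return "regression"
+    if labels.dtype in (torch.long, torch.int):
+        return "single_label_classification"
+    return "multi_label_classification"
+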
+
+@auto_docstring
+class EsmForTokenClassification(EsmPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+
+ self.esm = EsmModel(config, add_pooling_layer=False)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+
+ self.init_weights()
+
+ self.post_init()
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> Union[tuple, TokenClassifierOutput]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
+ """
+
+ outputs = self.esm(
+ input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ **kwargs,
+ )
+
+ sequence_output = outputs[0]
+
+ sequence_output = self.dropout(sequence_output)
+ logits = self.classifier(sequence_output)
+
+ loss = None
+ if labels is not None:
+ loss_fct = CrossEntropyLoss()
+
+ labels = labels.to(logits.device)
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+
+ return TokenClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+class EsmClassificationHead(nn.Module):
+ """Head for sentence-level classification tasks."""
+
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+ self.out_proj = nn.Linear(config.hidden_size, config.num_labels)
+
+ def forward(self, features, **kwargs):
+        x = features[:, 0, :]  # take <s> token (equiv. to [CLS])
+ x = self.dropout(x)
+ x = self.dense(x)
+ x = torch.tanh(x)
+ x = self.dropout(x)
+ x = self.out_proj(x)
+ return x
+
+
+def create_position_ids_from_input_ids(input_ids, padding_idx):
+ """
+ Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
+ are ignored. This is modified from fairseq's `utils.make_positions`.
+
+    Args:
+        input_ids (`torch.Tensor`): Tensor of token ids, with padding positions equal to `padding_idx`.
+        padding_idx (`int`): The id of the padding token.
+
+    Returns: torch.Tensor
+ """
+ # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
+ mask = input_ids.ne(padding_idx).int()
+ incremental_indices = torch.cumsum(mask, dim=1).type_as(mask) * mask
+ return incremental_indices.long() + padding_idx
+
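+
+# NOTE: Illustrative sketch, not part of the upstream module. With padding_idx=1
+# (the RoBERTa/ESM convention), padded positions keep the padding index and real
+# tokens are numbered consecutively starting at padding_idx + 1.
+def _example_position_ids_from_input_ids():
+    import torch
+
+    input_ids = torch.tensor([[5, 6, 7, 1, 1]])  # token id 1 is padding here
+    position_ids = create_position_ids_from_input_ids(input_ids, padding_idx=1)
+    # position_ids == tensor([[2, 3, 4, 1, 1]])
+    return position_ids
+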
+
+__all__ = [
+ "EsmForMaskedLM",
+ "EsmForSequenceClassification",
+ "EsmForTokenClassification",
+ "EsmModel",
+ "EsmPreTrainedModel",
+]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/esm/modeling_esmfold.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/esm/modeling_esmfold.py
new file mode 100644
index 0000000000000000000000000000000000000000..7bc1f0dbdc701991a4109ddbc617eb1b3769c6a1
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/esm/modeling_esmfold.py
@@ -0,0 +1,2309 @@
+# coding=utf-8
+# Copyright 2022 Meta and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import math
+import sys
+from collections.abc import Sequence
+from dataclasses import dataclass
+from functools import partial
+from typing import Callable, Optional, Union
+
+import numpy as np
+import torch
+import torch.nn as nn
+from torch.nn import LayerNorm
+
+from ...integrations.deepspeed import is_deepspeed_available
+from ...modeling_outputs import ModelOutput
+from ...utils import (
+ ContextManagers,
+ auto_docstring,
+ is_scipy_available,
+ logging,
+)
+from .modeling_esm import EsmModel, EsmPreTrainedModel
+from .openfold_utils import (
+ OFProtein,
+ Rigid,
+ Rotation,
+ atom14_to_atom37,
+ chunk_layer,
+ compute_predicted_aligned_error,
+ compute_tm,
+ frames_and_literature_positions_to_atom14_pos,
+ make_atom14_masks,
+ residue_constants,
+ to_pdb,
+ torsion_angles_to_frames,
+)
+
+
+logger = logging.get_logger(__name__)
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+    Output type of [`EsmForProteinFolding`].
+ """
+)
+class EsmForProteinFoldingOutput(ModelOutput):
+ r"""
+ frames (`torch.FloatTensor`):
+ Output frames.
+ sidechain_frames (`torch.FloatTensor`):
+ Output sidechain frames.
+ unnormalized_angles (`torch.FloatTensor`):
+ Predicted unnormalized backbone and side chain torsion angles.
+ angles (`torch.FloatTensor`):
+ Predicted backbone and side chain torsion angles.
+ positions (`torch.FloatTensor`):
+ Predicted positions of the backbone and side chain atoms.
+ states (`torch.FloatTensor`):
+ Hidden states from the protein folding trunk.
+ s_s (`torch.FloatTensor`):
+ Per-residue embeddings derived by concatenating the hidden states of each layer of the ESM-2 LM stem.
+ s_z (`torch.FloatTensor`):
+ Pairwise residue embeddings.
+ distogram_logits (`torch.FloatTensor`):
+ Input logits to the distogram used to compute residue distances.
+ lm_logits (`torch.FloatTensor`):
+ Logits output by the ESM-2 protein language model stem.
+ aatype (`torch.FloatTensor`):
+ Input amino acids (AlphaFold2 indices).
+ atom14_atom_exists (`torch.FloatTensor`):
+ Whether each atom exists in the atom14 representation.
+ residx_atom14_to_atom37 (`torch.FloatTensor`):
+ Mapping between atoms in the atom14 and atom37 representations.
+ residx_atom37_to_atom14 (`torch.FloatTensor`):
+ Mapping between atoms in the atom37 and atom14 representations.
+ atom37_atom_exists (`torch.FloatTensor`):
+ Whether each atom exists in the atom37 representation.
+ residue_index (`torch.FloatTensor`):
+ The index of each residue in the protein chain. Unless internal padding tokens are used, this will just be
+ a sequence of integers from 0 to `sequence_length`.
+ lddt_head (`torch.FloatTensor`):
+ Raw outputs from the lddt head used to compute plddt.
+ plddt (`torch.FloatTensor`):
+ Per-residue confidence scores. Regions of low confidence may indicate areas where the model's prediction is
+ uncertain, or where the protein structure is disordered.
+ ptm_logits (`torch.FloatTensor`):
+ Raw logits used for computing ptm.
+ ptm (`torch.FloatTensor`):
+ TM-score output representing the model's high-level confidence in the overall structure.
+ aligned_confidence_probs (`torch.FloatTensor`):
+ Per-residue confidence scores for the aligned structure.
+ predicted_aligned_error (`torch.FloatTensor`):
+ Predicted error between the model's prediction and the ground truth.
+ max_predicted_aligned_error (`torch.FloatTensor`):
+ Per-sample maximum predicted error.
+ """
+
+ frames: Optional[torch.FloatTensor] = None
+ sidechain_frames: Optional[torch.FloatTensor] = None
+ unnormalized_angles: Optional[torch.FloatTensor] = None
+ angles: Optional[torch.FloatTensor] = None
+ positions: Optional[torch.FloatTensor] = None
+ states: Optional[torch.FloatTensor] = None
+ s_s: Optional[torch.FloatTensor] = None
+ s_z: Optional[torch.FloatTensor] = None
+ distogram_logits: Optional[torch.FloatTensor] = None
+ lm_logits: Optional[torch.FloatTensor] = None
+ aatype: Optional[torch.FloatTensor] = None
+ atom14_atom_exists: Optional[torch.FloatTensor] = None
+ residx_atom14_to_atom37: Optional[torch.FloatTensor] = None
+ residx_atom37_to_atom14: Optional[torch.FloatTensor] = None
+ atom37_atom_exists: Optional[torch.FloatTensor] = None
+ residue_index: Optional[torch.FloatTensor] = None
+ lddt_head: Optional[torch.FloatTensor] = None
+ plddt: Optional[torch.FloatTensor] = None
+ ptm_logits: Optional[torch.FloatTensor] = None
+ ptm: Optional[torch.FloatTensor] = None
+ aligned_confidence_probs: Optional[torch.FloatTensor] = None
+ predicted_aligned_error: Optional[torch.FloatTensor] = None
+ max_predicted_aligned_error: Optional[torch.FloatTensor] = None
+
+
+def is_fp16_enabled(device_type):
+ # Autocast world
+ autocast_dtype = (
+ torch.get_autocast_dtype(device_type)
+ if hasattr(torch, "get_autocast_dtype")
+ else torch.get_autocast_gpu_dtype()
+ )
+ fp16_enabled = autocast_dtype == torch.float16
+ fp16_enabled = fp16_enabled and torch.is_autocast_enabled()
+
+ return fp16_enabled
+
+
+def is_deepspeed_initialized():
+ if is_deepspeed_available():
+ return False
+ else:
+ try:
+ import deepspeed
+
+ # This is not available in all DeepSpeed versions.
+ return deepspeed.utils.is_initialized()
+ except Exception:
+ return False
+
+
+def collate_dense_tensors(samples: list[torch.Tensor], pad_v: float = 0) -> torch.Tensor:
+ """
+ Takes a list of tensors with the following dimensions:
+ [(d_11, ..., d_1K),
+ (d_21, ..., d_2K), ..., (d_N1, ..., d_NK)]
+ and stack + pads them into a single tensor of:
+ (N, max_i=1,N { d_i1 }, ..., max_i=1,N {diK})
+ """
+ if len(samples) == 0:
+ return torch.Tensor()
+ if len({x.dim() for x in samples}) != 1:
+ raise RuntimeError(f"Samples has varying dimensions: {[x.dim() for x in samples]}")
+ (device,) = tuple({x.device for x in samples}) # assumes all on same device
+ max_shape = [max(lst) for lst in zip(*[x.shape for x in samples])]
+ result = torch.empty(len(samples), *max_shape, dtype=samples[0].dtype, device=device)
+ result.fill_(pad_v)
+ for i in range(len(samples)):
+ result_i = result[i]
+ t = samples[i]
+ result_i[tuple(slice(0, k) for k in t.shape)] = t
+ return result
+
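+
+# NOTE: Illustrative sketch, not part of the upstream module. Two tensors with the
+# same rank but different sizes are stacked into one batch, right-padded with pad_v.
+def _example_collate_dense_tensors():
+    import torch
+
+    shorter = torch.ones(3, 4)
+    longer = torch.ones(5, 4)
+    batch = collate_dense_tensors([shorter, longer], pad_v=0.0)
+    # batch.shape == (2, 5, 4); batch[0, 3:] is all padding (zeros)
+    return batch
+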
+
+def flatten_final_dims(t: torch.Tensor, no_dims: int):
+ return t.reshape(t.shape[:-no_dims] + (-1,))
+
+
+def permute_final_dims(tensor: torch.Tensor, inds: list[int]):
+ zero_index = -1 * len(inds)
+ first_inds = list(range(len(tensor.shape[:zero_index])))
+ return tensor.permute(first_inds + [zero_index + i for i in inds])
+
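+
+# NOTE: Illustrative sketch, not part of the upstream module. Both helpers only
+# touch trailing dimensions, so arbitrary leading batch dimensions are preserved.
+def _example_final_dim_helpers():
+    import torch
+
+    t = torch.randn(2, 4, 5, 3)
+    flat = flatten_final_dims(t, 2)           # shape (2, 4, 15): last two dims merged
+    moved = permute_final_dims(t, (2, 0, 1))  # shape (2, 3, 4, 5): last three dims reordered
+    return flat.shape, moved.shape
+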
+
+def dict_multimap(fn, dicts):
+ first = dicts[0]
+ new_dict = {}
+ for k, v in first.items():
+ all_v = [d[k] for d in dicts]
+ if isinstance(v, dict):
+ new_dict[k] = dict_multimap(fn, all_v)
+ else:
+ new_dict[k] = fn(all_v)
+
+ return new_dict
+
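+
+# NOTE: Illustrative sketch, not part of the upstream module. The same function is
+# applied leaf-by-leaf across a list of nested dicts with identical structure, e.g.
+# to stack per-iteration output dicts along a new leading dimension.
+def _example_dict_multimap():
+    import torch
+
+    step_0 = {"angles": torch.zeros(2), "aux": {"score": torch.tensor(1.0)}}
+    step_1 = {"angles": torch.ones(2), "aux": {"score": torch.tensor(2.0)}}
+    stacked = dict_multimap(torch.stack, [step_0, step_1])
+    # stacked["angles"].shape == (2, 2); stacked["aux"]["score"] == tensor([1., 2.])
+    return stacked
+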
+
+def trunc_normal_init_(weights, scale=1.0, fan="fan_in"):
+ shape = weights.shape
+ scale = scale / max(1, shape[1])
+
+ if not is_scipy_available():
+ logger.warning(
+ "This init requires scipy, but scipy was not found, default to an approximation that might not be"
+ " equivalent."
+ )
+ std = math.sqrt(scale)
+ torch.nn.init.normal_(weights, std=std).clamp(min=0.0, max=2.0 * std)
+
+ else:
+ from scipy.stats import truncnorm
+
+ std = math.sqrt(scale) / truncnorm.std(a=-2, b=2, loc=0, scale=1)
+ samples = truncnorm.rvs(a=-2, b=2, loc=0, scale=std, size=weights.numel())
+ samples = np.reshape(samples, shape)
+ weights.copy_(torch.tensor(samples, device=weights.device))
+
+
+def ipa_point_weights_init_(weights):
+ with torch.no_grad():
+ softplus_inverse_1 = 0.541324854612918
+ weights.fill_(softplus_inverse_1)
+
+
+class EsmFoldLinear(nn.Linear):
+ """
+ A Linear layer with built-in nonstandard initializations. Called just like torch.nn.Linear.
+
+ Implements the initializers in 1.11.4, plus some additional ones found in the code.
+ """
+
+ def __init__(
+ self,
+ in_dim: int,
+ out_dim: int,
+ bias: bool = True,
+ init: str = "default",
+ init_fn: Optional[Callable[[torch.Tensor, torch.Tensor], None]] = None,
+ ):
+ """
+ Args:
+ in_dim:
+ The final dimension of inputs to the layer
+ out_dim:
+ The final dimension of layer outputs
+ bias:
+ Whether to learn an additive bias. True by default
+ init:
+ The initializer to use. Choose from:
+
+                "default": LeCun fan-in truncated normal initialization
+                "relu": He initialization w/ truncated normal distribution
+                "glorot": Fan-average Glorot uniform initialization
+                "gating": Weights=0, Bias=1
+                "normal": Normal initialization with std=1/sqrt(fan_in)
+                "final": Weights=0, Bias=0
+
+ Overridden by init_fn if the latter is not None.
+ init_fn:
+ A custom initializer taking weight and bias as inputs. Overrides init if not None.
+ """
+ super().__init__(in_dim, out_dim, bias=bias)
+
+ if bias:
+ with torch.no_grad():
+ self.bias.fill_(0)
+ self.init = init
+ self.init_fn = init_fn
+
+ if init not in ["default", "relu", "glorot", "gating", "normal", "final"]:
+ raise ValueError("Invalid init string.")
+
+
+class EsmFoldLayerNorm(nn.Module):
+ def __init__(self, c_in, eps=1e-5):
+ super().__init__()
+
+ self.c_in = (c_in,)
+ self.eps = eps
+
+ self.weight = nn.Parameter(torch.ones(c_in))
+ self.bias = nn.Parameter(torch.zeros(c_in))
+
+ def forward(self, x):
+ d = x.dtype
+ if d is torch.bfloat16 and not is_deepspeed_initialized():
+ with torch.autocast(device_type="cuda", enabled=False):
+ out = nn.functional.layer_norm(x, self.c_in, self.weight.to(dtype=d), self.bias.to(dtype=d), self.eps)
+ else:
+ out = nn.functional.layer_norm(x, self.c_in, self.weight, self.bias, self.eps)
+
+ return out
+
+
+@torch.jit.ignore
+def softmax_no_cast(t: torch.Tensor, dim: int = -1) -> torch.Tensor:
+ """
+ Softmax, but without automatic casting to fp32 when the input is of type bfloat16
+ """
+ d = t.dtype
+ if d is torch.bfloat16 and not is_deepspeed_initialized():
+ with torch.autocast(device_type="cuda", enabled=False):
+ s = torch.nn.functional.softmax(t, dim=dim)
+ else:
+ s = torch.nn.functional.softmax(t, dim=dim)
+
+ return s
+
+
+class EsmFoldAttention(nn.Module):
+ """
+ Standard multi-head attention using AlphaFold's default layer initialization. Allows multiple bias vectors.
+ """
+
+ def __init__(
+ self,
+ c_q: int,
+ c_k: int,
+ c_v: int,
+ c_hidden: int,
+ no_heads: int,
+ gating: bool = True,
+ ):
+ """
+ Args:
+ c_q:
+ Input dimension of query data
+ c_k:
+ Input dimension of key data
+ c_v:
+ Input dimension of value data
+ c_hidden:
+ Per-head hidden dimension
+ no_heads:
+ Number of attention heads
+ gating:
+ Whether the output should be gated using query data
+ """
+ super().__init__()
+
+ self.c_q = c_q
+ self.c_k = c_k
+ self.c_v = c_v
+ self.c_hidden = c_hidden
+ self.no_heads = no_heads
+ self.gating = gating
+
+ # DISCREPANCY: c_hidden is not the per-head channel dimension, as
+ # stated in the supplement, but the overall channel dimension.
+
+ self.linear_q = EsmFoldLinear(self.c_q, self.c_hidden * self.no_heads, bias=False, init="glorot")
+ self.linear_k = EsmFoldLinear(self.c_k, self.c_hidden * self.no_heads, bias=False, init="glorot")
+ self.linear_v = EsmFoldLinear(self.c_v, self.c_hidden * self.no_heads, bias=False, init="glorot")
+ self.linear_o = EsmFoldLinear(self.c_hidden * self.no_heads, self.c_q, init="final")
+
+ self.linear_g = None
+ if self.gating:
+ self.linear_g = EsmFoldLinear(self.c_q, self.c_hidden * self.no_heads, init="gating")
+
+ self.sigmoid = nn.Sigmoid()
+
+ def _prep_qkv(self, q_x: torch.Tensor, kv_x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ # [*, Q/K/V, H * C_hidden]
+ q = self.linear_q(q_x)
+ k = self.linear_k(kv_x)
+ v = self.linear_v(kv_x)
+
+ # [*, Q/K, H, C_hidden]
+ q = q.view(q.shape[:-1] + (self.no_heads, -1))
+ k = k.view(k.shape[:-1] + (self.no_heads, -1))
+ v = v.view(v.shape[:-1] + (self.no_heads, -1))
+
+ # [*, H, Q/K, C_hidden]
+ q = q.transpose(-2, -3)
+ k = k.transpose(-2, -3)
+ v = v.transpose(-2, -3)
+
+ q /= math.sqrt(self.c_hidden)
+
+ return q, k, v
+
+ def _wrap_up(self, o: torch.Tensor, q_x: torch.Tensor) -> torch.Tensor:
+ if self.linear_g is not None:
+ g = self.sigmoid(self.linear_g(q_x))
+
+ # [*, Q, H, C_hidden]
+ g = g.view(g.shape[:-1] + (self.no_heads, -1))
+ o = o * g
+
+ # [*, Q, H * C_hidden]
+ o = flatten_final_dims(o, 2)
+
+ # [*, Q, C_q]
+ o = self.linear_o(o)
+
+ return o
+
+ def forward(
+ self,
+ q_x: torch.Tensor,
+ kv_x: torch.Tensor,
+ biases: Optional[list[torch.Tensor]] = None,
+ use_memory_efficient_kernel: bool = False,
+ use_lma: bool = False,
+ lma_q_chunk_size: int = 1024,
+ lma_kv_chunk_size: int = 4096,
+ use_flash: bool = False,
+ flash_mask: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ """
+ Args:
+ q_x:
+ [*, Q, C_q] query data
+ kv_x:
+ [*, K, C_k] key data
+ biases:
+ List of biases that broadcast to [*, H, Q, K]
+ use_memory_efficient_kernel:
+ Whether to use a custom memory-efficient attention kernel. This should be the default choice for most.
+ If none of the "use_<...>" flags are True, a stock PyTorch implementation is used instead
+ use_lma:
+ Whether to use low-memory attention (Staats & Rabe 2021). If none of the "use_<...>" flags are True, a
+ stock PyTorch implementation is used instead
+ lma_q_chunk_size:
+ Query chunk size (for LMA)
+ lma_kv_chunk_size:
+ Key/Value chunk size (for LMA)
+ Returns
+ [*, Q, C_q] attention update
+ """
+ if use_lma and (lma_q_chunk_size is None or lma_kv_chunk_size is None):
+ raise ValueError("If use_lma is specified, lma_q_chunk_size and lma_kv_chunk_size must be provided")
+
+ if use_flash and biases is not None:
+ raise ValueError("use_flash is incompatible with the bias option. For masking, use flash_mask instead")
+
+ attn_options = [use_memory_efficient_kernel, use_lma, use_flash]
+ if sum(attn_options) > 1:
+ raise ValueError("Choose at most one alternative attention algorithm")
+
+ if biases is None:
+ biases = []
+
+ # [*, H, Q/K, C_hidden]
+ query, key, value = self._prep_qkv(q_x, kv_x)
+ key = permute_final_dims(key, (1, 0))
+
+ # [*, H, Q, K]
+ output = torch.matmul(query, key)
+ for b in biases:
+ output += b
+ output = softmax_no_cast(output, -1)
+
+ # [*, H, Q, C_hidden]
+ output = torch.matmul(output, value)
+ output = output.transpose(-2, -3)
+ output = self._wrap_up(output, q_x)
+
+ return output
+
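+
+# NOTE: Illustrative sketch, not part of the upstream module; layer sizes are
+# arbitrary. A single additive bias broadcasting to [*, H, Q, K] is passed in,
+# which is how the pair representation biases attention elsewhere in this file.
+def _example_esmfold_attention():
+    import torch
+
+    attn = EsmFoldAttention(c_q=16, c_k=16, c_v=16, c_hidden=8, no_heads=2, gating=True)
+    q_x = torch.randn(1, 5, 16)
+    kv_x = torch.randn(1, 7, 16)
+    bias = torch.zeros(1, 2, 5, 7)  # [*, H, Q, K]
+    out = attn(q_x=q_x, kv_x=kv_x, biases=[bias])
+    # out.shape == (1, 5, 16): one update per query position
+    return out.shape
+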
+
+class EsmFoldTriangleAttention(nn.Module):
+ def __init__(self, c_in, c_hidden, no_heads, starting=True, inf=1e9):
+ """
+ Args:
+ c_in:
+ Input channel dimension
+ c_hidden:
+ Overall hidden channel dimension (not per-head)
+ no_heads:
+ Number of attention heads
+ """
+ super().__init__()
+
+ self.c_in = c_in
+ self.c_hidden = c_hidden
+ self.no_heads = no_heads
+ self.starting = starting
+ self.inf = inf
+
+ self.layer_norm = LayerNorm(self.c_in)
+
+ self.linear = EsmFoldLinear(c_in, self.no_heads, bias=False, init="normal")
+
+ self.mha = EsmFoldAttention(self.c_in, self.c_in, self.c_in, self.c_hidden, self.no_heads)
+
+ @torch.jit.ignore
+ def _chunk(
+ self,
+ x: torch.Tensor,
+ biases: list[torch.Tensor],
+ chunk_size: int,
+ use_memory_efficient_kernel: bool = False,
+ use_lma: bool = False,
+ inplace_safe: bool = False,
+ ) -> torch.Tensor:
+ "triangle! triangle!"
+ mha_inputs = {
+ "q_x": x,
+ "kv_x": x,
+ "biases": biases,
+ }
+
+ return chunk_layer(
+ partial(self.mha, use_memory_efficient_kernel=use_memory_efficient_kernel, use_lma=use_lma),
+ mha_inputs,
+ chunk_size=chunk_size,
+ no_batch_dims=len(x.shape[:-2]),
+ _out=x if inplace_safe else None,
+ )
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ mask: Optional[torch.Tensor] = None,
+ chunk_size: Optional[int] = None,
+ use_memory_efficient_kernel: bool = False,
+ use_lma: bool = False,
+ inplace_safe: bool = False,
+ ) -> torch.Tensor:
+ """
+ Args:
+ x:
+ [*, I, J, C_in] input tensor (e.g. the pair representation)
+ Returns:
+ [*, I, J, C_in] output tensor
+ """
+ if mask is None:
+ # [*, I, J]
+ mask = x.new_ones(
+ x.shape[:-1],
+ )
+
+ if not self.starting:
+ x = x.transpose(-2, -3)
+ mask = mask.transpose(-1, -2)
+
+ # [*, I, J, C_in]
+ x = self.layer_norm(x)
+
+ # [*, I, 1, 1, J]
+ mask_bias = (self.inf * (mask - 1))[..., :, None, None, :]
+
+ # [*, H, I, J]
+ triangle_bias = permute_final_dims(self.linear(x), (2, 0, 1))
+
+ # [*, 1, H, I, J]
+ triangle_bias = triangle_bias.unsqueeze(-4)
+
+ biases = [mask_bias, triangle_bias]
+
+ if chunk_size is not None:
+ x = self._chunk(
+ x,
+ biases,
+ chunk_size,
+ use_memory_efficient_kernel=use_memory_efficient_kernel,
+ use_lma=use_lma,
+ inplace_safe=inplace_safe,
+ )
+ else:
+ x = self.mha(
+ q_x=x, kv_x=x, biases=biases, use_memory_efficient_kernel=use_memory_efficient_kernel, use_lma=use_lma
+ )
+
+ if not self.starting:
+ x = x.transpose(-2, -3)
+
+ return x
+
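+
+# NOTE: Illustrative sketch, not part of the upstream module. Chunking the attention
+# trades speed for peak memory; the chunked and unchunked paths should agree up to
+# floating-point error since each row attends independently.
+def _example_triangle_attention_chunking():
+    import torch
+
+    tri_att = EsmFoldTriangleAttention(c_in=8, c_hidden=4, no_heads=2, starting=True)
+    pair = torch.randn(1, 6, 6, 8)
+    with torch.no_grad():
+        full = tri_att(pair)
+        chunked = tri_att(pair, chunk_size=2)
+    return torch.allclose(full, chunked, atol=1e-5)
+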
+
+class EsmFoldTriangleMultiplicativeUpdate(nn.Module):
+ """
+ Implements Algorithms 11 and 12.
+ """
+
+ def __init__(self, config, _outgoing=True):
+ super().__init__()
+ c_hidden = config.pairwise_state_dim
+ self._outgoing = _outgoing
+
+ self.linear_a_p = EsmFoldLinear(c_hidden, c_hidden)
+ self.linear_a_g = EsmFoldLinear(c_hidden, c_hidden, init="gating")
+ self.linear_b_p = EsmFoldLinear(c_hidden, c_hidden)
+ self.linear_b_g = EsmFoldLinear(c_hidden, c_hidden, init="gating")
+ self.linear_g = EsmFoldLinear(c_hidden, c_hidden, init="gating")
+ self.linear_z = EsmFoldLinear(c_hidden, c_hidden, init="final")
+
+ self.layer_norm_in = LayerNorm(c_hidden)
+ self.layer_norm_out = LayerNorm(c_hidden)
+
+ self.sigmoid = nn.Sigmoid()
+
+ def _combine_projections(
+ self, a: torch.Tensor, b: torch.Tensor, _inplace_chunk_size: Optional[int] = None
+ ) -> torch.Tensor:
+ if self._outgoing:
+ a = permute_final_dims(a, (2, 0, 1))
+ b = permute_final_dims(b, (2, 1, 0))
+ else:
+ a = permute_final_dims(a, (2, 1, 0))
+ b = permute_final_dims(b, (2, 0, 1))
+
+ if _inplace_chunk_size is not None:
+ # To be replaced by torch vmap
+ for i in range(0, a.shape[-3], _inplace_chunk_size):
+ a_chunk = a[..., i : i + _inplace_chunk_size, :, :]
+ b_chunk = b[..., i : i + _inplace_chunk_size, :, :]
+ a[..., i : i + _inplace_chunk_size, :, :] = torch.matmul(
+ a_chunk,
+ b_chunk,
+ )
+
+ p = a
+ else:
+ p = torch.matmul(a, b)
+
+ return permute_final_dims(p, (1, 2, 0))
+
+ def _inference_forward(
+ self,
+ z: torch.Tensor,
+ mask: Optional[torch.Tensor] = None,
+ inplace_chunk_size: Optional[int] = None,
+ with_add: bool = True,
+ ):
+ """
+ Args:
+ z:
+ A [*, N, N, C_z] pair representation
+ mask:
+ A [*, N, N] pair mask
+ inplace_chunk_size:
+ Size of chunks used in the main computation. Increase to trade memory for speed.
+ with_add:
+ If True, z is overwritten with (z + update). Otherwise, it is overwritten with (update).
+ Returns:
+ A reference to the overwritten z
+
+ More memory-efficient, inference-only version of the forward function. Uses in-place operations, fusion of the
+ addition that happens after this module in the Evoformer, a smidge of recomputation, and a cache of overwritten
+ values to lower peak memory consumption of this module from 5x the size of the input tensor z to 2.5x its size.
+ Useful for inference on extremely long sequences.
+
+ It works as follows. We will make reference to variables used in the default forward implementation below.
+ Naively, triangle multiplication attention requires the manifestation of 5 tensors the size of z: 1) z, the
+        "square" input tensor, 2) a, the first projection of z, 3) b, the second projection of z, 4) g, a z-sized mask,
+ and 5) a z-sized tensor for intermediate computations. For large N, this is prohibitively expensive; for
+ N=4000, for example, z is more than 8GB alone. To avoid this problem, we compute b, g, and all intermediate
+ tensors in small chunks, noting that the chunks required to compute a chunk of the output depend only on the
+ tensor a and corresponding vertical and horizontal chunks of z. This suggests an algorithm that loops over
+ pairs of chunks of z: hereafter "columns" and "rows" of z, even though each "column" and "row" in fact contains
+ inplace_chunk_size contiguous true columns and rows of z. Writing output chunks to a new tensor would bring
+ total memory consumption down to 3x the size of z. However, more memory can be saved by writing output chunks
+ directly to z in-place. WLOG, we choose to write output chunks vertically, overwriting the ith "column" of z at
+ the end of the ith iteration of the main loop. Despite this overwriting, the ith column is always one column
+ ahead of previously overwritten columns and can be recovered directly from z. After the first iteration,
+ however, the ith row of z is always at least partially overwritten. For this reason, we introduce the z-cache,
+ a tensor one-half the size of z. The z-cache initially contains the left half (2nd and 3rd quadrants) of z. For
+ 0 < i < N/2, the missing left part of the ith row of z is recovered from this cache at the beginning of the ith
+ iteration. Once i exceeds n/2, the cache is "reoriented" to encompass the 3rd and 4th quadrants of z instead.
+ Though the 3rd quadrant of the original z is entirely overwritten at this point, it can be recovered from the
+ z-cache itself. Thereafter, the ith row of z can be recovered in its entirety from the reoriented z-cache.
+ After the final iteration, z has been completely overwritten and contains the triangular multiplicative update.
+ If with_add is True, it instead contains the sum of z and the triangular multiplicative update. In either case,
+ peak memory consumption is just 2.5x the size of z, disregarding memory used for chunks and other small
+ variables.
+ """
+ if mask is None:
+ mask = z.new_ones(z.shape[:-1])
+
+ mask = mask.unsqueeze(-1)
+
+ def compute_projection_helper(pair, mask, a=True):
+ if a:
+ linear_g = self.linear_a_g
+ linear_p = self.linear_a_p
+ else:
+ linear_g = self.linear_b_g
+ linear_p = self.linear_b_p
+
+ pair = self.layer_norm_in(pair)
+ p = linear_g(pair)
+ p.sigmoid_()
+ p *= linear_p(pair)
+ p *= mask
+ p = permute_final_dims(p, (2, 0, 1))
+ return p
+
+ def compute_projection(pair, mask, a=True, chunked=True):
+ need_transpose = self._outgoing ^ a
+ if not chunked:
+ p = compute_projection_helper(pair, mask, a)
+ if need_transpose:
+ p = p.transpose(-1, -2)
+ else:
+ # This computation is chunked so as not to exceed our 2.5x
+ # budget with a large intermediate tensor
+ linear_g = self.linear_a_g if a else self.linear_b_g
+ c = linear_g.bias.shape[-1]
+ out_shape = pair.shape[:-3] + (c,) + pair.shape[-3:-1]
+ p = pair.new_zeros(out_shape)
+ for i in range(0, pair.shape[-3], inplace_chunk_size):
+ pair_chunk = pair[..., i : i + inplace_chunk_size, :, :]
+ pair_chunk = compute_projection_helper(
+ pair[..., i : i + inplace_chunk_size, :, :],
+ mask[..., i : i + inplace_chunk_size, :, :],
+ a,
+ )
+ if need_transpose:
+ pair_chunk = pair_chunk.transpose(-1, -2)
+ p[..., i : i + inplace_chunk_size] = pair_chunk
+ else:
+ p[..., i : i + inplace_chunk_size, :] = pair_chunk
+
+ del pair_chunk
+
+ return p
+
+ # We start by fully manifesting a. In addition to the input, this
+ # brings total memory consumption to 2x z (disregarding size of chunks)
+ # [*, N, N, c]
+ a = compute_projection(z, mask, True, chunked=True)
+
+ if inplace_chunk_size is not None:
+ n = a.shape[-1]
+ half_n = n // 2 + n % 2
+ row_dim = -3
+ col_dim = -2
+ b_chunk_dim = row_dim if self._outgoing else col_dim
+
+ def empty_slicer(t):
+ return [slice(None) for _ in t.shape]
+
+ def slice_tensor(t, start, end, dim):
+ # Slices start:end from the dim dimension of t
+ s = empty_slicer(t)
+ s[dim] = slice(start, end)
+ return t[s]
+
+ def flip_z_cache_(z_cache, z):
+ # "Reorient" the z_cache (see below), filling it with quadrants
+ # 3---recovered from the z_cache---and 4---recovered from z---
+ # of the input tensor z.
+ quadrant_3 = slice_tensor(z_cache, half_n, None, row_dim)
+ z_cache = z_cache.transpose(row_dim, col_dim)
+
+ # If n is odd, we need to shrink the z_cache by one row
+ z_cache = z_cache[..., : (n // 2), :, :]
+
+                # Move the 3rd quadrant of z into the first half of the z-cache
+ first_half_slicer = empty_slicer(z_cache)
+ first_half_slicer[col_dim] = slice(0, half_n)
+ z_cache[first_half_slicer] = quadrant_3
+
+ # Get the fourth quadrant of z
+ quadrant_4 = slice_tensor(z, half_n, None, row_dim)
+ quadrant_4 = slice_tensor(quadrant_4, half_n, None, col_dim)
+
+ # Insert said quadrant into the rotated z-cache
+ quadrant_3_slicer = empty_slicer(z_cache)
+ quadrant_3_slicer[col_dim] = slice(half_n, None)
+
+ z_cache[quadrant_3_slicer] = quadrant_4
+
+ return z_cache
+
+ # Initialize the z cache to the left half of z.
+ z_cache_shape = list(z.shape)
+ z_cache_shape[col_dim] = half_n
+ z_cache = z.new_zeros(z_cache_shape)
+ z_cache_slicer = empty_slicer(z_cache)
+ z_cache_slicer[col_dim] = slice(0, half_n)
+ z_cache.copy_(z[z_cache_slicer])
+ z_cache_rotated = False
+
+ # We need to reorient the z-cache at the halfway point, and we
+ # don't want a single chunk to straddle that point. We contract one
+ # of the chunks in the middle to address that problem.
+ i_range = list(range(0, half_n, inplace_chunk_size))
+ initial_offsets = [i_2 - i_1 for i_1, i_2 in zip(i_range, i_range[1:] + [half_n])]
+ after_half = list(range(half_n, n, inplace_chunk_size))
+ after_half_offsets = [inplace_chunk_size for _ in after_half]
+ combined_range_with_offsets = zip(i_range + after_half, initial_offsets + after_half_offsets)
+ for i, offset in combined_range_with_offsets:
+ if not z_cache_rotated and i >= half_n:
+ z_cache = flip_z_cache_(z_cache, z)
+ z_cache_rotated = True
+
+ z_chunk_b = slice_tensor(z, i, i + offset, b_chunk_dim)
+ mask_chunk = slice_tensor(mask, i, i + offset, b_chunk_dim)
+
+ z_chunk_b = z_chunk_b.clone()
+ if b_chunk_dim == col_dim:
+ z_chunk_b = slice_tensor(z, i, i + offset, col_dim)
+ else: # b_chunk_dim == row_dim
+ # In this case, the b-dimension (b_chunk_dim) is partially
+ # overwritten at the end of each iteration. We need to
+ # restore the missing component from the z-cache.
+ if not z_cache_rotated:
+ z_chunk_slicer = empty_slicer(z_chunk_b)
+ z_chunk_slicer[col_dim] = slice(0, half_n)
+ z_chunk_b[z_chunk_slicer] = slice_tensor(z_cache, i, i + offset, row_dim)
+ else:
+ z_cache_offset = i - half_n
+ z_chunk_b = slice_tensor(z_cache, z_cache_offset, z_cache_offset + offset, row_dim)
+
+ b_chunk = compute_projection(z_chunk_b, mask_chunk, a=False, chunked=False)
+ del z_chunk_b
+
+ x_chunk = torch.matmul(a, b_chunk)
+ x_chunk = permute_final_dims(x_chunk, (1, 2, 0))
+ x_chunk = self.layer_norm_out(x_chunk)
+ x_chunk = self.linear_z(x_chunk)
+
+ # The g dimension (col_dim) is parallel to and ahead of the
+ # overwrites in z. We can extract the g chunk normally.
+ z_chunk_g = slice_tensor(z, i, i + offset, col_dim)
+ g_chunk = self.linear_g(self.layer_norm_in(z_chunk_g))
+ g_chunk.sigmoid_()
+ del z_chunk_g
+
+ x_chunk *= g_chunk
+
+ # Write the columns into z in-place
+ z_slicer = empty_slicer(z)
+ z_slicer[col_dim] = slice(i, i + offset)
+ if with_add:
+ z[z_slicer] += x_chunk
+ else:
+ z[z_slicer] = x_chunk
+ else:
+ b = compute_projection(z, mask, False, False)
+ x = torch.matmul(a, b)
+ x = self.layer_norm_out(x)
+ x = self.linear_z(x)
+ g = self.linear_g(z)
+ g.sigmoid_()
+ x *= g
+ if with_add:
+ z += x
+ else:
+ z = x
+
+ return z
+
+ def forward(
+ self,
+ z: torch.Tensor,
+ mask: Optional[torch.Tensor] = None,
+ inplace_safe: bool = False,
+ _add_with_inplace: bool = False,
+ _inplace_chunk_size: Optional[int] = 256,
+ ) -> torch.Tensor:
+ """
+ Args:
+ x:
+ [*, N_res, N_res, C_z] input tensor
+ mask:
+ [*, N_res, N_res] input mask
+ Returns:
+ [*, N_res, N_res, C_z] output tensor
+ """
+ if inplace_safe:
+ x = self._inference_forward(
+ z,
+ mask,
+ inplace_chunk_size=_inplace_chunk_size,
+ with_add=_add_with_inplace,
+ )
+ return x
+
+ if mask is None:
+ mask = z.new_ones(z.shape[:-1])
+
+ mask = mask.unsqueeze(-1)
+
+ z = self.layer_norm_in(z)
+ a = mask
+ a = a * self.sigmoid(self.linear_a_g(z))
+ a = a * self.linear_a_p(z)
+ b = mask
+ b = b * self.sigmoid(self.linear_b_g(z))
+ b = b * self.linear_b_p(z)
+
+ device_type = a.device.type if a.device.type != "mps" else "cpu"
+ if is_fp16_enabled(device_type):
+ with torch.autocast(device_type=device_type, enabled=False):
+ x = self._combine_projections(a.float(), b.float())
+ else:
+ x = self._combine_projections(a, b)
+
+ del a, b
+ x = self.layer_norm_out(x)
+ x = self.linear_z(x)
+ g = self.sigmoid(self.linear_g(z))
+ x = x * g
+
+ return x
+
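+
+# NOTE: Illustrative sketch, not part of the upstream module; the config stub is an
+# assumption (only `pairwise_state_dim` is read here). With `_add_with_inplace=False`
+# the chunked in-place path overwrites `z` with the same update that the regular
+# forward returns, up to floating-point error.
+def _example_triangle_multiplicative_update():
+    import torch
+    from types import SimpleNamespace
+
+    tri_mul = EsmFoldTriangleMultiplicativeUpdate(SimpleNamespace(pairwise_state_dim=8), _outgoing=True)
+    z = torch.randn(1, 8, 8, 8)
+    with torch.no_grad():
+        update = tri_mul(z.clone())
+        inplace_update = tri_mul(
+            z.clone(), inplace_safe=True, _add_with_inplace=False, _inplace_chunk_size=2
+        )
+    return torch.allclose(update, inplace_update, atol=1e-4)
+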
+
+class EsmFoldPreTrainedModel(EsmPreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+    # Subclass `EsmPreTrainedModel` to deal with special init
+ def _init_weights(self, module):
+ """Initialize the weights"""
+ if isinstance(module, EsmFoldLinear):
+ with torch.no_grad():
+ if module.init_fn is not None:
+ module.init_fn(module.weight, module.bias)
+ elif module.init == "default":
+ trunc_normal_init_(module.weight, scale=1.0)
+ elif module.init == "relu":
+ trunc_normal_init_(module.weight, scale=2.0)
+ elif module.init == "glorot":
+ nn.init.xavier_uniform_(module.weight, gain=1)
+ elif module.init == "gating":
+ module.weight.fill_(0.0)
+                    if module.bias is not None:
+ module.bias.fill_(1.0)
+ elif module.init == "normal":
+ torch.nn.init.kaiming_normal_(module.weight, nonlinearity="linear")
+ elif module.init == "final":
+ module.weight.fill_(0.0)
+ elif isinstance(module, EsmFoldInvariantPointAttention):
+ ipa_point_weights_init_(module.head_weights)
+ elif isinstance(module, EsmFoldTriangularSelfAttentionBlock):
+ torch.nn.init.zeros_(module.tri_mul_in.linear_z.weight)
+ torch.nn.init.zeros_(module.tri_mul_in.linear_z.bias)
+ torch.nn.init.zeros_(module.tri_mul_out.linear_z.weight)
+ torch.nn.init.zeros_(module.tri_mul_out.linear_z.bias)
+ torch.nn.init.zeros_(module.tri_att_start.mha.linear_o.weight)
+ torch.nn.init.zeros_(module.tri_att_start.mha.linear_o.bias)
+ torch.nn.init.zeros_(module.tri_att_end.mha.linear_o.weight)
+ torch.nn.init.zeros_(module.tri_att_end.mha.linear_o.bias)
+
+ torch.nn.init.zeros_(module.sequence_to_pair.o_proj.weight)
+ torch.nn.init.zeros_(module.sequence_to_pair.o_proj.bias)
+ torch.nn.init.zeros_(module.pair_to_sequence.linear.weight)
+ torch.nn.init.zeros_(module.seq_attention.o_proj.weight)
+ torch.nn.init.zeros_(module.seq_attention.o_proj.bias)
+ torch.nn.init.zeros_(module.mlp_seq.mlp[-2].weight)
+ torch.nn.init.zeros_(module.mlp_seq.mlp[-2].bias)
+ torch.nn.init.zeros_(module.mlp_pair.mlp[-2].weight)
+ torch.nn.init.zeros_(module.mlp_pair.mlp[-2].bias)
+ else:
+ super()._init_weights(module)
+
+
+class EsmFoldSelfAttention(nn.Module):
+ def __init__(self, embed_dim, num_heads, head_width, gated=False):
+ super().__init__()
+ assert embed_dim == num_heads * head_width
+
+ self.embed_dim = embed_dim
+ self.num_heads = num_heads
+ self.head_width = head_width
+
+ self.proj = nn.Linear(embed_dim, embed_dim * 3, bias=False)
+ self.o_proj = nn.Linear(embed_dim, embed_dim, bias=True)
+ self.gated = gated
+ if gated:
+ self.g_proj = nn.Linear(embed_dim, embed_dim)
+ torch.nn.init.zeros_(self.g_proj.weight)
+ torch.nn.init.ones_(self.g_proj.bias)
+
+ self.rescale_factor = self.head_width**-0.5
+
+ torch.nn.init.zeros_(self.o_proj.bias)
+
+ def forward(self, x, mask=None, bias=None, indices=None):
+ """
+ Basic self attention with optional mask and external pairwise bias. To handle sequences of different lengths,
+ use mask.
+
+ Inputs:
+            x: batch of input sequences (.. x L x C)
+            mask: batch of boolean masks where 1=valid, 0=padding position (.. x L_k)
+            bias: batch of scalar pairwise attention biases (.. x Lq x Lk x num_heads)
+
+ Outputs:
+ sequence projection (B x L x embed_dim), attention maps (B x L x L x num_heads)
+ """
+
+ t = self.proj(x).view(*x.shape[:2], self.num_heads, -1)
+ t = t.permute(0, 2, 1, 3)
+ q, k, v = t.chunk(3, dim=-1)
+
+ q = self.rescale_factor * q
+ a = torch.einsum("...qc,...kc->...qk", q, k)
+
+ # Add external attention bias.
+ if bias is not None:
+ a = a + bias.permute(0, 3, 1, 2)
+
+ # Do not attend to padding tokens.
+ if mask is not None:
+ mask = mask[:, None, None]
+ a = a.masked_fill(mask == False, -np.inf) # noqa: E712
+
+ a = nn.functional.softmax(a, dim=-1)
+
+ y = torch.einsum("...hqk,...hkc->...qhc", a, v)
+ y = y.reshape(*y.shape[:2], -1)
+
+ if self.gated:
+ y = self.g_proj(x).sigmoid() * y
+ y = self.o_proj(y)
+
+ return y, a.permute(0, 3, 1, 2)
+
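+
+# NOTE: Illustrative sketch, not part of the upstream module. Padded key positions
+# are masked to -inf before the softmax, so outputs at valid positions do not depend
+# on whatever values occupy the padding slots.
+def _example_self_attention_padding_invariance():
+    import torch
+
+    layer = EsmFoldSelfAttention(embed_dim=8, num_heads=2, head_width=4, gated=True)
+    mask = torch.tensor([[True, True, True, False, False]])
+    x = torch.randn(1, 5, 8)
+    x_scrambled = x.clone()
+    x_scrambled[:, 3:] = torch.randn(1, 2, 8)  # perturb only the padded positions
+    with torch.no_grad():
+        y, _ = layer(x, mask=mask)
+        y_scrambled, _ = layer(x_scrambled, mask=mask)
+    # Valid positions 0..2 are unchanged by the perturbation.
+    return torch.allclose(y[:, :3], y_scrambled[:, :3], atol=1e-6)
+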
+
+class EsmFoldDropout(nn.Module):
+ """
+ Implementation of dropout with the ability to share the dropout mask along a particular dimension.
+ """
+
+ def __init__(self, r: float, batch_dim: Union[int, list[int]]):
+ super().__init__()
+
+ self.r = r
+ if isinstance(batch_dim, int):
+ batch_dim = [batch_dim]
+ self.batch_dim = batch_dim
+ self.dropout = nn.Dropout(self.r)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ shape = list(x.shape)
+ if self.batch_dim is not None:
+ for bd in self.batch_dim:
+ shape[bd] = 1
+ return x * self.dropout(x.new_ones(shape))
+
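+
+# NOTE: Illustrative sketch, not part of the upstream module. Sharing the dropout
+# mask along a dimension drops whole slices of the pair representation together
+# instead of independent elements, as the row/column dropouts below rely on.
+def _example_shared_dropout():
+    import torch
+
+    shared_drop = EsmFoldDropout(r=0.5, batch_dim=1)
+    shared_drop.train()
+    x = torch.ones(1, 4, 6)
+    out = shared_drop(x)
+    # The mask has shape (1, 1, 6): each of the 6 positions is either zeroed for all
+    # 4 entries along dim 1 at once, or kept and rescaled by 1 / (1 - r).
+    return out
+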
+
+class EsmFoldSequenceToPair(nn.Module):
+ def __init__(self, sequence_state_dim, inner_dim, pairwise_state_dim):
+ super().__init__()
+
+ self.layernorm = nn.LayerNorm(sequence_state_dim)
+ self.proj = nn.Linear(sequence_state_dim, inner_dim * 2, bias=True)
+ self.o_proj = nn.Linear(2 * inner_dim, pairwise_state_dim, bias=True)
+
+ torch.nn.init.zeros_(self.proj.bias)
+ torch.nn.init.zeros_(self.o_proj.bias)
+
+ def forward(self, sequence_state):
+ """
+ Inputs:
+ sequence_state: B x L x sequence_state_dim
+
+ Output:
+ pairwise_state: B x L x L x pairwise_state_dim
+
+ Intermediate state:
+ B x L x L x 2*inner_dim
+ """
+
+ assert len(sequence_state.shape) == 3
+
+ s = self.layernorm(sequence_state)
+ s = self.proj(s)
+ q, k = s.chunk(2, dim=-1)
+
+ prod = q[:, None, :, :] * k[:, :, None, :]
+ diff = q[:, None, :, :] - k[:, :, None, :]
+
+ x = torch.cat([prod, diff], dim=-1)
+ x = self.o_proj(x)
+
+ return x
+
+
+class EsmFoldPairToSequence(nn.Module):
+ def __init__(self, pairwise_state_dim, num_heads):
+ super().__init__()
+
+ self.layernorm = nn.LayerNorm(pairwise_state_dim)
+ self.linear = nn.Linear(pairwise_state_dim, num_heads, bias=False)
+
+ def forward(self, pairwise_state):
+ """
+ Inputs:
+ pairwise_state: B x L x L x pairwise_state_dim
+
+ Output:
+ pairwise_bias: B x L x L x num_heads
+ """
+ assert len(pairwise_state.shape) == 4
+ z = self.layernorm(pairwise_state)
+ pairwise_bias = self.linear(z)
+ return pairwise_bias
+
+
+class EsmFoldResidueMLP(nn.Module):
+ def __init__(self, embed_dim, inner_dim, dropout=0):
+ super().__init__()
+
+ self.mlp = nn.Sequential(
+ nn.LayerNorm(embed_dim),
+ nn.Linear(embed_dim, inner_dim),
+ nn.ReLU(),
+ nn.Linear(inner_dim, embed_dim),
+ nn.Dropout(dropout),
+ )
+
+ def forward(self, x):
+ return x + self.mlp(x)
+
+
+class EsmFoldTriangularSelfAttentionBlock(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+
+ sequence_state_dim = config.sequence_state_dim
+ pairwise_state_dim = config.pairwise_state_dim
+ sequence_num_heads = sequence_state_dim // config.sequence_head_width
+ pairwise_num_heads = pairwise_state_dim // config.pairwise_head_width
+
+ self.layernorm_1 = nn.LayerNorm(sequence_state_dim)
+
+ self.sequence_to_pair = EsmFoldSequenceToPair(sequence_state_dim, pairwise_state_dim // 2, pairwise_state_dim)
+ self.pair_to_sequence = EsmFoldPairToSequence(pairwise_state_dim, sequence_num_heads)
+
+ self.seq_attention = EsmFoldSelfAttention(
+ sequence_state_dim, sequence_num_heads, config.sequence_head_width, gated=True
+ )
+ self.tri_mul_out = EsmFoldTriangleMultiplicativeUpdate(config, _outgoing=True)
+ self.tri_mul_in = EsmFoldTriangleMultiplicativeUpdate(config, _outgoing=False)
+
+ self.tri_att_start = EsmFoldTriangleAttention(
+ pairwise_state_dim, config.pairwise_head_width, pairwise_num_heads, inf=1e9, starting=True
+ )
+ self.tri_att_end = EsmFoldTriangleAttention(
+ pairwise_state_dim, config.pairwise_head_width, pairwise_num_heads, inf=1e9, starting=False
+ )
+
+ self.mlp_seq = EsmFoldResidueMLP(sequence_state_dim, 4 * sequence_state_dim, dropout=config.dropout)
+ self.mlp_pair = EsmFoldResidueMLP(pairwise_state_dim, 4 * pairwise_state_dim, dropout=config.dropout)
+
+ self.drop = nn.Dropout(config.dropout)
+ self.row_drop = EsmFoldDropout(config.dropout * 2, 2)
+ self.col_drop = EsmFoldDropout(config.dropout * 2, 1)
+
+ def forward(self, sequence_state, pairwise_state, mask=None, chunk_size=None, **__kwargs):
+ """
+ Inputs:
+          sequence_state: B x L x sequence_state_dim
+          pairwise_state: B x L x L x pairwise_state_dim
+          mask: B x L boolean tensor of valid positions
+
+ Output:
+          sequence_state: B x L x sequence_state_dim
+          pairwise_state: B x L x L x pairwise_state_dim
+ """
+ if len(sequence_state.shape) != 3:
+ raise ValueError(f"`sequence_state` should be a 3d-tensor, got {len(sequence_state.shape)} dims.")
+ if len(pairwise_state.shape) != 4:
+ raise ValueError(f"`pairwise_state` should be a 4d-tensor, got {len(pairwise_state.shape)} dims.")
+ if mask is not None and len(mask.shape) != 2:
+ raise ValueError(f"`mask` should be a 2d-tensor, got {len(mask.shape)} dims.")
+
+ batch_dim, seq_dim, sequence_state_dim = sequence_state.shape
+ pairwise_state_dim = pairwise_state.shape[3]
+
+ if sequence_state_dim != self.config.sequence_state_dim:
+ raise ValueError(
+ "`sequence_state` last dimension should be equal to `self.sequence_state_dim`. Got "
+ f"{sequence_state_dim} != {self.config.sequence_state_dim}."
+ )
+ if pairwise_state_dim != self.config.pairwise_state_dim:
+ raise ValueError(
+ "`pairwise_state` last dimension should be equal to `self.pairwise_state_dim`. Got "
+ f"{pairwise_state_dim} != {self.config.pairwise_state_dim}."
+ )
+ if batch_dim != pairwise_state.shape[0]:
+ raise ValueError(
+ f"`sequence_state` and `pairwise_state` have inconsistent batch size: {batch_dim} != "
+ f"{pairwise_state.shape[0]}."
+ )
+ if seq_dim != pairwise_state.shape[1] or seq_dim != pairwise_state.shape[2]:
+ raise ValueError(
+ f"`sequence_state` and `pairwise_state` have inconsistent sequence length: {seq_dim} != "
+ f"{pairwise_state.shape[1]} or {pairwise_state.shape[2]}."
+ )
+
+ # Update sequence state
+ bias = self.pair_to_sequence(pairwise_state)
+
+ # Self attention with bias + mlp.
+ y = self.layernorm_1(sequence_state)
+ y, _ = self.seq_attention(y, mask=mask, bias=bias)
+ sequence_state = sequence_state + self.drop(y)
+ sequence_state = self.mlp_seq(sequence_state)
+
+ # Update pairwise state
+ pairwise_state = pairwise_state + self.sequence_to_pair(sequence_state)
+
+ # Axial attention with triangular bias.
+ tri_mask = mask.unsqueeze(2) * mask.unsqueeze(1) if mask is not None else None
+ pairwise_state = pairwise_state + self.row_drop(self.tri_mul_out(pairwise_state, mask=tri_mask))
+ pairwise_state = pairwise_state + self.col_drop(self.tri_mul_in(pairwise_state, mask=tri_mask))
+ pairwise_state = pairwise_state + self.row_drop(
+ self.tri_att_start(pairwise_state, mask=tri_mask, chunk_size=chunk_size)
+ )
+ pairwise_state = pairwise_state + self.col_drop(
+ self.tri_att_end(pairwise_state, mask=tri_mask, chunk_size=chunk_size)
+ )
+
+ # MLP over pairs.
+ pairwise_state = self.mlp_pair(pairwise_state)
+
+ return sequence_state, pairwise_state
+
+
+class EsmCategoricalMixture:
+ def __init__(self, param, bins=50, start=0, end=1):
+ # All tensors are of shape ..., bins.
+ self.logits = param
+ bins = torch.linspace(start, end, bins + 1, device=self.logits.device, dtype=self.logits.dtype)
+ self.v_bins = (bins[:-1] + bins[1:]) / 2
+
+ def log_prob(self, true):
+ # Shapes are:
+ # self.probs: ... x bins
+ # true : ...
+ true_index = (true.unsqueeze(-1) - self.v_bins[[None] * true.ndim]).abs().argmin(-1)
+ nll = self.logits.log_softmax(-1)
+ return torch.take_along_dim(nll, true_index.unsqueeze(-1), dim=-1).squeeze(-1)
+
+ def mean(self):
+ return (self.logits.softmax(-1) @ self.v_bins.unsqueeze(1)).squeeze(-1)
+
+
+def categorical_lddt(logits, bins=50):
+ # Logits are ..., 37, bins.
+ return EsmCategoricalMixture(logits, bins=bins).mean()
+
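+
+# NOTE: Illustrative sketch, not part of the upstream module. The mixture converts
+# per-bin logits into an expected value over bin centers; this is how per-residue
+# confidence (pLDDT) is recovered from the lddt head's output.
+def _example_categorical_lddt():
+    import torch
+
+    logits = torch.full((1, 10, 37, 50), -10.0)
+    logits[..., 25] = 10.0  # put nearly all probability mass in bin 25
+    plddt = categorical_lddt(logits, bins=50)
+    # plddt.shape == (1, 10, 37); every value is close to the center of bin 25, 0.51
+    return plddt
+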
+
+def get_axial_mask(mask):
+ """
+ Helper to convert B x L mask of valid positions to axial mask used in row column attentions.
+
+ Input:
+ mask: B x L tensor of booleans
+
+ Output:
+        mask: (B*L) x L tensor of booleans
+ """
+
+ if mask is None:
+ return None
+
+ if len(mask.shape) != 2:
+ raise ValueError(f"`mask` should be a 2d-tensor, got {len(mask.shape)} dims.")
+ batch_dim, seq_dim = mask.shape
+ m = mask.unsqueeze(1).expand(batch_dim, seq_dim, seq_dim)
+ m = m.reshape(batch_dim * seq_dim, seq_dim)
+ return m
+
+
+class EsmFoldRelativePosition(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.bins = config.position_bins
+
+ # Note an additional offset is used so that the 0th position
+ # is reserved for masked pairs.
+ self.embedding = torch.nn.Embedding(2 * self.bins + 2, config.pairwise_state_dim)
+
+ def forward(self, residue_index, mask=None):
+ """
+ Input:
+            residue_index: B x L tensor of indices (dtype=torch.long)
+            mask: B x L tensor of booleans
+
+ Output:
+ pairwise_state: B x L x L x pairwise_state_dim tensor of embeddings
+ """
+ if residue_index.dtype != torch.long:
+ raise ValueError(f"`residue_index` has dtype {residue_index.dtype}, it should be `torch.long`.")
+ if mask is not None and residue_index.shape != mask.shape:
+ raise ValueError(
+ f"`residue_index` and `mask` have inconsistent shapes: {residue_index.shape} != {mask.shape}."
+ )
+
+ diff = residue_index[:, None, :] - residue_index[:, :, None]
+ diff = diff.clamp(-self.bins, self.bins)
+ diff = diff + self.bins + 1 # Add 1 to adjust for padding index.
+
+ if mask is not None:
+ mask = mask[:, None, :] * mask[:, :, None]
+ diff[mask == False] = 0 # noqa: E712
+
+ output = self.embedding(diff)
+ return output
+
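+
+# NOTE: Illustrative sketch, not part of the upstream module; the config stub is an
+# assumption covering only the fields read above. Each (i, j) pair is embedded from
+# the clamped sequence offset j - i, with index 0 reserved for masked pairs.
+def _example_relative_position_embedding():
+    import torch
+    from types import SimpleNamespace
+
+    rel_pos = EsmFoldRelativePosition(SimpleNamespace(position_bins=32, pairwise_state_dim=16))
+    residue_index = torch.arange(7)[None]  # (1, 7), dtype torch.long
+    mask = torch.ones(1, 7, dtype=torch.bool)
+    pairwise = rel_pos(residue_index, mask=mask)
+    # pairwise.shape == (1, 7, 7, 16)
+    return pairwise.shape
+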
+
+class EsmFoldAngleResnetBlock(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+
+ self.linear_1 = EsmFoldLinear(config.resnet_dim, config.resnet_dim, init="relu")
+ self.linear_2 = EsmFoldLinear(config.resnet_dim, config.resnet_dim, init="final")
+
+ self.relu = nn.ReLU()
+
+ def forward(self, a: torch.Tensor) -> torch.Tensor:
+ s_initial = a
+
+ a = self.relu(a)
+ a = self.linear_1(a)
+ a = self.relu(a)
+ a = self.linear_2(a)
+
+ return a + s_initial
+
+
+class EsmFoldAngleResnet(nn.Module):
+ """
+ Implements Algorithm 20, lines 11-14
+ """
+
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+
+ self.linear_in = EsmFoldLinear(config.sequence_dim, config.resnet_dim)
+ self.linear_initial = EsmFoldLinear(config.sequence_dim, config.resnet_dim)
+
+ self.layers = nn.ModuleList()
+ for _ in range(config.num_resnet_blocks):
+ layer = EsmFoldAngleResnetBlock(config)
+ self.layers.append(layer)
+
+ self.linear_out = EsmFoldLinear(config.resnet_dim, config.num_angles * 2)
+
+ self.relu = nn.ReLU()
+
+ def forward(self, s: torch.Tensor, s_initial: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+ """
+ Args:
+ s:
+ [*, C_hidden] single embedding
+ s_initial:
+ [*, C_hidden] single embedding as of the start of the StructureModule
+ Returns:
+ [*, no_angles, 2] predicted angles
+ """
+ # NOTE: The ReLU's applied to the inputs are absent from the supplement
+ # pseudocode but present in the source. For maximal compatibility with
+ # the pretrained weights, I'm going with the source.
+
+ # [*, C_hidden]
+ s_initial = self.relu(s_initial)
+ s_initial = self.linear_initial(s_initial)
+ s = self.relu(s)
+ s = self.linear_in(s)
+ s = s + s_initial
+
+ for l in self.layers:
+ s = l(s)
+
+ s = self.relu(s)
+
+ # [*, no_angles * 2]
+ s = self.linear_out(s)
+
+ # [*, no_angles, 2]
+ s = s.view(s.shape[:-1] + (-1, 2))
+
+ unnormalized_s = s
+ norm_denom = torch.sqrt(
+ torch.clamp(
+ torch.sum(s**2, dim=-1, keepdim=True),
+ min=self.config.epsilon,
+ )
+ )
+ s = s / norm_denom
+
+ return unnormalized_s, s
+
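+
+# NOTE: Illustrative sketch, not part of the upstream module; the config stub is an
+# assumption covering only the fields read above. The resnet returns raw (x, y)
+# pairs and the same pairs normalized to unit length, i.e. points on the unit
+# circle encoding each predicted torsion angle.
+def _example_angle_resnet():
+    import torch
+    from types import SimpleNamespace
+
+    config = SimpleNamespace(sequence_dim=32, resnet_dim=16, num_resnet_blocks=2, num_angles=7, epsilon=1e-8)
+    angle_resnet = EsmFoldAngleResnet(config)
+    s = torch.randn(1, 10, 32)
+    unnormalized, normalized = angle_resnet(s, s.clone())
+    # Both outputs have shape (1, 10, 7, 2); rows of `normalized` have unit norm.
+    return normalized.pow(2).sum(-1).sqrt()
+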
+
+class EsmFoldInvariantPointAttention(nn.Module):
+ """
+ Implements Algorithm 22.
+ """
+
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+
+ c_s = config.sequence_dim
+ c_z = config.pairwise_dim
+ self.hidden_dim = config.ipa_dim
+ self.num_heads = config.num_heads_ipa
+ self.num_qk_points = config.num_qk_points
+ self.num_v_points = config.num_v_points
+
+ # These linear layers differ from their specifications in the
+ # supplement. There, they lack bias and use Glorot initialization.
+ # Here as in the official source, they have bias and use the default
+ # Lecun initialization.
+ hc = config.ipa_dim * config.num_heads_ipa
+ self.linear_q = EsmFoldLinear(c_s, hc)
+ self.linear_kv = EsmFoldLinear(c_s, 2 * hc)
+
+ hpq = config.num_heads_ipa * config.num_qk_points * 3
+ self.linear_q_points = EsmFoldLinear(c_s, hpq)
+
+ hpkv = config.num_heads_ipa * (config.num_qk_points + config.num_v_points) * 3
+ self.linear_kv_points = EsmFoldLinear(c_s, hpkv)
+
+ self.linear_b = EsmFoldLinear(c_z, config.num_heads_ipa)
+
+ self.head_weights = nn.Parameter(torch.zeros(config.num_heads_ipa))
+
+ concat_out_dim = config.num_heads_ipa * (c_z + config.ipa_dim + config.num_v_points * 4)
+ self.linear_out = EsmFoldLinear(concat_out_dim, c_s, init="final")
+
+ self.softmax = nn.Softmax(dim=-1)
+ self.softplus = nn.Softplus()
+
+ def forward(
+ self,
+ s: torch.Tensor,
+ z: Optional[torch.Tensor],
+ r: Rigid,
+ mask: torch.Tensor,
+ _offload_inference: bool = False,
+ _z_reference_list: Optional[Sequence[torch.Tensor]] = None,
+ ) -> torch.Tensor:
+ """
+ Args:
+ s:
+ [*, N_res, C_s] single representation
+ z:
+ [*, N_res, N_res, C_z] pair representation
+ r:
+ [*, N_res] transformation object
+ mask:
+ [*, N_res] mask
+ Returns:
+ [*, N_res, C_s] single representation update
+ """
+ z = [z]
+
+ #######################################
+ # Generate scalar and point activations
+ #######################################
+ # [*, N_res, H * C_hidden]
+ q = self.linear_q(s)
+ kv = self.linear_kv(s)
+
+ # [*, N_res, H, C_hidden]
+ q = q.view(q.shape[:-1] + (self.num_heads, -1))
+
+ # [*, N_res, H, 2 * C_hidden]
+ kv = kv.view(kv.shape[:-1] + (self.num_heads, -1))
+
+ # [*, N_res, H, C_hidden]
+ k, v = torch.split(kv, self.hidden_dim, dim=-1)
+
+ # [*, N_res, H * P_q * 3]
+ q_pts = self.linear_q_points(s)
+
+ # This is kind of clunky, but it's how the original does it
+ # [*, N_res, H * P_q, 3]
+ q_pts = torch.split(q_pts, q_pts.shape[-1] // 3, dim=-1)
+ q_pts = torch.stack(q_pts, dim=-1)
+ q_pts = r[..., None].apply(q_pts)
+
+ # [*, N_res, H, P_q, 3]
+ q_pts = q_pts.view(q_pts.shape[:-2] + (self.num_heads, self.num_qk_points, 3))
+
+ # [*, N_res, H * (P_q + P_v) * 3]
+ kv_pts = self.linear_kv_points(s)
+
+ # [*, N_res, H * (P_q + P_v), 3]
+ kv_pts = torch.split(kv_pts, kv_pts.shape[-1] // 3, dim=-1)
+ kv_pts = torch.stack(kv_pts, dim=-1)
+ kv_pts = r[..., None].apply(kv_pts)
+
+ # [*, N_res, H, (P_q + P_v), 3]
+ kv_pts = kv_pts.view(kv_pts.shape[:-2] + (self.num_heads, -1, 3))
+
+ # [*, N_res, H, P_q/P_v, 3]
+ k_pts, v_pts = torch.split(kv_pts, [self.num_qk_points, self.num_v_points], dim=-2)
+
+ ##########################
+ # Compute attention scores
+ ##########################
+ # [*, N_res, N_res, H]
+ b = self.linear_b(z[0])
+
+ if _offload_inference:
+ assert sys.getrefcount(z[0]) == 2
+ z[0] = z[0].cpu()
+
+ # [*, H, N_res, N_res]
+ device_type = q.device.type if q.device.type != "mps" else "cpu"
+ if is_fp16_enabled(device_type):
+ with torch.autocast(device_type=device_type, enabled=False):
+ a = torch.matmul(
+ permute_final_dims(q.float(), (1, 0, 2)), # [*, H, N_res, C_hidden]
+ permute_final_dims(k.float(), (1, 2, 0)), # [*, H, C_hidden, N_res]
+ )
+ else:
+ a = torch.matmul(
+ permute_final_dims(q, (1, 0, 2)), # [*, H, N_res, C_hidden]
+ permute_final_dims(k, (1, 2, 0)), # [*, H, C_hidden, N_res]
+ )
+
+ a *= math.sqrt(1.0 / (3 * self.hidden_dim))
+ a += math.sqrt(1.0 / 3) * permute_final_dims(b, (2, 0, 1))
+
+ # [*, N_res, N_res, H, P_q, 3]
+ pt_att = q_pts.unsqueeze(-4) - k_pts.unsqueeze(-5)
+ pt_att = pt_att**2
+
+ # [*, N_res, N_res, H, P_q]
+ pt_att = sum(torch.unbind(pt_att, dim=-1))
+ head_weights = self.softplus(self.head_weights).view(*((1,) * len(pt_att.shape[:-2]) + (-1, 1)))
+ head_weights = head_weights * math.sqrt(1.0 / (3 * (self.num_qk_points * 9.0 / 2)))
+ pt_att = pt_att * head_weights
+
+ # [*, N_res, N_res, H]
+ pt_att = torch.sum(pt_att, dim=-1) * (-0.5)
+ # [*, N_res, N_res]
+ square_mask = mask.unsqueeze(-1) * mask.unsqueeze(-2)
+ square_mask = self.config.inf * (square_mask - 1)
+
+ # [*, H, N_res, N_res]
+ pt_att = permute_final_dims(pt_att, (2, 0, 1))
+
+ a = a + pt_att
+ a = a + square_mask.unsqueeze(-3)
+ a = self.softmax(a)
+
+ ################
+ # Compute output
+ ################
+ # [*, N_res, H, C_hidden]
+ o = torch.matmul(a, v.transpose(-2, -3).to(dtype=a.dtype)).transpose(-2, -3)
+
+ # [*, N_res, H * C_hidden]
+ o = flatten_final_dims(o, 2)
+
+ # [*, H, 3, N_res, P_v]
+ o_pt = torch.sum(
+ (a[..., None, :, :, None] * permute_final_dims(v_pts, (1, 3, 0, 2))[..., None, :, :]),
+ dim=-2,
+ )
+
+ # [*, N_res, H, P_v, 3]
+ o_pt = permute_final_dims(o_pt, (2, 0, 3, 1))
+ o_pt = r[..., None, None].invert_apply(o_pt)
+
+ # [*, N_res, H * P_v]
+ o_pt_norm = flatten_final_dims(torch.sqrt(torch.sum(o_pt**2, dim=-1) + self.config.epsilon), 2)
+
+ # [*, N_res, H * P_v, 3]
+ o_pt = o_pt.reshape(*o_pt.shape[:-3], -1, 3)
+
+ if _offload_inference:
+ z[0] = z[0].to(o_pt.device)
+
+ # [*, N_res, H, C_z]
+ o_pair = torch.matmul(a.transpose(-2, -3), z[0].to(dtype=a.dtype))
+
+ # [*, N_res, H * C_z]
+ o_pair = flatten_final_dims(o_pair, 2)
+
+ # [*, N_res, C_s]
+ s = self.linear_out(
+ torch.cat((o, *torch.unbind(o_pt, dim=-1), o_pt_norm, o_pair), dim=-1).to(dtype=z[0].dtype)
+ )
+
+ return s
+
+
+class EsmFoldBackboneUpdate(nn.Module):
+ """
+ Implements part of Algorithm 23.
+ """
+
+ def __init__(self, config):
+ super().__init__()
+
+ self.linear = EsmFoldLinear(config.sequence_dim, 6, init="final")
+
+ def forward(self, s: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+ """
+ Args:
+ [*, N_res, C_s] single representation
+ Returns:
+ [*, N_res, 6] update vector
+ """
+ # [*, 6]
+ update = self.linear(s)
+
+ return update
+
+
+class EsmFoldStructureModuleTransitionLayer(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+
+ self.linear_1 = EsmFoldLinear(config.sequence_dim, config.sequence_dim, init="relu")
+ self.linear_2 = EsmFoldLinear(config.sequence_dim, config.sequence_dim, init="relu")
+ self.linear_3 = EsmFoldLinear(config.sequence_dim, config.sequence_dim, init="final")
+
+ self.relu = nn.ReLU()
+
+ def forward(self, s):
+ s_initial = s
+ s = self.linear_1(s)
+ s = self.relu(s)
+ s = self.linear_2(s)
+ s = self.relu(s)
+ s = self.linear_3(s)
+
+ s = s + s_initial
+
+ return s
+
+
+class EsmFoldStructureModuleTransition(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+
+ self.layers = nn.ModuleList()
+ for _ in range(config.num_transition_layers):
+ l = EsmFoldStructureModuleTransitionLayer(config)
+ self.layers.append(l)
+
+ self.dropout = nn.Dropout(config.dropout_rate)
+ self.layer_norm = LayerNorm(config.sequence_dim)
+
+ def forward(self, s):
+ for l in self.layers:
+ s = l(s)
+
+ s = self.dropout(s)
+ s = self.layer_norm(s)
+
+ return s
+
+
+class EsmFoldStructureModule(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+
+ # Buffers to be lazily initialized later
+ # self.default_frames
+ # self.group_idx
+ # self.atom_mask
+ # self.lit_positions
+
+ self.layer_norm_s = LayerNorm(config.sequence_dim)
+ self.layer_norm_z = LayerNorm(config.pairwise_dim)
+
+ self.linear_in = EsmFoldLinear(config.sequence_dim, config.sequence_dim)
+
+ self.ipa = EsmFoldInvariantPointAttention(config)
+
+ self.ipa_dropout = nn.Dropout(config.dropout_rate)
+ self.layer_norm_ipa = LayerNorm(config.sequence_dim)
+
+ self.transition = EsmFoldStructureModuleTransition(config)
+ self.bb_update = EsmFoldBackboneUpdate(config)
+ self.angle_resnet = EsmFoldAngleResnet(config)
+
+ def forward(
+ self,
+ evoformer_output_dict,
+ aatype,
+ mask=None,
+ _offload_inference=False,
+ ):
+ """
+ Args:
+ evoformer_output_dict:
+ Dictionary containing:
+ "single":
+ [*, N_res, C_s] single representation
+ "pair":
+ [*, N_res, N_res, C_z] pair representation
+ aatype:
+ [*, N_res] amino acid indices
+ mask:
+ Optional [*, N_res] sequence mask
+ Returns:
+ A dictionary of outputs
+ """
+ s = evoformer_output_dict["single"]
+
+ if mask is None:
+ # [*, N]
+ mask = s.new_ones(s.shape[:-1])
+
+ # [*, N, C_s]
+ s = self.layer_norm_s(s)
+
+ # [*, N, N, C_z]
+ z = self.layer_norm_z(evoformer_output_dict["pair"])
+
+ z_reference_list = None
+ if _offload_inference:
+ assert sys.getrefcount(evoformer_output_dict["pair"]) == 2
+ evoformer_output_dict["pair"] = evoformer_output_dict["pair"].cpu()
+ z_reference_list = [z]
+ z = None
+
+ # [*, N, C_s]
+ s_initial = s
+ s = self.linear_in(s)
+
+ # [*, N]
+ rigids = Rigid.identity(
+ s.shape[:-1],
+ s.dtype,
+ s.device,
+ self.training,
+ fmt="quat",
+ )
+ outputs = []
+ for i in range(self.config.num_blocks):
+ # [*, N, C_s]
+ s = s + self.ipa(
+ s,
+ z,
+ rigids,
+ mask,
+ _offload_inference=_offload_inference,
+ _z_reference_list=z_reference_list,
+ )
+ s = self.ipa_dropout(s)
+ s = self.layer_norm_ipa(s)
+ s = self.transition(s)
+
+ # [*, N]
+ rigids = rigids.compose_q_update_vec(self.bb_update(s))
+
+ # To hew as closely as possible to AlphaFold, we convert our
+ # quaternion-based transformations to rotation-matrix ones
+ # here
+ backb_to_global = Rigid(
+ Rotation(rot_mats=rigids.get_rots().get_rot_mats(), quats=None),
+ rigids.get_trans(),
+ )
+
+ backb_to_global = backb_to_global.scale_translation(self.config.trans_scale_factor)
+
+ # [*, N, 7, 2]
+ unnormalized_angles, angles = self.angle_resnet(s, s_initial)
+
+ all_frames_to_global = self.torsion_angles_to_frames(backb_to_global, angles, aatype)
+
+ pred_xyz = self.frames_and_literature_positions_to_atom14_pos(all_frames_to_global, aatype)
+
+ scaled_rigids = rigids.scale_translation(self.config.trans_scale_factor)
+
+ preds = {
+ "frames": scaled_rigids.to_tensor_7(),
+ "sidechain_frames": all_frames_to_global.to_tensor_4x4(),
+ "unnormalized_angles": unnormalized_angles,
+ "angles": angles,
+ "positions": pred_xyz,
+ "states": s,
+ }
+
+ outputs.append(preds)
+
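+ # As in AlphaFold, gradients are stopped through the rotation component between blocks;
+ # only the translations carry gradient across iterations.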
+ rigids = rigids.stop_rot_gradient()
+
+ del z, z_reference_list
+
+ if _offload_inference:
+ evoformer_output_dict["pair"] = evoformer_output_dict["pair"].to(s.device)
+
+ outputs = dict_multimap(torch.stack, outputs)
+ outputs["single"] = s
+
+ return outputs
+
+ def _init_residue_constants(self, float_dtype, device):
+ if not hasattr(self, "default_frames"):
+ self.register_buffer(
+ "default_frames",
+ torch.tensor(
+ residue_constants.restype_rigid_group_default_frame,
+ dtype=float_dtype,
+ device=device,
+ requires_grad=False,
+ ),
+ persistent=False,
+ )
+ if not hasattr(self, "group_idx"):
+ self.register_buffer(
+ "group_idx",
+ torch.tensor(
+ residue_constants.restype_atom14_to_rigid_group,
+ device=device,
+ requires_grad=False,
+ ),
+ persistent=False,
+ )
+ if not hasattr(self, "atom_mask"):
+ self.register_buffer(
+ "atom_mask",
+ torch.tensor(
+ residue_constants.restype_atom14_mask,
+ dtype=float_dtype,
+ device=device,
+ requires_grad=False,
+ ),
+ persistent=False,
+ )
+ if not hasattr(self, "lit_positions"):
+ self.register_buffer(
+ "lit_positions",
+ torch.tensor(
+ residue_constants.restype_atom14_rigid_group_positions,
+ dtype=float_dtype,
+ device=device,
+ requires_grad=False,
+ ),
+ persistent=False,
+ )
+
+ def torsion_angles_to_frames(self, r, alpha, f):
+ # Lazily initialize the residue constants on the correct device
+ self._init_residue_constants(alpha.dtype, alpha.device)
+ # Separated purely to make testing less annoying
+ return torsion_angles_to_frames(r, alpha, f, self.default_frames)
+
+ def frames_and_literature_positions_to_atom14_pos(self, r, f): # [*, N, 8] # [*, N]
+ # Lazily initialize the residue constants on the correct device
+ self._init_residue_constants(r.get_rots().dtype, r.get_rots().device)
+ return frames_and_literature_positions_to_atom14_pos(
+ r,
+ f,
+ self.default_frames,
+ self.group_idx,
+ self.atom_mask,
+ self.lit_positions,
+ )
+
+
+class EsmFoldingTrunk(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+
+ c_s = config.sequence_state_dim
+ c_z = config.pairwise_state_dim
+
+ self.pairwise_positional_embedding = EsmFoldRelativePosition(config)
+
+ self.blocks = nn.ModuleList([EsmFoldTriangularSelfAttentionBlock(config) for _ in range(config.num_blocks)])
+
+ self.recycle_bins = 15
+ self.recycle_s_norm = nn.LayerNorm(c_s)
+ self.recycle_z_norm = nn.LayerNorm(c_z)
+ self.recycle_disto = nn.Embedding(self.recycle_bins, c_z)
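+ # Zero the embedding for bin 0 so the first pass, where recycle_bins is all zeros
+ # (no previous structure yet), contributes nothing to the pair representation.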
+ self.recycle_disto.weight[0].detach().zero_()
+
+ self.structure_module = EsmFoldStructureModule(config.structure_module)
+ self.trunk2sm_s = nn.Linear(c_s, config.structure_module.sequence_dim)
+ self.trunk2sm_z = nn.Linear(c_z, config.structure_module.pairwise_dim)
+
+ self.chunk_size = config.chunk_size
+
+ def set_chunk_size(self, chunk_size):
+ # This parameter means the axial attention will be computed
+ # in a chunked manner. This should make the memory used more or less O(L) instead of O(L^2).
+ # It's equivalent to running a for loop over chunks of the dimension we're iterative over,
+ # where the chunk_size is the size of the chunks, so 128 would mean to parse 128-length chunks.
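+ # For example (a hedged illustration, not a prescription): given an EsmForProteinFolding
+ # instance `model`, model.trunk.set_chunk_size(64) trades some speed for lower peak memory
+ # on long sequences.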
+ self.chunk_size = chunk_size
+
+ def forward(self, seq_feats, pair_feats, true_aa, residx, mask, no_recycles):
+ """
+ Inputs:
+ seq_feats: B x L x C tensor of sequence features
+ pair_feats: B x L x L x C tensor of pair features
+ true_aa: B x L long tensor of amino acid indices
+ residx: B x L long tensor giving the position in the sequence
+ mask: B x L boolean tensor indicating valid residues
+ no_recycles: number of recycling passes to run (None falls back to config.max_recycles)
+
+ Output:
+ a dictionary of structure-module outputs (frames, positions, angles, ...) together with
+ the final trunk representations "s_s" and "s_z"
+ """
+
+ device = seq_feats.device
+ s_s_0 = seq_feats
+ s_z_0 = pair_feats
+
+ if no_recycles is None:
+ no_recycles = self.config.max_recycles
+ else:
+ if no_recycles < 0:
+ raise ValueError("Number of recycles must not be negative.")
+ no_recycles += 1 # First 'recycle' is just the standard forward pass through the model.
+
+ def trunk_iter(s, z, residx, mask):
+ z = z + self.pairwise_positional_embedding(residx, mask=mask)
+
+ for block in self.blocks:
+ s, z = block(s, z, mask=mask, residue_index=residx, chunk_size=self.chunk_size)
+ return s, z
+
+ s_s = s_s_0
+ s_z = s_z_0
+ recycle_s = torch.zeros_like(s_s)
+ recycle_z = torch.zeros_like(s_z)
+ recycle_bins = torch.zeros(*s_z.shape[:-1], device=device, dtype=torch.int64)
+
+ for recycle_idx in range(no_recycles):
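+ # Only the final recycling iteration runs with gradients enabled; earlier passes are
+ # wrapped in torch.no_grad().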
+ with ContextManagers([] if recycle_idx == no_recycles - 1 else [torch.no_grad()]):
+ # === Recycling ===
+ recycle_s = self.recycle_s_norm(recycle_s.detach()).to(device)
+ recycle_z = self.recycle_z_norm(recycle_z.detach()).to(device)
+ recycle_z += self.recycle_disto(recycle_bins.detach()).to(device)
+
+ s_s, s_z = trunk_iter(s_s_0 + recycle_s, s_z_0 + recycle_z, residx, mask)
+
+ # === Structure module ===
+ structure = self.structure_module(
+ {"single": self.trunk2sm_s(s_s), "pair": self.trunk2sm_z(s_z)},
+ true_aa,
+ mask.float(),
+ )
+
+ recycle_s = s_s
+ recycle_z = s_z
+ # The recycled distogram is computed from the N, CA, C coordinates, with the same bin constants as AlphaFold.
+ recycle_bins = EsmFoldingTrunk.distogram(
+ structure["positions"][-1][:, :, :3],
+ 3.375,
+ 21.375,
+ self.recycle_bins,
+ )
+
+ structure["s_s"] = s_s
+ structure["s_z"] = s_z
+
+ return structure
+
+ @staticmethod
+ def distogram(coords, min_bin, max_bin, num_bins):
+ # Coords are [... L x 3 x 3], where it's [N, CA, C] x 3 coordinates.
+ boundaries = torch.linspace(
+ min_bin,
+ max_bin,
+ num_bins - 1,
+ device=coords.device,
+ )
+ boundaries = boundaries**2
+ N, CA, C = [x.squeeze(-2) for x in coords.chunk(3, dim=-2)]
+ # Infer CB coordinates.
+ b = CA - N
+ c = C - CA
+ a = b.cross(c, dim=-1)
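+ # Place an idealized CB atom from the local N-CA-C frame using fixed geometric coefficients.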
+ CB = -0.58273431 * a + 0.56802827 * b - 0.54067466 * c + CA
+ dists = (CB[..., None, :, :] - CB[..., :, None, :]).pow(2).sum(dim=-1, keepdims=True)
+ bins = torch.sum(dists > boundaries, dim=-1) # [..., L, L]
+ return bins
+
+
+# TODO Add information to the docstring about any methods that convert to PDB format, or otherwise prepare
+# the outputs for downstream use.
+
+
+@auto_docstring(
+ custom_intro="""
+ ESMForProteinFolding is the HuggingFace port of the original ESMFold model. It consists of an ESM-2 "stem" followed
+ by a protein folding "head", although unlike most other output heads, this "head" is similar in size and runtime to
+ the rest of the model combined! It outputs a dictionary containing predicted structural information about the input
+ protein(s).
+ """
+)
+class EsmForProteinFolding(EsmPreTrainedModel):
+ _no_split_modules = ["EsmFoldStructureModule", "EsmFoldTriangularSelfAttentionBlock"]
+ _supports_flash_attn = False
+ _supports_sdpa = False
+ _supports_attention_backend = False
+
+ _can_record_outputs = None
+
+ def __init__(self, config):
+ super().__init__(config)
+
+ self.config = config
+
+ self.distogram_bins = 64
+
+ self.esm = EsmModel(config, add_pooling_layer=False)
+
+ self.esm.requires_grad_(False)
+ if self.config.esmfold_config.fp16_esm:
+ self.esm.half()
+
+ self.esm_feats = self.config.hidden_size
+ self.esm_attns = self.config.num_hidden_layers * self.config.num_attention_heads
+ self.esm_layers = self.config.num_hidden_layers
+ self.register_buffer("af2_to_esm", self._af2_to_esm_from_vocab_list(config.vocab_list))
+ self.esm_s_combine = nn.Parameter(torch.zeros(self.esm_layers + 1))
+
+ trunk_config = self.config.esmfold_config.trunk
+ c_s = trunk_config.sequence_state_dim
+ c_z = trunk_config.pairwise_state_dim
+ self.esm_s_mlp = nn.Sequential(
+ LayerNorm(self.esm_feats),
+ nn.Linear(self.esm_feats, c_s),
+ nn.ReLU(),
+ nn.Linear(c_s, c_s),
+ )
+
+ # 0 is padding, N is unknown residues, N + 1 is mask.
+ self.n_tokens_embed = residue_constants.restype_num + 3
+ self.pad_idx = 0
+ self.unk_idx = self.n_tokens_embed - 2
+ self.mask_idx = self.n_tokens_embed - 1
+ self.esm_dict_cls_idx = self.config.vocab_list.index("<cls>")
+ self.esm_dict_mask_idx = self.config.vocab_list.index("<mask>")
+ self.esm_dict_eos_idx = self.config.vocab_list.index("<eos>")
+ self.esm_dict_padding_idx = self.config.vocab_list.index("<pad>")
+ if self.config.esmfold_config.embed_aa:
+ self.embedding = nn.Embedding(self.n_tokens_embed, c_s, padding_idx=0)
+
+ self.trunk = EsmFoldingTrunk(trunk_config)
+
+ self.distogram_head = nn.Linear(c_z, self.distogram_bins)
+ self.ptm_head = nn.Linear(c_z, self.distogram_bins)
+ self.lm_head = nn.Linear(c_s, self.n_tokens_embed)
+ self.lddt_bins = 50
+ structure_module_config = trunk_config.structure_module
+ self.lddt_head = nn.Sequential(
+ nn.LayerNorm(structure_module_config.sequence_dim),
+ nn.Linear(structure_module_config.sequence_dim, self.config.esmfold_config.lddt_head_hid_dim),
+ nn.Linear(self.config.esmfold_config.lddt_head_hid_dim, self.config.esmfold_config.lddt_head_hid_dim),
+ nn.Linear(self.config.esmfold_config.lddt_head_hid_dim, 37 * self.lddt_bins),
+ )
+
+ @staticmethod
+ def _af2_to_esm_from_vocab_list(vocab_list: list[str]) -> torch.Tensor:
+ # Remember that t is shifted from residue_constants by 1 (0 is padding).
+ esm_reorder = [vocab_list.index("<pad>")] + [vocab_list.index(v) for v in residue_constants.restypes_with_x]
+ return torch.tensor(esm_reorder)
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ masking_pattern: Optional[torch.Tensor] = None,
+ num_recycles: Optional[int] = None,
+ output_hidden_states: Optional[bool] = False,
+ ) -> EsmForProteinFoldingOutput:
+ r"""
+ masking_pattern (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Locations of tokens to mask during training as a form of regularization. Mask values selected in `[0, 1]`.
+ num_recycles (`int`, *optional*, defaults to `None`):
+ Number of times to recycle the input sequence. If `None`, defaults to `config.max_recycles`. "Recycling"
+ consists of passing the output of the folding trunk back in as input to the trunk. During training, the
+ number of recycles should vary with each batch, to ensure that the model learns to output valid predictions
+ after each recycle. During inference, num_recycles should be set to the highest value that the model was
+ trained with for maximum accuracy. Accordingly, when this value is set to `None`, config.max_recycles is
+ used.
+
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, EsmForProteinFolding
+
+ >>> model = EsmForProteinFolding.from_pretrained("facebook/esmfold_v1")
+ >>> tokenizer = AutoTokenizer.from_pretrained("facebook/esmfold_v1")
+ >>> inputs = tokenizer(["MLKNVQVQLV"], return_tensors="pt", add_special_tokens=False) # A tiny random peptide
+ >>> outputs = model(**inputs)
+ >>> folded_positions = outputs.positions
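+ >>> # For a PDB-format string directly from a raw sequence, `infer_pdb` (defined below) can be used,
+ >>> # e.g. pdb_string = model.infer_pdb("MLKNVQVQLV")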
+ ```
+
+ """
+ cfg = self.config.esmfold_config
+
+ aa = input_ids # B x L
+ B = aa.shape[0]
+ L = aa.shape[1]
+ device = input_ids.device
+ if attention_mask is None:
+ attention_mask = torch.ones_like(aa, device=device)
+ if position_ids is None:
+ position_ids = torch.arange(L, device=device).expand_as(input_ids)
+
+ # === ESM ===
+ esmaa = self.af2_idx_to_esm_idx(aa, attention_mask)
+
+ if masking_pattern is not None:
+ masked_aa, esmaa, mlm_targets = self.bert_mask(aa, esmaa, attention_mask, masking_pattern)
+ else:
+ masked_aa = aa
+ mlm_targets = None
+
+ # We get sequence and pair representations from whatever version of ESM /
+ # configuration we are using. The sequence representation esm_s is always
+ # present. The pair embedding esm_z may be present depending on the
+ # configuration of the model. If esm_z is not used by the model then it
+ # is returned as None here.
+ esm_s = self.compute_language_model_representations(esmaa)
+
+ # Convert esm_s and esm_z, if present, to the precision used by the trunk and
+ # the structure module. These tensors may be a lower precision if, for example,
+ # we're running the language model in fp16 precision.
+ esm_s = esm_s.to(self.esm_s_combine.dtype)
+
+ if cfg.esm_ablate_sequence:
+ esm_s = esm_s * 0
+
+ esm_s = esm_s.detach()
+
+ # === preprocessing ===
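+ # Collapse the stack of per-layer ESM hidden states into a single representation using a
+ # learned softmax weighting over layers (esm_s_combine), then project it to the trunk width.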
+ esm_s = (self.esm_s_combine.softmax(0).unsqueeze(0) @ esm_s).squeeze(2)
+ s_s_0 = self.esm_s_mlp(esm_s)
+
+ s_z_0 = s_s_0.new_zeros(B, L, L, cfg.trunk.pairwise_state_dim)
+
+ if self.config.esmfold_config.embed_aa:
+ s_s_0 += self.embedding(masked_aa)
+
+ structure: dict = self.trunk(s_s_0, s_z_0, aa, position_ids, attention_mask, no_recycles=num_recycles)
+ # Documenting what we expect:
+ structure = {
+ k: v
+ for k, v in structure.items()
+ if k
+ in [
+ "s_z",
+ "s_s",
+ "frames",
+ "sidechain_frames",
+ "unnormalized_angles",
+ "angles",
+ "positions",
+ "states",
+ ]
+ }
+
+ # Add BERT mask for the loss to use, if available.
+ if mlm_targets is not None:
+ structure["mlm_targets"] = mlm_targets
+
+ disto_logits = self.distogram_head(structure["s_z"])
+ disto_logits = (disto_logits + disto_logits.transpose(1, 2)) / 2
+ structure["distogram_logits"] = disto_logits
+
+ lm_logits = self.lm_head(structure["s_s"])
+ structure["lm_logits"] = lm_logits
+
+ structure["aatype"] = aa
+ make_atom14_masks(structure)
+ # Of course, this doesn't respect the true mask because it doesn't know about it...
+ # We're not going to properly mask change of index tensors:
+ # "residx_atom14_to_atom37",
+ # "residx_atom37_to_atom14",
+ for k in [
+ "atom14_atom_exists",
+ "atom37_atom_exists",
+ ]:
+ structure[k] *= attention_mask.unsqueeze(-1)
+ structure["residue_index"] = position_ids
+
+ lddt_head = self.lddt_head(structure["states"]).reshape(structure["states"].shape[0], B, L, -1, self.lddt_bins)
+ structure["lddt_head"] = lddt_head
+ plddt = categorical_lddt(lddt_head[-1], bins=self.lddt_bins)
+ structure["plddt"] = plddt
+
+ ptm_logits = self.ptm_head(structure["s_z"])
+ structure["ptm_logits"] = ptm_logits
+ structure["ptm"] = compute_tm(ptm_logits, max_bin=31, no_bins=self.distogram_bins)
+ structure.update(compute_predicted_aligned_error(ptm_logits, max_bin=31, no_bins=self.distogram_bins))
+
+ return EsmForProteinFoldingOutput(**structure)
+
+ def af2_idx_to_esm_idx(self, aa, mask):
+ # avoid indexing on different devices
+ if self.af2_to_esm.device != aa.device:
+ self.af2_to_esm = self.af2_to_esm.to(aa.device)
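+ # Shift by 1 because slot 0 of af2_to_esm is the padding token; masked positions are sent to padding.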
+ aa = (aa + 1).masked_fill(mask != 1, 0)
+ return self.af2_to_esm[aa]
+
+ def compute_language_model_representations(self, esmaa: torch.Tensor) -> torch.Tensor:
+ device = next(self.parameters()).device
+ B, L = esmaa.shape # B = batch size, L = sequence length.
+
+ if self.config.esmfold_config.bypass_lm:
+ esm_s = torch.zeros(B, L, self.esm_s_combine.size(0), self.esm_feats, device=device)
+ return esm_s
+
+ bosi, eosi = self.esm_dict_cls_idx, self.esm_dict_eos_idx
+ bos = esmaa.new_full((B, 1), bosi)
+ eos = esmaa.new_full((B, 1), self.esm_dict_padding_idx)
+ esmaa = torch.cat([bos, esmaa, eos], dim=1)
+ # Use the first padding index as eos during inference.
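+ # Token id 1 is the ESM padding id here (see the attention_mask construction below), so
+ # (esmaa != 1).sum(1) is each sequence's length including BOS, i.e. the first padding slot.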
+ esmaa[range(B), (esmaa != 1).sum(1)] = eosi
+
+ # _, esm_z, esm_s = self.esm(esmaa, return_pairs=self.config.esmfold_config.use_esm_attn_map)
+ # Because we do not support use_esm_attn_map in the HF port as it is not used in any public models,
+ # esm_z is always None
+ esm_hidden_states = self.esm(esmaa, attention_mask=esmaa != 1, output_hidden_states=True)["hidden_states"]
+ esm_s = torch.stack(esm_hidden_states, dim=2)
+
+ esm_s = esm_s[:, 1:-1] # B, L, nLayers, C
+
+ return esm_s
+
+ def bert_mask(self, aa, esmaa, mask, pattern):
+ new_aa = aa.clone()
+ target = aa.clone()
+ new_esmaa = esmaa.clone()
+ new_aa[pattern == 1] = self.mask_idx
+ target[pattern != 1] = 0
+ new_esmaa[pattern == 1] = self.esm_dict_mask_idx
+ return new_aa, new_esmaa, target
+
+ @torch.no_grad()
+ def infer(
+ self,
+ seqs: Union[str, list[str]],
+ position_ids=None,
+ ):
+ if isinstance(seqs, str):
+ lst = [seqs]
+ else:
+ lst = seqs
+ # Returns the raw outputs of the model given an input sequence.
+ device = next(self.parameters()).device
+ aatype = collate_dense_tensors(
+ [
+ torch.from_numpy(
+ residue_constants.sequence_to_onehot(
+ sequence=seq,
+ mapping=residue_constants.restype_order_with_x,
+ map_unknown_to_x=True,
+ )
+ )
+ .to(device)
+ .argmax(dim=1)
+ for seq in lst
+ ]
+ ) # B x L
+ mask = collate_dense_tensors([aatype.new_ones(len(seq)) for seq in lst])
+ position_ids = (
+ torch.arange(aatype.shape[1], device=device).expand(len(lst), -1)
+ if position_ids is None
+ else position_ids.to(device)
+ )
+ if position_ids.ndim == 1:
+ position_ids = position_ids.unsqueeze(0)
+ return self.forward(
+ aatype,
+ mask,
+ position_ids=position_ids,
+ )
+
+ @staticmethod
+ def output_to_pdb(output: dict) -> list[str]:
+ """Returns the pbd (file) string from the model given the model output."""
+ output = {k: v.to("cpu").numpy() for k, v in output.items()}
+ pdbs = []
+ final_atom_positions = atom14_to_atom37(output["positions"][-1], output)
+ final_atom_mask = output["atom37_atom_exists"]
+ for i in range(output["aatype"].shape[0]):
+ aa = output["aatype"][i]
+ pred_pos = final_atom_positions[i]
+ mask = final_atom_mask[i]
+ resid = output["residue_index"][i] + 1
+ pred = OFProtein(
+ aatype=aa,
+ atom_positions=pred_pos,
+ atom_mask=mask,
+ residue_index=resid,
+ b_factors=output["plddt"][i],
+ )
+ pdbs.append(to_pdb(pred))
+ return pdbs
+
+ def infer_pdb(self, seqs, *args, **kwargs) -> str:
+ """Returns the pdb (file) string from the model given an input sequence."""
+ assert isinstance(seqs, str)
+ output = self.infer(seqs, *args, **kwargs)
+ return self.output_to_pdb(output)[0]
+
+ def infer_pdbs(self, seqs: list[str], *args, **kwargs) -> list[str]:
+ """Returns the pdb (file) string from the model given an input sequence."""
+ output = self.infer(seqs, *args, **kwargs)
+ return self.output_to_pdb(output)
+
+
+__all__ = ["EsmForProteinFolding", "EsmFoldPreTrainedModel"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/esm/modeling_tf_esm.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/esm/modeling_tf_esm.py
new file mode 100644
index 0000000000000000000000000000000000000000..3fd066868f0e86cd2e130ad170eb38c4791f4f88
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/esm/modeling_tf_esm.py
@@ -0,0 +1,1574 @@
+# coding=utf-8
+# Copyright 2022 Meta and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch ESM model."""
+
+from __future__ import annotations
+
+import os
+
+import numpy as np
+import tensorflow as tf
+
+from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
+from ...modeling_tf_outputs import (
+ TFBaseModelOutputWithPastAndCrossAttentions,
+ TFBaseModelOutputWithPoolingAndCrossAttentions,
+ TFMaskedLMOutput,
+ TFSequenceClassifierOutput,
+ TFTokenClassifierOutput,
+)
+from ...modeling_tf_utils import (
+ TFMaskedLanguageModelingLoss,
+ TFModelInputType,
+ TFPreTrainedModel,
+ TFSequenceClassificationLoss,
+ TFTokenClassificationLoss,
+ get_initializer,
+ keras,
+ shape_list,
+ unpack_inputs,
+)
+from ...tf_utils import check_embeddings_within_bounds, stable_softmax
+from ...utils import logging
+from .configuration_esm import EsmConfig
+
+
+logger = logging.get_logger(__name__)
+
+_CHECKPOINT_FOR_DOC = "facebook/esm2_t6_8M_UR50D"
+_CONFIG_FOR_DOC = "EsmConfig"
+
+
+def rotate_half(x):
+ x1, x2 = tf.split(x, 2, axis=-1)
+ return tf.concat((-x2, x1), axis=-1)
+
+
+def apply_rotary_pos_emb(x, cos, sin):
+ cos = cos[:, :, : tf.shape(x)[-2], :]
+ sin = sin[:, :, : tf.shape(x)[-2], :]
+
+ return (x * cos) + (rotate_half(x) * sin)
+
+
+def symmetrize(x):
+ "Make layer symmetric in final two dimensions, used for contact prediction."
+ return x + tf.linalg.matrix_transpose(x) # Transposes last two dimensions only
+
+
+def average_product_correct(x):
+ "Perform average product correct, used for contact prediction."
+ a1 = tf.reduce_sum(x, -1, keepdims=True)
+ a2 = tf.reduce_sum(x, -2, keepdims=True)
+ a12 = tf.reduce_sum(x, (-1, -2), keepdims=True)
+
+ avg = a1 * a2
+ avg = avg / a12
+ normalized = x - avg
+ return normalized
+
+
+class TFRotaryEmbedding(keras.layers.Layer):
+ """
+ Rotary position embeddings based on those in
+ [RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer). Queries and keys are transformed by
+ rotation matrices that depend on their positions, so that their dot product depends only on their relative offset.
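+
+ In this half-split formulation (see `rotate_half` above), feature i in the first half is paired with
+ feature i + dim/2, and for a token at position m with frequency theta_i the pair (x1, x2) is rotated as
+ (x1 * cos(m * theta_i) - x2 * sin(m * theta_i), x2 * cos(m * theta_i) + x1 * sin(m * theta_i)),
+ which is what `apply_rotary_pos_emb` computes via x * cos + rotate_half(x) * sin.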
+ """
+
+ def __init__(self, dim: int, name=None):
+ super().__init__(name=name)
+ # Matt: The PyTorch version of this layer does a lot of work to cache values, but we just rely on TF compilation
+ # and/or XLA to sort out constants like that. It actually may not seem like this layer needs to be stateful at
+ # all when we benefit from TF compilation, but it does. The reason is that self.inv_freq is a buffer in the
+ # original implementation, but all the shared ESM checkpoints were trained with fp16 params. This means that
+ # the inv_freq tensor was stored as a float16, and we need to replicate those lower-precision values or our
+ # models give different outputs from the original.
+ self.dim = dim
+
+ def build(self, input_shape):
+ super().build(input_shape)
+ self.inv_freq = self.add_weight(
+ "inv_freq", shape=(self.dim // 2,), dtype=tf.float32, initializer=get_initializer(1.0), trainable=False
+ )
+ self.inv_freq.assign(
+ 1.0 / (10000 ** (tf.range(start=0, limit=self.dim, delta=2, dtype=tf.float32) / self.dim))
+ )
+
+ def _compute_cos_sin(self, x, seq_dimension=2):
+ seq_len = tf.shape(x)[seq_dimension]
+
+ t = tf.range(seq_len, dtype=self.inv_freq.dtype)
+ freqs = tf.einsum("i, j -> ij", t, self.inv_freq) # Outer multiplication
+ emb = tf.concat((freqs, freqs), axis=-1)[None, None, :, :]
+
+ return tf.cos(emb), tf.sin(emb)
+
+ def call(self, q: tf.Tensor, k: tf.Tensor) -> tuple[tf.Tensor, tf.Tensor]:
+ cos_emb, sin_emb = self._compute_cos_sin(k, seq_dimension=-2)
+
+ return (
+ apply_rotary_pos_emb(q, cos_emb, sin_emb),
+ apply_rotary_pos_emb(k, cos_emb, sin_emb),
+ )
+
+
+class TFEsmContactPredictionHead(keras.layers.Layer):
+ """Performs symmetrization, apc, and computes a logistic regression on the output features"""
+
+ def __init__(
+ self,
+ in_features: int,
+ bias=True,
+ eos_idx: int = 2,
+ name=None,
+ ):
+ super().__init__(name=name)
+ self.eos_idx = eos_idx
+ self.in_features = in_features
+ self.regression = keras.layers.Dense(1, use_bias=bias, activation="sigmoid", name="regression")
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "regression", None) is not None:
+ with tf.name_scope(self.regression.name):
+ self.regression.build((None, self.in_features))
+
+ def call(self, tokens, attentions):
+ # remove eos token attentions
+ eos_mask = tf.cast(tokens != self.eos_idx, attentions.dtype)
+ eos_mask = tf.expand_dims(eos_mask, 1) * tf.expand_dims(eos_mask, 2)
+ attentions = attentions * eos_mask[:, None, None, :, :]
+ attentions = attentions[..., :-1, :-1]
+ # remove cls token attentions
+ attentions = attentions[..., 1:, 1:]
+ batch_size, layers, heads, seqlen, _ = shape_list(attentions)
+ attentions = tf.reshape(attentions, (batch_size, layers * heads, seqlen, seqlen))
+
+ # features: batch x channels x tokens x tokens (symmetric)
+ attentions = average_product_correct(symmetrize(attentions))
+ attentions = tf.transpose(attentions, perm=(0, 2, 3, 1))
+ return tf.squeeze(self.regression(attentions), 3)
+
+
+class TFEsmEmbeddings(keras.layers.Layer):
+ """
+ Same as BertEmbeddings with a tiny tweak for positional embeddings indexing.
+ """
+
+ def __init__(self, config, name=None):
+ super().__init__(name=name)
+ self.word_embeddings = keras.layers.Embedding(
+ config.vocab_size,
+ config.hidden_size,
+ embeddings_initializer=get_initializer(config.initializer_range),
+ name="word_embeddings",
+ )
+ self.position_embeddings = keras.layers.Embedding(
+ config.max_position_embeddings,
+ config.hidden_size,
+ embeddings_initializer=get_initializer(config.initializer_range),
+ name="position_embeddings",
+ )
+
+ if config.emb_layer_norm_before:
+ self.layer_norm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm")
+ else:
+ self.layer_norm = None
+ # Matt: I think this line was copied incorrectly from BERT, disabling for now
+ # self.dropout = Dropout(config.hidden_dropout_prob)
+ self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
+
+ self.position_ids = tf.range(config.max_position_embeddings)[None, :]
+
+ self.padding_idx = config.pad_token_id
+ self.token_dropout = config.token_dropout
+ self.mask_token_id = config.mask_token_id
+ self.config = config
+
+ def call(
+ self, input_ids=None, attention_mask=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
+ ):
+ if position_ids is None:
+ if input_ids is not None:
+ # Create the position ids from the input token ids. Any padded tokens remain padded.
+ position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx, past_key_values_length)
+ else:
+ position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds)
+
+ if inputs_embeds is None:
+ check_embeddings_within_bounds(input_ids, self.config.vocab_size)
+ inputs_embeds = self.word_embeddings(input_ids)
+
+ # Note that if we want to support ESM-1 (not 1b!) in future then we need to support an
+ # embedding_scale factor here.
+ embeddings = inputs_embeds
+
+ # Matt: ESM has the option to handle masking in MLM in a slightly unusual way. If the token_dropout
+ # flag is False then it is handled in the same was as BERT/RoBERTa. If it is set to True, however,
+ # masked tokens are treated as if they were selected for input dropout and zeroed out.
+ # This "mask-dropout" is compensated for when masked tokens are not present, by scaling embeddings by
+ # a factor of (fraction of unmasked tokens during training) / (fraction of unmasked tokens in sample).
+ # This is analogous to the way that dropout layers scale down outputs during evaluation when not
+ # actually dropping out values (or, equivalently, scale up their un-dropped outputs in training).
+ if self.token_dropout:
+ embeddings = tf.where((input_ids == self.mask_token_id)[:, :, None], 0.0, embeddings)
+ mask_ratio_train = 0.15 * 0.8 # Hardcoded as the ratio used in all ESM model training runs
+ src_lengths = tf.cast(tf.reduce_sum(attention_mask, axis=-1), tf.float32)
+ masked_tokens = input_ids == self.mask_token_id
+ mask_ratio_observed = tf.math.count_nonzero(masked_tokens, dtype=tf.float32, axis=-1) / src_lengths
+ embeddings = embeddings * (1 - mask_ratio_train) / (1 - mask_ratio_observed)[:, None, None]
+
+ if self.position_embedding_type == "absolute":
+ position_embeddings = self.position_embeddings(position_ids)
+ embeddings += position_embeddings
+
+ if self.layer_norm is not None:
+ embeddings = self.layer_norm(embeddings)
+ if attention_mask is not None:
+ embeddings = embeddings * tf.cast(tf.expand_dims(attention_mask, -1), embeddings.dtype)
+ # Matt: I think this line was copied incorrectly from BERT, disabling it for now.
+ # embeddings = self.dropout(embeddings)
+ return embeddings
+
+ def create_position_ids_from_inputs_embeds(self, inputs_embeds):
+ """
+ We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids.
+
+ Args:
+ inputs_embeds: tf.Tensor
+
+ Returns: tf.Tensor
+ """
+ input_shape = shape_list(inputs_embeds)[:-1]
+ sequence_length = input_shape[1]
+
+ position_ids = tf.range(
+ start=self.padding_idx + 1, limit=sequence_length + self.padding_idx + 1, dtype=tf.int64
+ )
+ return tf.broadcast_to(tf.expand_dims(position_ids, 0), input_shape)
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "word_embeddings", None) is not None:
+ with tf.name_scope(self.word_embeddings.name):
+ self.word_embeddings.build(None)
+ if getattr(self, "position_embeddings", None) is not None:
+ with tf.name_scope(self.position_embeddings.name):
+ self.position_embeddings.build(None)
+ if getattr(self, "layer_norm", None) is not None:
+ with tf.name_scope(self.layer_norm.name):
+ self.layer_norm.build([None, None, self.config.hidden_size])
+
+
+class TFEsmSelfAttention(keras.layers.Layer):
+ def __init__(self, config, position_embedding_type=None, name=None):
+ super().__init__(name=name)
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
+ raise ValueError(
+ f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
+ f"heads ({config.num_attention_heads})"
+ )
+
+ self.num_attention_heads = config.num_attention_heads
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
+
+ self.query = keras.layers.Dense(
+ self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="query"
+ )
+ self.key = keras.layers.Dense(
+ self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="key"
+ )
+ self.value = keras.layers.Dense(
+ self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="value"
+ )
+
+ self.dropout = keras.layers.Dropout(config.attention_probs_dropout_prob)
+ self.position_embedding_type = position_embedding_type or getattr(
+ config, "position_embedding_type", "absolute"
+ )
+ self.rotary_embeddings = None
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
+ self.max_position_embeddings = config.max_position_embeddings
+ self.distance_embedding = keras.layers.Embedding(
+ 2 * config.max_position_embeddings - 1,
+ self.attention_head_size,
+ embeddings_initializer=get_initializer(config.initializer_range),
+ )
+ elif self.position_embedding_type == "rotary":
+ self.rotary_embeddings = TFRotaryEmbedding(dim=self.attention_head_size, name="rotary_embeddings")
+
+ self.is_decoder = config.is_decoder
+ self.config = config
+
+ def transpose_for_scores(self, x: tf.Tensor) -> tf.Tensor:
+ new_x_shape = shape_list(x)[:-1] + [self.num_attention_heads, self.attention_head_size]
+ x = tf.reshape(x, new_x_shape)
+ return tf.transpose(x, perm=(0, 2, 1, 3))
+
+ def call(
+ self,
+ hidden_states: tf.Tensor,
+ attention_mask: tf.Tensor | None = None,
+ head_mask: tf.Tensor | None = None,
+ encoder_hidden_states: tf.Tensor | None = None,
+ encoder_attention_mask: tf.Tensor | None = None,
+ past_key_value: tuple[tuple[tf.Tensor]] | None = None,
+ output_attentions: bool | None = False,
+ training: bool = False,
+ ) -> tuple[tf.Tensor]:
+ mixed_query_layer = self.query(hidden_states)
+
+ # If this is instantiated as a cross-attention module, the keys
+ # and values come from an encoder; the attention mask needs to be
+ # such that the encoder's padding tokens are not attended to.
+ is_cross_attention = encoder_hidden_states is not None
+
+ if is_cross_attention and past_key_value is not None:
+ # reuse k,v, cross_attentions
+ key_layer = past_key_value[0]
+ value_layer = past_key_value[1]
+ attention_mask = encoder_attention_mask
+ elif is_cross_attention:
+ key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
+ value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
+ attention_mask = encoder_attention_mask
+ elif past_key_value is not None:
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
+ key_layer = tf.concat([past_key_value[0], key_layer], axis=2)
+ value_layer = tf.concat([past_key_value[1], value_layer], axis=2)
+ else:
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
+
+ query_layer = self.transpose_for_scores(mixed_query_layer)
+
+ # Matt: Our BERT model (which this code was derived from) scales attention logits down by sqrt(head_dim).
+ # ESM scales the query down by the same factor instead. Modulo numerical stability these are equivalent,
+ # but not when rotary embeddings get involved. Therefore, we scale the query here to match the original
+ # ESM code and fix rotary embeddings.
+ query_layer = query_layer * self.attention_head_size**-0.5
+
+ if self.is_decoder:
+ # if cross_attention save Tuple(tf.Tensor, tf.Tensor) of all cross attention key/value_states.
+ # Further calls to cross_attention layer can then reuse all cross-attention
+ # key/value_states (first "if" case)
+ # if uni-directional self-attention (decoder) save Tuple(tf.Tensor, tf.Tensor) of
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
+ past_key_value = (key_layer, value_layer)
+
+ if self.position_embedding_type == "rotary":
+ query_layer, key_layer = self.rotary_embeddings(query_layer, key_layer)
+
+ # Take the dot product between "query" and "key" to get the raw attention scores.
+ attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True)
+
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
+ seq_length = shape_list(hidden_states)[1]
+ position_ids_l = tf.expand_dims(tf.range(seq_length, dtype=tf.int64), -1)
+ position_ids_r = tf.expand_dims(tf.range(seq_length, dtype=tf.int64), 0)
+ distance = position_ids_l - position_ids_r
+ positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
+ positional_embedding = tf.cast(positional_embedding, query_layer.dtype) # fp16 compatibility
+
+ if self.position_embedding_type == "relative_key":
+ relative_position_scores = tf.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
+ attention_scores = attention_scores + relative_position_scores
+ elif self.position_embedding_type == "relative_key_query":
+ relative_position_scores_query = tf.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
+ relative_position_scores_key = tf.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
+ attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
+
+ if attention_mask is not None:
+ # Apply the attention mask is (precomputed for all layers in EsmModel forward() function)
+ attention_scores = attention_scores + attention_mask
+
+ # Normalize the attention scores to probabilities.
+ attention_probs = stable_softmax(attention_scores, axis=-1)
+
+ # This is actually dropping out entire tokens to attend to, which might
+ # seem a bit unusual, but is taken from the original Transformer paper.
+ attention_probs = self.dropout(attention_probs, training=training)
+
+ # Mask heads if we want to
+ if head_mask is not None:
+ attention_probs = attention_probs * head_mask
+
+ context_layer = attention_probs @ value_layer
+
+ context_layer = tf.transpose(context_layer, perm=(0, 2, 1, 3))
+ new_context_layer_shape = shape_list(context_layer)[:-2] + [self.all_head_size]
+ context_layer = tf.reshape(context_layer, new_context_layer_shape)
+
+ outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
+
+ if self.is_decoder:
+ outputs = outputs + (past_key_value,)
+ return outputs
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "query", None) is not None:
+ with tf.name_scope(self.query.name):
+ self.query.build([None, None, self.config.hidden_size])
+ if getattr(self, "key", None) is not None:
+ with tf.name_scope(self.key.name):
+ self.key.build([None, None, self.config.hidden_size])
+ if getattr(self, "value", None) is not None:
+ with tf.name_scope(self.value.name):
+ self.value.build([None, None, self.config.hidden_size])
+ if getattr(self, "rotary_embeddings", None) is not None:
+ with tf.name_scope(self.rotary_embeddings.name):
+ self.rotary_embeddings.build(None)
+
+
+class TFEsmSelfOutput(keras.layers.Layer):
+ def __init__(self, config, name=None):
+ super().__init__(name=name)
+ self.dense = keras.layers.Dense(
+ config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
+ )
+ self.dropout = keras.layers.Dropout(config.hidden_dropout_prob)
+ self.config = config
+
+ def call(self, hidden_states, input_tensor, training=False):
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states, training=training)
+ hidden_states += input_tensor
+ return hidden_states
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "dense", None) is not None:
+ with tf.name_scope(self.dense.name):
+ self.dense.build([None, None, self.config.hidden_size])
+
+
+class TFEsmAttention(keras.layers.Layer):
+ def __init__(self, config, name=None):
+ super().__init__(name=name)
+ self.self = TFEsmSelfAttention(config, name="self")
+ self.output_layer = TFEsmSelfOutput(config, name="output")
+ self.pruned_heads = set()
+ self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
+ self.config = config
+
+ def prune_heads(self, heads):
+ raise NotImplementedError
+
+ def call(
+ self,
+ hidden_states,
+ attention_mask=None,
+ head_mask=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ past_key_value=None,
+ output_attentions=False,
+ training=False,
+ ):
+ hidden_states_ln = self.LayerNorm(hidden_states)
+ self_outputs = self.self(
+ hidden_states_ln,
+ attention_mask,
+ head_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ past_key_value,
+ output_attentions,
+ training,
+ )
+ attention_output = self.output_layer(self_outputs[0], hidden_states)
+ outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
+ return outputs
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "self", None) is not None:
+ with tf.name_scope(self.self.name):
+ self.self.build(None)
+ if getattr(self, "output_layer", None) is not None:
+ with tf.name_scope(self.output_layer.name):
+ self.output_layer.build(None)
+ if getattr(self, "LayerNorm", None) is not None:
+ with tf.name_scope(self.LayerNorm.name):
+ self.LayerNorm.build([None, None, self.config.hidden_size])
+
+
+class TFEsmIntermediate(keras.layers.Layer):
+ def __init__(self, config: EsmConfig, **kwargs):
+ super().__init__(**kwargs)
+
+ self.dense = keras.layers.Dense(
+ units=config.intermediate_size,
+ kernel_initializer=get_initializer(config.initializer_range),
+ name="dense",
+ )
+ self.config = config
+
+ def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
+ hidden_states = self.dense(inputs=hidden_states)
+ hidden_states = tf.nn.gelu(hidden_states)
+ return hidden_states
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "dense", None) is not None:
+ with tf.name_scope(self.dense.name):
+ self.dense.build([None, None, self.config.hidden_size])
+
+
+class TFEsmOutput(keras.layers.Layer):
+ def __init__(self, config, name=None):
+ super().__init__(name=name)
+ self.dense = keras.layers.Dense(
+ config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
+ )
+ self.dropout = keras.layers.Dropout(config.hidden_dropout_prob)
+ self.config = config
+
+ def call(self, hidden_states, input_tensor, training=False):
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states, training=training)
+ hidden_states += input_tensor
+ return hidden_states
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "dense", None) is not None:
+ with tf.name_scope(self.dense.name):
+ self.dense.build([None, None, self.config.intermediate_size])
+
+
+class TFEsmLayer(keras.layers.Layer):
+ def __init__(self, config, name=None):
+ super().__init__(name=name)
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
+ self.seq_len_dim = 1
+ self.attention = TFEsmAttention(config, name="attention")
+ self.is_decoder = config.is_decoder
+ self.add_cross_attention = config.add_cross_attention
+ if self.add_cross_attention:
+ if not self.is_decoder:
+ raise RuntimeError(f"{self} should be used as a decoder model if cross attention is added")
+ self.crossattention = TFEsmAttention(config)
+ self.intermediate = TFEsmIntermediate(config, name="intermediate")
+ self.output_layer = TFEsmOutput(config, name="output")
+ self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
+ self.config = config
+
+ def call(
+ self,
+ hidden_states,
+ attention_mask=None,
+ head_mask=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ past_key_value=None,
+ output_attentions=False,
+ training=False,
+ ):
+ # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
+ self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
+ self_attention_outputs = self.attention(
+ hidden_states,
+ attention_mask,
+ head_mask,
+ output_attentions=output_attentions,
+ past_key_value=self_attn_past_key_value,
+ training=training,
+ )
+ attention_output = self_attention_outputs[0]
+
+ # if decoder, the last output is tuple of self-attn cache
+ if self.is_decoder:
+ outputs = self_attention_outputs[1:-1]
+ present_key_value = self_attention_outputs[-1]
+ else:
+ outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
+
+ cross_attn_present_key_value = None
+ if self.is_decoder and encoder_hidden_states is not None:
+ if not hasattr(self, "crossattention"):
+ raise AttributeError(
+ f"If `encoder_hidden_states` are passed, {self} has to be instantiated"
+ " with cross-attention layers by setting `config.add_cross_attention=True`"
+ )
+
+ # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
+ cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
+ cross_attention_outputs = self.crossattention(
+ attention_output,
+ attention_mask,
+ head_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ cross_attn_past_key_value,
+ output_attentions,
+ training=training,
+ )
+ attention_output = cross_attention_outputs[0]
+ outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
+
+ # add cross-attn cache to positions 3,4 of present_key_value tuple
+ cross_attn_present_key_value = cross_attention_outputs[-1]
+ present_key_value = present_key_value + cross_attn_present_key_value
+
+ layernorm_output = self.LayerNorm(attention_output)
+ intermediate_output = self.intermediate(hidden_states=layernorm_output)
+ layer_output = self.output_layer(
+ hidden_states=intermediate_output, input_tensor=attention_output, training=training
+ )
+ outputs = (layer_output,) + outputs # add attentions if we output them
+
+ # if decoder, return the attn key/values as the last output
+ if self.is_decoder:
+ outputs = outputs + (present_key_value,)
+
+ return outputs
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "attention", None) is not None:
+ with tf.name_scope(self.attention.name):
+ self.attention.build(None)
+ if getattr(self, "intermediate", None) is not None:
+ with tf.name_scope(self.intermediate.name):
+ self.intermediate.build(None)
+ if getattr(self, "output_layer", None) is not None:
+ with tf.name_scope(self.output_layer.name):
+ self.output_layer.build(None)
+ if getattr(self, "LayerNorm", None) is not None:
+ with tf.name_scope(self.LayerNorm.name):
+ self.LayerNorm.build([None, None, self.config.hidden_size])
+
+
+class TFEsmEncoder(keras.layers.Layer):
+ def __init__(self, config, name=None):
+ super().__init__(name=name)
+ self.config = config
+ self.layer = [TFEsmLayer(config, name=f"layer_._{i}") for i in range(config.num_hidden_layers)]
+ self.emb_layer_norm_after = keras.layers.LayerNormalization(
+ epsilon=config.layer_norm_eps, name="emb_layer_norm_after"
+ )
+
+ def call(
+ self,
+ hidden_states,
+ attention_mask=None,
+ head_mask=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ past_key_values=None,
+ use_cache=None,
+ output_attentions=False,
+ output_hidden_states=False,
+ return_dict=True,
+ training=False,
+ ):
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attentions = () if output_attentions else None
+ all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
+
+ next_decoder_cache = () if use_cache else None
+ for i, layer_module in enumerate(self.layer):
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ layer_head_mask = head_mask[i] if head_mask is not None else None
+ past_key_value = past_key_values[i] if past_key_values is not None else None
+
+ layer_outputs = layer_module(
+ hidden_states,
+ attention_mask,
+ layer_head_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ past_key_value,
+ output_attentions,
+ training,
+ )
+
+ hidden_states = layer_outputs[0]
+ if use_cache:
+ next_decoder_cache += (layer_outputs[-1],)
+ if output_attentions:
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
+ if self.config.add_cross_attention:
+ all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
+
+ if self.emb_layer_norm_after:
+ hidden_states = self.emb_layer_norm_after(hidden_states)
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(
+ v
+ for v in [
+ hidden_states,
+ next_decoder_cache,
+ all_hidden_states,
+ all_self_attentions,
+ all_cross_attentions,
+ ]
+ if v is not None
+ )
+ return TFBaseModelOutputWithPastAndCrossAttentions(
+ last_hidden_state=hidden_states,
+ past_key_values=next_decoder_cache,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attentions,
+ cross_attentions=all_cross_attentions,
+ )
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "emb_layer_norm_after", None) is not None:
+ with tf.name_scope(self.emb_layer_norm_after.name):
+ self.emb_layer_norm_after.build([None, None, self.config.hidden_size])
+ if getattr(self, "layer", None) is not None:
+ for layer in self.layer:
+ with tf.name_scope(layer.name):
+ layer.build(None)
+
+
+# Copied from transformers.models.bert.modeling_tf_bert.TFBertPooler with Bert->Esm
+class TFEsmPooler(keras.layers.Layer):
+ def __init__(self, config: EsmConfig, **kwargs):
+ super().__init__(**kwargs)
+
+ self.dense = keras.layers.Dense(
+ units=config.hidden_size,
+ kernel_initializer=get_initializer(config.initializer_range),
+ activation="tanh",
+ name="dense",
+ )
+ self.config = config
+
+ def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
+ # We "pool" the model by simply taking the hidden state corresponding
+ # to the first token.
+ first_token_tensor = hidden_states[:, 0]
+ pooled_output = self.dense(inputs=first_token_tensor)
+
+ return pooled_output
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "dense", None) is not None:
+ with tf.name_scope(self.dense.name):
+ self.dense.build([None, None, self.config.hidden_size])
+
+
+class TFEsmPreTrainedModel(TFPreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = EsmConfig
+ base_model_prefix = "esm"
+
+
+ESM_START_DOCSTRING = r"""
+
+ This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
+ etc.)
+
+ This model is also a Keras [Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it as a
+ regular Keras model and refer to the TF/Keras documentation for all matters related to general usage and behavior.
+
+ Parameters:
+ config ([`EsmConfig`]): Model configuration class with all the parameters of the
+ model. Initializing with a config file does not load the weights associated with the model, only the
+ configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+ESM_INPUTS_DOCSTRING = r"""
+ Args:
+ input_ids (`tf.Tensor` of shape `({0})`):
+ Indices of input sequence tokens in the vocabulary.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ attention_mask (`tf.Tensor` of shape `({0})`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ position_ids (`tf.Tensor` of shape `({0})`, *optional*):
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+ config.max_position_embeddings - 1]`.
+
+ [What are position IDs?](../glossary#position-ids)
+ head_mask (`tf.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ inputs_embeds (`tf.Tensor` of shape `({0}, hidden_size)`, *optional*):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
+ model's internal embedding lookup matrix.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+@add_start_docstrings(
+ "The bare ESM Model transformer outputting raw hidden-states without any specific head on top.",
+ ESM_START_DOCSTRING,
+)
+class TFEsmMainLayer(keras.layers.Layer):
+ """
+
+ The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
+ cross-attention is added between the self-attention layers, following the architecture described in [Attention is
+ all you need](https://huggingface.co/papers/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
+ Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
+
+ To behave as an decoder the model needs to be initialized with the `is_decoder` argument of the configuration set
+ to `True`. To be used in a Seq2Seq model, the model needs to initialized with both `is_decoder` argument and
+ `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass.
+ """
+
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
+
+ def __init__(self, config, add_pooling_layer=True, name=None, **kwargs):
+ super().__init__(name=name, **kwargs)
+
+ self.config = config
+ self.is_decoder = config.is_decoder
+
+ self.embeddings = TFEsmEmbeddings(config, name="embeddings")
+ self.encoder = TFEsmEncoder(config, name="encoder")
+ self.pooler = TFEsmPooler(config, name="pooler") if add_pooling_layer else None
+
+ self.contact_head = TFEsmContactPredictionHead(
+ in_features=self.config.num_hidden_layers * self.config.num_attention_heads, bias=True, name="contact_head"
+ )
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "embeddings", None) is not None:
+ with tf.name_scope(self.embeddings.name):
+ self.embeddings.build(None)
+ if getattr(self, "encoder", None) is not None:
+ with tf.name_scope(self.encoder.name):
+ self.encoder.build(None)
+ if getattr(self, "pooler", None) is not None:
+ with tf.name_scope(self.pooler.name):
+ self.pooler.build(None)
+ if getattr(self, "contact_head", None) is not None:
+ with tf.name_scope(self.contact_head.name):
+ self.contact_head.build(None)
+
+ def get_input_embeddings(self):
+ return self.embeddings.word_embeddings
+
+ def set_input_embeddings(self, value: tf.Variable):
+ self.embeddings.word_embeddings.weight = value
+ self.embeddings.vocab_size = shape_list(value)[0]
+
+ def _prune_heads(self, heads_to_prune):
+ raise NotImplementedError
+
+ def call(
+ self,
+ input_ids: TFModelInputType | None = None,
+ attention_mask: np.ndarray | tf.Tensor | None = None,
+ position_ids: np.ndarray | tf.Tensor | None = None,
+ head_mask: np.ndarray | tf.Tensor | None = None,
+ inputs_embeds: np.ndarray | tf.Tensor | None = None,
+ encoder_hidden_states: np.ndarray | tf.Tensor | None = None,
+ encoder_attention_mask: np.ndarray | tf.Tensor | None = None,
+ past_key_values: tuple[tuple[np.ndarray | tf.Tensor]] | None = None,
+ use_cache: bool | None = None,
+ output_attentions: bool | None = None,
+ output_hidden_states: bool | None = None,
+ return_dict: bool | None = None,
+ training: bool = False,
+ ) -> TFBaseModelOutputWithPoolingAndCrossAttentions | tuple[tf.Tensor]:
+ if not self.config.is_decoder:
+ use_cache = False
+
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+ elif input_ids is not None:
+ input_shape = shape_list(input_ids)
+ elif inputs_embeds is not None:
+ input_shape = shape_list(inputs_embeds)[:-1]
+ else:
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+ batch_size, seq_length = input_shape
+
+ if past_key_values is None:
+ past_key_values_length = 0
+ past_key_values = [None] * len(self.encoder.layer)
+ else:
+ past_key_values_length = shape_list(past_key_values[0][0])[-2]
+
+ if attention_mask is None:
+ attention_mask = tf.fill(dims=(batch_size, seq_length + past_key_values_length), value=1)
+
+ embedding_output = self.embeddings(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ inputs_embeds=inputs_embeds,
+ past_key_values_length=past_key_values_length,
+ training=training,
+ )
+
+ # We create a 3D attention mask from a 2D tensor mask.
+ # Sizes are [batch_size, 1, 1, to_seq_length]
+ # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
+ # this attention mask is more simple than the triangular masking of causal attention
+ # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
+ attention_mask_shape = shape_list(attention_mask)
+
+ mask_seq_length = seq_length + past_key_values_length
+ # Copied from `modeling_tf_t5.py`
+ # Provided a padding mask of dimensions [batch_size, mask_seq_length]
+ # - if the model is a decoder, apply a causal mask in addition to the padding mask
+ # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, mask_seq_length, mask_seq_length]
+ if self.is_decoder:
+ seq_ids = tf.range(mask_seq_length)
+ causal_mask = tf.less_equal(
+ tf.tile(seq_ids[None, None, :], (batch_size, mask_seq_length, 1)),
+ seq_ids[None, :, None],
+ )
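+ # e.g. for mask_seq_length == 3 this yields the lower-triangular pattern
+ # [[1, 0, 0], [1, 1, 0], [1, 1, 1]] for every batch entry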
+ causal_mask = tf.cast(causal_mask, dtype=attention_mask.dtype)
+ extended_attention_mask = causal_mask * attention_mask[:, None, :]
+ attention_mask_shape = shape_list(extended_attention_mask)
+ extended_attention_mask = tf.reshape(
+ extended_attention_mask, (attention_mask_shape[0], 1, attention_mask_shape[1], attention_mask_shape[2])
+ )
+ if past_key_values[0] is not None:
+ # attention_mask needs to be sliced to the shape `[batch_size, 1, from_seq_length - cached_seq_length, to_seq_length]`
+ extended_attention_mask = extended_attention_mask[:, :, -seq_length:, :]
+ else:
+ extended_attention_mask = tf.reshape(
+ attention_mask, (attention_mask_shape[0], 1, 1, attention_mask_shape[1])
+ )
+
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
+ # masked positions, this operation will create a tensor which is 0.0 for
+ # positions we want to attend and -10000.0 for masked positions.
+ # Since we are adding it to the raw scores before the softmax, this is
+ # effectively the same as removing these entirely.
+ extended_attention_mask = tf.cast(extended_attention_mask, dtype=embedding_output.dtype)
+ one_cst = tf.constant(1.0, dtype=embedding_output.dtype)
+ ten_thousand_cst = tf.constant(-10000.0, dtype=embedding_output.dtype)
+ extended_attention_mask = tf.multiply(tf.subtract(one_cst, extended_attention_mask), ten_thousand_cst)
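+ # e.g. an attention_mask row [1, 1, 0] becomes the additive values [0.0, 0.0, -10000.0]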
+
+ # Copied from `modeling_tf_t5.py` with -1e9 -> -10000
+ if self.is_decoder and encoder_attention_mask is not None:
+ # If a 2D or 3D attention mask is provided for the cross-attention,
+ # we need to make it broadcastable to [batch_size, num_heads, seq_length, seq_length]
+ encoder_attention_mask = tf.cast(encoder_attention_mask, dtype=extended_attention_mask.dtype)
+ num_dims_encoder_attention_mask = len(shape_list(encoder_attention_mask))
+ if num_dims_encoder_attention_mask == 3:
+ encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :]
+ if num_dims_encoder_attention_mask == 2:
+ encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :]
+
+ # T5 has a mask that can compare sequence ids, we can simulate this here with this transposition
+ # Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow/transformer/transformer_layers.py#L270
+ # encoder_extended_attention_mask = tf.math.equal(encoder_extended_attention_mask,
+ # tf.transpose(encoder_extended_attention_mask, perm=(-1, -2)))
+
+ encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -10000.0
+ else:
+ encoder_extended_attention_mask = None
+
+ # Prepare head mask if needed
+ # 1.0 in head_mask indicate we keep the head
+ # attention_probs has shape bsz x n_heads x N x N
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+ if head_mask is not None:
+ raise NotImplementedError
+ else:
+ head_mask = [None] * self.config.num_hidden_layers
+
+ encoder_outputs = self.encoder(
+ hidden_states=embedding_output,
+ attention_mask=extended_attention_mask,
+ head_mask=head_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_extended_attention_mask,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ training=training,
+ )
+
+ sequence_output = encoder_outputs[0]
+ pooled_output = self.pooler(hidden_states=sequence_output) if self.pooler is not None else None
+
+ if not return_dict:
+ return (
+ sequence_output,
+ pooled_output,
+ ) + encoder_outputs[1:]
+
+ return TFBaseModelOutputWithPoolingAndCrossAttentions(
+ last_hidden_state=sequence_output,
+ pooler_output=pooled_output,
+ past_key_values=encoder_outputs.past_key_values,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ cross_attentions=encoder_outputs.cross_attentions,
+ )
+
+ def predict_contacts(self, tokens, attention_mask):
+ attns = self(tokens, attention_mask=attention_mask, return_dict=True, output_attentions=True).attentions
+ attns = tf.stack(attns, axis=1) # Matches the original model layout
+ # In the original model, attentions for padding tokens are completely zeroed out.
+ # This makes no difference most of the time because the other tokens won't attend to them,
+ # but it does for the contact prediction task, which takes attentions as input,
+ # so we have to mimic that here.
+ attention_mask = tf.cast(attention_mask, attns.dtype)
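+ # The first multiply below zeroes attention *to* padded positions (keys); the second zeroes attention
+ # *from* padded positions (queries).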
+ attns *= attention_mask[:, None, None, None]
+ attns *= attention_mask[:, None, None, :, None]
+ return self.contact_head(tokens, attns)
+
+
+@add_start_docstrings(
+ "The bare ESM Model transformer outputting raw hidden-states without any specific head on top.",
+ ESM_START_DOCSTRING,
+)
+class TFEsmModel(TFEsmPreTrainedModel):
+ def __init__(self, config: EsmConfig, add_pooling_layer=True, *inputs, **kwargs):
+ super().__init__(config, *inputs, **kwargs)
+
+ self.esm = TFEsmMainLayer(config, add_pooling_layer=add_pooling_layer, name="esm")
+
+ @unpack_inputs
+ @add_start_docstrings_to_model_forward(ESM_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+ @add_code_sample_docstrings(
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=TFBaseModelOutputWithPoolingAndCrossAttentions,
+ config_class=_CONFIG_FOR_DOC,
+ )
+ def call(
+ self,
+ input_ids: TFModelInputType | None = None,
+ attention_mask: np.ndarray | tf.Tensor | None = None,
+ position_ids: np.ndarray | tf.Tensor | None = None,
+ head_mask: np.ndarray | tf.Tensor | None = None,
+ inputs_embeds: np.ndarray | tf.Tensor | None = None,
+ encoder_hidden_states: np.ndarray | tf.Tensor | None = None,
+ encoder_attention_mask: np.ndarray | tf.Tensor | None = None,
+ past_key_values: tuple[tuple[np.ndarray | tf.Tensor]] | None = None,
+ use_cache: bool | None = None,
+ output_attentions: bool | None = None,
+ output_hidden_states: bool | None = None,
+ return_dict: bool | None = None,
+ training: bool | None = False,
+ ) -> TFBaseModelOutputWithPoolingAndCrossAttentions | tuple[tf.Tensor]:
+ r"""
+ encoder_hidden_states (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
+ the model is configured as a decoder.
+ encoder_attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
+ the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ past_key_values (`tuple[tuple[tf.Tensor]]` of length `config.n_layers`):
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+ `past_key_values`). Set to `False` during training and to `True` during generation.
+ """
+ outputs = self.esm(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ training=training,
+ )
+ return outputs
+
+ def predict_contacts(self, tokens, attention_mask):
+ return self.esm.predict_contacts(tokens, attention_mask)
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "esm", None) is not None:
+ with tf.name_scope(self.esm.name):
+ self.esm.build(None)
+
+
+@add_start_docstrings("""ESM Model with a `language modeling` head on top.""", ESM_START_DOCSTRING)
+class TFEsmForMaskedLM(TFEsmPreTrainedModel, TFMaskedLanguageModelingLoss):
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
+
+ def __init__(self, config):
+ super().__init__(config)
+
+ if config.is_decoder:
+ logger.warning(
+ "If you want to use `EsmForMaskedLM` make sure `config.is_decoder=False` for "
+ "bi-directional self-attention."
+ )
+
+ self.esm = TFEsmMainLayer(config, add_pooling_layer=False, name="esm")
+ self.lm_head = TFEsmLMHead(config, name="lm_head")
+ if config.tie_word_embeddings:
+ # Ensure word embeddings are built so that we actually have something to tie
+ with tf.name_scope(os.path.join(self._name_scope(), "esm", "embeddings", "word_embeddings")):
+ self.esm.embeddings.word_embeddings.build((None, None))
+ self.lm_head.decoder = self.esm.embeddings.word_embeddings.weights[0]
+
+ def get_output_embeddings(self):
+ return self.lm_head.decoder
+
+ def set_output_embeddings(self, new_embeddings):
+ self.lm_head.decoder = new_embeddings
+
+ def get_lm_head(self):
+ return self.lm_head
+
+ @unpack_inputs
+ @add_start_docstrings_to_model_forward(ESM_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+ @add_code_sample_docstrings(
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=TFMaskedLMOutput,
+ config_class=_CONFIG_FOR_DOC,
+ mask="",
+ )
+ def call(
+ self,
+ input_ids: TFModelInputType | None = None,
+ attention_mask: np.ndarray | tf.Tensor | None = None,
+ position_ids: np.ndarray | tf.Tensor | None = None,
+ head_mask: np.ndarray | tf.Tensor | None = None,
+ inputs_embeds: np.ndarray | tf.Tensor | None = None,
+ encoder_hidden_states: np.ndarray | tf.Tensor | None = None,
+ encoder_attention_mask: np.ndarray | tf.Tensor | None = None,
+ labels: np.ndarray | tf.Tensor | None = None,
+ output_attentions: bool | None = None,
+ output_hidden_states: bool | None = None,
+ return_dict: bool | None = None,
+ training: bool = False,
+ ) -> TFMaskedLMOutput | tuple[tf.Tensor]:
+ r"""
+ labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
+ config.vocab_size]` (see the `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked); the
+ loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+ kwargs (`dict[str, any]`, *optional*, defaults to `{}`):
+ Used to hide legacy arguments that have been deprecated.
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.esm(
+ input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ training=training,
+ )
+ sequence_output = outputs[0]
+ prediction_scores = self.lm_head(sequence_output)
+
+ masked_lm_loss = None
+ if labels is not None:
+ masked_lm_loss = self.hf_compute_loss(labels=labels, logits=prediction_scores)
+
+ if not return_dict:
+ output = (prediction_scores,) + outputs[2:]
+ return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
+
+ return TFMaskedLMOutput(
+ loss=masked_lm_loss,
+ logits=prediction_scores,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+ def predict_contacts(self, tokens, attention_mask):
+ return self.esm.predict_contacts(tokens, attention_mask)
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "esm", None) is not None:
+ with tf.name_scope(self.esm.name):
+ self.esm.build(None)
+ if getattr(self, "lm_head", None) is not None:
+ with tf.name_scope(self.lm_head.name):
+ self.lm_head.build(None)
+
+
+class TFEsmLMHead(keras.layers.Layer):
+ """ESM Head for masked language modeling."""
+
+ def __init__(self, config, name=None):
+ super().__init__(name=name)
+ self.dense = keras.layers.Dense(
+ config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
+ )
+
+ self.layer_norm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm")
+ if config.tie_word_embeddings:
+ self.decoder = None
+ else:
+ self.decoder = keras.layers.Dense(
+ config.vocab_size,
+ kernel_initializer=get_initializer(config.initializer_range),
+ name="decoder",
+ use_bias=False,
+ )
+ self.config = config
+
+ def build(self, input_shape=None):
+ # Separate bias to match the PT model and allow weight cross-loading to work
+ # Put it in the build so it gets the right name when adding it as a weight
+ if self.built:
+ return
+ self.built = True
+ self.bias = self.add_weight("bias", shape=(self.config.vocab_size,), initializer="zeros", trainable=True)
+ if getattr(self, "dense", None) is not None:
+ with tf.name_scope(self.dense.name):
+ self.dense.build([None, None, self.config.hidden_size])
+ if getattr(self, "layer_norm", None) is not None:
+ with tf.name_scope(self.layer_norm.name):
+ self.layer_norm.build([None, None, self.config.hidden_size])
+ if getattr(self, "decoder", None) is not None and not self.config.tie_word_embeddings:
+ with tf.name_scope(self.decoder.name):
+ self.decoder.build([None, None, self.config.hidden_size])
+
+ def get_bias(self):
+ return {"bias": self.bias}
+
+ def call(self, features):
+ x = self.dense(features)
+ x = tf.nn.gelu(x)
+ x = self.layer_norm(x)
+
+ # project back to size of vocabulary with bias
+ if self.config.tie_word_embeddings:
+ x = tf.matmul(x, self.decoder, transpose_b=True) + self.bias
+ else:
+ x = self.decoder(x) + self.bias
+ return x
+
+
+@add_start_docstrings(
+ """
+ ESM Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
+ output) e.g. for GLUE tasks.
+ """,
+ ESM_START_DOCSTRING,
+)
+class TFEsmForSequenceClassification(TFEsmPreTrainedModel, TFSequenceClassificationLoss):
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+ self.config = config
+
+ self.esm = TFEsmMainLayer(config, add_pooling_layer=False, name="esm")
+ self.classifier = TFEsmClassificationHead(config, name="classifier")
+
+ @unpack_inputs
+ @add_start_docstrings_to_model_forward(ESM_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+ @add_code_sample_docstrings(
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=TFSequenceClassifierOutput,
+ config_class=_CONFIG_FOR_DOC,
+ )
+ def call(
+ self,
+ input_ids: TFModelInputType | None = None,
+ attention_mask: np.ndarray | tf.Tensor | None = None,
+ position_ids: np.ndarray | tf.Tensor | None = None,
+ head_mask: np.ndarray | tf.Tensor | None = None,
+ inputs_embeds: np.ndarray | tf.Tensor | None = None,
+ labels: np.ndarray | tf.Tensor | None = None,
+ output_attentions: bool | None = None,
+ output_hidden_states: bool | None = None,
+ return_dict: bool | None = None,
+ training: bool = False,
+ ) -> TFSequenceClassifierOutput | tuple[tf.Tensor]:
+ r"""
+ labels (`tf.Tensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (mean-squared loss); if
+ `config.num_labels > 1` a classification loss is computed (cross-entropy).
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.esm(
+ input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ training=training,
+ )
+ sequence_output = outputs[0]
+ logits = self.classifier(sequence_output)
+
+ loss = None if labels is None else self.hf_compute_loss(labels, logits)
+
+ if not return_dict:
+ output = (logits,) + outputs[2:]
+ return ((loss,) + output) if loss is not None else output
+
+ return TFSequenceClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "esm", None) is not None:
+ with tf.name_scope(self.esm.name):
+ self.esm.build(None)
+ if getattr(self, "classifier", None) is not None:
+ with tf.name_scope(self.classifier.name):
+ self.classifier.build(None)
+
+
+@add_start_docstrings(
+ """
+ ESM Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
+ Named-Entity-Recognition (NER) tasks.
+ """,
+ ESM_START_DOCSTRING,
+)
+class TFEsmForTokenClassification(TFEsmPreTrainedModel, TFTokenClassificationLoss):
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+
+ self.esm = TFEsmMainLayer(config, add_pooling_layer=False, name="esm")
+ self.dropout = keras.layers.Dropout(config.hidden_dropout_prob)
+ self.classifier = keras.layers.Dense(config.num_labels, name="classifier")
+ self.config = config
+
+ @unpack_inputs
+ @add_start_docstrings_to_model_forward(ESM_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+ @add_code_sample_docstrings(
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=TFTokenClassifierOutput,
+ config_class=_CONFIG_FOR_DOC,
+ )
+ def call(
+ self,
+ input_ids: TFModelInputType | None = None,
+ attention_mask: np.ndarray | tf.Tensor | None = None,
+ position_ids: np.ndarray | tf.Tensor | None = None,
+ head_mask: np.ndarray | tf.Tensor | None = None,
+ inputs_embeds: np.ndarray | tf.Tensor | None = None,
+ labels: np.ndarray | tf.Tensor | None = None,
+ output_attentions: bool | None = None,
+ output_hidden_states: bool | None = None,
+ return_dict: bool | None = None,
+ training: bool = False,
+ ) -> TFTokenClassifierOutput | tuple[tf.Tensor]:
+ r"""
+ labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.esm(
+ input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ training=training,
+ )
+
+ sequence_output = outputs[0]
+
+ sequence_output = self.dropout(sequence_output, training=training)
+ logits = self.classifier(sequence_output)
+
+ loss = None if labels is None else self.hf_compute_loss(labels, logits)
+
+ if not return_dict:
+ output = (logits,) + outputs[2:]
+ return ((loss,) + output) if loss is not None else output
+
+ return TFTokenClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "esm", None) is not None:
+ with tf.name_scope(self.esm.name):
+ self.esm.build(None)
+ if getattr(self, "classifier", None) is not None:
+ with tf.name_scope(self.classifier.name):
+ self.classifier.build([None, None, self.config.hidden_size])
+
+
+class TFEsmClassificationHead(keras.layers.Layer):
+ """Head for sentence-level classification tasks."""
+
+ def __init__(self, config, name=None):
+ super().__init__(name=name)
+ self.dense = keras.layers.Dense(
+ config.hidden_size,
+ kernel_initializer=get_initializer(config.initializer_range),
+ activation="tanh",
+ name="dense",
+ )
+ self.dropout = keras.layers.Dropout(config.hidden_dropout_prob)
+ self.out_proj = keras.layers.Dense(
+ config.num_labels,
+ kernel_initializer=get_initializer(config.initializer_range),
+ activation="linear",
+ name="out_proj",
+ )
+ self.config = config
+
+ def call(self, features, training=False):
+ x = features[:, 0, :]  # take the first (<cls>) token (equivalent to [CLS])
+ x = self.dropout(x, training=training)
+ x = self.dense(x)
+ x = self.dropout(x, training=training)
+ x = self.out_proj(x)
+ return x
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "dense", None) is not None:
+ with tf.name_scope(self.dense.name):
+ self.dense.build([None, None, self.config.hidden_size])
+ if getattr(self, "out_proj", None) is not None:
+ with tf.name_scope(self.out_proj.name):
+ self.out_proj.build([None, None, self.config.hidden_size])
+
+
+def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0):
+ """
+ Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
+ are ignored. This is modified from fairseq's `utils.make_positions`.
+
+ Args:
+ input_ids (`tf.Tensor`): Tensor of input token ids.
+ padding_idx (`int`): The id of the padding token.
+ past_key_values_length (`int`, *optional*, defaults to 0): Length of any cached past key values.
+
+ Returns: tf.Tensor
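+
+ Example (illustrative):
+
+ >>> create_position_ids_from_input_ids(tf.constant([[5, 6, 1, 1]]), padding_idx=1)
+ <tf.Tensor: shape=(1, 4), dtype=int64, numpy=array([[2, 3, 1, 1]])>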
+ """
+ # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
+ mask = tf.cast(input_ids != padding_idx, tf.int64)
+ incremental_indices = (tf.cumsum(mask, axis=1) + past_key_values_length) * mask
+ return incremental_indices + padding_idx
+
+
+__all__ = [
+ "TFEsmForMaskedLM",
+ "TFEsmForSequenceClassification",
+ "TFEsmForTokenClassification",
+ "TFEsmModel",
+ "TFEsmPreTrainedModel",
+]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/esm/tokenization_esm.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/esm/tokenization_esm.py
new file mode 100644
index 0000000000000000000000000000000000000000..7d9705f7dbd33216a327eab04415ec57fe8e858d
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/esm/tokenization_esm.py
@@ -0,0 +1,147 @@
+# coding=utf-8
+# Copyright 2022 Meta and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tokenization classes for ESM."""
+
+import os
+from typing import Optional
+
+from ...tokenization_utils import PreTrainedTokenizer
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}
+
+
+def load_vocab_file(vocab_file):
+ with open(vocab_file, "r") as f:
+ lines = f.read().splitlines()
+ return [l.strip() for l in lines]
+
+
+class EsmTokenizer(PreTrainedTokenizer):
+ """
+ Constructs an ESM tokenizer.
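+
+ Example (illustrative; the checkpoint name below is only an example):
+
+ ```python
+ >>> from transformers import EsmTokenizer
+
+ >>> tokenizer = EsmTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
+ >>> tokenizer("MKTV").input_ids  # ids for <cls> M K T V <eos>
+ ```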
+ """
+
+ vocab_files_names = VOCAB_FILES_NAMES
+ model_input_names = ["input_ids", "attention_mask"]
+
+ def __init__(
+ self,
+ vocab_file,
+ unk_token="",
+ cls_token="",
+ pad_token="",
+ mask_token="",
+ eos_token="",
+ **kwargs,
+ ):
+ self.all_tokens = load_vocab_file(vocab_file)
+ self._id_to_token = dict(enumerate(self.all_tokens))
+ self._token_to_id = {tok: ind for ind, tok in enumerate(self.all_tokens)}
+ super().__init__(
+ unk_token=unk_token,
+ cls_token=cls_token,
+ pad_token=pad_token,
+ mask_token=mask_token,
+ eos_token=eos_token,
+ **kwargs,
+ )
+
+ # TODO, all the tokens are added? But they are also part of the vocab... bit strange.
+ # none of them are special, but they all need special splitting.
+
+ self.unique_no_split_tokens = self.all_tokens
+ self._update_trie(self.unique_no_split_tokens)
+
+ def _convert_id_to_token(self, index: int) -> str:
+ return self._id_to_token.get(index, self.unk_token)
+
+ def _convert_token_to_id(self, token: str) -> int:
+ return self._token_to_id.get(token, self._token_to_id.get(self.unk_token))
+
+ def _tokenize(self, text, **kwargs):
+ return text.split()
+
+ def get_vocab(self):
+ base_vocab = self._token_to_id.copy()
+ base_vocab.update(self.added_tokens_encoder)
+ return base_vocab
+
+ def token_to_id(self, token: str) -> int:
+ return self._token_to_id.get(token, self._token_to_id.get(self.unk_token))
+
+ def id_to_token(self, index: int) -> str:
+ return self._id_to_token.get(index, self.unk_token)
+
+ def build_inputs_with_special_tokens(
+ self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None
+ ) -> list[int]:
+ cls = [self.cls_token_id]
+ sep = [self.eos_token_id] # No sep token in ESM vocabulary
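+ # Layouts produced below: <cls> seq_0 <eos> for a single sequence (when an EOS id exists)
+ # and <cls> seq_0 <eos> seq_1 <eos> for a pair of sequences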
+ if token_ids_1 is None:
+ if self.eos_token_id is None:
+ return cls + token_ids_0
+ else:
+ return cls + token_ids_0 + sep
+ elif self.eos_token_id is None:
+ raise ValueError("Cannot tokenize multiple sequences when EOS token is not set!")
+ return cls + token_ids_0 + sep + token_ids_1 + sep # Multiple inputs always have an EOS token
+
+ def get_special_tokens_mask(
+ self, token_ids_0: list, token_ids_1: Optional[list] = None, already_has_special_tokens: bool = False
+ ) -> list[int]:
+ """
+ Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
+ special tokens using the tokenizer `prepare_for_model` or `encode_plus` methods.
+
+ Args:
+ token_ids_0 (`list[int]`):
+ List of ids of the first sequence.
+ token_ids_1 (`list[int]`, *optional*):
+ List of ids of the second sequence.
+ already_has_special_tokens (`bool`, *optional*, defaults to `False`):
+ Whether or not the token list is already formatted with special tokens for the model.
+
+ Returns:
+ A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
+ """
+ if already_has_special_tokens:
+ if token_ids_1 is not None:
+ raise ValueError(
+ "You should not supply a second sequence if the provided sequence of "
+ "ids is already formatted with special tokens for the model."
+ )
+
+ return [1 if token in self.all_special_ids else 0 for token in token_ids_0]
+ mask = [1] + ([0] * len(token_ids_0)) + [1]
+ if token_ids_1 is not None:
+ mask += [0] * len(token_ids_1) + [1]
+ return mask
+
+ def save_vocabulary(self, save_directory, filename_prefix):
+ vocab_file = os.path.join(save_directory, (filename_prefix + "-" if filename_prefix else "") + "vocab.txt")
+ with open(vocab_file, "w") as f:
+ f.write("\n".join(self.all_tokens))
+ return (vocab_file,)
+
+ @property
+ def vocab_size(self) -> int:
+ return len(self.all_tokens)
+
+
+__all__ = ["EsmTokenizer"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/falcon_mamba/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/falcon_mamba/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..202147c938465dd7dfcb7e79ecbeeb93ce632dbf
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/falcon_mamba/__init__.py
@@ -0,0 +1,27 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_falcon_mamba import *
+ from .modeling_falcon_mamba import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/falcon_mamba/configuration_falcon_mamba.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/falcon_mamba/configuration_falcon_mamba.py
new file mode 100644
index 0000000000000000000000000000000000000000..7630ebd6343ac968303fc0c31f2742bb352b4f8a
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/falcon_mamba/configuration_falcon_mamba.py
@@ -0,0 +1,170 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/falcon_mamba/modular_falcon_mamba.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_falcon_mamba.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2024 Tri Dao, Albert Gu, Technological Innovation Institute and HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+
+from ...configuration_utils import PretrainedConfig
+
+
+class FalconMambaConfig(PretrainedConfig):
+ """
+ This is the configuration class to store the configuration of a [`FalconMambaModel`]. It is used to instantiate a FALCON_MAMBA
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
+ defaults will yield a similar configuration to that of the FALCON_MAMBA
+ [tiiuae/falcon-mamba-7b](https://huggingface.co/tiiuae/falcon-mamba-7b) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 50280):
+ Vocabulary size of the FALCON_MAMBA model. Defines the number of different tokens that can be represented by the
+ `inputs_ids` passed when calling [`FalconMambaModel`].
+ hidden_size (`int`, *optional*, defaults to 768):
+ Dimensionality of the embeddings and hidden states.
+ state_size (`int`, *optional*, defaults to 16): shape of the state space latents.
+ num_hidden_layers (`int`, *optional*, defaults to 32):
+ Number of hidden layers in the model.
+ layer_norm_epsilon (`float`, *optional*, defaults to 1e-05):
+ The epsilon to use in the layer normalization layers.
+ pad_token_id (`int`, *optional*, defaults to 0):
+ Padding token id.
+ bos_token_id (`int`, *optional*, defaults to 0):
+ The id of the beginning of sentence token in the vocabulary.
+ eos_token_id (`int`, *optional*, defaults to 0):
+ The id of the end of sentence token in the vocabulary.
+ expand (`int`, *optional*, defaults to 2): Expanding factor used to determine the intermediate size.
+ conv_kernel (`int`, *optional*, defaults to 4): Size of the convolution kernel.
+ use_bias (`bool`, *optional*, defaults to `False`):
+ Whether or not to use bias in ["in_proj", "out_proj"] of the mixer block
+ use_conv_bias (`bool`, *optional*, defaults to `True`):
+ Whether or not to use bias in the convolution layer of the mixer block.
+ hidden_act (`str`, *optional*, defaults to `"silu"`):
+ The non-linear activation function (function or string) in the decoder.
+ initializer_range (`float`, *optional*, defaults to 0.1):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ residual_in_fp32 (`bool`, *optional*, defaults to `True`):
+ Whether or not residuals should be in `float32`. If set to `False` residuals will keep the same `dtype` as the rest of the model
+ time_step_rank (`Union[int,str]`, *optional*, defaults to `"auto"`):
+ Rank of the discretization projection matrix. `"auto"` means that it will default to `math.ceil(self.hidden_size / 16)`
+ time_step_scale (`float`, *optional*, defaults to 1.0):
+ Scale used to scale `dt_proj.bias`.
+ time_step_min (`float`, *optional*, defaults to 0.001):
+ Minimum `time_step` used to bound `dt_proj.bias`.
+ time_step_max (`float`, *optional*, defaults to 0.1):
+ Maximum `time_step` used to bound `dt_proj.bias`.
+ time_step_init_scheme (`str`, *optional*, defaults to `"random"`):
+ Init scheme used for `dt_proj.weight`. Should be one of `["random", "uniform"]`.
+ time_step_floor (`float`, *optional*, defaults to 0.0001):
+ Minimum clamping value of the `dt_proj.bias` layer initialization.
+ rescale_prenorm_residual (`bool`, *optional*, defaults to `False`):
+ Whether or not to rescale `out_proj` weights when initializing.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the cache should be used.
+ use_falcon_mambapy (`bool`, *optional*, defaults to `False`):
+ This argument corresponds to `use_mambapy` in MambaConfig.
+ Determines the fallback strategy during training if the CUDA-based official implementation of Mamba is not available. If `True`, the mamba.py implementation is used. If `False`, the naive and slower implementation is used. Consider switching to the naive version if memory is limited.
+ mixer_rms_eps (`float`, *optional*, defaults to 1e-06):
+ The RMS norm epsilon value that is used in the Mixer RMS norm for B, C and dt states.
+
+
+ Example:
+
+ ```python
+ >>> from transformers import FalconMambaConfig, FalconMambaModel
+
+ >>> # Initializing a FalconMamba configuration
+ >>> configuration = FalconMambaConfig()
+
+ >>> # Initializing a model (with random weights) from the configuration
+ >>> model = FalconMambaModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "falcon_mamba"
+
+ def __init__(
+ self,
+ vocab_size=50280,
+ hidden_size=768,
+ state_size=16,
+ num_hidden_layers=32,
+ layer_norm_epsilon=1e-5,
+ pad_token_id=0,
+ bos_token_id=0,
+ eos_token_id=0,
+ expand=2,
+ conv_kernel=4,
+ use_bias=False,
+ use_conv_bias=True,
+ hidden_act="silu",
+ initializer_range=0.1,
+ residual_in_fp32=True,
+ time_step_rank="auto",
+ time_step_scale=1.0,
+ time_step_min=0.001,
+ time_step_max=0.1,
+ time_step_init_scheme="random",
+ time_step_floor=1e-4,
+ rescale_prenorm_residual=False,
+ use_cache=True,
+ use_falcon_mambapy=False,
+ mixer_rms_eps=1e-6,
+ **kwargs,
+ ):
+ super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, pad_token_id=pad_token_id, **kwargs)
+ self.vocab_size = vocab_size
+ self.hidden_size = hidden_size
+ self.state_size = state_size
+ self.num_hidden_layers = num_hidden_layers
+ self.layer_norm_epsilon = layer_norm_epsilon
+ self.conv_kernel = conv_kernel
+ self.expand = expand
+ # This is needed since mamba overrides the intermediate_size attribute
+ self.intermediate_size = (
+ int(expand * self.hidden_size)
+ if kwargs.get("intermediate_size") is None
+ else kwargs.get("intermediate_size")
+ )
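+ # e.g. with the defaults expand=2 and hidden_size=768, intermediate_size becomes 1536 (illustrative)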
+ self.bos_token_id = bos_token_id
+ self.eos_token_id = eos_token_id
+ self.pad_token_id = pad_token_id
+ self.use_bias = use_bias
+ self.use_conv_bias = use_conv_bias
+ self.hidden_act = hidden_act
+ self.initializer_range = initializer_range
+ self.time_step_rank = math.ceil(self.hidden_size / 16) if time_step_rank == "auto" else time_step_rank
+ self.time_step_scale = time_step_scale
+ self.time_step_min = time_step_min
+ self.time_step_max = time_step_max
+ self.time_step_init_scheme = time_step_init_scheme
+ self.time_step_floor = time_step_floor
+ self.rescale_prenorm_residual = rescale_prenorm_residual
+ self.residual_in_fp32 = residual_in_fp32
+ self.use_cache = use_cache
+ self.use_falcon_mambapy = use_falcon_mambapy
+ self.mixer_rms_eps = mixer_rms_eps
+
+
+__all__ = ["FalconMambaConfig"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/falcon_mamba/modeling_falcon_mamba.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/falcon_mamba/modeling_falcon_mamba.py
new file mode 100644
index 0000000000000000000000000000000000000000..3cdf6da7bda376a34dd00545d96c44a99fc0e660
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/falcon_mamba/modeling_falcon_mamba.py
@@ -0,0 +1,937 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/falcon_mamba/modular_falcon_mamba.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_falcon_mamba.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2024 Tri Dao, Albert Gu, Technological Innovation Institute and HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from dataclasses import dataclass
+from typing import Any, Optional, Union
+
+import torch
+from torch import nn
+from torch.nn import CrossEntropyLoss
+
+from ...activations import ACT2FN
+from ...configuration_utils import PretrainedConfig
+from ...generation import GenerationMixin
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_utils import PreTrainedModel
+from ...utils import ModelOutput, auto_docstring, logging
+from ...utils.import_utils import (
+ is_causal_conv1d_available,
+ is_kernels_available,
+ is_mamba_ssm_available,
+ is_mambapy_available,
+)
+from .configuration_falcon_mamba import FalconMambaConfig
+
+
+if is_mambapy_available():
+ from mambapy.pscan import pscan
+else:
+ pscan = None
+
+if is_mamba_ssm_available():
+ from mamba_ssm.ops.selective_scan_interface import selective_scan_fn
+ from mamba_ssm.ops.triton.selective_state_update import selective_state_update
+
+ from ...kernels.falcon_mamba import mamba_inner_fn
+else:
+ selective_state_update, selective_scan_fn, mamba_inner_fn = None, None, None
+
+
+logger = logging.get_logger(__name__)
+
+
+class FalconMambaCache:
+ """
+ Cache for the falcon_mamba model, which does not have an attention mechanism or key-value states.
+
+ Arguments:
+ config (`PretrainedConfig`):
+ The configuration file defining the shape-related attributes required to initialize the static cache.
+ max_batch_size (`int`):
+ The maximum batch size with which the model will be used. Note that a new instance must be instantiated if
+ a smaller batch size is used.
+ dtype (`torch.dtype`, *optional*, defaults to `torch.float16`):
+ The default `dtype` to use when initializing the layer.
+ device (`torch.device` or `str`, *optional*):
+ The device on which the cache should be initialized. Should be the same as the layer.
+
+ Example:
+
+ ```python
+ >>> import torch
+ >>> from transformers import AutoTokenizer, FalconMambaForCausalLM, FalconMambaCache
+
+ >>> model = FalconMambaForCausalLM.from_pretrained("tiiuae/falcon-mamba-7b")
+ >>> tokenizer = AutoTokenizer.from_pretrained("tiiuae/falcon-mamba-7b")
+
+ >>> inputs = tokenizer(text="My name is FalconMamba", return_tensors="pt")
+
+ >>> # Prepare a cache class and pass it to model's forward
+ >>> cache_params = FalconMambaCache(config=model.config, max_batch_size=1, device=model.device, dtype=model.dtype)
+ >>> cache_position = torch.arange(len(inputs["input_ids"][0]), device=model.device) # sequence length
+ >>> outputs = model(**inputs, cache_params=cache_params, cache_position=cache_position, use_cache=True)
+ >>> outputs.cache_params
+ ```
+ """
+
+ is_compileable = True
+
+ # TODO (joao): add layer_device_map arg and update code in `generate` accordingly
+ def __init__(
+ self,
+ config: PretrainedConfig,
+ max_batch_size: int,
+ dtype: torch.dtype = torch.float16,
+ device: Union[torch.device, str, None] = None,
+ ):
+ self.max_batch_size = max_batch_size
+ self._dtype = dtype
+ self.intermediate_size = config.intermediate_size
+ self.ssm_state_size = config.state_size
+ self.conv_kernel_size = config.conv_kernel
+
+ self.conv_states: list[torch.Tensor] = []
+ self.ssm_states: list[torch.Tensor] = []
+ device = torch.device(device) if device is not None else None
+ for _ in range(config.num_hidden_layers):
+ conv_state: torch.Tensor = torch.zeros(
+ self.max_batch_size,
+ self.intermediate_size,
+ self.conv_kernel_size,
+ device=device,
+ dtype=self._dtype,
+ )
+ ssm_state: torch.Tensor = torch.zeros(
+ self.max_batch_size,
+ self.intermediate_size,
+ self.ssm_state_size,
+ device=device,
+ dtype=self._dtype,
+ )
+
+ torch._dynamo.mark_static_address(conv_state)
+ torch._dynamo.mark_static_address(ssm_state)
+ self.conv_states.append(conv_state)
+ self.ssm_states.append(ssm_state)
+
+ def update_conv_state(
+ self, layer_idx: int, new_conv_state: torch.Tensor, cache_position: torch.LongTensor
+ ) -> torch.Tensor:
+ # This `if` block is only reached in multigpu and if `layer_device_map` is not passed. It is used
+ # when the cache is initialized in the forward pass (e.g. FalconMamba)
+ if self.conv_states[layer_idx].device != new_conv_state.device:
+ self.conv_states[layer_idx] = self.conv_states[layer_idx].to(new_conv_state.device)
+
+ conv_state = self.conv_states[layer_idx]
+ cache_position = cache_position.clamp(0, self.conv_kernel_size - 1)
+
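+ # Shift the rolling window left by one step and write the newest values at the (clamped) cache position,
+ # keeping a buffer of the last `conv_kernel_size` inputs.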
+ conv_state = conv_state.roll(shifts=-1, dims=-1)
+ conv_state[:, :, cache_position] = new_conv_state.to(device=conv_state.device, dtype=conv_state.dtype)
+ self.conv_states[layer_idx].zero_()
+ self.conv_states[layer_idx] += conv_state
+ return self.conv_states[layer_idx]
+
+ def update_ssm_state(self, layer_idx: int, new_ssm_state: torch.Tensor):
+ self.ssm_states[layer_idx].zero_()
+ self.ssm_states[layer_idx] += new_ssm_state.to(self.ssm_states[layer_idx].device)
+ return self.ssm_states[layer_idx]
+
+ def reset(self):
+ for layer_idx in range(len(self.conv_states)):
+ # In-place ops prevent breaking the static address
+ self.conv_states[layer_idx].zero_()
+ self.ssm_states[layer_idx].zero_()
+
+
+def _lazy_load_causal_conv1d():
+ global _causal_conv1d_cache
+ if _causal_conv1d_cache is not None:
+ return _causal_conv1d_cache
+
+ if is_kernels_available():
+ from kernels import get_kernel
+
+ _causal_conv1d_kernel = get_kernel("kernels-community/causal-conv1d")
+ _causal_conv1d_cache = (_causal_conv1d_kernel.causal_conv1d_update, _causal_conv1d_kernel.causal_conv1d_fn)
+ elif is_causal_conv1d_available():
+ from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
+
+ _causal_conv1d_cache = (causal_conv1d_update, causal_conv1d_fn)
+ else:
+ _causal_conv1d_cache = (None, None)
+ return _causal_conv1d_cache
+
+
+_causal_conv1d_cache = None
+
+
+def rms_forward(hidden_states, variance_epsilon=1e-6):
+ """
+ Calculates a simple RMSNorm with no learnable weights. `MambaRMSNorm` will
+ leverage this in order to multiply the final result with the RMSNorm weight.
+
+ Args:
+ hidden_states (`torch.Tensor`):
+ Hidden states to normalize
+ variance_epsilon (`float`):
+ The eps value to add in the square root scaling factor
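+
+ Example (illustrative):
+
+ ```python
+ >>> import torch
+ >>> x = torch.randn(2, 4, 8)
+ >>> rms_forward(x).shape
+ torch.Size([2, 4, 8])
+ ```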
+ """
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + variance_epsilon)
+ return hidden_states.to(input_dtype)
+
+
+class FalconMambaMixer(nn.Module):
+ """
+ Compute ∆, A, B, C, and D, the state space parameters, and compute the `contextualized_states`.
+ A, D are input independent (see FalconMamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective)
+ ∆, B, C are input-dependent (this is a key difference between FalconMamba and the linear time invariant S4,
+ and is why FalconMamba is called **selective** state spaces)
+ """
+
+ def __init__(self, config: FalconMambaConfig, layer_idx: int):
+ super().__init__()
+ self.config = config
+ self.hidden_size = config.hidden_size
+ self.ssm_state_size = config.state_size
+ self.conv_kernel_size = config.conv_kernel
+ self.intermediate_size = config.intermediate_size
+ self.time_step_rank = int(config.time_step_rank)
+ self.layer_idx = layer_idx
+ self.use_conv_bias = config.use_conv_bias
+ self.conv1d = nn.Conv1d(
+ in_channels=self.intermediate_size,
+ out_channels=self.intermediate_size,
+ bias=config.use_conv_bias,
+ kernel_size=config.conv_kernel,
+ groups=self.intermediate_size,
+ padding=config.conv_kernel - 1,
+ )
+
+ self.activation = config.hidden_act
+ self.act = ACT2FN[config.hidden_act]
+
+ self.use_falcon_mambapy = config.use_falcon_mambapy
+
+ # projection of the input hidden states
+ self.in_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=config.use_bias)
+ # selective projection used to make dt, B and C input dependent
+ self.x_proj = nn.Linear(self.intermediate_size, self.time_step_rank + self.ssm_state_size * 2, bias=False)
+ # time step projection (discretization)
+ self.dt_proj = nn.Linear(self.time_step_rank, self.intermediate_size, bias=True)
+
+ # S4D real initialization. These are not discretized!
+ # The core is to load them, compute the discrete states, then write the updated state. Keeps the memory bounded
+ A = torch.arange(1, self.ssm_state_size + 1, dtype=torch.float32)[None, :]
+ A = A.expand(self.intermediate_size, -1).contiguous()
+
+ self.A_log = nn.Parameter(torch.log(A))
+ self.D = nn.Parameter(torch.ones(self.intermediate_size))
+ self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.use_bias)
+ self.use_bias = config.use_bias
+
+ self.warn_slow_implementation()
+ # Triton expects to pass RMS weights even if they are non learnable, thus we need to create these weights here
+ self.register_buffer(
+ "b_c_rms", torch.nn.Parameter(torch.ones(self.ssm_state_size), requires_grad=False), persistent=False
+ )
+ self.register_buffer(
+ "dt_rms", torch.nn.Parameter(torch.ones(self.intermediate_size), requires_grad=False), persistent=False
+ )
+ self.rms_eps = config.mixer_rms_eps
+
+ def warn_slow_implementation(self):
+ causal_conv1d_update, causal_conv1d_fn = _lazy_load_causal_conv1d()
+ is_fast_path_available = all(
+ (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
+ )
+ if not is_fast_path_available:
+ if self.use_falcon_mambapy:
+ if is_mambapy_available():
+ logger.warning_once(
+ "The fast path is not available because one of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)`"
+ " is None. Falling back to the mamba.py backend. To install follow https://github.com/state-spaces/mamba/#installation for mamba-ssm and"
+ " https://github.com/Dao-AILab/causal-conv1d or `pip install kernels` for causal-conv1d"
+ )
+ else:
+ raise ImportError(
+ "use_mambapy is set to True but the mambapy package is not installed. To install it follow https://github.com/alxndrTL/mamba.py."
+ )
+ else:
+ logger.warning_once(
+ "The fast path is not available because one of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)`"
+ " is None. Falling back to the sequential implementation of Mamba, as use_mambapy is set to False. To install follow https://github.com/state-spaces/mamba/#installation for mamba-ssm and"
+ " https://github.com/Dao-AILab/causal-conv1d or `pip install kernels` for causal-conv1d. For the mamba.py backend, follow https://github.com/alxndrTL/mamba.py."
+ )
+
+ def cuda_kernels_forward(
+ self,
+ hidden_states: torch.Tensor,
+ cache_params: Optional[FalconMambaCache] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.LongTensor] = None,
+ ):
+ # 1. Gated MLP's linear projection
+ projected_states = self.in_proj(hidden_states).transpose(1, 2)
+
+ if self.training and cache_params is None: # Doesn't support outputting the states -> used for training
+ contextualized_states = mamba_inner_fn(
+ projected_states,
+ self.conv1d.weight,
+ self.conv1d.bias if self.use_conv_bias else None,
+ self.x_proj.weight,
+ self.dt_proj.weight,
+ self.out_proj.weight,
+ self.out_proj.bias.float() if self.use_bias else None,
+ -torch.exp(self.A_log.float()),
+ None, # input-dependent B
+ None, # input-dependent C
+ self.D.float(),
+ delta_bias=self.dt_proj.bias.float(),
+ delta_softplus=True,
+ b_rms_weight=self.b_c_rms,
+ c_rms_weight=self.b_c_rms,
+ dt_rms_weight=self.dt_rms,
+ b_c_dt_rms_eps=self.rms_eps,
+ )
+
+ else:
+ causal_conv1d_update, causal_conv1d_fn = _lazy_load_causal_conv1d()
+ hidden_states, gate = projected_states.chunk(2, dim=1)
+
+ if attention_mask is not None:
+ hidden_states = hidden_states * attention_mask.unsqueeze(1)
+
+ # 2. Convolution sequence transformation
+ conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2))
+ if cache_params is not None and cache_position[0] > 0:
+ hidden_states = causal_conv1d_update(
+ hidden_states.squeeze(-1),
+ cache_params.conv_states[self.layer_idx],
+ conv_weights,
+ self.conv1d.bias,
+ self.activation,
+ )
+ hidden_states = hidden_states.unsqueeze(-1)
+ else:
+ if cache_params is not None:
+ conv_states = nn.functional.pad(
+ hidden_states, (self.conv_kernel_size - hidden_states.shape[-1], 0)
+ )
+ cache_params.update_conv_state(self.layer_idx, conv_states, cache_position)
+ hidden_states = causal_conv1d_fn(
+ hidden_states, conv_weights, self.conv1d.bias, activation=self.activation
+ )
+
+ if attention_mask is not None:
+ hidden_states = hidden_states * attention_mask.unsqueeze(1)
+
+ # 3. State Space Model sequence transformation
+ # 3.a. input varying initialization of time_step, B and C
+ ssm_parameters = self.x_proj(hidden_states.transpose(1, 2))
+ time_step, B, C = torch.split(
+ ssm_parameters, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1
+ )
+
+ B = rms_forward(B, variance_epsilon=self.rms_eps)
+ C = rms_forward(C, variance_epsilon=self.rms_eps)
+ time_step = rms_forward(time_step, variance_epsilon=self.rms_eps)
+
+ # In case the model has been quantized, we need a hack to properly call the `nn.Linear` module
+ # at the price of a small overhead.
+ if hasattr(self.config, "_pre_quantization_dtype"):
+ discrete_time_step = (self.dt_proj(time_step) - self.dt_proj.bias).transpose(1, 2)
+ else:
+ discrete_time_step = self.dt_proj.weight @ time_step.transpose(1, 2)
+
+ A = -torch.exp(self.A_log.float())
+ # 3.c perform the recurrence y ← SSM(A, B, C)(x)
+ time_proj_bias = self.dt_proj.bias.float() if hasattr(self.dt_proj, "bias") else None
+ if cache_params is not None and cache_position[0] > 0:
+ scan_outputs = selective_state_update(
+ cache_params.ssm_states[self.layer_idx],
+ hidden_states[..., 0],
+ discrete_time_step[..., 0],
+ A,
+ B[:, 0],
+ C[:, 0],
+ self.D,
+ gate[..., 0],
+ time_proj_bias,
+ dt_softplus=True,
+ ).unsqueeze(-1)
+ else:
+ scan_outputs, ssm_state = selective_scan_fn(
+ hidden_states,
+ discrete_time_step,
+ A,
+ B.transpose(1, 2),
+ C.transpose(1, 2),
+ self.D.float(),
+ gate,
+ time_proj_bias,
+ delta_softplus=True,
+ return_last_state=True,
+ )
+ if ssm_state is not None and cache_params is not None:
+ cache_params.update_ssm_state(self.layer_idx, ssm_state)
+
+ # 4. Final linear projection
+ contextualized_states = self.out_proj(scan_outputs.transpose(1, 2))
+ return contextualized_states
+
+ # fmt: off
+ def slow_forward(self,
+ input_states,
+ cache_params: Optional[FalconMambaCache] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.LongTensor] = None,
+ ):
+ batch_size, seq_len, _ = input_states.shape
+ dtype = input_states.dtype
+ # 1. Gated MLP's linear projection
+ projected_states = self.in_proj(input_states).transpose(1, 2) # [batch, 2 * intermediate_size, seq_len]
+ hidden_states, gate = projected_states.chunk(2, dim=1)
+
+ if attention_mask is not None:
+ hidden_states = hidden_states * attention_mask.unsqueeze(1)
+
+ # 2. Convolution sequence transformation
+ if cache_params is not None:
+ ssm_state = cache_params.ssm_states[self.layer_idx].clone()
+ ssm_state = ssm_state.to(hidden_states.device)
+ # use `cache_position.shape[0]` to check whether we are in prefill
+ # stage, it's equivalent to check `cache_position[0] == 0`, which
+ # breaks dynamo fullgraph constraints
+ if cache_position is not None and cache_position.shape[0] == self.conv_kernel_size:
+ conv_state = nn.functional.pad(hidden_states, (self.conv_kernel_size - hidden_states.shape[-1], 0))
+
+ cache_params.update_conv_state(self.layer_idx, conv_state, cache_position)
+ hidden_states = self.act(
+ self.conv1d(hidden_states)[..., :seq_len]
+ ) # [batch, intermediate_size, seq_len]
+ else:
+ conv_state = cache_params.update_conv_state(self.layer_idx, hidden_states, cache_position)
+ conv_state = conv_state.to(self.conv1d.weight.device)
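+ # Single-token decoding: the depthwise causal conv collapses to a dot product between the
+ # rolling window stored in the conv state and the kernel weights.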
+ hidden_states = torch.sum(conv_state * self.conv1d.weight[:, 0, :], dim=-1)
+ if self.use_conv_bias:
+ hidden_states += self.conv1d.bias
+ hidden_states = (
+ self.act(hidden_states).to(dtype).unsqueeze(-1)
+ ) # [batch, intermediate_size, 1] : decoding
+ else:
+ ssm_state = torch.zeros(
+ (batch_size, self.intermediate_size, self.ssm_state_size), device=hidden_states.device, dtype=dtype
+ )
+ hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len]) # [batch, intermediate_size, seq_len]
+
+ if attention_mask is not None:
+ hidden_states = hidden_states * attention_mask.unsqueeze(1)
+
+ # 3. State Space Model sequence transformation
+ # 3.a. Selection: [batch, seq_len, self.time_step_rank + self.ssm_state_size * 2]
+ ssm_parameters = self.x_proj(hidden_states.transpose(1, 2))
+ time_step, B, C = torch.split(
+ ssm_parameters, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1
+ )
+
+ B = rms_forward(B, variance_epsilon=self.rms_eps)
+ C = rms_forward(C, variance_epsilon=self.rms_eps)
+ time_step = rms_forward(time_step, variance_epsilon=self.rms_eps)
+
+ discrete_time_step = self.dt_proj(time_step) # [batch, seq_len, intermediate_size]
+ discrete_time_step = nn.functional.softplus(discrete_time_step).transpose(
+ 1, 2
+ ) # [batch, intermediate_size, seq_len]
+
+ # 3.b. Discretization: B and C to [batch, seq_len, intermediate_size, ssm_state_size] (SRAM)
+ A = -torch.exp(self.A_log.float()) # [intermediate_size, ssm_state_size]
+ discrete_A = torch.exp(
+ A[None, :, None, :] * discrete_time_step[:, :, :, None]
+ ) # [batch, intermediate_size, seq_len, ssm_state_size]
+ discrete_B = (
+ discrete_time_step[:, :, :, None] * B[:, None, :, :].float()
+ ) # [batch, intermediate_size, seq_len, ssm_state_size]
+ deltaB_u = discrete_B * hidden_states[:, :, :, None].float()
+
+ # 3.c perform the recurrence y ← SSM(A, B, C)(x)
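+ # Discretized recurrence, evaluated step by step below:
+ #   h_t = exp(A * dt_t) * h_{t-1} + (dt_t * B_t) * x_t
+ #   y_t = C_t · h_t + D * x_t, then gated by act(gate)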
+ if self.use_falcon_mambapy and self.training and cache_params is None:
+ hs = pscan(
+ discrete_A.transpose(1, 2), deltaB_u.transpose(1, 2)
+ ) # [batch, seq_len, intermediate_size, ssm_state_size]
+ scan_output = (hs @ C.unsqueeze(-1)).squeeze(3).transpose(1, 2) # [batch, intermediate_size, seq_len]
+ scan_output = scan_output + hidden_states * self.D[None, :, None]
+ scan_output = scan_output * self.act(gate)
+ else:
+ scan_outputs = []
+ for i in range(seq_len):
+ ssm_state = (
+ discrete_A[:, :, i, :] * ssm_state + deltaB_u[:, :, i, :]
+ ) # [batch, intermediate_size, ssm_state]
+ scan_output = torch.matmul(
+ ssm_state.to(dtype), C[:, i, :].unsqueeze(-1)
+ ) # [batch, intermediate_size, 1]
+ scan_outputs.append(scan_output[:, :, 0])
+ scan_output = torch.stack(scan_outputs, dim=-1) # [batch, intermediate_size, seq_len]
+ scan_output = scan_output + (hidden_states * self.D[None, :, None])
+ scan_output = scan_output * self.act(gate)
+
+ if cache_params is not None:
+ cache_params.update_ssm_state(self.layer_idx, ssm_state)
+
+ # 4. Final linear projection
+ contextualized_states = self.out_proj(scan_output.transpose(1, 2)) # [batch, seq_len, hidden_size]
+ return contextualized_states
+ # fmt: on
+
+ def forward(
+ self,
+ hidden_states,
+ cache_params: Optional[FalconMambaCache] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.LongTensor] = None,
+ ):
+ causal_conv1d_update, causal_conv1d_fn = _lazy_load_causal_conv1d()
+ is_fast_path_available = all(
+ (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
+ )
+ if is_fast_path_available and "cuda" in self.x_proj.weight.device.type and not torch._dynamo.is_compiling():
+ return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask)
+ return self.slow_forward(hidden_states, cache_params, cache_position, attention_mask)
+
+
+class FalconMambaRMSNorm(nn.Module):
+ def __init__(self, hidden_size, eps=1e-6):
+ """
+ FalconMambaRMSNorm is equivalent to T5LayerNorm and LlamaRMSNorm
+ """
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+ self.variance_epsilon = eps
+
+ def forward(self, hidden_states):
+ return self.weight.to(hidden_states.device) * rms_forward(
+ hidden_states, variance_epsilon=self.variance_epsilon
+ )
+
+ def extra_repr(self):
+ return f"{self.weight.shape[0]}, eps={self.variance_epsilon}"
+
+
+class FalconMambaBlock(GradientCheckpointingLayer):
+ def __init__(self, config, layer_idx):
+ super().__init__()
+ self.config = config
+ self.layer_idx = layer_idx
+ self.residual_in_fp32 = config.residual_in_fp32
+ self.norm = FalconMambaRMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
+ self.mixer = FalconMambaMixer(config, layer_idx=layer_idx)
+
+ def forward(
+ self,
+ hidden_states,
+ cache_params: Optional[FalconMambaCache] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.LongTensor] = None,
+ ):
+ residual = hidden_states
+ hidden_states = self.norm(hidden_states.to(dtype=self.norm.weight.dtype))
+ if self.residual_in_fp32:
+ residual = residual.to(torch.float32)
+
+ hidden_states = self.mixer(
+ hidden_states, cache_params=cache_params, cache_position=cache_position, attention_mask=attention_mask
+ )
+ hidden_states = residual + hidden_states
+ return hidden_states
+
+
+@auto_docstring
+class FalconMambaPreTrainedModel(PreTrainedModel):
+ config: FalconMambaConfig
+ base_model_prefix = "backbone"
+ _no_split_modules = ["FalconMambaBlock", "FalconMambaMixer"]
+ supports_gradient_checkpointing = True
+ _is_stateful = True
+
+ def _init_weights(self, module):
+ """Initialize the weights."""
+ std = self.config.initializer_range
+ if isinstance(module, FalconMambaMixer):
+ # S4D real initialization. These are not discretized!
+ # The core is to load them, compute the discrete states, then write the updated state. Keeps the memory bounded
+ A = torch.arange(1, module.ssm_state_size + 1, dtype=torch.float32)[None, :]
+ A = A.expand(module.intermediate_size, -1).contiguous()
+ module.A_log.copy_(torch.log(A))
+ module.D.data.fill_(1.0)
+
+ dt_init_std = self.config.time_step_rank**-0.5 * self.config.time_step_scale
+ if self.config.time_step_init_scheme == "constant":
+ nn.init.constant_(module.dt_proj.weight, dt_init_std)
+ elif self.config.time_step_init_scheme == "random":
+ nn.init.uniform_(module.dt_proj.weight, -dt_init_std, dt_init_std)
+
+ dt = torch.exp(
+ torch.rand(self.config.intermediate_size)
+ * (math.log(self.config.time_step_max) - math.log(self.config.time_step_min))
+ + math.log(self.config.time_step_min)
+ ).clamp(min=self.config.time_step_floor)
+ # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
+ inv_dt = dt + torch.log(-torch.expm1(-dt))
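+ # softplus(inv_dt) == dt, so after the softplus applied in the forward pass the bias reproduces
+ # the sampled time steps.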
+ module.dt_proj.bias.copy_(inv_dt)
+ module.dt_proj.bias._no_reinit = True
+
+ nn.init.kaiming_uniform_(module.conv1d.weight, a=math.sqrt(5))
+ if module.conv1d.bias is not None:
+ if not getattr(module.conv1d.bias, "_no_reinit", False):
+ nn.init.zeros_(module.conv1d.bias)
+ nn.init.kaiming_uniform_(module.out_proj.weight, a=math.sqrt(5))
+
+ if self.config.rescale_prenorm_residual:
+ # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
+ # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
+ # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/
+ #
+ # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
+ # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
+ # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
+ # We need to reinit p since this code could be called multiple times
+ # Having just p *= scale would repeatedly scale it down
+ p = module.out_proj.weight
+ p /= math.sqrt(self.config.num_hidden_layers)
+
+ if isinstance(module, nn.Linear):
+ if not getattr(module.weight, "_no_reinit", False):
+ nn.init.normal_(module.weight, std=std)
+ if module.bias is not None:
+ if not getattr(module.bias, "_no_reinit", False):
+ nn.init.zeros_(module.bias)
+ elif isinstance(module, FalconMambaRMSNorm):
+ module.weight.data.fill_(1.0)
+ elif isinstance(module, nn.Embedding):
+ nn.init.normal_(module.weight, std=std)
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Class for the FALCON_MAMBA model outputs.
+ """
+)
+class FalconMambaOutput(ModelOutput):
+ r"""
+ cache_params (`FalconMambaCache`):
+ The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
+ avoid providing the old `input_ids`.
+
+ Includes both the state space model states after the selective scan and the convolutional states.
+ """
+
+ last_hidden_state: Optional[torch.FloatTensor] = None
+ cache_params: Optional[FalconMambaCache] = None
+ hidden_states: Optional[tuple[torch.FloatTensor]] = None
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Base class for causal language model (or autoregressive) outputs.
+ """
+)
+class FalconMambaCausalLMOutput(ModelOutput):
+ r"""
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Language modeling loss (for next-token prediction).
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+ cache_params (`FalconMambaCache`):
+ The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
+ avoid providing the old `input_ids`.
+
+ Includes both the state space model states after the selective scan and the convolutional states.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ logits: Optional[torch.FloatTensor] = None
+ cache_params: Optional[FalconMambaCache] = None
+ hidden_states: Optional[tuple[torch.FloatTensor]] = None
+
+
+@auto_docstring
+class FalconMambaModel(FalconMambaPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+
+ self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
+ self.layers = nn.ModuleList(
+ [FalconMambaBlock(config, layer_idx=idx) for idx in range(config.num_hidden_layers)]
+ )
+
+ self.gradient_checkpointing = False
+ self.norm_f = FalconMambaRMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.embeddings
+
+ def set_input_embeddings(self, new_embeddings):
+ self.embeddings = new_embeddings
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ inputs_embeds: Optional[torch.LongTensor] = None,
+ cache_params: Optional[FalconMambaCache] = None,
+ use_cache: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.LongTensor] = None,
+ ) -> Union[tuple, FalconMambaOutput]:
+ r"""
+ cache_params (`FalconMambaCache`, *optional*):
+ If passed along, the model uses the previous state in all the blocks (which will give the output for the
+ `input_ids` provided as if the model added `state_input_ids + input_ids` as context).
+ use_cache (`bool`, *optional*):
+ If set to `True`, the `cache_params` is returned and can be used to quickly generate the next logits.
+ """
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if (input_ids is None) ^ (inputs_embeds is not None): # ^ is python for xor
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embeddings(input_ids)
+
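+ # Gradient checkpointing re-runs the forward pass, which does not play well with the stateful
+ # cache, so caching is disabled during checkpointed training.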
+ if self.gradient_checkpointing and self.training and use_cache:
+ use_cache = False
+
+ if use_cache:
+ if cache_params is None:
+ cache_params = FalconMambaCache(
+ self.config, inputs_embeds.size(0), device=inputs_embeds.device, dtype=inputs_embeds.dtype
+ )
+ cache_position = torch.arange(0, self.config.conv_kernel, device=inputs_embeds.device)
+ elif cache_position is None:
+ # cases where forward is called manually instead of through `model.generate`, which
+ # initializes `cache_position` and ensures it is not None; raise an error here instead
+ # of guessing the current cache position with some hack
+ raise ValueError(
+ "You have to specify the `cache_position` manually when `use_cache=True` and `cache_params` is passed, "
+ "you don't have to pass a `cache_params` if you are in prefilling stage because in that case it will "
+ "be initialized for you automatically"
+ )
+ else:
+ cache_params = None
+
+ hidden_states = inputs_embeds
+ all_hidden_states = () if output_hidden_states else None
+ for mixer_block in self.layers:
+ hidden_states = mixer_block(
+ hidden_states,
+ cache_params=cache_params,
+ cache_position=cache_position,
+ attention_mask=attention_mask,
+ )
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ hidden_states = self.norm_f(hidden_states)
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, cache_params, all_hidden_states] if v is not None)
+
+ return FalconMambaOutput(
+ last_hidden_state=hidden_states,
+ cache_params=cache_params if use_cache else None,
+ hidden_states=all_hidden_states,
+ )
+
+
+@auto_docstring(
+ custom_intro="""
+ The FALCON_MAMBA Model transformer with a language modeling head on top (linear layer with weights tied to the input
+ embeddings).
+ """
+)
+class FalconMambaForCausalLM(FalconMambaPreTrainedModel, GenerationMixin):
+ _tied_weights_keys = ["lm_head.weight"]
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.backbone = FalconMambaModel(config)
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.backbone.get_input_embeddings()
+
+ def set_input_embeddings(self, new_embeddings):
+ return self.backbone.set_input_embeddings(new_embeddings)
+
+ def _update_model_kwargs_for_generation(
+ self, outputs: ModelOutput, model_kwargs: dict[str, Any], num_new_tokens: int = 1, **kwargs
+ ) -> dict[str, Any]:
+ model_kwargs["cache_params"] = outputs.get("cache_params", None)
+ if (
+ model_kwargs.get("use_cache", True)
+ and "cache_position" in model_kwargs
+ and model_kwargs["cache_position"] is not None
+ ):
+ model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + num_new_tokens
+
+ if "attention_mask" in model_kwargs:
+ attention_mask = model_kwargs["attention_mask"]
+ model_kwargs["attention_mask"] = torch.cat(
+ [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
+ )
+
+ return model_kwargs
+
+ def prepare_inputs_for_generation(
+ self,
+ input_ids,
+ inputs_embeds=None,
+ use_cache=None,
+ cache_params: Optional[FalconMambaCache] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.LongTensor] = None,
+ **kwargs,
+ ):
+ # Overwritten -- uses `cache_params` as opposed to `past_key_values`
+ model_inputs = {"input_ids": input_ids.contiguous()}
+ if use_cache and cache_params is None:
+ # At the prefill stage we initialize `cache_position` to the full size of `conv_states`:
+ # shorter inputs are padded and longer inputs are truncated, so it always matches
+ # the length of `cache_params.conv_states`, which is `config.conv_kernel`
+ cache_position = torch.arange(0, self.backbone.config.conv_kernel, device=input_ids.device)
+ if inputs_embeds is not None:
+ model_inputs = {"inputs_embeds": inputs_embeds}
+ max_batch_size = inputs_embeds.size(0)
+ else:
+ max_batch_size = input_ids.size(0)
+ cache_params = FalconMambaCache(self.backbone.config, max_batch_size, device=self.device, dtype=self.dtype)
+
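+ # Once the cache has been primed (decoding), only the last token needs to be fed: the conv and
+ # SSM states already summarize the earlier context.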
+ if use_cache and cache_position[0] > 0:
+ model_inputs["input_ids"] = input_ids[:, -1].unsqueeze(-1).contiguous()
+ attention_mask = None
+
+ if not use_cache and inputs_embeds is not None:
+ model_inputs = {"inputs_embeds": inputs_embeds}
+
+ model_inputs.update(
+ {
+ "cache_params": cache_params,
+ "use_cache": use_cache,
+ "cache_position": cache_position,
+ "attention_mask": attention_mask,
+ }
+ )
+
+ # Forward ALL kwargs that are uninitialized (e.g. `use_cache`).
+ for key, value in kwargs.items():
+ if key not in model_inputs:
+ model_inputs[key] = value
+
+ return model_inputs
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.LongTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ cache_params: Optional[FalconMambaCache] = None,
+ labels: Optional[torch.LongTensor] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ use_cache: Optional[bool] = None,
+ cache_position: Optional[torch.Tensor] = None,
+ **kwargs, # for now we need this for generation
+ ) -> Union[tuple, FalconMambaCausalLMOutput]:
+ r"""
+ cache_params (`FalconMambaCache`, *optional*):
+ If passed along, the model uses the previous state in all the blocks (which will give the output for the
+ `input_ids` provided as if the model added `state_input_ids + input_ids` as context).
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
+ `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size]`. All labels set to `-100`
+ are ignored (masked); the loss is only computed for labels in `[0, ..., config.vocab_size]`
+ use_cache (`bool`, *optional*):
+ If set to `True`, the `cache_params` is returned and can be used to quickly generate the next logits.
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ falcon_mamba_outputs = self.backbone(
+ input_ids,
+ cache_params=cache_params,
+ inputs_embeds=inputs_embeds,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ attention_mask=attention_mask,
+ )
+ hidden_states = falcon_mamba_outputs[0]
+
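+ # Logits are upcast to float32 so the cross-entropy loss below stays numerically stable when the
+ # model runs in half precision.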
+ logits = self.lm_head(hidden_states.to(self.lm_head.weight.dtype)).float()
+
+ loss = None
+ if labels is not None:
+ # move labels to correct device to enable model parallelism
+ labels = labels.to(logits.device)
+ # Shift so that tokens < n predict n
+ shift_logits = logits[..., :-1, :].contiguous()
+ shift_labels = labels[..., 1:].contiguous()
+ # Flatten the tokens
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
+
+ if not return_dict:
+ output = (logits,) + falcon_mamba_outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return FalconMambaCausalLMOutput(
+ loss=loss,
+ logits=logits,
+ cache_params=falcon_mamba_outputs.cache_params,
+ hidden_states=falcon_mamba_outputs.hidden_states,
+ )
+
+
+__all__ = ["FalconMambaForCausalLM", "FalconMambaModel", "FalconMambaPreTrainedModel", "FalconMambaCache"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/falcon_mamba/modular_falcon_mamba.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/falcon_mamba/modular_falcon_mamba.py
new file mode 100644
index 0000000000000000000000000000000000000000..6df2be3a2652cf47100b82aa69be3c2554ba4161
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/falcon_mamba/modular_falcon_mamba.py
@@ -0,0 +1,582 @@
+# coding=utf-8
+# Copyright 2024 Tri Dao, Albert Gu, Technological Innovation Institute and HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch FALCONMAMBA model."""
+
+from typing import Optional
+
+import torch
+from torch import nn
+
+from ...utils import auto_docstring, logging
+from ...utils.import_utils import (
+ is_mamba_ssm_available,
+ is_mambapy_available,
+)
+from ..mamba.configuration_mamba import MambaConfig
+from ..mamba.modeling_mamba import (
+ MambaBlock,
+ MambaCache,
+ MambaCausalLMOutput,
+ MambaForCausalLM,
+ MambaMixer,
+ MambaModel,
+ MambaOutput,
+ MambaPreTrainedModel,
+ MambaRMSNorm,
+ _lazy_load_causal_conv1d,
+)
+
+
+logger = logging.get_logger(__name__)
+
+if is_mambapy_available():
+ from mambapy.pscan import pscan
+else:
+ pscan = None
+
+if is_mamba_ssm_available():
+ from mamba_ssm.ops.selective_scan_interface import selective_scan_fn
+ from mamba_ssm.ops.triton.selective_state_update import selective_state_update
+
+ from ...kernels.falcon_mamba import mamba_inner_fn
+else:
+ selective_state_update, selective_scan_fn, mamba_inner_fn = None, None, None
+
+_causal_conv1d_cache = None
+
+
+class FalconMambaConfig(MambaConfig):
+ """
+ This is the configuration class to store the configuration of a [`FalconMambaModel`]. It is used to instantiate a FALCON_MAMBA
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
+ defaults will yield a similar configuration to that of the FALCON_MAMBA
+ [tiiuae/falcon-mamba-7b](https://huggingface.co/tiiuae/falcon-mamba-7b) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 50280):
+ Vocabulary size of the FALCON_MAMBA model. Defines the number of different tokens that can be represented by the
+ `inputs_ids` passed when calling [`FalconMambaModel`].
+ hidden_size (`int`, *optional*, defaults to 768):
+ Dimensionality of the embeddings and hidden states.
+ state_size (`int`, *optional*, defaults to 16): Dimensionality of the state space latents.
+ num_hidden_layers (`int`, *optional*, defaults to 32):
+ Number of hidden layers in the model.
+ layer_norm_epsilon (`float`, *optional*, defaults to 1e-05):
+ The epsilon to use in the layer normalization layers.
+ pad_token_id (`int`, *optional*, defaults to 0):
+ Padding token id.
+ bos_token_id (`int`, *optional*, defaults to 0):
+ The id of the beginning of sentence token in the vocabulary.
+ eos_token_id (`int`, *optional*, defaults to 0):
+ The id of the end of sentence token in the vocabulary.
+ expand (`int`, *optional*, defaults to 2): Expanding factor used to determine the intermediate size.
+ conv_kernel (`int`, *optional*, defaults to 4): Size of the convolution kernel.
+ use_bias (`bool`, *optional*, defaults to `False`):
+ Whether or not to use bias in `["in_proj", "out_proj"]` of the mixer block.
+ use_conv_bias (`bool`, *optional*, defaults to `True`):
+ Whether or not to use bias in the convolution layer of the mixer block.
+ hidden_act (`str`, *optional*, defaults to `"silu"`):
+ The non-linear activation function (function or string) in the decoder.
+ initializer_range (`float`, *optional*, defaults to 0.1):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ residual_in_fp32 (`bool`, *optional*, defaults to `True`):
+ Whether or not residuals should be in `float32`. If set to `False`, residuals will keep the same `dtype` as the rest of the model.
+ time_step_rank (`Union[int,str]`, *optional*, defaults to `"auto"`):
+ Rank of the discretization projection matrix. `"auto"` means that it will default to `math.ceil(self.hidden_size / 16)`
+ time_step_scale (`float`, *optional*, defaults to 1.0):
+ Scale used to scale `dt_proj.bias`.
+ time_step_min (`float`, *optional*, defaults to 0.001):
+ Minimum `time_step` used to bound `dt_proj.bias`.
+ time_step_max (`float`, *optional*, defaults to 0.1):
+ Maximum `time_step` used to bound `dt_proj.bias`.
+ time_step_init_scheme (`str`, *optional*, defaults to `"random"`):
+ Init scheme used for `dt_proj.weight`. Should be one of `["constant", "random"]`
+ time_step_floor (`float`, *optional*, defaults to 0.0001):
+ Minimum clamping value of the `dt_proj.bias` layer initialization.
+ rescale_prenorm_residual (`bool`, *optional*, defaults to `False`):
+ Whether or not to rescale `out_proj` weights when initializing.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the cache should be used.
+ use_falcon_mambapy (`bool`, *optional*, defaults to `False`):
+ This argument corresponds to `use_mambapy` in MambaConfig.
+ Determines the fallback strategy during training if the CUDA-based official implementation of Mamba is not available. If `True`, the mamba.py implementation is used. If `False`, the naive and slower implementation is used. Consider switching to the naive version if memory is limited.
+ mixer_rms_eps (`float`, *optional*, defaults to 1e-06):
+ The RMS norm epsilon value that is used in the Mixer RMS norm for B, C and dt states.
+
+
+ Example:
+
+ ```python
+ >>> from transformers import FalconMambaConfig, FalconMambaModel
+
+ >>> # Initializing a FalconMamba configuration
+ >>> configuration = FalconMambaConfig()
+
+ >>> # Initializing a model (with random weights) from the configuration
+ >>> model = FalconMambaModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ def __init__(
+ self,
+ vocab_size=50280,
+ hidden_size=768,
+ state_size=16,
+ num_hidden_layers=32,
+ layer_norm_epsilon=1e-5,
+ pad_token_id=0,
+ bos_token_id=0,
+ eos_token_id=0,
+ expand=2,
+ conv_kernel=4,
+ use_bias=False,
+ use_conv_bias=True,
+ hidden_act="silu",
+ initializer_range=0.1,
+ residual_in_fp32=True,
+ time_step_rank="auto",
+ time_step_scale=1.0,
+ time_step_min=0.001,
+ time_step_max=0.1,
+ time_step_init_scheme="random",
+ time_step_floor=1e-4,
+ rescale_prenorm_residual=False,
+ use_cache=True,
+ use_falcon_mambapy=False,
+ mixer_rms_eps=1e-6,
+ **kwargs,
+ ):
+ super().__init__(
+ vocab_size=vocab_size,
+ hidden_size=hidden_size,
+ state_size=state_size,
+ num_hidden_layers=num_hidden_layers,
+ layer_norm_epsilon=layer_norm_epsilon,
+ pad_token_id=pad_token_id,
+ bos_token_id=bos_token_id,
+ eos_token_id=eos_token_id,
+ expand=expand,
+ conv_kernel=conv_kernel,
+ use_bias=use_bias,
+ use_conv_bias=use_conv_bias,
+ hidden_act=hidden_act,
+ initializer_range=initializer_range,
+ residual_in_fp32=residual_in_fp32,
+ time_step_rank=time_step_rank,
+ time_step_scale=time_step_scale,
+ time_step_min=time_step_min,
+ time_step_max=time_step_max,
+ time_step_init_scheme=time_step_init_scheme,
+ time_step_floor=time_step_floor,
+ rescale_prenorm_residual=rescale_prenorm_residual,
+ use_cache=use_cache,
+ use_falcon_mambapy=use_falcon_mambapy,
+ **kwargs,
+ )
+ self.mixer_rms_eps = mixer_rms_eps
+ # This is needed since mamba overrides the intermediate_size attribute
+ self.intermediate_size = (
+ int(expand * self.hidden_size)
+ if kwargs.get("intermediate_size") is None
+ else kwargs.get("intermediate_size")
+ )
+
+
+class FalconMambaCache(MambaCache):
+ """
+ Cache for the FalconMamba model, which has no attention mechanism or key/value states.
+
+ Arguments:
+ config (`PretrainedConfig`):
+ The configuration file defining the shape-related attributes required to initialize the static cache.
+ max_batch_size (`int`):
+ The maximum batch size with which the model will be used. Note that a new instance must be instantiated if
+ a smaller batch size is used.
+ dtype (`torch.dtype`, *optional*, defaults to `torch.float16`):
+ The default `dtype` to use when initializing the layer.
+ device (`torch.device` or `str`, *optional*):
+ The device on which the cache should be initialized. Should be the same as the layer.
+
+ Example:
+
+ ```python
+ >>> import torch
+ >>> from transformers import AutoTokenizer, FalconMambaForCausalLM, FalconMambaCache
+
+ >>> model = FalconMambaForCausalLM.from_pretrained("tiiuae/falcon-mamba-7b")
+ >>> tokenizer = AutoTokenizer.from_pretrained("tiiuae/falcon-mamba-7b")
+
+ >>> inputs = tokenizer(text="My name is FalconMamba", return_tensors="pt")
+
+ >>> # Prepare a cache class and pass it to model's forward
+ >>> cache_params = FalconMambaCache(config=model.config, max_batch_size=1, device=model.device, dtype=model.dtype)
+ >>> cache_position = torch.arange(len(inputs["input_ids"][0]), device=model.device) # sequence length
+ >>> outputs = model(**inputs, cache_params=cache_params, cache_position=cache_position, use_cache=True)
+ >>> outputs.cache_params
+ ```
+ """
+
+ pass
+
+
+def rms_forward(hidden_states, variance_epsilon=1e-6):
+ """
+ Calculates a simple RMSNorm with no learnable weights. `FalconMambaRMSNorm` leverages this
+ and multiplies the final result by its learnable weight.
+
+ Args:
+ hidden_states (`torch.Tensor`):
+ Hidden states to normalize
+ variance_epsilon (`float`):
+ The epsilon added to the variance before taking the square root, for numerical stability
+ """
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
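+ # The variance is computed in float32 for numerical stability; the result is cast back to the
+ # input dtype at the end.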
+
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + variance_epsilon)
+ return hidden_states.to(input_dtype)
+
+
+class FalconMambaMixer(MambaMixer):
+ def warn_slow_implementation(self):
+ causal_conv1d_update, causal_conv1d_fn = _lazy_load_causal_conv1d()
+ is_fast_path_available = all(
+ (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
+ )
+ if not is_fast_path_available:
+ if self.use_falcon_mambapy:
+ if is_mambapy_available():
+ logger.warning_once(
+ "The fast path is not available because one of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)`"
+ " is None. Falling back to the mamba.py backend. To install follow https://github.com/state-spaces/mamba/#installation for mamba-ssm and"
+ " https://github.com/Dao-AILab/causal-conv1d or `pip install kernels` for causal-conv1d"
+ )
+ else:
+ raise ImportError(
+ "use_mambapy is set to True but the mambapy package is not installed. To install it follow https://github.com/alxndrTL/mamba.py."
+ )
+ else:
+ logger.warning_once(
+ "The fast path is not available because one of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)`"
+ " is None. Falling back to the sequential implementation of Mamba, as use_mambapy is set to False. To install follow https://github.com/state-spaces/mamba/#installation for mamba-ssm and"
+ " https://github.com/Dao-AILab/causal-conv1d or `pip install kernels` for causal-conv1d. For the mamba.py backend, follow https://github.com/alxndrTL/mamba.py."
+ )
+
+ def __init__(self, config: FalconMambaConfig, layer_idx: int):
+ super().__init__(config, layer_idx)
+ # The Triton kernels expect RMS weights to be passed even though they are not learnable, so we create them here as non-trainable buffers
+ self.register_buffer(
+ "b_c_rms", torch.nn.Parameter(torch.ones(self.ssm_state_size), requires_grad=False), persistent=False
+ )
+ self.register_buffer(
+ "dt_rms", torch.nn.Parameter(torch.ones(self.intermediate_size), requires_grad=False), persistent=False
+ )
+ self.rms_eps = config.mixer_rms_eps
+
+ def cuda_kernels_forward(
+ self,
+ hidden_states: torch.Tensor,
+ cache_params: Optional[FalconMambaCache] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.LongTensor] = None,
+ ):
+ # 1. Gated MLP's linear projection
+ projected_states = self.in_proj(hidden_states).transpose(1, 2)
+
+ if self.training and cache_params is None: # Doesn't support outputting the states -> used for training
+ contextualized_states = mamba_inner_fn(
+ projected_states,
+ self.conv1d.weight,
+ self.conv1d.bias if self.use_conv_bias else None,
+ self.x_proj.weight,
+ self.dt_proj.weight,
+ self.out_proj.weight,
+ self.out_proj.bias.float() if self.use_bias else None,
+ -torch.exp(self.A_log.float()),
+ None, # input-dependent B
+ None, # input-dependent C
+ self.D.float(),
+ delta_bias=self.dt_proj.bias.float(),
+ delta_softplus=True,
+ b_rms_weight=self.b_c_rms,
+ c_rms_weight=self.b_c_rms,
+ dt_rms_weight=self.dt_rms,
+ b_c_dt_rms_eps=self.rms_eps,
+ )
+
+ else:
+ causal_conv1d_update, causal_conv1d_fn = _lazy_load_causal_conv1d()
+ hidden_states, gate = projected_states.chunk(2, dim=1)
+
+ if attention_mask is not None:
+ hidden_states = hidden_states * attention_mask.unsqueeze(1)
+
+ # 2. Convolution sequence transformation
+ conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2))
+ if cache_params is not None and cache_position[0] > 0:
+ hidden_states = causal_conv1d_update(
+ hidden_states.squeeze(-1),
+ cache_params.conv_states[self.layer_idx],
+ conv_weights,
+ self.conv1d.bias,
+ self.activation,
+ )
+ hidden_states = hidden_states.unsqueeze(-1)
+ else:
+ if cache_params is not None:
+ conv_states = nn.functional.pad(
+ hidden_states, (self.conv_kernel_size - hidden_states.shape[-1], 0)
+ )
+ cache_params.update_conv_state(self.layer_idx, conv_states, cache_position)
+ hidden_states = causal_conv1d_fn(
+ hidden_states, conv_weights, self.conv1d.bias, activation=self.activation
+ )
+
+ if attention_mask is not None:
+ hidden_states = hidden_states * attention_mask.unsqueeze(1)
+
+ # 3. State Space Model sequence transformation
+ # 3.a. input varying initialization of time_step, B and C
+ ssm_parameters = self.x_proj(hidden_states.transpose(1, 2))
+ time_step, B, C = torch.split(
+ ssm_parameters, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1
+ )
+
+ B = rms_forward(B, variance_epsilon=self.rms_eps)
+ C = rms_forward(C, variance_epsilon=self.rms_eps)
+ time_step = rms_forward(time_step, variance_epsilon=self.rms_eps)
+
+ # In case the model has been quantized, we need a hack to properly call the `nn.Linear` module
+ # at the price of a small overhead.
+ if hasattr(self.config, "_pre_quantization_dtype"):
+ discrete_time_step = (self.dt_proj(time_step) - self.dt_proj.bias).transpose(1, 2)
+ else:
+ discrete_time_step = self.dt_proj.weight @ time_step.transpose(1, 2)
+
+ A = -torch.exp(self.A_log.float())
+ # 3.c perform the recurrence y ← SSM(A, B, C)(x)
+ time_proj_bias = self.dt_proj.bias.float() if hasattr(self.dt_proj, "bias") else None
+ if cache_params is not None and cache_position[0] > 0:
+ scan_outputs = selective_state_update(
+ cache_params.ssm_states[self.layer_idx],
+ hidden_states[..., 0],
+ discrete_time_step[..., 0],
+ A,
+ B[:, 0],
+ C[:, 0],
+ self.D,
+ gate[..., 0],
+ time_proj_bias,
+ dt_softplus=True,
+ ).unsqueeze(-1)
+ else:
+ scan_outputs, ssm_state = selective_scan_fn(
+ hidden_states,
+ discrete_time_step,
+ A,
+ B.transpose(1, 2),
+ C.transpose(1, 2),
+ self.D.float(),
+ gate,
+ time_proj_bias,
+ delta_softplus=True,
+ return_last_state=True,
+ )
+ if ssm_state is not None and cache_params is not None:
+ cache_params.update_ssm_state(self.layer_idx, ssm_state)
+
+ # 4. Final linear projection
+ contextualized_states = self.out_proj(scan_outputs.transpose(1, 2))
+ return contextualized_states
+
+ def slow_forward(
+ self,
+ input_states,
+ cache_params: Optional[FalconMambaCache] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.LongTensor] = None,
+ ):
+ batch_size, seq_len, _ = input_states.shape
+ dtype = input_states.dtype
+ # 1. Gated MLP's linear projection
+ projected_states = self.in_proj(input_states).transpose(1, 2) # [batch, 2 * intermediate_size, seq_len]
+ hidden_states, gate = projected_states.chunk(2, dim=1)
+
+ if attention_mask is not None:
+ hidden_states = hidden_states * attention_mask.unsqueeze(1)
+
+ # 2. Convolution sequence transformation
+ if cache_params is not None:
+ ssm_state = cache_params.ssm_states[self.layer_idx].clone()
+ ssm_state = ssm_state.to(hidden_states.device)
+ # use `cache_position.shape[0]` to check whether we are in prefill
+ # stage; it is equivalent to checking `cache_position[0] == 0`, but that
+ # check breaks dynamo fullgraph constraints
+ if cache_position is not None and cache_position.shape[0] == self.conv_kernel_size:
+ conv_state = nn.functional.pad(hidden_states, (self.conv_kernel_size - hidden_states.shape[-1], 0))
+
+ cache_params.update_conv_state(self.layer_idx, conv_state, cache_position)
+ hidden_states = self.act(
+ self.conv1d(hidden_states)[..., :seq_len]
+ ) # [batch, intermediate_size, seq_len]
+ else:
+ conv_state = cache_params.update_conv_state(self.layer_idx, hidden_states, cache_position)
+ conv_state = conv_state.to(self.conv1d.weight.device)
+ hidden_states = torch.sum(conv_state * self.conv1d.weight[:, 0, :], dim=-1)
+ if self.use_conv_bias:
+ hidden_states += self.conv1d.bias
+ hidden_states = (
+ self.act(hidden_states).to(dtype).unsqueeze(-1)
+ ) # [batch, intermediate_size, 1] : decoding
+ else:
+ ssm_state = torch.zeros(
+ (batch_size, self.intermediate_size, self.ssm_state_size), device=hidden_states.device, dtype=dtype
+ )
+ hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len]) # [batch, intermediate_size, seq_len]
+
+ if attention_mask is not None:
+ hidden_states = hidden_states * attention_mask.unsqueeze(1)
+
+ # 3. State Space Model sequence transformation
+ # 3.a. Selection: [batch, seq_len, self.time_step_rank + self.ssm_state_size * 2]
+ ssm_parameters = self.x_proj(hidden_states.transpose(1, 2))
+ time_step, B, C = torch.split(
+ ssm_parameters, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1
+ )
+
+ B = rms_forward(B, variance_epsilon=self.rms_eps)
+ C = rms_forward(C, variance_epsilon=self.rms_eps)
+ time_step = rms_forward(time_step, variance_epsilon=self.rms_eps)
+
+ discrete_time_step = self.dt_proj(time_step) # [batch, seq_len, intermediate_size]
+ discrete_time_step = nn.functional.softplus(discrete_time_step).transpose(
+ 1, 2
+ ) # [batch, intermediate_size, seq_len]
+
+ # 3.b. Discretization: B and C to [batch, seq_len, intermediate_size, ssm_state_size] (SRAM)
+ A = -torch.exp(self.A_log.float()) # [intermediate_size, ssm_state_size]
+ discrete_A = torch.exp(
+ A[None, :, None, :] * discrete_time_step[:, :, :, None]
+ ) # [batch, intermediate_size, seq_len, ssm_state_size]
+ discrete_B = (
+ discrete_time_step[:, :, :, None] * B[:, None, :, :].float()
+ ) # [batch, intermediate_size, seq_len, ssm_state_size]
+ deltaB_u = discrete_B * hidden_states[:, :, :, None].float()
+
+ # 3.c perform the recurrence y ← SSM(A, B, C)(x)
+ if self.use_falcon_mambapy and self.training and cache_params is None:
+ hs = pscan(
+ discrete_A.transpose(1, 2), deltaB_u.transpose(1, 2)
+ ) # [batch, seq_len, intermediate_size, ssm_state_size]
+ scan_output = (hs @ C.unsqueeze(-1)).squeeze(3).transpose(1, 2) # [batch, intermediate_size, seq_len]
+ scan_output = scan_output + hidden_states * self.D[None, :, None]
+ scan_output = scan_output * self.act(gate)
+ else:
+ scan_outputs = []
+ for i in range(seq_len):
+ ssm_state = (
+ discrete_A[:, :, i, :] * ssm_state + deltaB_u[:, :, i, :]
+ ) # [batch, intermediate_size, ssm_state]
+ scan_output = torch.matmul(
+ ssm_state.to(dtype), C[:, i, :].unsqueeze(-1)
+ ) # [batch, intermediate_size, 1]
+ scan_outputs.append(scan_output[:, :, 0])
+ scan_output = torch.stack(scan_outputs, dim=-1) # [batch, intermediate_size, seq_len]
+ scan_output = scan_output + (hidden_states * self.D[None, :, None])
+ scan_output = scan_output * self.act(gate)
+
+ if cache_params is not None:
+ cache_params.update_ssm_state(self.layer_idx, ssm_state)
+
+ # 4. Final linear projection
+ contextualized_states = self.out_proj(scan_output.transpose(1, 2)) # [batch, seq_len, hidden_size]
+ return contextualized_states
+
+ def forward(
+ self,
+ hidden_states,
+ cache_params: Optional[FalconMambaCache] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.LongTensor] = None,
+ ):
+ causal_conv1d_update, causal_conv1d_fn = _lazy_load_causal_conv1d()
+ is_fast_path_available = all(
+ (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
+ )
+ if is_fast_path_available and "cuda" in self.x_proj.weight.device.type and not torch._dynamo.is_compiling():
+ return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask)
+ return self.slow_forward(hidden_states, cache_params, cache_position, attention_mask)
+
+
+class FalconMambaRMSNorm(MambaRMSNorm):
+ def forward(self, hidden_states):
+ return self.weight.to(hidden_states.device) * rms_forward(
+ hidden_states, variance_epsilon=self.variance_epsilon
+ )
+
+
+class FalconMambaBlock(MambaBlock):
+ pass
+
+
+@auto_docstring
+class FalconMambaPreTrainedModel(MambaPreTrainedModel):
+ pass
+
+
+class FalconMambaOutput(MambaOutput):
+ pass
+
+
+class FalconMambaCausalLMOutput(MambaCausalLMOutput):
+ pass
+
+
+class FalconMambaModel(MambaModel, FalconMambaPreTrainedModel):
+ def __init__(self, config):
+ FalconMambaPreTrainedModel.__init__(self, config)
+
+ self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
+ self.layers = nn.ModuleList(
+ [FalconMambaBlock(config, layer_idx=idx) for idx in range(config.num_hidden_layers)]
+ )
+
+ self.gradient_checkpointing = False
+ self.norm_f = FalconMambaRMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
+ # Initialize weights and apply final processing
+ self.post_init()
+
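+ # Modular-transformers convention (assumption): raising AttributeError in an overridden method tells
+ # the converter to drop the parent's `load_hook` from the generated modeling file.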
+ def load_hook(self, state_dict, prefix, *args):
+ raise AttributeError("Not needed for FalconMamba")
+
+
+class FalconMambaForCausalLM(MambaForCausalLM):
+ pass
+
+
+__all__ = [
+ "FalconMambaForCausalLM",
+ "FalconMambaModel",
+ "FalconMambaPreTrainedModel",
+ "FalconMambaCache",
+ "FalconMambaConfig",
+]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/fastspeech2_conformer/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/fastspeech2_conformer/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..44d1ec7236310774ed6b1379683c144d7f93ecce
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/fastspeech2_conformer/__init__.py
@@ -0,0 +1,28 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_fastspeech2_conformer import *
+ from .modeling_fastspeech2_conformer import *
+ from .tokenization_fastspeech2_conformer import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/fastspeech2_conformer/configuration_fastspeech2_conformer.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/fastspeech2_conformer/configuration_fastspeech2_conformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..89d65a261c64fbabe493aa37677aaacb6f226a3b
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/fastspeech2_conformer/configuration_fastspeech2_conformer.py
@@ -0,0 +1,480 @@
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""FastSpeech2Conformer model configuration"""
+
+from typing import Optional
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class FastSpeech2ConformerConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`FastSpeech2ConformerModel`]. It is used to
+ instantiate a FastSpeech2Conformer model according to the specified arguments, defining the model architecture.
+ Instantiating a configuration with the defaults will yield a similar configuration to that of the
+ FastSpeech2Conformer [espnet/fastspeech2_conformer](https://huggingface.co/espnet/fastspeech2_conformer)
+ architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ hidden_size (`int`, *optional*, defaults to 384):
+ The dimensionality of the hidden layers.
+ vocab_size (`int`, *optional*, defaults to 78):
+ The size of the vocabulary.
+ num_mel_bins (`int`, *optional*, defaults to 80):
+ The number of mel filters used in the filter bank.
+ encoder_num_attention_heads (`int`, *optional*, defaults to 2):
+ The number of attention heads in the encoder.
+ encoder_layers (`int`, *optional*, defaults to 4):
+ The number of layers in the encoder.
+ encoder_linear_units (`int`, *optional*, defaults to 1536):
+ The number of units in the linear layer of the encoder.
+ decoder_layers (`int`, *optional*, defaults to 4):
+ The number of layers in the decoder.
+ decoder_num_attention_heads (`int`, *optional*, defaults to 2):
+ The number of attention heads in the decoder.
+ decoder_linear_units (`int`, *optional*, defaults to 1536):
+ The number of units in the linear layer of the decoder.
+ speech_decoder_postnet_layers (`int`, *optional*, defaults to 5):
+ The number of layers in the post-net of the speech decoder.
+ speech_decoder_postnet_units (`int`, *optional*, defaults to 256):
+ The number of units in the post-net layers of the speech decoder.
+ speech_decoder_postnet_kernel (`int`, *optional*, defaults to 5):
+ The kernel size in the post-net of the speech decoder.
+ positionwise_conv_kernel_size (`int`, *optional*, defaults to 3):
+ The size of the convolution kernel used in the position-wise layer.
+ encoder_normalize_before (`bool`, *optional*, defaults to `False`):
+ Specifies whether to normalize before encoder layers.
+ decoder_normalize_before (`bool`, *optional*, defaults to `False`):
+ Specifies whether to normalize before decoder layers.
+ encoder_concat_after (`bool`, *optional*, defaults to `False`):
+ Specifies whether to concatenate after encoder layers.
+ decoder_concat_after (`bool`, *optional*, defaults to `False`):
+ Specifies whether to concatenate after decoder layers.
+ reduction_factor (`int`, *optional*, defaults to 1):
+ The factor by which the speech frame rate is reduced.
+ speaking_speed (`float`, *optional*, defaults to 1.0):
+ The speed of the speech produced.
+ use_macaron_style_in_conformer (`bool`, *optional*, defaults to `True`):
+ Specifies whether to use macaron style in the conformer.
+ use_cnn_in_conformer (`bool`, *optional*, defaults to `True`):
+ Specifies whether to use convolutional neural networks in the conformer.
+ encoder_kernel_size (`int`, *optional*, defaults to 7):
+ The kernel size used in the encoder.
+ decoder_kernel_size (`int`, *optional*, defaults to 31):
+ The kernel size used in the decoder.
+ duration_predictor_layers (`int`, *optional*, defaults to 2):
+ The number of layers in the duration predictor.
+ duration_predictor_channels (`int`, *optional*, defaults to 256):
+ The number of channels in the duration predictor.
+ duration_predictor_kernel_size (`int`, *optional*, defaults to 3):
+ The kernel size used in the duration predictor.
+ energy_predictor_layers (`int`, *optional*, defaults to 2):
+ The number of layers in the energy predictor.
+ energy_predictor_channels (`int`, *optional*, defaults to 256):
+ The number of channels in the energy predictor.
+ energy_predictor_kernel_size (`int`, *optional*, defaults to 3):
+ The kernel size used in the energy predictor.
+ energy_predictor_dropout (`float`, *optional*, defaults to 0.5):
+ The dropout rate in the energy predictor.
+ energy_embed_kernel_size (`int`, *optional*, defaults to 1):
+ The kernel size used in the energy embed layer.
+ energy_embed_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout rate in the energy embed layer.
+ stop_gradient_from_energy_predictor (`bool`, *optional*, defaults to `False`):
+ Specifies whether to stop gradients from the energy predictor.
+ pitch_predictor_layers (`int`, *optional*, defaults to 5):
+ The number of layers in the pitch predictor.
+ pitch_predictor_channels (`int`, *optional*, defaults to 256):
+ The number of channels in the pitch predictor.
+ pitch_predictor_kernel_size (`int`, *optional*, defaults to 5):
+ The kernel size used in the pitch predictor.
+ pitch_predictor_dropout (`float`, *optional*, defaults to 0.5):
+ The dropout rate in the pitch predictor.
+ pitch_embed_kernel_size (`int`, *optional*, defaults to 1):
+ The kernel size used in the pitch embed layer.
+ pitch_embed_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout rate in the pitch embed layer.
+ stop_gradient_from_pitch_predictor (`bool`, *optional*, defaults to `True`):
+ Specifies whether to stop gradients from the pitch predictor.
+ encoder_dropout_rate (`float`, *optional*, defaults to 0.2):
+ The dropout rate in the encoder.
+ encoder_positional_dropout_rate (`float`, *optional*, defaults to 0.2):
+ The positional dropout rate in the encoder.
+ encoder_attention_dropout_rate (`float`, *optional*, defaults to 0.2):
+ The attention dropout rate in the encoder.
+ decoder_dropout_rate (`float`, *optional*, defaults to 0.2):
+ The dropout rate in the decoder.
+ decoder_positional_dropout_rate (`float`, *optional*, defaults to 0.2):
+ The positional dropout rate in the decoder.
+ decoder_attention_dropout_rate (`float`, *optional*, defaults to 0.2):
+ The attention dropout rate in the decoder.
+ duration_predictor_dropout_rate (`float`, *optional*, defaults to 0.2):
+ The dropout rate in the duration predictor.
+ speech_decoder_postnet_dropout (`float`, *optional*, defaults to 0.5):
+ The dropout rate in the speech decoder postnet.
+ max_source_positions (`int`, *optional*, defaults to 5000):
+ if `"relative"` position embeddings are used, defines the maximum source input positions.
+ use_masking (`bool`, *optional*, defaults to `True`):
+ Specifies whether to use masking in the model.
+ use_weighted_masking (`bool`, *optional*, defaults to `False`):
+ Specifies whether to use weighted masking in the model.
+ num_speakers (`int`, *optional*):
+ Number of speakers. If set to > 1, speaker ids are expected as input and a speaker id embedding
+ layer is used.
+ num_languages (`int`, *optional*):
+ Number of languages. If set to > 1, language ids are expected as input and a language id embedding
+ layer is used.
+ speaker_embed_dim (`int`, *optional*):
+ Speaker embedding dimension. If set to > 0, `speaker_embedding` is expected to be provided as input.
+ is_encoder_decoder (`bool`, *optional*, defaults to `True`):
+ Specifies whether the model is an encoder-decoder.
+
+ Example:
+
+ ```python
+ >>> from transformers import FastSpeech2ConformerModel, FastSpeech2ConformerConfig
+
+ >>> # Initializing a FastSpeech2Conformer style configuration
+ >>> configuration = FastSpeech2ConformerConfig()
+
+ >>> # Initializing a model from the FastSpeech2Conformer style configuration
+ >>> model = FastSpeech2ConformerModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "fastspeech2_conformer"
+ base_config_key = "model_config"
+ attribute_map = {"num_hidden_layers": "encoder_layers", "num_attention_heads": "encoder_num_attention_heads"}
+
+ def __init__(
+ self,
+ hidden_size=384,
+ vocab_size=78,
+ num_mel_bins=80,
+ encoder_num_attention_heads=2,
+ encoder_layers=4,
+ encoder_linear_units=1536,
+ decoder_layers=4,
+ decoder_num_attention_heads=2,
+ decoder_linear_units=1536,
+ speech_decoder_postnet_layers=5,
+ speech_decoder_postnet_units=256,
+ speech_decoder_postnet_kernel=5,
+ positionwise_conv_kernel_size=3,
+ encoder_normalize_before=False,
+ decoder_normalize_before=False,
+ encoder_concat_after=False,
+ decoder_concat_after=False,
+ reduction_factor=1,
+ speaking_speed=1.0,
+ use_macaron_style_in_conformer=True,
+ use_cnn_in_conformer=True,
+ encoder_kernel_size=7,
+ decoder_kernel_size=31,
+ duration_predictor_layers=2,
+ duration_predictor_channels=256,
+ duration_predictor_kernel_size=3,
+ energy_predictor_layers=2,
+ energy_predictor_channels=256,
+ energy_predictor_kernel_size=3,
+ energy_predictor_dropout=0.5,
+ energy_embed_kernel_size=1,
+ energy_embed_dropout=0.0,
+ stop_gradient_from_energy_predictor=False,
+ pitch_predictor_layers=5,
+ pitch_predictor_channels=256,
+ pitch_predictor_kernel_size=5,
+ pitch_predictor_dropout=0.5,
+ pitch_embed_kernel_size=1,
+ pitch_embed_dropout=0.0,
+ stop_gradient_from_pitch_predictor=True,
+ encoder_dropout_rate=0.2,
+ encoder_positional_dropout_rate=0.2,
+ encoder_attention_dropout_rate=0.2,
+ decoder_dropout_rate=0.2,
+ decoder_positional_dropout_rate=0.2,
+ decoder_attention_dropout_rate=0.2,
+ duration_predictor_dropout_rate=0.2,
+ speech_decoder_postnet_dropout=0.5,
+ max_source_positions=5000,
+ use_masking=True,
+ use_weighted_masking=False,
+ num_speakers=None,
+ num_languages=None,
+ speaker_embed_dim=None,
+ is_encoder_decoder=True,
+ **kwargs,
+ ):
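+ # Odd kernel sizes allow symmetric 'same' padding of (kernel_size - 1) // 2, which keeps sequence
+ # lengths unchanged; the checks below enforce this.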
+ if positionwise_conv_kernel_size % 2 == 0:
+ raise ValueError(
+ f"positionwise_conv_kernel_size must be odd, but got {positionwise_conv_kernel_size} instead."
+ )
+ if encoder_kernel_size % 2 == 0:
+ raise ValueError(f"encoder_kernel_size must be odd, but got {encoder_kernel_size} instead.")
+ if decoder_kernel_size % 2 == 0:
+ raise ValueError(f"decoder_kernel_size must be odd, but got {decoder_kernel_size} instead.")
+ if duration_predictor_kernel_size % 2 == 0:
+ raise ValueError(
+ f"duration_predictor_kernel_size must be odd, but got {duration_predictor_kernel_size} instead."
+ )
+ if energy_predictor_kernel_size % 2 == 0:
+ raise ValueError(
+ f"energy_predictor_kernel_size must be odd, but got {energy_predictor_kernel_size} instead."
+ )
+ if energy_embed_kernel_size % 2 == 0:
+ raise ValueError(f"energy_embed_kernel_size must be odd, but got {energy_embed_kernel_size} instead.")
+ if pitch_predictor_kernel_size % 2 == 0:
+ raise ValueError(
+ f"pitch_predictor_kernel_size must be odd, but got {pitch_predictor_kernel_size} instead."
+ )
+ if pitch_embed_kernel_size % 2 == 0:
+ raise ValueError(f"pitch_embed_kernel_size must be odd, but got {pitch_embed_kernel_size} instead.")
+ if hidden_size % encoder_num_attention_heads != 0:
+ raise ValueError("The hidden_size must be evenly divisible by encoder_num_attention_heads.")
+ if hidden_size % decoder_num_attention_heads != 0:
+ raise ValueError("The hidden_size must be evenly divisible by decoder_num_attention_heads.")
+ if use_masking and use_weighted_masking:
+ raise ValueError("Either use_masking or use_weighted_masking can be True, but not both.")
+
+ self.hidden_size = hidden_size
+ self.vocab_size = vocab_size
+ self.num_mel_bins = num_mel_bins
+ self.encoder_config = {
+ "num_attention_heads": encoder_num_attention_heads,
+ "layers": encoder_layers,
+ "kernel_size": encoder_kernel_size,
+ "attention_dropout_rate": encoder_attention_dropout_rate,
+ "dropout_rate": encoder_dropout_rate,
+ "positional_dropout_rate": encoder_positional_dropout_rate,
+ "linear_units": encoder_linear_units,
+ "normalize_before": encoder_normalize_before,
+ "concat_after": encoder_concat_after,
+ }
+ self.decoder_config = {
+ "num_attention_heads": decoder_num_attention_heads,
+ "layers": decoder_layers,
+ "kernel_size": decoder_kernel_size,
+ "attention_dropout_rate": decoder_attention_dropout_rate,
+ "dropout_rate": decoder_dropout_rate,
+ "positional_dropout_rate": decoder_positional_dropout_rate,
+ "linear_units": decoder_linear_units,
+ "normalize_before": decoder_normalize_before,
+ "concat_after": decoder_concat_after,
+ }
+ self.encoder_num_attention_heads = encoder_num_attention_heads
+ self.encoder_layers = encoder_layers
+ self.duration_predictor_channels = duration_predictor_channels
+ self.duration_predictor_kernel_size = duration_predictor_kernel_size
+ self.duration_predictor_layers = duration_predictor_layers
+ self.energy_embed_dropout = energy_embed_dropout
+ self.energy_embed_kernel_size = energy_embed_kernel_size
+ self.energy_predictor_channels = energy_predictor_channels
+ self.energy_predictor_dropout = energy_predictor_dropout
+ self.energy_predictor_kernel_size = energy_predictor_kernel_size
+ self.energy_predictor_layers = energy_predictor_layers
+ self.pitch_embed_dropout = pitch_embed_dropout
+ self.pitch_embed_kernel_size = pitch_embed_kernel_size
+ self.pitch_predictor_channels = pitch_predictor_channels
+ self.pitch_predictor_dropout = pitch_predictor_dropout
+ self.pitch_predictor_kernel_size = pitch_predictor_kernel_size
+ self.pitch_predictor_layers = pitch_predictor_layers
+ self.positionwise_conv_kernel_size = positionwise_conv_kernel_size
+ self.speech_decoder_postnet_units = speech_decoder_postnet_units
+ self.speech_decoder_postnet_dropout = speech_decoder_postnet_dropout
+ self.speech_decoder_postnet_kernel = speech_decoder_postnet_kernel
+ self.speech_decoder_postnet_layers = speech_decoder_postnet_layers
+ self.reduction_factor = reduction_factor
+ self.speaking_speed = speaking_speed
+ self.stop_gradient_from_energy_predictor = stop_gradient_from_energy_predictor
+ self.stop_gradient_from_pitch_predictor = stop_gradient_from_pitch_predictor
+ self.max_source_positions = max_source_positions
+ self.use_cnn_in_conformer = use_cnn_in_conformer
+ self.use_macaron_style_in_conformer = use_macaron_style_in_conformer
+ self.use_masking = use_masking
+ self.use_weighted_masking = use_weighted_masking
+ self.num_speakers = num_speakers
+ self.num_languages = num_languages
+ self.speaker_embed_dim = speaker_embed_dim
+ self.duration_predictor_dropout_rate = duration_predictor_dropout_rate
+ self.is_encoder_decoder = is_encoder_decoder
+
+ super().__init__(
+ is_encoder_decoder=is_encoder_decoder,
+ **kwargs,
+ )
+
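+# A minimal usage sketch (illustrative only, not part of the library docstrings): the
+# encoder- and decoder-specific arguments above are bundled into plain dicts that the
+# Conformer blocks read at build time, so constructor overrides propagate into those dicts.
+#
+#     config = FastSpeech2ConformerConfig(encoder_layers=6, encoder_kernel_size=7)
+#     assert config.encoder_config["layers"] == 6
+#     assert config.encoder_config["kernel_size"] == 7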
+
+class FastSpeech2ConformerHifiGanConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`FastSpeech2ConformerHifiGanModel`]. It is used to
+ instantiate a FastSpeech2Conformer HiFi-GAN vocoder model according to the specified arguments, defining the model
+ architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the
+ FastSpeech2ConformerHifiGan
+ [espnet/fastspeech2_conformer_hifigan](https://huggingface.co/espnet/fastspeech2_conformer_hifigan) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ model_in_dim (`int`, *optional*, defaults to 80):
+ The number of frequency bins in the input log-mel spectrogram.
+ upsample_initial_channel (`int`, *optional*, defaults to 512):
+ The number of input channels into the upsampling network.
+ upsample_rates (`tuple[int]` or `list[int]`, *optional*, defaults to `[8, 8, 2, 2]`):
+ A tuple of integers defining the stride of each 1D convolutional layer in the upsampling network. The
+ length of *upsample_rates* defines the number of convolutional layers and has to match the length of
+ *upsample_kernel_sizes*.
+ upsample_kernel_sizes (`tuple[int]` or `list[int]`, *optional*, defaults to `[16, 16, 4, 4]`):
+ A tuple of integers defining the kernel size of each 1D convolutional layer in the upsampling network. The
+ length of *upsample_kernel_sizes* defines the number of convolutional layers and has to match the length of
+ *upsample_rates*.
+ resblock_kernel_sizes (`tuple[int]` or `list[int]`, *optional*, defaults to `[3, 7, 11]`):
+ A tuple of integers defining the kernel sizes of the 1D convolutional layers in the multi-receptive field
+ fusion (MRF) module.
+ resblock_dilation_sizes (`tuple[tuple[int]]` or `list[list[int]]`, *optional*, defaults to `[[1, 3, 5], [1, 3, 5], [1, 3, 5]]`):
+ A nested tuple of integers defining the dilation rates of the dilated 1D convolutional layers in the
+ multi-receptive field fusion (MRF) module.
+ initializer_range (`float`, *optional*, defaults to 0.01):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ leaky_relu_slope (`float`, *optional*, defaults to 0.1):
+ The angle of the negative slope used by the leaky ReLU activation.
+ normalize_before (`bool`, *optional*, defaults to `True`):
+ Whether or not to normalize the spectrogram before vocoding using the vocoder's learned mean and variance.
+
+ Example:
+
+ ```python
+ >>> from transformers import FastSpeech2ConformerHifiGan, FastSpeech2ConformerHifiGanConfig
+
+ >>> # Initializing a FastSpeech2ConformerHifiGan configuration
+ >>> configuration = FastSpeech2ConformerHifiGanConfig()
+
+ >>> # Initializing a model (with random weights) from the configuration
+ >>> model = FastSpeech2ConformerHifiGan(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "hifigan"
+ base_config_key = "vocoder_config"
+
+ def __init__(
+ self,
+ model_in_dim=80,
+ upsample_initial_channel=512,
+ upsample_rates=[8, 8, 2, 2],
+ upsample_kernel_sizes=[16, 16, 4, 4],
+ resblock_kernel_sizes=[3, 7, 11],
+ resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
+ initializer_range=0.01,
+ leaky_relu_slope=0.1,
+ normalize_before=True,
+ **kwargs,
+ ):
+ self.model_in_dim = model_in_dim
+ self.upsample_initial_channel = upsample_initial_channel
+ self.upsample_rates = upsample_rates
+ self.upsample_kernel_sizes = upsample_kernel_sizes
+ self.resblock_kernel_sizes = resblock_kernel_sizes
+ self.resblock_dilation_sizes = resblock_dilation_sizes
+ self.initializer_range = initializer_range
+ self.leaky_relu_slope = leaky_relu_slope
+ self.normalize_before = normalize_before
+ super().__init__(**kwargs)
+
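+# Illustrative note (not from the original docstring): the product of `upsample_rates`
+# is the vocoder's total upsampling factor, i.e. the number of waveform samples produced
+# per mel-spectrogram frame. With the defaults above it is 8 * 8 * 2 * 2 = 256, and
+# `model_in_dim` (80) matches `num_mel_bins` in `FastSpeech2ConformerConfig`.
+#
+#     import math
+#     config = FastSpeech2ConformerHifiGanConfig()
+#     assert math.prod(config.upsample_rates) == 256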
+
+class FastSpeech2ConformerWithHifiGanConfig(PretrainedConfig):
+ """
+ This is the configuration class to store the configuration of a [`FastSpeech2ConformerWithHifiGan`]. It is used to
+ instantiate a `FastSpeech2ConformerWithHifiGanModel` model according to the specified sub-models configurations,
+ defining the model architecture.
+
+ Instantiating a configuration with the defaults will yield a similar configuration to that of the
+ FastSpeech2ConformerModel [espnet/fastspeech2_conformer](https://huggingface.co/espnet/fastspeech2_conformer) and
+ FastSpeech2ConformerHifiGan
+ [espnet/fastspeech2_conformer_hifigan](https://huggingface.co/espnet/fastspeech2_conformer_hifigan) architectures.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ model_config (`dict`, *optional*):
+ Configuration dictionary for the text-to-speech model ([`FastSpeech2ConformerConfig`]).
+ vocoder_config (`dict`, *optional*):
+ Configuration dictionary for the vocoder model ([`FastSpeech2ConformerHifiGanConfig`]).
+
+ Example:
+
+ ```python
+ >>> from transformers import (
+ ... FastSpeech2ConformerConfig,
+ ... FastSpeech2ConformerHifiGanConfig,
+ ... FastSpeech2ConformerWithHifiGanConfig,
+ ... FastSpeech2ConformerWithHifiGan,
+ ... )
+
+ >>> # Initializing FastSpeech2ConformerWithHifiGan sub-modules configurations.
+ >>> model_config = FastSpeech2ConformerConfig()
+ >>> vocoder_config = FastSpeech2ConformerHifiGanConfig()
+
+ >>> # Initializing a FastSpeech2ConformerWithHifiGan module style configuration
+ >>> configuration = FastSpeech2ConformerWithHifiGanConfig(model_config.to_dict(), vocoder_config.to_dict())
+
+ >>> # Initializing a model (with random weights)
+ >>> model = FastSpeech2ConformerWithHifiGan(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```
+ """
+
+ model_type = "fastspeech2_conformer_with_hifigan"
+ sub_configs = {"model_config": FastSpeech2ConformerConfig, "vocoder_config": FastSpeech2ConformerHifiGanConfig}
+
+ def __init__(
+ self,
+ model_config: Optional[dict] = None,
+ vocoder_config: Optional[dict] = None,
+ **kwargs,
+ ):
+ if model_config is None:
+ model_config = {}
+ logger.info("model_config is None. initializing the model with default values.")
+
+ if vocoder_config is None:
+ vocoder_config = {}
+ logger.info("vocoder_config is None. initializing the coarse model with default values.")
+
+ self.model_config = FastSpeech2ConformerConfig(**model_config)
+ self.vocoder_config = FastSpeech2ConformerHifiGanConfig(**vocoder_config)
+
+ super().__init__(**kwargs)
+
+
+__all__ = ["FastSpeech2ConformerConfig", "FastSpeech2ConformerHifiGanConfig", "FastSpeech2ConformerWithHifiGanConfig"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..5a2dc39385b3c953eb256bb03c04d6456c8f8890
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py
@@ -0,0 +1,1588 @@
+# coding=utf-8
+# Copyright 2023 The Espnet authors, IMS Toucan authors, and the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch FastSpeech2Conformer model."""
+
+import math
+from dataclasses import dataclass
+from typing import Optional, Union
+
+import torch
+from torch import nn
+
+from ...activations import ACT2FN
+from ...modeling_outputs import BaseModelOutput
+from ...modeling_utils import PreTrainedModel
+from ...utils import ModelOutput, auto_docstring, logging
+from .configuration_fastspeech2_conformer import (
+ FastSpeech2ConformerConfig,
+ FastSpeech2ConformerHifiGanConfig,
+ FastSpeech2ConformerWithHifiGanConfig,
+)
+
+
+logger = logging.get_logger(__name__)
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Output type of [`FastSpeech2ConformerModel`].
+ """
+)
+class FastSpeech2ConformerModelOutput(ModelOutput):
+ r"""
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Spectrogram generation loss.
+ duration_outputs (`torch.LongTensor` of shape `(batch_size, max_text_length + 1)`, *optional*):
+ Outputs of the duration predictor.
+ pitch_outputs (`torch.FloatTensor` of shape `(batch_size, max_text_length + 1, 1)`, *optional*):
+ Outputs of the pitch predictor.
+ energy_outputs (`torch.FloatTensor` of shape `(batch_size, max_text_length + 1, 1)`, *optional*):
+ Outputs of the energy predictor.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ spectrogram: Optional[torch.FloatTensor] = None
+ encoder_last_hidden_state: Optional[torch.FloatTensor] = None
+ encoder_hidden_states: Optional[tuple[torch.FloatTensor]] = None
+ encoder_attentions: Optional[tuple[torch.FloatTensor]] = None
+ decoder_hidden_states: Optional[tuple[torch.FloatTensor]] = None
+ decoder_attentions: Optional[tuple[torch.FloatTensor]] = None
+ duration_outputs: Optional[torch.LongTensor] = None
+ pitch_outputs: Optional[torch.FloatTensor] = None
+ energy_outputs: Optional[torch.FloatTensor] = None
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Output type of [`FastSpeech2ConformerWithHifiGan`].
+ """
+)
+class FastSpeech2ConformerWithHifiGanOutput(FastSpeech2ConformerModelOutput):
+ r"""
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Spectrogram generation loss.
+ duration_outputs (`torch.LongTensor` of shape `(batch_size, max_text_length + 1)`, *optional*):
+ Outputs of the duration predictor.
+ pitch_outputs (`torch.FloatTensor` of shape `(batch_size, max_text_length + 1, 1)`, *optional*):
+ Outputs of the pitch predictor.
+ energy_outputs (`torch.FloatTensor` of shape `(batch_size, max_text_length + 1, 1)`, *optional*):
+ Outputs of the energy predictor.
+ waveform (`torch.FloatTensor` of shape `(batch_size, audio_length)`):
+ Speech output as a result of passing the predicted mel spectrogram through the vocoder.
+ """
+
+ waveform: Optional[torch.FloatTensor] = None
+
+
+def length_regulator(encoded_embeddings, duration_labels, speaking_speed=1.0):
+ """
+ Length regulator for feed-forward Transformer.
+
+ This is the length regulator module described in `FastSpeech: Fast, Robust and Controllable Text to Speech`
+ https://huggingface.co/papers/1905.09263. The length regulator expands char or phoneme-level embedding features to
+ frame-level by repeating each feature based on the corresponding predicted durations.
+
+ Args:
+ encoded_embeddings (`torch.Tensor` of shape `(batch_size, max_text_length, embedding_dim)`):
+ Batch of sequences of char or phoneme embeddings.
+ duration_labels (`torch.LongTensor` of shape `(batch_size, time)`):
+ Batch of durations of each frame.
+ speaking_speed (`float`, *optional*, defaults to 1.0):
+ Value to control speed of speech.
+
+ Returns:
+ `torch.Tensor`:
+ Replicated input tensor based on durations (batch_size, time*, embedding_dim).
+ """
+
+ if speaking_speed <= 0:
+ raise ValueError("`speaking_speed` must be greater than 0.")
+ elif speaking_speed != 1.0:
+ duration_labels = torch.round(duration_labels.float() * speaking_speed).long()
+
+ if duration_labels.sum() == 0:
+ duration_labels[duration_labels.sum(dim=1).eq(0)] = 1
+
+ # Calculate the maximum length needed
+ max_len = torch.sum(duration_labels, dim=1).max()
+
+ # Create a padded tensor to hold the results
+ hidden_states = torch.zeros(
+ (encoded_embeddings.size(0), max_len, encoded_embeddings.size(2)),
+ dtype=torch.float,
+ device=encoded_embeddings.device,
+ )
+
+ # Loop through the batch and fill in the data
+ for i, (encoded_embedding, target_duration) in enumerate(zip(encoded_embeddings, duration_labels)):
+ repeated = torch.repeat_interleave(encoded_embedding, target_duration, dim=0)
+ hidden_states[i, : repeated.size(0)] = repeated
+
+ return hidden_states
+
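+# A minimal sketch of `length_regulator` (illustrative only; shapes assume the default
+# hidden size of 384): durations [2, 1, 3] expand three phoneme-level vectors into
+# 2 + 1 + 3 = 6 frame-level vectors by repetition along the time axis.
+#
+#     embeddings = torch.randn(1, 3, 384)
+#     durations = torch.tensor([[2, 1, 3]])
+#     frames = length_regulator(embeddings, durations)
+#     assert frames.shape == (1, 6, 384)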
+
+class FastSpeech2ConformerDurationPredictor(nn.Module):
+ """
+ Duration predictor module.
+
+ This is the duration predictor module described in the paper 'FastSpeech: Fast, Robust and Controllable Text to
+ Speech' (https://huggingface.co/papers/1905.09263). It predicts the duration of each frame in the log domain from
+ the hidden embeddings of the encoder.
+
+ Note:
+ The output domain differs between training and inference: in training mode the predictor returns values in
+ the log domain, while in eval mode it returns rounded durations in the linear domain.
+
+ """
+
+ def __init__(self, config: FastSpeech2ConformerConfig):
+ super().__init__()
+
+ self.conv_layers = nn.ModuleList()
+ self.log_domain_offset = 1.0
+
+ for layer_idx in range(config.duration_predictor_layers):
+ num_chans = config.duration_predictor_channels
+ input_channels = config.hidden_size if layer_idx == 0 else num_chans
+ layer = FastSpeech2ConformerPredictorLayer(
+ input_channels,
+ num_chans,
+ config.duration_predictor_kernel_size,
+ config.duration_predictor_dropout_rate,
+ )
+ self.conv_layers.append(layer)
+ self.linear = nn.Linear(config.duration_predictor_channels, 1)
+
+ def forward(self, encoder_hidden_states):
+ """
+ Args:
+ encoder_hidden_states (`torch.Tensor` of shape `(batch_size, max_text_length, input_dim)`):
+ Batch of input sequences.
+
+ Returns:
+ `torch.Tensor`: Batch of predicted durations in log domain `(batch_size, max_text_length)`.
+
+ """
+ # (batch_size, input_dim, max_text_length)
+ hidden_states = encoder_hidden_states.transpose(1, -1)
+ for layer in self.conv_layers:
+ hidden_states = layer(hidden_states)
+
+ # NOTE: calculate in log domain, (batch_size, max_text_length)
+ hidden_states = self.linear(hidden_states.transpose(1, -1)).squeeze(-1)
+
+ if not self.training:
+ # NOTE: calculate in linear domain
+ hidden_states = torch.clamp(torch.round(hidden_states.exp() - self.log_domain_offset), min=0).long()
+
+ return hidden_states
+
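+# A minimal sketch of the train/eval behaviour above (illustrative only): in training
+# mode the predictor returns log-domain values; in eval mode it maps them back to
+# integer frame counts via round(exp(x) - 1) clamped at zero, e.g. a raw prediction of
+# 1.6 becomes round(exp(1.6) - 1) = 4 frames.
+#
+#     predictor = FastSpeech2ConformerDurationPredictor(FastSpeech2ConformerConfig())
+#     predictor.eval()
+#     durations = predictor(torch.randn(1, 5, 384))  # LongTensor of shape (1, 5)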
+
+# Copied from transformers.models.speecht5.modeling_speecht5.SpeechT5BatchNormConvLayer
+class FastSpeech2ConformerBatchNormConvLayer(nn.Module):
+ def __init__(self, config, layer_id=0):
+ super().__init__()
+
+ if layer_id == 0:
+ in_conv_dim = config.num_mel_bins
+ else:
+ in_conv_dim = config.speech_decoder_postnet_units
+
+ if layer_id == config.speech_decoder_postnet_layers - 1:
+ out_conv_dim = config.num_mel_bins
+ else:
+ out_conv_dim = config.speech_decoder_postnet_units
+
+ self.conv = nn.Conv1d(
+ in_conv_dim,
+ out_conv_dim,
+ kernel_size=config.speech_decoder_postnet_kernel,
+ stride=1,
+ padding=(config.speech_decoder_postnet_kernel - 1) // 2,
+ bias=False,
+ )
+ self.batch_norm = nn.BatchNorm1d(out_conv_dim)
+
+ if layer_id < config.speech_decoder_postnet_layers - 1:
+ self.activation = nn.Tanh()
+ else:
+ self.activation = None
+
+ self.dropout = nn.Dropout(config.speech_decoder_postnet_dropout)
+
+ def forward(self, hidden_states):
+ hidden_states = self.conv(hidden_states)
+ hidden_states = self.batch_norm(hidden_states)
+ if self.activation is not None:
+ hidden_states = self.activation(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ return hidden_states
+
+
+class FastSpeech2ConformerSpeechDecoderPostnet(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.feat_out = nn.Linear(config.hidden_size, config.num_mel_bins * config.reduction_factor)
+ self.layers = nn.ModuleList(
+ [FastSpeech2ConformerBatchNormConvLayer(config, i) for i in range(config.speech_decoder_postnet_layers)]
+ )
+
+ def forward(self, hidden_states: torch.Tensor):
+ outputs_before_postnet = self.feat_out(hidden_states).view(hidden_states.size(0), -1, self.config.num_mel_bins)
+ layer_output = outputs_before_postnet.transpose(1, 2)
+ for layer in self.layers:
+ layer_output = layer(layer_output)
+ outputs_after_postnet = outputs_before_postnet + layer_output.transpose(1, 2)
+ return outputs_before_postnet, outputs_after_postnet
+
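+# Illustrative note on the postnet above: `feat_out` projects each decoder state to
+# `num_mel_bins * reduction_factor` values, and the stack of batch-normalized
+# convolutions predicts a residual correction, so the module returns both the coarse
+# spectrogram and the refined `coarse + residual` spectrogram.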
+
+class FastSpeech2ConformerPredictorLayer(nn.Module):
+ def __init__(self, input_channels, num_chans, kernel_size, dropout_rate):
+ super().__init__()
+ self.conv = nn.Conv1d(
+ input_channels,
+ num_chans,
+ kernel_size,
+ stride=1,
+ padding=(kernel_size - 1) // 2,
+ )
+ self.activation = nn.ReLU()
+ self.layer_norm = nn.LayerNorm(num_chans)
+ self.dropout = nn.Dropout(dropout_rate)
+
+ def forward(self, hidden_states):
+ hidden_states = self.conv(hidden_states)
+ hidden_states = self.activation(hidden_states)
+
+ # Perform layer norm on dimension 1
+ hidden_states = hidden_states.transpose(1, -1)
+ hidden_states = self.layer_norm(hidden_states)
+ hidden_states = hidden_states.transpose(1, -1)
+
+ hidden_states = self.dropout(hidden_states)
+
+ return hidden_states
+
+
+class FastSpeech2ConformerVariancePredictor(nn.Module):
+ def __init__(
+ self,
+ config: FastSpeech2ConformerConfig,
+ num_layers=2,
+ num_chans=384,
+ kernel_size=3,
+ dropout_rate=0.5,
+ ):
+ """
+ Initialize variance predictor module.
+
+ Args:
+ config (`FastSpeech2ConformerConfig`): Model configuration; `config.hidden_size` sets the input dimension.
+ num_layers (`int`, *optional*, defaults to 2): Number of convolutional layers.
+ num_chans (`int`, *optional*, defaults to 384): Number of channels of convolutional layers.
+ kernel_size (`int`, *optional*, defaults to 3): Kernel size of convolutional layers.
+ dropout_rate (`float`, *optional*, defaults to 0.5): Dropout rate.
+ """
+ super().__init__()
+ self.conv_layers = nn.ModuleList()
+ for idx in range(num_layers):
+ input_channels = config.hidden_size if idx == 0 else num_chans
+ layer = FastSpeech2ConformerPredictorLayer(input_channels, num_chans, kernel_size, dropout_rate)
+ self.conv_layers.append(layer)
+ self.linear = nn.Linear(num_chans, 1)
+
+ def forward(self, encoder_hidden_states, padding_masks=None):
+ """
+ Calculate forward propagation.
+
+ Args:
+ encoder_hidden_states (`torch.Tensor` of shape `(batch_size, max_text_length, input_dim)`):
+ Batch of input sequences.
+ padding_masks (`torch.ByteTensor` of shape `(batch_size, max_text_length)`, *optional*):
+ Batch of masks indicating padded part.
+
+ Returns:
+ Tensor: Batch of predicted sequences `(batch_size, max_text_length, 1)`.
+ """
+ # (batch_size, input_dim, max_text_length)
+ hidden_states = encoder_hidden_states.transpose(1, -1)
+ for layer in self.conv_layers:
+ hidden_states = layer(hidden_states)
+
+ hidden_states = self.linear(hidden_states.transpose(1, 2))
+
+ if padding_masks is not None:
+ hidden_states = hidden_states.masked_fill(padding_masks, 0.0)
+
+ return hidden_states
+
+
+class FastSpeech2ConformerVarianceEmbedding(nn.Module):
+ def __init__(
+ self,
+ in_channels=1,
+ out_channels=384,
+ kernel_size=1,
+ padding=0,
+ dropout_rate=0.0,
+ ):
+ super().__init__()
+ self.conv = nn.Conv1d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=kernel_size,
+ padding=padding,
+ )
+ self.dropout = nn.Dropout(dropout_rate)
+
+ def forward(self, hidden_states):
+ hidden_states = hidden_states.transpose(1, 2)
+ hidden_states = self.conv(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = hidden_states.transpose(1, 2)
+ return hidden_states
+
+
+class FastSpeech2ConformerAttention(nn.Module):
+ """
+ Multi-Head attention layer with relative position encoding. Details can be found in
+ https://github.com/espnet/espnet/pull/2816. Paper: https://huggingface.co/papers/1901.02860.
+ """
+
+ def __init__(self, config: FastSpeech2ConformerConfig, module_config):
+ """Construct an FastSpeech2ConformerAttention object."""
+ super().__init__()
+ # We assume d_v always equals dim_key
+ self.num_heads = module_config["num_attention_heads"]
+ self.hidden_size = config.hidden_size
+ self.dim_key = self.hidden_size // self.num_heads
+ self.head_dim = self.hidden_size // self.num_heads
+ self.linear_q = nn.Linear(self.hidden_size, self.hidden_size)
+ self.linear_k = nn.Linear(self.hidden_size, self.hidden_size)
+ self.linear_v = nn.Linear(self.hidden_size, self.hidden_size)
+ self.linear_out = nn.Linear(self.hidden_size, self.hidden_size)
+ self.dropout = nn.Dropout(p=module_config["attention_dropout_rate"])
+
+ # linear transformation for positional encoding
+ self.linear_pos = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
+ # these two learnable biases are used in matrix c and matrix d
+ # as described in https://huggingface.co/papers/1901.02860 Section 3.3
+ self.pos_bias_u = nn.Parameter(torch.Tensor(self.num_heads, self.head_dim))
+ self.pos_bias_v = nn.Parameter(torch.Tensor(self.num_heads, self.head_dim))
+
+ def shift_relative_position_tensor(self, pos_tensor):
+ """
+ Args:
+ pos_tensor (torch.Tensor of shape (batch_size, head, time1, 2*time1-1)): Input tensor.
+ """
+ zero_pad = torch.zeros((*pos_tensor.size()[:3], 1), device=pos_tensor.device, dtype=pos_tensor.dtype)
+ pos_tensor_padded = torch.cat([zero_pad, pos_tensor], dim=-1)
+
+ pos_tensor_padded = pos_tensor_padded.view(*pos_tensor.size()[:2], pos_tensor.size(3) + 1, pos_tensor.size(2))
+ # only keep the positions from 0 to time2
+ pos_tensor = pos_tensor_padded[:, :, 1:].view_as(pos_tensor)[:, :, :, : pos_tensor.size(-1) // 2 + 1]
+
+ return pos_tensor
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ pos_emb: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = False,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ """
+ Compute 'Scaled Dot Product Attention' with rel. positional encoding.
+
+ Args:
+ hidden_states (`torch.Tensor` of shape `(batch, time2, size)`): Values of the hidden states
+ attention_mask (`torch.Tensor` of shape `(batch, 1, time2)`): Mask tensor.
+ pos_emb (`torch.Tensor` of shape `(batch, 2*time1-1, size)`): Positional embedding tensor.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ Returns:
+ `torch.Tensor`: Output tensor of shape `(batch, time1, d_model)`.
+ """
+ bsz, q_len, _ = hidden_states.size()
+ query_states = self.linear_q(hidden_states).view(bsz, -1, self.num_heads, self.head_dim)
+ key_states = self.linear_k(hidden_states).view(bsz, -1, self.num_heads, self.head_dim)
+ value_states = self.linear_v(hidden_states).view(bsz, -1, self.num_heads, self.head_dim)
+
+ bsz_pos = pos_emb.size(0)
+ pos_encoding = self.linear_pos(pos_emb).view(bsz_pos, -1, self.num_heads, self.head_dim)
+
+ # (batch_size, head, time1, dim_key)
+ query_with_bias_u = (query_states + self.pos_bias_u).transpose(1, 2)
+ # (batch_size, head, time1, dim_key)
+ query_with_bias_v = (query_states + self.pos_bias_v).transpose(1, 2)
+
+ # compute attention score
+ # first compute matrix a and matrix c
+ # as described in https://huggingface.co/papers/1901.02860 Section 3.3
+ # (batch_size, head, time1, time2)
+ matrix_ac = torch.matmul(query_with_bias_u, key_states.permute(0, 2, 3, 1))
+
+ # compute matrix b and matrix d
+ # (batch_size, head, time1, 2*time1-1)
+ matrix_bd = torch.matmul(query_with_bias_v, pos_encoding.permute(0, 2, 3, 1))
+ matrix_bd = self.shift_relative_position_tensor(matrix_bd)
+
+ # (batch_size, head, time1, time2)
+ scores = (matrix_ac + matrix_bd) / math.sqrt(self.dim_key)
+
+ # Forward attention
+ if attention_mask is not None:
+ expected_size = (bsz, 1, q_len)
+ if attention_mask.size() != expected_size:
+ raise ValueError(f"Attention mask should be of size {expected_size}, but is {attention_mask.size()}")
+ attention_mask = attention_mask.unsqueeze(1).eq(0)
+ min_value = float(torch.finfo(scores.dtype).min)
+ scores = scores.masked_fill(attention_mask, min_value)
+ attn_weights = torch.softmax(scores, dim=-1).masked_fill(attention_mask, 0.0)
+ else:
+ attn_weights = torch.softmax(scores, dim=-1)
+
+ attn_weights = self.dropout(attn_weights)
+ attn_output = torch.matmul(attn_weights, value_states.transpose(1, 2))
+ attn_output = attn_output.transpose(1, 2).contiguous().view(bsz, q_len, -1)
+
+ attn_output = self.linear_out(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights
+
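+# Summary of the score computation above (restating `forward`, not adding new behaviour):
+#
+#     scores = ((Q + u) K^T + shift((Q + v) R^T)) / sqrt(head_dim)
+#
+# where u and v are `pos_bias_u` / `pos_bias_v`, R = `linear_pos(pos_emb)` is the
+# projected relative positional embedding, and `shift` is
+# `shift_relative_position_tensor`, which realigns the (time1, 2*time1-1) position
+# scores so each query row is paired with the correct relative offsets (Transformer-XL,
+# Section 3.3).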
+
+class FastSpeech2ConformerConvolutionModule(nn.Module):
+ def __init__(self, config: FastSpeech2ConformerConfig, module_config=None):
+ """
+ Args:
+ config (FastSpeech2ConformerConfig): Configuration for the model.
+ module_config (dict): Configuration for the module (e.g., encoder or decoder).
+ """
+ super().__init__()
+ channels = config.hidden_size
+ # kernel_size should be an odd number for 'SAME' padding
+ if module_config is None:
+ # e.g. using `ParakeetEncoderConfig` in src/transformers/models/parakeet/configuration_parakeet.py
+ kernel_size = config.conv_kernel_size
+ self.activation = ACT2FN[getattr(config, "hidden_act", "silu")]
+ else:
+ kernel_size = module_config["kernel_size"]
+ self.activation = ACT2FN[module_config.get("activation", "silu")]
+ self.padding = (kernel_size - 1) // 2
+ self.pointwise_conv1 = nn.Conv1d(channels, 2 * channels, kernel_size=1, stride=1, padding=0, bias=True)
+ self.depthwise_conv = nn.Conv1d(
+ channels, channels, kernel_size, stride=1, padding=self.padding, groups=channels, bias=True
+ )
+ self.norm = nn.BatchNorm1d(channels)
+ self.pointwise_conv2 = nn.Conv1d(channels, channels, kernel_size=1, stride=1, padding=0, bias=True)
+
+ def forward(self, hidden_states, attention_mask=None):
+ """
+ Compute convolution module.
+
+ Args:
+ hidden_states (`torch.Tensor` of shape `(batch, time, channels)`): Input tensor.
+ attention_mask (`torch.Tensor` of shape `(batch, 1, time)`): Attention mask.
+
+ Returns:
+ `torch.Tensor`: Output tensor of shape `(batch, time, channels)`.
+
+ """
+ # exchange the temporal dimension and the feature dimension
+ hidden_states = hidden_states.transpose(1, 2)
+
+ # GLU mechanism, (batch_size, 2*channel, dim)
+ hidden_states = self.pointwise_conv1(hidden_states)
+ # (batch_size, channel, dim)
+ hidden_states = nn.functional.glu(hidden_states, dim=1)
+
+ # Apply padding mask before convolution
+ if attention_mask is not None:
+ all_masked_rows = torch.all(~attention_mask, dim=-1)
+ hidden_states = hidden_states.masked_fill(all_masked_rows, 0.0)
+
+ # 1D Depthwise Conv
+ hidden_states = self.depthwise_conv(hidden_states)
+ hidden_states = self.norm(hidden_states)
+ hidden_states = self.activation(hidden_states)
+ hidden_states = self.pointwise_conv2(hidden_states)
+
+ return hidden_states.transpose(1, 2)
+
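+# A minimal shape check for the convolution module above (illustrative only, assuming
+# the default hidden size of 384): `pointwise_conv1` doubles the channels so the GLU can
+# gate them back down (2*C -> C), the depthwise convolution mixes along time only
+# (groups == channels), and `pointwise_conv2` mixes channels again; the time length is
+# preserved throughout.
+#
+#     module = FastSpeech2ConformerConvolutionModule(
+#         FastSpeech2ConformerConfig(), module_config={"kernel_size": 7}
+#     )
+#     out = module(torch.randn(2, 10, 384))
+#     assert out.shape == (2, 10, 384)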
+
+class FastSpeech2ConformerEncoderLayer(nn.Module):
+ def __init__(self, config: FastSpeech2ConformerConfig, module_config):
+ super().__init__()
+
+ # self-attention module definition
+ self.self_attn = FastSpeech2ConformerAttention(config, module_config)
+
+ # feed-forward module definition
+ self.feed_forward = FastSpeech2ConformerMultiLayeredConv1d(config, module_config)
+
+ self.macaron_style = config.use_macaron_style_in_conformer
+ if self.macaron_style:
+ self.feed_forward_macaron = FastSpeech2ConformerMultiLayeredConv1d(config, module_config)
+ self.ff_macaron_layer_norm = nn.LayerNorm(config.hidden_size)
+ self.ff_scale = 0.5
+ else:
+ self.ff_scale = 1.0
+
+ # convolution module definition
+ self.use_cnn_module = config.use_cnn_in_conformer
+ if self.use_cnn_module:
+ self.conv_module = FastSpeech2ConformerConvolutionModule(config, module_config)
+ self.conv_layer_norm = nn.LayerNorm(config.hidden_size)
+ self.final_layer_norm = nn.LayerNorm(config.hidden_size)
+
+ self.ff_layer_norm = nn.LayerNorm(config.hidden_size)
+
+ self.self_attn_layer_norm = nn.LayerNorm(config.hidden_size)
+
+ self.dropout = nn.Dropout(module_config["dropout_rate"])
+ self.size = config.hidden_size
+ self.normalize_before = module_config["normalize_before"]
+ self.concat_after = module_config["concat_after"]
+ if self.concat_after:
+ self.concat_linear = nn.Linear(config.hidden_size + config.hidden_size, config.hidden_size)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ pos_emb: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = False,
+ ):
+ """
+ Compute encoded features.
+
+ Args:
+ hidden_states (`torch.Tensor` of shape `(batch, time, size)`): Input tensor.
+ pos_emb (`torch.Tensor` of shape `(1, time, size)`): Positional embeddings tensor.
+ attention_mask (`torch.Tensor` of shape `(batch, time)`): Attention mask tensor for the input.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ Returns:
+ `torch.Tensor`: Output tensor of shape `(batch, time, size)`.
+
+ """
+ # whether to use macaron style
+ if self.macaron_style:
+ residual = hidden_states
+ if self.normalize_before:
+ hidden_states = self.ff_macaron_layer_norm(hidden_states)
+ hidden_states = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(hidden_states))
+ if not self.normalize_before:
+ hidden_states = self.ff_macaron_layer_norm(hidden_states)
+
+ # multi-headed self-attention module
+ residual = hidden_states
+ if self.normalize_before:
+ hidden_states = self.self_attn_layer_norm(hidden_states)
+
+ attention_output, attention_scores = self.self_attn(
+ hidden_states, attention_mask=attention_mask, pos_emb=pos_emb, output_attentions=output_attentions
+ )
+
+ if self.concat_after:
+ x_concat = torch.cat((hidden_states, attention_output), dim=-1)
+ hidden_states = self.concat_linear(x_concat)
+ hidden_states = residual + hidden_states
+ else:
+ hidden_states = self.dropout(attention_output)
+ hidden_states = residual + hidden_states
+ if not self.normalize_before:
+ hidden_states = self.self_attn_layer_norm(hidden_states)
+
+ # convolution module
+ if self.use_cnn_module:
+ residual = hidden_states
+ if self.normalize_before:
+ hidden_states = self.conv_layer_norm(hidden_states)
+ hidden_states = self.conv_module(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = residual + hidden_states
+ if not self.normalize_before:
+ hidden_states = self.conv_layer_norm(hidden_states)
+
+ # feed forward module
+ residual = hidden_states
+ if self.normalize_before:
+ hidden_states = self.ff_layer_norm(hidden_states)
+ hidden_states = self.feed_forward(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = residual + self.ff_scale * hidden_states
+ if not self.normalize_before:
+ hidden_states = self.ff_layer_norm(hidden_states)
+
+ if self.use_cnn_module:
+ hidden_states = self.final_layer_norm(hidden_states)
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (attention_scores,)
+
+ return outputs
+
+
+class FastSpeech2ConformerMultiLayeredConv1d(nn.Module):
+ """
+ Multi-layered conv1d for Transformer block.
+
+ This is a module of multi-layered conv1d designed to replace positionwise feed-forward network in Transformer
+ block, which is introduced in 'FastSpeech: Fast, Robust and Controllable Text to Speech'
+ https://huggingface.co/papers/1905.09263
+ """
+
+ def __init__(self, config: FastSpeech2ConformerConfig, module_config):
+ """
+ Initialize FastSpeech2ConformerMultiLayeredConv1d module.
+
+ Args:
+ config (`FastSpeech2ConformerConfig`): Model configuration; provides the number of input channels
+ (`hidden_size`) and the conv1d kernel size (`positionwise_conv_kernel_size`).
+ module_config (`dict`): Encoder or decoder module configuration; provides the number of hidden
+ channels (`linear_units`) and the dropout rate (`dropout_rate`).
+ """
+ super().__init__()
+ input_channels = config.hidden_size
+ hidden_channels = module_config["linear_units"]
+ kernel_size = config.positionwise_conv_kernel_size
+ self.conv1 = nn.Conv1d(input_channels, hidden_channels, kernel_size, stride=1, padding=(kernel_size - 1) // 2)
+ self.conv2 = nn.Conv1d(hidden_channels, input_channels, kernel_size, stride=1, padding=(kernel_size - 1) // 2)
+ self.dropout = nn.Dropout(module_config["dropout_rate"])
+
+ def forward(self, hidden_states):
+ """
+ Calculate forward propagation.
+
+ Args:
+ hidden_states (torch.Tensor): Batch of input tensors (batch_size, time, input_channels).
+
+ Returns:
+ torch.Tensor: Batch of output tensors (batch_size, time, input_channels).
+ """
+ hidden_states = hidden_states.transpose(-1, 1)
+ hidden_states = self.conv1(hidden_states)
+ hidden_states = torch.relu(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.conv2(hidden_states)
+ hidden_states = hidden_states.transpose(-1, 1)
+ return hidden_states
+
+
+class FastSpeech2ConformerRelPositionalEncoding(nn.Module):
+ """
+ Relative positional encoding module (new implementation). Details can be found in
+ https://github.com/espnet/espnet/pull/2816. See Appendix B in https://huggingface.co/papers/1901.02860.
+
+ Args:
+ config (`FastSpeech2ConformerConfig`):
+ FastSpeech2ConformerConfig instance.
+ module_config (`dict`):
+ Dictionary containing the encoder or decoder module configuration from the `FastSpeech2ConformerConfig`.
+ """
+
+ def __init__(self, config: FastSpeech2ConformerConfig, module_config):
+ """
+ Construct a PositionalEncoding object.
+ """
+ super().__init__()
+ self.embed_dim = config.hidden_size
+ self.input_scale = math.sqrt(self.embed_dim)
+ self.dropout = nn.Dropout(p=module_config["positional_dropout_rate"])
+ self.pos_enc = None
+ self.max_len = 5000
+ self.extend_pos_enc(torch.tensor(0.0).expand(1, self.max_len))
+
+ def extend_pos_enc(self, x):
+ """Reset the positional encodings."""
+ if self.pos_enc is not None:
+ # self.pos_enc contains both positive and negative parts
+ # the length of self.pos_enc is 2 * input_len - 1
+ if self.pos_enc.size(1) >= x.size(1) * 2 - 1:
+ if self.pos_enc.dtype != x.dtype or self.pos_enc.device != x.device:
+ self.pos_enc = self.pos_enc.to(dtype=x.dtype, device=x.device)
+ return
+ # Suppose `i` means the position of the query vector and `j` means the
+ # position of the key vector. We use positive relative positions when keys
+ # are to the left (i > j) and negative relative positions otherwise (i < j).
+ self.multilingual_model = config.num_languages is not None and config.num_languages > 1
+ if self.multilingual_model:
+ self.language_id_embedding = torch.nn.Embedding(config.num_languages, self.hidden_size)
+
+ self.multispeaker_model = config.num_speakers is not None and config.num_speakers > 1
+ if self.multispeaker_model:
+ self.speaker_id_embedding = torch.nn.Embedding(config.num_speakers, config.hidden_size)
+
+ self.speaker_embed_dim = config.speaker_embed_dim
+ if self.speaker_embed_dim:
+ self.projection = nn.Linear(config.hidden_size + self.speaker_embed_dim, config.hidden_size)
+
+ self.encoder = FastSpeech2ConformerEncoder(config, config.encoder_config, use_encoder_input_layer=True)
+
+ self.duration_predictor = FastSpeech2ConformerDurationPredictor(config)
+
+ self.pitch_predictor = FastSpeech2ConformerVariancePredictor(
+ config,
+ num_layers=config.pitch_predictor_layers,
+ num_chans=config.pitch_predictor_channels,
+ kernel_size=config.pitch_predictor_kernel_size,
+ dropout_rate=config.pitch_predictor_dropout,
+ )
+ # continuous pitch + FastPitch style avg
+ self.pitch_embed = FastSpeech2ConformerVarianceEmbedding(
+ out_channels=self.hidden_size,
+ kernel_size=config.pitch_embed_kernel_size,
+ padding=(config.pitch_embed_kernel_size - 1) // 2,
+ dropout_rate=config.pitch_embed_dropout,
+ )
+
+ self.energy_predictor = FastSpeech2ConformerVariancePredictor(
+ config,
+ num_layers=config.energy_predictor_layers,
+ num_chans=config.energy_predictor_channels,
+ kernel_size=config.energy_predictor_kernel_size,
+ dropout_rate=config.energy_predictor_dropout,
+ )
+ # continuous energy + FastPitch style avg
+ self.energy_embed = FastSpeech2ConformerVarianceEmbedding(
+ out_channels=self.hidden_size,
+ kernel_size=config.energy_embed_kernel_size,
+ padding=(config.energy_embed_kernel_size - 1) // 2,
+ dropout_rate=config.energy_embed_dropout,
+ )
+
+ # The decoder is an encoder
+ self.decoder = FastSpeech2ConformerEncoder(config, config.decoder_config, use_encoder_input_layer=False)
+
+ self.speech_decoder_postnet = FastSpeech2ConformerSpeechDecoderPostnet(config)
+
+ self.criterion = FastSpeech2ConformerLoss(config)
+
+ self.post_init()
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: torch.LongTensor,
+ attention_mask: Optional[torch.LongTensor] = None,
+ spectrogram_labels: Optional[torch.FloatTensor] = None,
+ duration_labels: Optional[torch.LongTensor] = None,
+ pitch_labels: Optional[torch.FloatTensor] = None,
+ energy_labels: Optional[torch.FloatTensor] = None,
+ speaker_ids: Optional[torch.LongTensor] = None,
+ lang_ids: Optional[torch.LongTensor] = None,
+ speaker_embedding: Optional[torch.FloatTensor] = None,
+ return_dict: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ ) -> Union[tuple, FastSpeech2ConformerModelOutput]:
+ r"""
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ Input sequence of text vectors.
+ spectrogram_labels (`torch.FloatTensor` of shape `(batch_size, max_spectrogram_length, num_mel_bins)`, *optional*, defaults to `None`):
+ Batch of padded target features.
+ duration_labels (`torch.LongTensor` of shape `(batch_size, sequence_length + 1)`, *optional*, defaults to `None`):
+ Batch of padded durations.
+ pitch_labels (`torch.FloatTensor` of shape `(batch_size, sequence_length + 1, 1)`, *optional*, defaults to `None`):
+ Batch of padded token-averaged pitch.
+ energy_labels (`torch.FloatTensor` of shape `(batch_size, sequence_length + 1, 1)`, *optional*, defaults to `None`):
+ Batch of padded token-averaged energy.
+ speaker_ids (`torch.LongTensor` of shape `(batch_size, 1)`, *optional*, defaults to `None`):
+ Speaker ids used to condition features of speech output by the model.
+ lang_ids (`torch.LongTensor` of shape `(batch_size, 1)`, *optional*, defaults to `None`):
+ Language ids used to condition features of speech output by the model.
+ speaker_embedding (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`, *optional*, defaults to `None`):
+ Embedding containing conditioning signals for the features of the speech.
+
+ Example:
+
+ ```python
+ >>> from transformers import (
+ ... FastSpeech2ConformerTokenizer,
+ ... FastSpeech2ConformerModel,
+ ... FastSpeech2ConformerHifiGan,
+ ... )
+
+ >>> tokenizer = FastSpeech2ConformerTokenizer.from_pretrained("espnet/fastspeech2_conformer")
+ >>> inputs = tokenizer("some text to convert to speech", return_tensors="pt")
+ >>> input_ids = inputs["input_ids"]
+
+ >>> model = FastSpeech2ConformerModel.from_pretrained("espnet/fastspeech2_conformer")
+ >>> output_dict = model(input_ids, return_dict=True)
+ >>> spectrogram = output_dict["spectrogram"]
+
+ >>> vocoder = FastSpeech2ConformerHifiGan.from_pretrained("espnet/fastspeech2_conformer_hifigan")
+ >>> waveform = vocoder(spectrogram)
+ >>> print(waveform.shape)
+ torch.Size([1, 49664])
+ ```
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+
+ if attention_mask is None:
+ attention_mask = torch.ones(input_ids.shape, device=input_ids.device)
+
+ has_missing_labels = (
+ spectrogram_labels is None or duration_labels is None or pitch_labels is None or energy_labels is None
+ )
+ if self.training and has_missing_labels:
+ raise ValueError("All labels must be provided to run in training mode.")
+
+ # forward encoder
+ text_masks = attention_mask.unsqueeze(-2)
+
+ encoder_outputs = self.encoder(
+ input_ids,
+ text_masks,
+ output_hidden_states=output_hidden_states,
+ output_attentions=output_attentions,
+ return_dict=return_dict,
+ )
+ hidden_states = encoder_outputs[0]
+
+ # Integrate with language id, speaker id, and speaker embedding
+ if self.multispeaker_model and speaker_ids is not None:
+ speaker_id_embeddings = self.speaker_id_embedding(speaker_ids.view(-1))
+ hidden_states = hidden_states + speaker_id_embeddings.unsqueeze(1)
+
+ if self.multilingual_model and lang_ids is not None:
+ language_id_embeddings = self.language_id_embedding(lang_ids.view(-1))
+ hidden_states = hidden_states + language_id_embeddings.unsqueeze(1)
+
+ if self.speaker_embed_dim is not None and speaker_embedding is not None:
+ embeddings_expanded = (
+ nn.functional.normalize(speaker_embedding).unsqueeze(1).expand(-1, hidden_states.size(1), -1)
+ )
+ hidden_states = self.projection(torch.cat([hidden_states, embeddings_expanded], dim=-1))
+
+ # forward duration predictor and variance predictors
+ duration_mask = ~attention_mask.bool()
+
+ if self.stop_gradient_from_pitch_predictor:
+ pitch_predictions = self.pitch_predictor(hidden_states.detach(), duration_mask.unsqueeze(-1))
+ else:
+ pitch_predictions = self.pitch_predictor(hidden_states, duration_mask.unsqueeze(-1))
+
+ if self.stop_gradient_from_energy_predictor:
+ energy_predictions = self.energy_predictor(hidden_states.detach(), duration_mask.unsqueeze(-1))
+ else:
+ energy_predictions = self.energy_predictor(hidden_states, duration_mask.unsqueeze(-1))
+
+ duration_predictions = self.duration_predictor(hidden_states)
+ duration_predictions = duration_predictions.masked_fill(duration_mask, 0.0)
+
+ if not self.training:
+ # use prediction in inference
+ embedded_pitch_curve = self.pitch_embed(pitch_predictions)
+ embedded_energy_curve = self.energy_embed(energy_predictions)
+ hidden_states = hidden_states + embedded_energy_curve + embedded_pitch_curve
+ hidden_states = length_regulator(hidden_states, duration_predictions, self.config.speaking_speed)
+ else:
+ # use groundtruth in training
+ embedded_pitch_curve = self.pitch_embed(pitch_labels)
+ embedded_energy_curve = self.energy_embed(energy_labels)
+ hidden_states = hidden_states + embedded_energy_curve + embedded_pitch_curve
+ hidden_states = length_regulator(hidden_states, duration_labels)
+
+ # forward decoder
+ if not self.training:
+ hidden_mask = None
+ else:
+ spectrogram_mask = (spectrogram_labels != -100).any(dim=-1)
+ spectrogram_mask = spectrogram_mask.int()
+ if self.reduction_factor > 1:
+ length_dim = spectrogram_mask.shape[1] - spectrogram_mask.shape[1] % self.reduction_factor
+ spectrogram_mask = spectrogram_mask[:, :length_dim]
+ hidden_mask = spectrogram_mask.unsqueeze(-2)
+
+ decoder_outputs = self.decoder(
+ hidden_states,
+ hidden_mask,
+ output_hidden_states=output_hidden_states,
+ output_attentions=output_attentions,
+ return_dict=return_dict,
+ )
+
+ outputs_before_postnet, outputs_after_postnet = self.speech_decoder_postnet(decoder_outputs[0])
+
+ loss = None
+ if self.training:
+ # calculate loss
+ loss_duration_mask = ~duration_mask
+ loss_spectrogram_mask = spectrogram_mask.unsqueeze(-1).bool()
+ loss = self.criterion(
+ outputs_after_postnet=outputs_after_postnet,
+ outputs_before_postnet=outputs_before_postnet,
+ duration_outputs=duration_predictions,
+ pitch_outputs=pitch_predictions,
+ energy_outputs=energy_predictions,
+ spectrogram_labels=spectrogram_labels,
+ duration_labels=duration_labels,
+ pitch_labels=pitch_labels,
+ energy_labels=energy_labels,
+ duration_mask=loss_duration_mask,
+ spectrogram_mask=loss_spectrogram_mask,
+ )
+
+ if not return_dict:
+ postnet_outputs = (outputs_after_postnet,)
+ audio_feature_predictions = (
+ duration_predictions,
+ pitch_predictions,
+ energy_predictions,
+ )
+ outputs = postnet_outputs + encoder_outputs + decoder_outputs[1:] + audio_feature_predictions
+ return ((loss,) + outputs) if loss is not None else outputs
+
+ return FastSpeech2ConformerModelOutput(
+ loss=loss,
+ spectrogram=outputs_after_postnet,
+ encoder_last_hidden_state=encoder_outputs.last_hidden_state,
+ encoder_hidden_states=encoder_outputs.hidden_states,
+ encoder_attentions=encoder_outputs.attentions,
+ decoder_hidden_states=decoder_outputs.hidden_states,
+ decoder_attentions=decoder_outputs.attentions,
+ duration_outputs=duration_predictions,
+ pitch_outputs=pitch_predictions,
+ energy_outputs=energy_predictions,
+ )
+
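+# Illustrative note on `FastSpeech2ConformerModel.forward` above: during training the
+# variance embeddings and the length regulator consume the ground-truth pitch, energy
+# and duration labels (teacher forcing); at inference the model's own predictions are
+# used instead, with `config.speaking_speed` rescaling the predicted durations before
+# length regulation.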
+
+# Copied from transformers.models.speecht5.modeling_speecht5.HifiGanResidualBlock
+class HifiGanResidualBlock(nn.Module):
+ def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5), leaky_relu_slope=0.1):
+ super().__init__()
+ self.leaky_relu_slope = leaky_relu_slope
+
+ self.convs1 = nn.ModuleList(
+ [
+ nn.Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ stride=1,
+ dilation=dilation[i],
+ padding=self.get_padding(kernel_size, dilation[i]),
+ )
+ for i in range(len(dilation))
+ ]
+ )
+ self.convs2 = nn.ModuleList(
+ [
+ nn.Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ stride=1,
+ dilation=1,
+ padding=self.get_padding(kernel_size, 1),
+ )
+ for _ in range(len(dilation))
+ ]
+ )
+
+ def get_padding(self, kernel_size, dilation=1):
+ return (kernel_size * dilation - dilation) // 2
+
+ def apply_weight_norm(self):
+ weight_norm = nn.utils.weight_norm
+ if hasattr(nn.utils.parametrizations, "weight_norm"):
+ weight_norm = nn.utils.parametrizations.weight_norm
+
+ for layer in self.convs1:
+ weight_norm(layer)
+ for layer in self.convs2:
+ weight_norm(layer)
+
+ def remove_weight_norm(self):
+ for layer in self.convs1:
+ nn.utils.remove_weight_norm(layer)
+ for layer in self.convs2:
+ nn.utils.remove_weight_norm(layer)
+
+ def forward(self, hidden_states):
+ for conv1, conv2 in zip(self.convs1, self.convs2):
+ residual = hidden_states
+ hidden_states = nn.functional.leaky_relu(hidden_states, self.leaky_relu_slope)
+ hidden_states = conv1(hidden_states)
+ hidden_states = nn.functional.leaky_relu(hidden_states, self.leaky_relu_slope)
+ hidden_states = conv2(hidden_states)
+ hidden_states = hidden_states + residual
+ return hidden_states
+
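+# Illustrative note on the residual block above: `get_padding` returns
+# (kernel_size * dilation - dilation) // 2, the "same" padding for a dilated 1D conv, so
+# every convolution preserves the time dimension. With kernel_size=3 and dilations
+# (1, 3, 5) the paddings are 1, 3 and 5.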
+
+@auto_docstring(
+ custom_intro="""
+ HiFi-GAN vocoder.
+ """
+)
+# Copied from transformers.models.speecht5.modeling_speecht5.SpeechT5HifiGan with SpeechT5->FastSpeech2Conformer
+class FastSpeech2ConformerHifiGan(PreTrainedModel):
+ config: FastSpeech2ConformerHifiGanConfig
+ main_input_name = "spectrogram"
+
+ def __init__(self, config: FastSpeech2ConformerHifiGanConfig):
+ super().__init__(config)
+ self.num_kernels = len(config.resblock_kernel_sizes)
+ self.num_upsamples = len(config.upsample_rates)
+ self.conv_pre = nn.Conv1d(
+ config.model_in_dim,
+ config.upsample_initial_channel,
+ kernel_size=7,
+ stride=1,
+ padding=3,
+ )
+
+ self.upsampler = nn.ModuleList()
+ for i, (upsample_rate, kernel_size) in enumerate(zip(config.upsample_rates, config.upsample_kernel_sizes)):
+ self.upsampler.append(
+ nn.ConvTranspose1d(
+ config.upsample_initial_channel // (2**i),
+ config.upsample_initial_channel // (2 ** (i + 1)),
+ kernel_size=kernel_size,
+ stride=upsample_rate,
+ padding=(kernel_size - upsample_rate) // 2,
+ )
+ )
+
+ self.resblocks = nn.ModuleList()
+ for i in range(len(self.upsampler)):
+ channels = config.upsample_initial_channel // (2 ** (i + 1))
+ for kernel_size, dilation in zip(config.resblock_kernel_sizes, config.resblock_dilation_sizes):
+ self.resblocks.append(HifiGanResidualBlock(channels, kernel_size, dilation, config.leaky_relu_slope))
+
+ self.conv_post = nn.Conv1d(channels, 1, kernel_size=7, stride=1, padding=3)
+
+ self.register_buffer("mean", torch.zeros(config.model_in_dim))
+ self.register_buffer("scale", torch.ones(config.model_in_dim))
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def _init_weights(self, module: nn.Module):
+ """Initialize the weights."""
+ if isinstance(module, (nn.Conv1d, nn.ConvTranspose1d)):
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ if module.bias is not None:
+ module.bias.data.zero_()
+
+ def apply_weight_norm(self):
+ weight_norm = nn.utils.weight_norm
+ if hasattr(nn.utils.parametrizations, "weight_norm"):
+ weight_norm = nn.utils.parametrizations.weight_norm
+
+ weight_norm(self.conv_pre)
+ for layer in self.upsampler:
+ weight_norm(layer)
+ for layer in self.resblocks:
+ layer.apply_weight_norm()
+ weight_norm(self.conv_post)
+
+ def remove_weight_norm(self):
+ nn.utils.remove_weight_norm(self.conv_pre)
+ for layer in self.upsampler:
+ nn.utils.remove_weight_norm(layer)
+ for layer in self.resblocks:
+ layer.remove_weight_norm()
+ nn.utils.remove_weight_norm(self.conv_post)
+
+ @auto_docstring(
+ custom_intro="""
+ Converts a log-mel spectrogram into a speech waveform. Passing a batch of log-mel spectrograms returns a batch
+ of speech waveforms. Passing a single, un-batched log-mel spectrogram returns a single, un-batched speech
+ waveform.
+ """
+ )
+ def forward(self, spectrogram: torch.FloatTensor) -> torch.FloatTensor:
+ r"""
+ spectrogram (`torch.FloatTensor`):
+ Tensor containing the log-mel spectrograms. Can be batched and of shape `(batch_size, sequence_length,
+ config.model_in_dim)`, or un-batched and of shape `(sequence_length, config.model_in_dim)`.
+
+ Returns:
+ `torch.FloatTensor`: Tensor containing the speech waveform. If the input spectrogram is batched, will be of
+ shape `(batch_size, num_frames,)`. If un-batched, will be of shape `(num_frames,)`.
+ """
+ if self.config.normalize_before:
+ spectrogram = (spectrogram - self.mean) / self.scale
+
+ is_batched = spectrogram.dim() == 3
+ if not is_batched:
+ spectrogram = spectrogram.unsqueeze(0)
+
+ hidden_states = spectrogram.transpose(2, 1)
+
+ hidden_states = self.conv_pre(hidden_states)
+ for i in range(self.num_upsamples):
+ hidden_states = nn.functional.leaky_relu(hidden_states, self.config.leaky_relu_slope)
+ hidden_states = self.upsampler[i](hidden_states)
+
+ res_state = self.resblocks[i * self.num_kernels](hidden_states)
+ for j in range(1, self.num_kernels):
+ res_state += self.resblocks[i * self.num_kernels + j](hidden_states)
+ hidden_states = res_state / self.num_kernels
+
+ hidden_states = nn.functional.leaky_relu(hidden_states)
+ hidden_states = self.conv_post(hidden_states)
+ hidden_states = torch.tanh(hidden_states)
+
+ if not is_batched:
+ # remove batch dim and collapse tensor to 1-d audio waveform
+ waveform = hidden_states.squeeze(0).transpose(1, 0).view(-1)
+ else:
+ # remove seq-len dim since this collapses to 1
+ waveform = hidden_states.squeeze(1)
+
+ return waveform
+
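+# A minimal sketch of the vocoder above (illustrative only, randomly initialized weights):
+# each mel frame is upsampled by prod(upsample_rates) = 256 samples, so a spectrogram of
+# shape (1, 194, 80) yields a waveform of shape (1, 194 * 256) = (1, 49664), consistent
+# with the doctest in `FastSpeech2ConformerModel.forward`.
+#
+#     vocoder = FastSpeech2ConformerHifiGan(FastSpeech2ConformerHifiGanConfig())
+#     waveform = vocoder(torch.randn(1, 194, 80))
+#     assert waveform.shape == (1, 194 * 256)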
+
+@auto_docstring(
+ custom_intro="""
+ The FastSpeech2ConformerModel with a FastSpeech2ConformerHifiGan vocoder head that performs text-to-speech (waveform).
+ """
+)
+class FastSpeech2ConformerWithHifiGan(PreTrainedModel):
+ config: FastSpeech2ConformerWithHifiGanConfig
+
+ def __init__(self, config: FastSpeech2ConformerWithHifiGanConfig):
+ super().__init__(config)
+
+ self.model = FastSpeech2ConformerModel(config.model_config)
+ self.vocoder = FastSpeech2ConformerHifiGan(config.vocoder_config)
+
+ self.config = config
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: torch.LongTensor,
+ attention_mask: Optional[torch.LongTensor] = None,
+ spectrogram_labels: Optional[torch.FloatTensor] = None,
+ duration_labels: Optional[torch.LongTensor] = None,
+ pitch_labels: Optional[torch.FloatTensor] = None,
+ energy_labels: Optional[torch.FloatTensor] = None,
+ speaker_ids: Optional[torch.LongTensor] = None,
+ lang_ids: Optional[torch.LongTensor] = None,
+ speaker_embedding: Optional[torch.FloatTensor] = None,
+ return_dict: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ ) -> Union[tuple, FastSpeech2ConformerModelOutput]:
+ r"""
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ Input sequence of text vectors.
+ spectrogram_labels (`torch.FloatTensor` of shape `(batch_size, max_spectrogram_length, num_mel_bins)`, *optional*, defaults to `None`):
+ Batch of padded target features.
+ duration_labels (`torch.LongTensor` of shape `(batch_size, sequence_length + 1)`, *optional*, defaults to `None`):
+ Batch of padded durations.
+ pitch_labels (`torch.FloatTensor` of shape `(batch_size, sequence_length + 1, 1)`, *optional*, defaults to `None`):
+ Batch of padded token-averaged pitch.
+ energy_labels (`torch.FloatTensor` of shape `(batch_size, sequence_length + 1, 1)`, *optional*, defaults to `None`):
+ Batch of padded token-averaged energy.
+ speaker_ids (`torch.LongTensor` of shape `(batch_size, 1)`, *optional*, defaults to `None`):
+ Speaker ids used to condition features of speech output by the model.
+ lang_ids (`torch.LongTensor` of shape `(batch_size, 1)`, *optional*, defaults to `None`):
+ Language ids used to condition features of speech output by the model.
+ speaker_embedding (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`, *optional*, defaults to `None`):
+ Embedding containing conditioning signals for the features of the speech.
+
+ Example:
+
+ ```python
+ >>> from transformers import (
+ ... FastSpeech2ConformerTokenizer,
+ ... FastSpeech2ConformerWithHifiGan,
+ ... )
+
+ >>> tokenizer = FastSpeech2ConformerTokenizer.from_pretrained("espnet/fastspeech2_conformer")
+ >>> inputs = tokenizer("some text to convert to speech", return_tensors="pt")
+ >>> input_ids = inputs["input_ids"]
+
+ >>> model = FastSpeech2ConformerWithHifiGan.from_pretrained("espnet/fastspeech2_conformer_with_hifigan")
+ >>> output_dict = model(input_ids, return_dict=True)
+ >>> waveform = output_dict["waveform"]
+ >>> print(waveform.shape)
+ torch.Size([1, 49664])
+ ```
+ """
+ return_dict = return_dict if return_dict is not None else self.config.model_config.use_return_dict
+ output_attentions = (
+ output_attentions if output_attentions is not None else self.config.model_config.output_attentions
+ )
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.model_config.output_hidden_states
+ )
+
+ model_outputs = self.model(
+ input_ids,
+ attention_mask,
+ spectrogram_labels=spectrogram_labels,
+ duration_labels=duration_labels,
+ pitch_labels=pitch_labels,
+ energy_labels=energy_labels,
+ speaker_ids=speaker_ids,
+ lang_ids=lang_ids,
+ speaker_embedding=speaker_embedding,
+ return_dict=return_dict,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ )
+
+ if not return_dict:
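+            # When all labels are provided, the acoustic model's tuple output starts with the
+            # loss, so the spectrogram shifts from index 0 to index 1.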
+ has_missing_labels = (
+ spectrogram_labels is None or duration_labels is None or pitch_labels is None or energy_labels is None
+ )
+ if has_missing_labels:
+ spectrogram = model_outputs[0]
+ else:
+ spectrogram = model_outputs[1]
+ else:
+ spectrogram = model_outputs["spectrogram"]
+ waveform = self.vocoder(spectrogram)
+
+ if not return_dict:
+ return model_outputs + (waveform,)
+
+ return FastSpeech2ConformerWithHifiGanOutput(waveform=waveform, **model_outputs)
+
+
+__all__ = [
+ "FastSpeech2ConformerWithHifiGan",
+ "FastSpeech2ConformerHifiGan",
+ "FastSpeech2ConformerModel",
+ "FastSpeech2ConformerPreTrainedModel",
+]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/fastspeech2_conformer/tokenization_fastspeech2_conformer.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/fastspeech2_conformer/tokenization_fastspeech2_conformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..004a1c36f59cc7942a7a132012bdccb40a4a38de
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/fastspeech2_conformer/tokenization_fastspeech2_conformer.py
@@ -0,0 +1,188 @@
+# coding=utf-8
+# Copyright 2023 The HuggingFace Team and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tokenization classes for FastSpeech2Conformer."""
+
+import json
+import os
+from typing import Optional
+
+import regex
+
+from ...tokenization_utils import PreTrainedTokenizer
+from ...utils import logging, requires_backends
+
+
+logger = logging.get_logger(__name__)
+
+VOCAB_FILES_NAMES = {"vocab_file": "vocab.json"}
+
+
+class FastSpeech2ConformerTokenizer(PreTrainedTokenizer):
+ """
+ Construct a FastSpeech2Conformer tokenizer.
+
+ Args:
+ vocab_file (`str`):
+ Path to the vocabulary file.
+ bos_token (`str`, *optional*, defaults to `""`):
+            The beginning of sequence token. Note that for FastSpeech2, it is the same as the `eos_token`.
+ eos_token (`str`, *optional*, defaults to `""`):
+ The end of sequence token. Note that for FastSpeech2, it is the same as the `bos_token`.
+ pad_token (`str`, *optional*, defaults to `""`):
+ The token used for padding, for example when batching sequences of different lengths.
+ unk_token (`str`, *optional*, defaults to `""`):
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+ token instead.
+ should_strip_spaces (`bool`, *optional*, defaults to `False`):
+ Whether or not to strip the spaces from the list of tokens.
+ """
+
+ vocab_files_names = VOCAB_FILES_NAMES
+ model_input_names = ["input_ids", "attention_mask"]
+
+ def __init__(
+ self,
+ vocab_file,
+ bos_token="",
+ eos_token="",
+ pad_token="",
+ unk_token="",
+ should_strip_spaces=False,
+ **kwargs,
+ ):
+ requires_backends(self, "g2p_en")
+
+ with open(vocab_file, encoding="utf-8") as vocab_handle:
+ self.encoder = json.load(vocab_handle)
+
+ import g2p_en
+
+ self.g2p = g2p_en.G2p()
+
+ self.decoder = {v: k for k, v in self.encoder.items()}
+
+ super().__init__(
+ bos_token=bos_token,
+ eos_token=eos_token,
+ unk_token=unk_token,
+ pad_token=pad_token,
+ should_strip_spaces=should_strip_spaces,
+ **kwargs,
+ )
+
+ self.should_strip_spaces = should_strip_spaces
+
+ @property
+ def vocab_size(self):
+ return len(self.decoder)
+
+ def get_vocab(self):
+ "Returns vocab as a dict"
+ return dict(self.encoder, **self.added_tokens_encoder)
+
+ def prepare_for_tokenization(self, text, is_split_into_words=False, **kwargs):
+ # expand symbols
+ text = regex.sub(";", ",", text)
+ text = regex.sub(":", ",", text)
+ text = regex.sub("-", " ", text)
+ text = regex.sub("&", "and", text)
+
+ # strip unnecessary symbols
+ text = regex.sub(r"[\(\)\[\]\<\>\"]+", "", text)
+
+ # strip whitespaces
+ text = regex.sub(r"\s+", " ", text)
+
+ text = text.upper()
+
+ return text, kwargs
+
+ def _tokenize(self, text):
+ """Returns a tokenized string."""
+ # phonemize
+ tokens = self.g2p(text)
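+        # g2p_en returns a list of ARPAbet phonemes and punctuation, with " " marking word boundaries.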
+
+ if self.should_strip_spaces:
+ tokens = list(filter(lambda s: s != " ", tokens))
+
+ tokens.append(self.eos_token)
+
+ return tokens
+
+ def _convert_token_to_id(self, token):
+ """Converts a token (str) in an id using the vocab."""
+ return self.encoder.get(token, self.encoder.get(self.unk_token))
+
+ def _convert_id_to_token(self, index):
+ """Converts an index (integer) in a token (str) using the vocab."""
+ return self.decoder.get(index, self.unk_token)
+
+ # Override since phonemes cannot be converted back to strings
+ def decode(self, token_ids, **kwargs):
+ logger.warning(
+ "Phonemes cannot be reliably converted to a string due to the one-many mapping, converting to tokens instead."
+ )
+ return self.convert_ids_to_tokens(token_ids)
+
+ # Override since phonemes cannot be converted back to strings
+ def convert_tokens_to_string(self, tokens, **kwargs):
+ logger.warning(
+ "Phonemes cannot be reliably converted to a string due to the one-many mapping, returning the tokens."
+ )
+ return tokens
+
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]:
+ """
+ Save the vocabulary and special tokens file to a directory.
+
+ Args:
+ save_directory (`str`):
+ The directory in which to save the vocabulary.
+
+ Returns:
+            `tuple(str)`: Paths to the files saved.
+ """
+ if not os.path.isdir(save_directory):
+ logger.error(f"Vocabulary path ({save_directory}) should be a directory")
+ return
+ vocab_file = os.path.join(
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
+ )
+
+ with open(vocab_file, "w", encoding="utf-8") as f:
+ f.write(json.dumps(self.get_vocab(), ensure_ascii=False))
+
+ return (vocab_file,)
+
+ def __getstate__(self):
+ state = self.__dict__.copy()
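+        # The g2p_en engine is not reliably picklable, so drop it here and rebuild it in `__setstate__`.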
+ state["g2p"] = None
+ return state
+
+ def __setstate__(self, d):
+ self.__dict__ = d
+
+ try:
+ import g2p_en
+
+ self.g2p = g2p_en.G2p()
+ except ImportError:
+ raise ImportError(
+ "You need to install g2p-en to use FastSpeech2ConformerTokenizer. "
+ "See https://pypi.org/project/g2p-en/ for installation."
+ )
+
+
+__all__ = ["FastSpeech2ConformerTokenizer"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/fuyu/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/fuyu/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c2a7d252010e00ec7e3192520ac401b200dc1da9
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/fuyu/__init__.py
@@ -0,0 +1,29 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_fuyu import *
+ from .image_processing_fuyu import *
+ from .modeling_fuyu import *
+ from .processing_fuyu import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/fuyu/configuration_fuyu.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/fuyu/configuration_fuyu.py
new file mode 100644
index 0000000000000000000000000000000000000000..40da84e2e780821f26765333a2cee51030e0bea4
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/fuyu/configuration_fuyu.py
@@ -0,0 +1,215 @@
+# coding=utf-8
+# Copyright 2023 Adept AI and the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Fuyu model configuration"""
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+from ..auto import CONFIG_MAPPING, AutoConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+class FuyuConfig(PretrainedConfig):
+ r"""
+    This is the configuration class to store the configuration of a [`FuyuForCausalLM`]. It is used to instantiate a
+    Fuyu model according to the specified arguments, defining the model architecture. Instantiating a configuration
+ with the defaults will yield a similar configuration to that of the
+ [adept/fuyu-8b](https://huggingface.co/adept/fuyu-8b).
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 262144):
+ Vocabulary size of the Fuyu model. Defines the number of different tokens that can be represented by the
+ `inputs_ids` passed when calling [`FuyuForCausalLM`]
+ hidden_size (`int`, *optional*, defaults to 4096):
+ Dimension of the hidden representations.
+ intermediate_size (`int`, *optional*, defaults to 16384):
+ Dimension of the MLP representations.
+ num_hidden_layers (`int`, *optional*, defaults to 36):
+ Number of hidden layers in the Transformer encoder.
+ num_attention_heads (`int`, *optional*, defaults to 64):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ hidden_act (`str` or `function`, *optional*, defaults to `"relu2"`):
+ The non-linear activation function (function or string) in the decoder.
+ max_position_embeddings (`int`, *optional*, defaults to 16384):
+ The maximum sequence length that this model might ever be used with.
+ image_size (`int`, *optional*, defaults to 300):
+ The input image size.
+ patch_size (`int`, *optional*, defaults to 30):
+ The input vision transformer encoding patch size.
+ num_channels (`int`, *optional*, defaults to 3):
+ The input image number of channels.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ layer_norm_eps (`float`, *optional*, defaults to 1e-05):
+            The epsilon used by the layer normalization layers.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
+            relevant if `config.is_decoder=True`.
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+ Whether to tie input and output embeddings.
+ rope_theta (`float`, *optional*, defaults to 25000.0):
+ The base period of the RoPE embeddings.
+ rope_scaling (`Dict`, *optional*):
+ Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
+ strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
+ `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
+ `max_position_embeddings` to the expected new maximum. See the following thread for more information on how
+ these scaling strategies behave:
+ https://www.reddit.com/r/LocalFuyu/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
+ experimental feature, subject to breaking API changes in future versions.
+ qk_layernorm (`bool`, *optional*, defaults to `True`):
+ Whether or not to normalize the Queries and Keys after projecting the hidden states
+ hidden_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio after applying the MLP to the hidden states.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio after computing the attention scores.
+ partial_rotary_factor (`float`, *optional*, defaults to 0.5):
+ Percentage of the query and keys which will have rotary embedding.
+
+ pad_token_id (`int`, *optional*):
+ The id of the *padding* token.
+ bos_token_id (`int`, *optional*, defaults to 1):
+ The id of the *beginning-of-sequence* token.
+ eos_token_id (`Union[int, list[int]]`, *optional*, defaults to 2):
+ The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
+ image_token_id (`int`, *optional*, defaults to 71011):
+ The id of the image placeholder token.
+ text_config (`dict`, *optional*):
+            Dictionary of configuration options used to initialize the language model. If not provided, defaults to a
+            `persimmon`-style configuration.
+
+ ```python
+ >>> from transformers import FuyuConfig
+
+    >>> # Initializing a Fuyu fuyu-8b style configuration
+ >>> configuration = FuyuConfig()
+ ```"""
+
+ model_type = "fuyu"
+ sub_configs = {"text_config": AutoConfig}
+ keys_to_ignore_at_inference = ["past_key_values"]
+
+ def __init__(
+ self,
+ vocab_size=262144,
+ hidden_size=4096,
+ intermediate_size=16384,
+ num_hidden_layers=36,
+ num_attention_heads=64,
+ hidden_act="relu2",
+ max_position_embeddings=16384,
+ image_size=300,
+ patch_size=30,
+ num_channels=3,
+ initializer_range=0.02,
+ layer_norm_eps=1e-5,
+ use_cache=True,
+ tie_word_embeddings=False,
+ rope_theta=25000.0,
+ rope_scaling=None,
+ qk_layernorm=True,
+ hidden_dropout=0.0,
+ attention_dropout=0.0,
+ partial_rotary_factor=0.5,
+ pad_token_id=None,
+ bos_token_id=1,
+ eos_token_id=2,
+ image_token_id=71011,
+ text_config=None,
+ **kwargs,
+ ):
+ if text_config is None:
+ text_config = {
+ "vocab_size": vocab_size,
+ "max_position_embeddings": max_position_embeddings,
+ "hidden_size": hidden_size,
+ "intermediate_size": intermediate_size,
+ "num_hidden_layers": num_hidden_layers,
+ "num_attention_heads": num_attention_heads,
+ "hidden_act": hidden_act,
+ "initializer_range": initializer_range,
+ "layer_norm_eps": layer_norm_eps,
+ "use_cache": use_cache,
+ "rope_theta": rope_theta,
+ "rope_scaling": rope_scaling,
+ "qk_layernorm": qk_layernorm,
+ "hidden_dropout": hidden_dropout,
+ "attention_dropout": attention_dropout,
+ "partial_rotary_factor": partial_rotary_factor,
+ "pad_token_id": pad_token_id,
+ "bos_token_id": bos_token_id,
+ "eos_token_id": eos_token_id,
+ "tie_word_embeddings": tie_word_embeddings,
+ }
+ logger.info("text_config is None. initializing the text model with default values.")
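+        # The language backbone defaults to a Persimmon-style decoder unless `model_type` says otherwise.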
+ text_model_type = text_config.get("model_type", "persimmon")
+ self.text_config = CONFIG_MAPPING[text_model_type](**text_config)
+
+ self._vocab_size = vocab_size
+ self.max_position_embeddings = max_position_embeddings
+ self.image_size = image_size
+ self.patch_size = patch_size
+ self.num_channels = num_channels
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.hidden_act = hidden_act
+ self.initializer_range = initializer_range
+ self.layer_norm_eps = layer_norm_eps
+ self.use_cache = use_cache
+ self.rope_theta = rope_theta
+ self.rope_scaling = rope_scaling
+ self.qk_layernorm = qk_layernorm
+ self.hidden_dropout = hidden_dropout
+ self.attention_dropout = attention_dropout
+ self.partial_rotary_factor = partial_rotary_factor
+ self.image_token_id = image_token_id
+ self._rope_scaling_validation()
+
+ super().__init__(
+ pad_token_id=pad_token_id,
+ bos_token_id=bos_token_id,
+ eos_token_id=eos_token_id,
+ tie_word_embeddings=tie_word_embeddings,
+ **kwargs,
+ )
+
+ def _rope_scaling_validation(self):
+ """
+ Validate the `rope_scaling` configuration.
+ """
+ if self.rope_scaling is None:
+ return
+
+ if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
+ raise ValueError(
+ f"`rope_scaling` must be a dictionary with two fields, `type` and `factor`, got {self.rope_scaling}"
+ )
+ rope_scaling_type = self.rope_scaling.get("type", None)
+ rope_scaling_factor = self.rope_scaling.get("factor", None)
+ if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
+ raise ValueError(
+ f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
+ )
+ if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
+ raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}")
+
+
+__all__ = ["FuyuConfig"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/fuyu/image_processing_fuyu.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/fuyu/image_processing_fuyu.py
new file mode 100644
index 0000000000000000000000000000000000000000..366782be16f486d27f8c1f803685935a55909554
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/fuyu/image_processing_fuyu.py
@@ -0,0 +1,728 @@
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Image processor class for Fuyu."""
+
+import math
+from typing import Optional, Union
+
+import numpy as np
+
+from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
+from ...image_transforms import (
+ pad,
+ resize,
+ to_channel_dimension_format,
+)
+from ...image_utils import (
+ ChannelDimension,
+ ImageInput,
+ PILImageResampling,
+ get_image_size,
+ infer_channel_dimension_format,
+ is_scaled_image,
+ is_valid_image,
+ make_list_of_images,
+ to_numpy_array,
+ validate_preprocess_arguments,
+)
+from ...utils import (
+ TensorType,
+ filter_out_non_signature_kwargs,
+ is_torch_available,
+ is_torch_device,
+ is_torch_dtype,
+ logging,
+ requires_backends,
+)
+
+
+if is_torch_available():
+ import torch
+
+
+logger = logging.get_logger(__name__)
+
+
+def make_list_of_list_of_images(
+ images: Union[list[list[ImageInput]], list[ImageInput], ImageInput],
+) -> list[list[ImageInput]]:
+ if is_valid_image(images):
+ return [[images]]
+
+ if isinstance(images, list) and all(isinstance(image, list) for image in images):
+ return images
+
+ if isinstance(images, list):
+ return [make_list_of_images(image) for image in images]
+
+ raise ValueError("images must be a list of list of images or a list of images or an image.")
+
+
+class FuyuBatchFeature(BatchFeature):
+ """
+ BatchFeature class for Fuyu image processor and processor.
+
+ The outputs dictionary from the processors contains a mix of tensors and lists of tensors.
+ """
+
+ def convert_to_tensors(self, tensor_type: Optional[Union[str, TensorType]] = None):
+ """
+ Convert the inner content to tensors.
+
+ Args:
+ tensor_type (`str` or [`~utils.TensorType`], *optional*):
+ The type of tensors to use. If `str`, should be one of the values of the enum [`~utils.TensorType`]. If
+ `None`, no modification is done.
+ """
+ if tensor_type is None:
+ return self
+
+ is_tensor, as_tensor = self._get_is_as_tensor_fns(tensor_type=tensor_type)
+
+ def _convert_tensor(elem):
+ if is_tensor(elem):
+ return elem
+ return as_tensor(elem)
+
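+        # `key` inside `_safe_convert_tensor` resolves to the loop variable below at call time,
+        # which lets the error handling special-case the "overflowing_values" field.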
+ def _safe_convert_tensor(elem):
+ try:
+ return _convert_tensor(elem)
+ except: # noqa E722
+ if key == "overflowing_values":
+ raise ValueError("Unable to create tensor returning overflowing values of different lengths. ")
+ raise ValueError(
+ "Unable to create tensor, you should probably activate padding "
+ "with 'padding=True' to have batched tensors with the same length."
+ )
+
+ # Do the tensor conversion in batch
+ for key, value in self.items():
+ if isinstance(value, list) and isinstance(value[0], list):
+ # list[list[Any]] -> list[list[Tensor]]
+ self[key] = [[_safe_convert_tensor(elem) for elem in elems] for elems in value]
+ elif isinstance(value, list):
+ # list[Any] -> list[Tensor]
+ self[key] = [_safe_convert_tensor(elem) for elem in value]
+ else:
+ # Any -> Tensor
+ self[key] = _safe_convert_tensor(value)
+ return self
+
+ def to(self, *args, **kwargs) -> "BatchFeature":
+ """
+ Send all values to device by calling `v.to(*args, **kwargs)` (PyTorch only). This should support casting in
+ different `dtypes` and sending the `BatchFeature` to a different `device`.
+
+ Args:
+ args (`Tuple`):
+ Will be passed to the `to(...)` function of the tensors.
+ kwargs (`Dict`, *optional*):
+ Will be passed to the `to(...)` function of the tensors.
+
+ Returns:
+ [`BatchFeature`]: The same instance after modification.
+ """
+ requires_backends(self, ["torch"])
+ import torch
+
+ new_data = {}
+ device = kwargs.get("device")
+ # Check if the args are a device or a dtype
+ if device is None and len(args) > 0:
+ # device should be always the first argument
+ arg = args[0]
+ if is_torch_dtype(arg):
+ # The first argument is a dtype
+ pass
+ elif isinstance(arg, str) or is_torch_device(arg) or isinstance(arg, int):
+ device = arg
+ else:
+ # it's something else
+ raise ValueError(f"Attempting to cast a BatchFeature to type {str(arg)}. This is not supported.")
+
+ def _to(elem):
+ # check if v is a floating point
+ if torch.is_floating_point(elem):
+ # cast and send to device
+ return elem.to(*args, **kwargs)
+ if device is not None:
+ return elem.to(device=device)
+
+ return elem
+
+ # We cast only floating point tensors to avoid issues with tokenizers casting `LongTensor` to `FloatTensor`
+ for k, v in self.items():
+ if isinstance(v, list) and isinstance(v[0], list):
+ # Data structure is a list of lists
+ new_v = []
+ for elems in v:
+ new_v.append([_to(elem) for elem in elems])
+ new_data[k] = new_v
+ elif isinstance(v, list):
+ # Data structure is a list
+ new_data[k] = [_to(elem) for elem in v]
+ else:
+ new_data[k] = _to(v)
+ self.data = new_data
+ return self
+
+
+class FuyuImageProcessor(BaseImageProcessor):
+ """
+ This class should handle the image processing part before the main FuyuForCausalLM. In particular, it should
+ handle:
+
+ - Processing Images:
+        Taking a batch of images as input. If the images are variable-sized, they are scaled down to fit within the
+        target size and then padded. The output image always has a height of 1080 and a width of 1920.
+
+        Then, it splits these images into patches using the `patchify_image` function.
+
+ - Creating Image Input IDs:
+ For each patch, a placeholder ID is given to identify where these patches belong in a token sequence. For
+ variable-sized images, each line of patches is terminated with a newline ID.
+
+ - Image Patch Indices:
+ For each image patch, the code maintains an index where these patches should be inserted in a token stream.
+
+
+ Args:
+ do_resize (`bool`, *optional*, defaults to `True`):
+ Whether to resize the image to `size`.
+ size (`dict[str, int]`, *optional*, defaults to `{"height": 1080, "width": 1920}`):
+ Dictionary in the format `{"height": int, "width": int}` specifying the size of the output image.
+ resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`):
+ `PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BILINEAR`.
+ do_pad (`bool`, *optional*, defaults to `True`):
+ Whether to pad the image to `size`.
+ padding_value (`float`, *optional*, defaults to 1.0):
+ The value to pad the image with.
+ padding_mode (`str`, *optional*, defaults to `"constant"`):
+ The padding mode to use when padding the image.
+ do_normalize (`bool`, *optional*, defaults to `True`):
+ Whether to normalize the image.
+ image_mean (`float`, *optional*, defaults to 0.5):
+ The mean to use when normalizing the image.
+ image_std (`float`, *optional*, defaults to 0.5):
+ The standard deviation to use when normalizing the image.
+ do_rescale (`bool`, *optional*, defaults to `True`):
+ Whether to rescale the image.
+ rescale_factor (`float`, *optional*, defaults to `1 / 255`):
+ The factor to use when rescaling the image.
+ patch_size (`dict[str, int]`, *optional*, defaults to `{"height": 30, "width": 30}`):
+ Dictionary in the format `{"height": int, "width": int}` specifying the size of the patches.
+ """
+
+ model_input_names = [
+ "images",
+ "image_input_ids",
+ "image_patches",
+ "image_patch_indices_per_batch",
+ "image_patch_indices_per_subsequence",
+ ]
+
+ def __init__(
+ self,
+ do_resize: bool = True,
+ size: Optional[dict[str, int]] = None,
+ resample: PILImageResampling = PILImageResampling.BILINEAR,
+ do_pad: bool = True,
+ padding_value: float = 1.0,
+ padding_mode: str = "constant",
+ do_normalize: bool = True,
+ image_mean: Union[float, list[float]] = 0.5,
+ image_std: Union[float, list[float]] = 0.5,
+ do_rescale: bool = True,
+ rescale_factor: float = 1 / 255,
+ patch_size: Optional[dict[str, int]] = None,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.do_resize = do_resize
+ self.size = size if size is not None else {"height": 1080, "width": 1920}
+ self.resample = resample
+ self.do_pad = do_pad
+ self.padding_value = padding_value
+ self.padding_mode = padding_mode
+ self.do_normalize = do_normalize
+ self.image_mean = image_mean
+ self.image_std = image_std
+ self.do_rescale = do_rescale
+ self.rescale_factor = rescale_factor
+ self.patch_size = patch_size if patch_size is not None else {"height": 30, "width": 30}
+
+ def resize(
+ self,
+ image: np.ndarray,
+ size: dict[str, int],
+ resample: PILImageResampling = PILImageResampling.BILINEAR,
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ **kwargs,
+ ) -> np.ndarray:
+ """
+ Resize an image to `(size["height"], size["width"])`.
+
+ Args:
+ image (`np.ndarray`):
+ Image to resize.
+ size (`dict[str, int]`):
+ Dictionary in the format `{"height": int, "width": int}` specifying the size of the output image.
+ resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
+ `PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BILINEAR`.
+ data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format for the output image. If unset, the channel dimension format of the input
+ image is used. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
+ input_data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
+ from the input image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
+
+ Returns:
+ `np.ndarray`: The resized image.
+ """
+ image_height, image_width = get_image_size(image, input_data_format)
+ target_height, target_width = size["height"], size["width"]
+
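+        # Only downscale: images that already fit inside the target frame are returned unchanged;
+        # larger images are shrunk with their aspect ratio preserved.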
+ if image_width <= target_width and image_height <= target_height:
+ return image
+
+ height_scale_factor = target_height / image_height
+ width_scale_factor = target_width / image_width
+ optimal_scale_factor = min(height_scale_factor, width_scale_factor)
+
+ new_height = int(image_height * optimal_scale_factor)
+ new_width = int(image_width * optimal_scale_factor)
+
+ scaled_image = resize(
+ image=image,
+ size=(new_height, new_width),
+ resample=resample,
+ data_format=data_format,
+ input_data_format=input_data_format,
+ **kwargs,
+ )
+ return scaled_image
+
+ def pad_image(
+ self,
+ image: np.ndarray,
+ size: dict[str, int],
+ mode: str = "constant",
+ constant_values: float = 1.0,
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ ) -> np.ndarray:
+ """
+ Pad an image to `(size["height"], size["width"])`.
+
+ Args:
+ image (`np.ndarray`):
+ Image to pad.
+ size (`dict[str, int]`):
+ Dictionary in the format `{"height": int, "width": int}` specifying the size of the output image.
+ data_format (`ChannelDimension` or `str`, *optional*):
+ The data format of the output image. If unset, the same format as the input image is used.
+ input_data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format of the input image. If not provided, it will be inferred.
+ """
+ image_height, image_width = get_image_size(image, input_data_format)
+ target_height, target_width = size["height"], size["width"]
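+        # Pad only on the bottom and right so the original content stays anchored at the top-left corner.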
+ padding_top = 0
+ padding_left = 0
+ padding_bottom = target_height - image_height
+ padding_right = target_width - image_width
+ padded_image = pad(
+ image,
+ padding=((padding_top, padding_bottom), (padding_left, padding_right)),
+ mode=mode,
+ constant_values=constant_values,
+ data_format=data_format,
+ input_data_format=input_data_format,
+ )
+ return padded_image
+
+ @filter_out_non_signature_kwargs()
+ def preprocess(
+ self,
+ images,
+ do_resize: Optional[bool] = None,
+ size: Optional[dict[str, int]] = None,
+ resample: Optional[PILImageResampling] = None,
+ do_pad: Optional[bool] = None,
+ padding_value: Optional[float] = None,
+ padding_mode: Optional[str] = None,
+ do_normalize: Optional[bool] = None,
+ image_mean: Optional[float] = None,
+ image_std: Optional[float] = None,
+ do_rescale: Optional[bool] = None,
+ rescale_factor: Optional[float] = None,
+ patch_size: Optional[dict[str, int]] = None,
+ data_format: Optional[Union[str, ChannelDimension]] = ChannelDimension.FIRST,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ return_tensors: Optional[TensorType] = None,
+ ):
+ """
+
+ Utility function to preprocess the images and extract necessary information about original formats.
+
+ Args:
+ images (`ImageInput`):
+                Images to preprocess. Expects a single image, a list of images, or a list of lists of images. Pixel
+ values range from 0 to 255, or between 0 and 1 if `do_rescale` is `False`.
+ do_resize (`bool`, *optional*, defaults to `self.do_resize`):
+ Whether to resize the image to `size`.
+ size (`dict[str, int]`, *optional*, defaults to `self.size`):
+ Dictionary in the format `{"height": int, "width": int}` specifying the size of the output image.
+ resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
+ `PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BILINEAR`.
+ do_pad (`bool`, *optional*, defaults to `self.do_pad`):
+ Whether to pad the image to `size`.
+ padding_value (`float`, *optional*, defaults to `self.padding_value`):
+ The value to pad the image with.
+ padding_mode (`str`, *optional*, defaults to `self.padding_mode`):
+ The padding mode to use when padding the image.
+ do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
+ Whether to normalize the image.
+ image_mean (`float`, *optional*, defaults to `self.image_mean`):
+ The mean to use when normalizing the image.
+ image_std (`float`, *optional*, defaults to `self.image_std`):
+ The standard deviation to use when normalizing the image.
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
+ Whether to rescale the image.
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
+ The factor to use when rescaling the image.
+ patch_size (`dict[str, int]`, *optional*, defaults to `self.patch_size`):
+ Dictionary in the format `{"height": int, "width": int}` specifying the size of the patches.
+ return_tensors (`str` or `TensorType`, *optional*):
+ The type of tensors to return. Can be one of:
+ - Unset: Return a list of `np.ndarray`.
+ - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
+ - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
+ data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
+ The channel dimension format of the output image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ input_data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
+ from the input image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ """
+
+ do_resize = do_resize if do_resize is not None else self.do_resize
+ size = size if size is not None else self.size
+ resample = resample if resample is not None else self.resample
+ do_pad = do_pad if do_pad is not None else self.do_pad
+ do_rescale = do_rescale if do_rescale is not None else self.do_rescale
+ rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
+ do_normalize = do_normalize if do_normalize is not None else self.do_normalize
+ image_mean = image_mean if image_mean is not None else self.image_mean
+ image_std = image_std if image_std is not None else self.image_std
+ padding_value = padding_value if padding_value is not None else self.padding_value
+ padding_mode = padding_mode if padding_mode is not None else self.padding_mode
+ patch_size = patch_size if patch_size is not None else self.patch_size
+
+ if isinstance(images, list) and any(isinstance(elem, list) and len(elem) >= 2 for elem in images):
+ raise ValueError("Multiple images for a single sample are not yet supported.")
+
+ batch_images = make_list_of_list_of_images(images)
+
+ validate_preprocess_arguments(
+ do_rescale=do_rescale,
+ rescale_factor=rescale_factor,
+ do_normalize=do_normalize,
+ image_mean=image_mean,
+ image_std=image_std,
+ do_resize=do_resize,
+ size=size,
+ resample=resample,
+ )
+ # All transformations expect numpy arrays.
+ batch_images = [[to_numpy_array(image) for image in images] for images in batch_images]
+
+ # Search for the first image in the image list.
+ # NOTE: we can't slice the first image with images_list[0][0] if the first batch contains no images. See #36682
+ first_image_in_list = [images for images in batch_images if images][0][0]
+
+ if do_rescale and is_scaled_image(first_image_in_list):
+ logger.warning_once(
+ "It looks like you are trying to rescale already rescaled images. If the input"
+ " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
+ )
+
+ if input_data_format is None:
+ # We assume that all images have the same channel dimension format.
+ input_data_format = infer_channel_dimension_format(first_image_in_list)
+
+ original_image_sizes = [
+ get_image_size(images[0], channel_dim=input_data_format) for images in batch_images if images
+ ]
+ size = get_size_dict(size) # for BC
+
+ if do_resize:
+ batch_images = [
+ [self.resize(image, size=size, input_data_format=input_data_format) for image in images]
+ for images in batch_images
+ ]
+
+ image_sizes = [get_image_size(images[0], channel_dim=input_data_format) for images in batch_images if images]
+ image_unpadded_heights = [[image_size[0]] for image_size in image_sizes]
+ image_unpadded_widths = [[image_size[1]] for image_size in image_sizes]
+
+ # scale_h is the same as scale_w
+ image_scale_factors = [
+ [resized_size[0] / original_size[0]]
+ for original_size, resized_size in zip(original_image_sizes, image_sizes)
+ ]
+
+ if do_pad:
+ batch_images = [
+ [
+ self.pad_image(
+ image,
+ size=size,
+ mode=padding_mode,
+ constant_values=padding_value,
+ input_data_format=input_data_format,
+ )
+ for image in images
+ ]
+ for images in batch_images
+ ]
+
+ if do_rescale:
+ batch_images = [
+ [self.rescale(image, scale=rescale_factor, input_data_format=input_data_format) for image in images]
+ for images in batch_images
+ ]
+
+ if do_normalize:
+ batch_images = [
+ [
+ self.normalize(image, mean=image_mean, std=image_std, input_data_format=input_data_format)
+ for image in images
+ ]
+ for images in batch_images
+ ]
+
+ if data_format is not None:
+ batch_images = [
+ [to_channel_dimension_format(image, data_format, input_data_format) for image in images]
+ for images in batch_images
+ ]
+
+ data = {
+ "images": batch_images,
+ "image_unpadded_heights": image_unpadded_heights,
+ "image_unpadded_widths": image_unpadded_widths,
+ "image_scale_factors": image_scale_factors,
+ }
+ return FuyuBatchFeature(data=data, tensor_type=return_tensors)
+
+ def get_num_patches(self, image_height: int, image_width: int, patch_size: Optional[dict[str, int]] = None) -> int:
+ """
+ Calculate number of patches required to encode an image.
+
+ Args:
+ image_height (`int`):
+ Height of the image.
+ image_width (`int`):
+ Width of the image.
+ patch_size (`dict[str, int]`, *optional*, defaults to `self.patch_size`):
+ Dictionary in the format `{"height": int, "width": int}` specifying the size of the patches.
+ """
+ patch_size = patch_size if patch_size is not None else self.patch_size
+        patch_height, patch_width = patch_size["height"], patch_size["width"]
+
+ if image_height % patch_height != 0:
+ raise ValueError(f"{image_height=} must be divisible by {patch_height}")
+ if image_width % patch_width != 0:
+ raise ValueError(f"{image_width=} must be divisible by {patch_width}")
+
+ num_patches_per_dim_h = image_height // patch_height
+ num_patches_per_dim_w = image_width // patch_width
+ num_patches = num_patches_per_dim_h * num_patches_per_dim_w
+ return num_patches
+
+ def patchify_image(self, image: "torch.Tensor", patch_size: Optional[dict[str, int]] = None) -> "torch.Tensor":
+ """
+ Convert an image into a tensor of patches.
+
+ Args:
+ image (`torch.Tensor`):
+ Image to convert. Shape: [batch, channels, height, width]
+ patch_size (`dict[str, int]`, *optional*, defaults to `self.patch_size`):
+ Dictionary in the format `{"height": int, "width": int}` specifying the size of the patches.
+ """
+ requires_backends(self, ["torch"])
+ patch_size = patch_size if patch_size is not None else self.patch_size
+ patch_height, patch_width = patch_size["height"], patch_size["width"]
+
+ # TODO refer to https://github.com/ArthurZucker/transformers/blob/0f0a3fe5ca5697ee58faeb5b53f049af720b5e98/src/transformers/models/vit_mae/modeling_vit_mae.py#L871
+ # torch implementation is faster but does not handle non-squares
+
+ batch_size, channels, _, _ = image.shape
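+        # Two successive unfolds carve the image into a grid of non-overlapping
+        # (patch_height, patch_width) tiles, which is then flattened row-major into
+        # [batch_size, num_patches, channels * patch_height * patch_width].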
+ unfolded_along_height = image.unfold(2, patch_height, patch_height)
+ patches = unfolded_along_height.unfold(3, patch_width, patch_width)
+ patches = patches.contiguous()
+ patches = patches.view(batch_size, channels, -1, patch_height, patch_width)
+ patches = patches.permute(0, 2, 3, 4, 1)
+ patches = patches.reshape(batch_size, -1, channels * patch_height * patch_width)
+ return patches
+
+ def preprocess_with_tokenizer_info(
+ self,
+ image_input: "torch.Tensor",
+ image_present: "torch.Tensor",
+ image_unpadded_h: "torch.Tensor",
+ image_unpadded_w: "torch.Tensor",
+ image_placeholder_id: int,
+ image_newline_id: int,
+ variable_sized: bool,
+ patch_size: Optional[dict[str, int]] = None,
+ ) -> FuyuBatchFeature:
+ """Process images for model input. In particular, variable-sized images are handled here.
+
+ Args:
+ image_input (`torch.Tensor` of shape [batch_size, subsequence_size, num_channels, height, width]):
+ Tensor of images padded to model input size.
+ image_present (`torch.Tensor` of shape [batch_size, subsequence_size, num_images]):
+ Tensor of 1s and 0s indicating whether an image is present.
+ image_unpadded_h (`torch.Tensor` of shape [batch_size, subsequence_size]):
+ Tensor of unpadded image heights.
+ image_unpadded_w (`torch.Tensor` of shape [batch_size, subsequence_size]):
+ Tensor of unpadded image widths.
+ image_placeholder_id (int):
+ The id of the image placeholder token. Comes from an associated tokenizer.
+ image_newline_id (int):
+ The id of the image newline token. Comes from an associated tokenizer.
+ variable_sized (bool):
+ Whether to process images as variable-sized.
+ patch_size (`dict[str, int]`, *optional*, defaults to `self.patch_size`):
+ Size of the patches.
+ """
+ requires_backends(self, ["torch"])
+
+ patch_size = patch_size if patch_size is not None else self.patch_size
+ patch_height, patch_width = patch_size["height"], patch_size["width"]
+
+ # Only images that are present.
+ images: list[list[torch.Tensor]] = []
+ batch_image_patches: list[list[torch.Tensor]] = []
+ # Image input ids for every subsequence, including ones with no image present.
+ batch_image_input_ids: list[list[torch.Tensor]] = []
+ for batch_index in range(image_input.shape[0]):
+ image_input_ids = []
+ image_patches = []
+ for subseq_index in range(image_input.shape[1]):
+ if image_present[batch_index, subseq_index]:
+ image = image_input[batch_index, subseq_index]
+ image_height, image_width = image.shape[1], image.shape[2]
+ if variable_sized:
+ # The min() is required here due to floating point issues:
+ # math.ceil(torch.tensor(300).cuda() / 30) == 11
+ new_h = min(
+ image_height,
+ math.ceil(image_unpadded_h[batch_index, subseq_index] / patch_height) * patch_height,
+ )
+ new_w = min(
+ image_width,
+ math.ceil(image_unpadded_w[batch_index, subseq_index] / patch_width) * patch_width,
+ )
+ image = image[:, :new_h, :new_w]
+ image_height, image_width = new_h, new_w
+
+ num_patches = self.get_num_patches(image_height=image_height, image_width=image_width)
+ tensor_of_image_ids = torch.full(
+ [num_patches], image_placeholder_id, dtype=torch.int32, device=image_input.device
+ )
+ patches = self.patchify_image(image=image.unsqueeze(0)).squeeze(0)
+ assert num_patches == patches.shape[0]
+
+ if variable_sized:
+ # Now terminate each line with |NEWLINE|.
+ tensor_of_image_ids = tensor_of_image_ids.reshape(-1, image_width // patch_width)
+ newline_ids = torch.full(
+ [tensor_of_image_ids.shape[0], 1],
+ image_newline_id,
+ dtype=torch.int32,
+ device=image_input.device,
+ )
+ tensor_of_image_ids = torch.cat([tensor_of_image_ids, newline_ids], dim=1)
+ tensor_of_image_ids = tensor_of_image_ids.reshape(-1)
+
+ images.append([image])
+ image_input_ids.append(tensor_of_image_ids)
+ image_patches.append(patches)
+ else:
+ image_input_ids.append(torch.tensor([], dtype=torch.int32, device=image_input.device))
+
+ batch_image_input_ids.append(image_input_ids)
+ batch_image_patches.append(image_patches)
+
+ # Create image_patch_input_indices, where non-negative values correspond to image patches to be inserted in
+ # the stream.
+ image_patch_indices_per_batch: list[list[torch.Tensor]] = []
+ image_patch_indices_per_subsequence: list[list[torch.Tensor]] = []
+
+ for sample_image_input_ids in batch_image_input_ids:
+ index_offset = 0
+ per_batch_indices = []
+ per_subsequence_indices = []
+ for subseq_image_input_ids in sample_image_input_ids:
+ # Indices of image patches.
+ patches_mask = subseq_image_input_ids == image_placeholder_id
+ num_patches = torch.count_nonzero(patches_mask)
+ indices = torch.arange(num_patches, dtype=torch.int64, device=subseq_image_input_ids.device).type_as(
+ subseq_image_input_ids
+ )
+
+ # Place those indices in the image input ids token stream, with -1 representing non-index tokens.
+ indices_in_stream_per_batch = torch.full_like(subseq_image_input_ids, -1)
+ indices_in_stream_per_subsequence = torch.full_like(subseq_image_input_ids, -1)
+ patches_inds = torch.nonzero(patches_mask, as_tuple=True)[0]
+
+ indices_in_stream_per_batch[patches_inds] = indices + index_offset
+ indices_in_stream_per_subsequence[patches_inds] = indices
+
+ per_batch_indices.append(indices_in_stream_per_batch)
+ per_subsequence_indices.append(indices_in_stream_per_subsequence)
+ index_offset += num_patches
+
+ image_patch_indices_per_batch.append(per_batch_indices)
+ image_patch_indices_per_subsequence.append(per_subsequence_indices)
+
+ return FuyuBatchFeature(
+ data={
+ "images": images,
+ "image_input_ids": batch_image_input_ids,
+ "image_patches": batch_image_patches,
+ "image_patch_indices_per_batch": image_patch_indices_per_batch,
+ "image_patch_indices_per_subsequence": image_patch_indices_per_subsequence,
+ }
+ )
+
+
+__all__ = ["FuyuImageProcessor"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/fuyu/modeling_fuyu.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/fuyu/modeling_fuyu.py
new file mode 100644
index 0000000000000000000000000000000000000000..2095e9877c2c7759dd1d8668b65ecbb2f980ca30
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/fuyu/modeling_fuyu.py
@@ -0,0 +1,407 @@
+# coding=utf-8
+# Copyright 2023 HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch Fuyu model."""
+
+from typing import Optional, Union
+
+import torch
+from torch import nn
+
+from ...cache_utils import Cache
+from ...generation import GenerationMixin
+from ...modeling_outputs import CausalLMOutputWithPast
+from ...modeling_utils import PreTrainedModel
+from ...models.auto.modeling_auto import AutoModel
+from ...utils import auto_docstring, can_return_tuple, logging
+from .configuration_fuyu import FuyuConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+@auto_docstring
+class FuyuPreTrainedModel(PreTrainedModel):
+ config: FuyuConfig
+ base_model_prefix = "fuyu"
+ supports_gradient_checkpointing = True
+ _supports_attention_backend = True
+ _supports_flash_attn = True
+ _supports_sdpa = True
+ _supports_flex_attn = True
+ _no_split_modules = []
+ _skip_keys_device_placement = "past_key_values"
+
+ def _init_weights(self, module):
+ std = self.config.initializer_range
+ if isinstance(module, nn.Linear):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+
+
+@auto_docstring(
+ custom_intro="""
+ The Fuyu model which consists of a vision backbone and a language model, without a language modeling head.
+ """
+)
+class FuyuModel(FuyuPreTrainedModel):
+ _checkpoint_conversion_mapping = {"language_model.model": "language_model"}
+
+ def __init__(self, config: FuyuConfig):
+ super().__init__(config)
+ self.padding_idx = config.pad_token_id
+ self.vocab_size = config.text_config.vocab_size
+ self.language_model = AutoModel.from_config(config.text_config)
+ self.vision_embed_tokens = nn.Linear(
+ config.patch_size * config.patch_size * config.num_channels, config.hidden_size
+ )
+
+ self.gradient_checkpointing = False
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.language_model.get_input_embeddings()
+
+ def set_input_embeddings(self, value):
+ self.language_model.set_input_embeddings(value)
+
+ def set_decoder(self, decoder):
+ self.language_model = decoder
+
+ def get_decoder(self):
+ return self.language_model
+
+ def gather_continuous_embeddings(
+ self,
+ word_embeddings: torch.Tensor,
+ continuous_embeddings: list[torch.Tensor],
+ image_patch_input_indices: torch.Tensor,
+ ) -> torch.Tensor:
+ """This function places the continuous_embeddings into the word_embeddings at the locations
+ indicated by image_patch_input_indices. Different batch elements can have different numbers of continuous
+ embeddings.
+
+ Args:
+ word_embeddings (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+ Tensor of word embeddings.
+ continuous_embeddings (`torch.FloatTensor` of shape `(batch_size, num_patches, hidden_size)`):
+ Tensor of continuous embeddings. The length of the list is the batch size. Each entry is shape
+ [num_image_embeddings, hidden], and num_image_embeddings needs to match the number of non-negative
+ indices in image_patch_input_indices for that batch element.
+ image_patch_input_indices (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ Tensor of indices of the image patches in the input_ids tensor.
+ """
+ if not (word_embeddings.shape[0] == len(continuous_embeddings)):
+ raise ValueError(
+ f"Batch sizes must match! Got {len(continuous_embeddings)=} and {word_embeddings.shape[0]=}"
+ )
+
+ output_embeddings = word_embeddings.clone()
+ for batch_idx in range(word_embeddings.shape[0]):
+ # First, find the positions of all the non-negative values in image_patch_input_indices, those are the
+ # positions in word_embeddings that we want to replace with content from continuous_embeddings.
+ dst_indices = torch.nonzero(image_patch_input_indices[batch_idx] >= 0, as_tuple=True)[0]
+ # Next look up those indices in image_patch_input_indices to find the indices in continuous_embeddings that we
+ # want to use to replace the values in word_embeddings.
+ src_indices = image_patch_input_indices[batch_idx][dst_indices]
+ # Check if we have more indices than embeddings. Note that we could have fewer indices if images got truncated.
+ if src_indices.shape[0] > continuous_embeddings[batch_idx].shape[0]:
+ raise ValueError(
+ f"Number of continuous embeddings {continuous_embeddings[batch_idx].shape=} does not match "
+ f"number of continuous token ids {src_indices.shape=} in batch element {batch_idx}."
+ )
+ output_embeddings[batch_idx, dst_indices] = continuous_embeddings[batch_idx][src_indices].to(
+ output_embeddings.device
+ )
+ return output_embeddings
+
+ def get_image_features(self, pixel_values: torch.FloatTensor, **kwargs):
+ """
+ Encodes images into continuous embeddings that can be forwarded to the language model.
+
+ Args:
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
+ The tensors corresponding to the input images.
+ """
+ patch_embeddings = [
+ self.vision_embed_tokens(patch.to(self.vision_embed_tokens.weight.dtype)).squeeze(0)
+ for patch in pixel_values
+ ]
+ return patch_embeddings
+
+ def get_placeholder_mask(
+ self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, image_features: torch.FloatTensor
+ ):
+ """
+ Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is
+ equal to the length of multimodal features. If the lengths are different, an error is raised.
+ """
+ if input_ids is None:
+ special_image_mask = inputs_embeds == self.get_input_embeddings()(
+ torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
+ )
+ special_image_mask = special_image_mask.all(-1)
+ else:
+ special_image_mask = input_ids == self.config.image_token_id
+
+ n_image_tokens = special_image_mask.sum()
+ special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
+ n_image_features = image_features.shape[0] * image_features.shape[1]
+ if inputs_embeds[special_image_mask].numel() != image_features.numel():
+ raise ValueError(
+ f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
+ )
+ return special_image_mask
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+        # [batch_size, num_total_patches, patch_size x patch_size x num_channels]
+ image_patches: Optional[torch.Tensor] = None,
+ image_patches_indices: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ **kwargs,
+ ) -> Union[tuple, CausalLMOutputWithPast]:
+ r"""
+        image_patches (`torch.FloatTensor` of shape `(batch_size, num_total_patches, patch_size x patch_size x num_channels)`, *optional*):
+ Image patches to be used as continuous embeddings. The patches are flattened and then projected to the
+ hidden size of the model.
+ image_patches_indices (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Tensor of indices of the image patches in the input_ids tensor.
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+ elif input_ids is not None:
+ batch_size, seq_length = input_ids.shape
+ elif inputs_embeds is not None:
+ batch_size, seq_length, _ = inputs_embeds.shape
+ else:
+ raise ValueError("You have to specify either input_is or inputs_embeds")
+
+ if position_ids is None:
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
+ past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
+ position_ids = torch.arange(
+ past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
+ )
+ position_ids = position_ids.unsqueeze(0)
+
+ if inputs_embeds is None:
+ inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
+
+ if image_patches is not None:
+ patch_embeddings = self.get_image_features(image_patches)
+ patch_embeddings = torch.cat(patch_embeddings, dim=0).to(inputs_embeds.device, inputs_embeds.dtype)
+ special_image_mask = self.get_placeholder_mask(
+ input_ids, inputs_embeds=inputs_embeds, image_features=patch_embeddings
+ )
+ inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, patch_embeddings)
+
+ outputs = self.language_model(
+ inputs_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ use_cache=use_cache,
+ return_dict=return_dict,
+ **kwargs,
+ )
+
+ return outputs
+
+
+@auto_docstring(
+ custom_intro="""
+ Fuyu Model with a language modeling head on top for causal language modeling conditioned on image patches and text.
+ """
+)
+class FuyuForCausalLM(FuyuPreTrainedModel, GenerationMixin):
+ _checkpoint_conversion_mapping = {
+ "^language_model.model": "model.language_model",
+ "^vision_embed_tokens": "model.vision_embed_tokens",
+ "^language_model.lm_head": "lm_head",
+ }
+ _tied_weights_keys = ["lm_head.weight"]
+
+ def __init__(self, config: FuyuConfig):
+ super().__init__(config)
+ self.model = FuyuModel(config)
+ self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.get_input_embeddings()
+
+ def set_input_embeddings(self, value):
+ self.model.set_input_embeddings(value)
+
+ def set_decoder(self, decoder):
+ self.model.set_decoder(decoder)
+
+ def get_decoder(self):
+ return self.model.get_decoder()
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ # [batch_size, num_total_patches, patch_size x patch_size x num_channels]
+ image_patches: Optional[torch.Tensor] = None,
+ image_patches_indices: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ labels: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ logits_to_keep: Optional[int] = 0,
+ **kwargs,
+ ) -> Union[tuple, CausalLMOutputWithPast]:
+ r"""
+ image_patches (`torch.FloatTensor` of shape `(batch_size, num_total_patches, patch_size x patch_size x num_channels)`, *optional*):
+ Image patches to be used as continuous embeddings. The patches are flattened and then projected to the
+ hidden size of the model.
+ image_patches_indices (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Tensor of indices of the image patches in the input_ids tensor.
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`.
+
+ Examples:
+
+ ```python
+ >>> from transformers import FuyuProcessor, FuyuForCausalLM
+ >>> from PIL import Image
+ >>> import requests
+
+ >>> processor = FuyuProcessor.from_pretrained("adept/fuyu-8b")
+ >>> model = FuyuForCausalLM.from_pretrained("adept/fuyu-8b")
+
+ >>> url = "https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/bus.png"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+ >>> prompt = "Generate a coco-style caption.\n"
+
+ >>> inputs = processor(images=image, text=prompt, return_tensors="pt")
+ >>> outputs = model(**inputs)
+
+ >>> generated_ids = model.generate(**inputs, max_new_tokens=7)
+ >>> generation_text = processor.batch_decode(generated_ids[:, -7:], skip_special_tokens=True)
+ >>> print(generation_text[0])
+ A blue bus parked on the side of a road.
+ ```"""
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.model(
+ input_ids=input_ids,
+ image_patches=image_patches,
+ image_patches_indices=image_patches_indices,
+ inputs_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ use_cache=use_cache,
+ return_dict=True,
+ # don't pass kwargs because Persimmon-backbone doesn't accept FA2 kwargs yet, TODO: raushan
+ )
+
+ hidden_states = outputs[0]
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(
+ logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
+ )
+
+ return CausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+ def prepare_inputs_for_generation(
+ self,
+ input_ids,
+ past_key_values=None,
+ attention_mask=None,
+ inputs_embeds=None,
+ image_patches=None,
+ image_patches_indices=None,
+ cache_position=None,
+ **kwargs,
+ ):
+ # Overwritten -- in specific circumstances we don't want to forward image inputs to the model
+
+ model_inputs = super().prepare_inputs_for_generation(
+ input_ids,
+ past_key_values=past_key_values,
+ attention_mask=attention_mask,
+ inputs_embeds=inputs_embeds,
+ image_patches=image_patches,
+ image_patches_indices=image_patches_indices,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+ if cache_position[0] != 0:
+ # set image_patches and image_patches_indices to `None` for decoding stage
+ model_inputs["image_patches_indices"] = None
+ model_inputs["image_patches"] = None
+
+ return model_inputs
+
+
+__all__ = ["FuyuForCausalLM", "FuyuPreTrainedModel", "FuyuModel"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/fuyu/processing_fuyu.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/fuyu/processing_fuyu.py
new file mode 100644
index 0000000000000000000000000000000000000000..debbcb23aac1f9df63f052bdb39883a22efc71b9
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/fuyu/processing_fuyu.py
@@ -0,0 +1,793 @@
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Image/Text processor class for Fuyu
+"""
+
+import re
+from typing import Optional, Union
+
+import numpy as np
+
+from ...image_utils import ImageInput
+from ...processing_utils import (
+ MultiModalData,
+ ProcessingKwargs,
+ ProcessorMixin,
+ Unpack,
+)
+from ...tokenization_utils_base import PreTokenizedInput, TextInput
+from ...utils import is_torch_available, logging, requires_backends
+from ...utils.import_utils import requires
+
+
+if is_torch_available():
+ from .image_processing_fuyu import FuyuBatchFeature
+
+
+logger = logging.get_logger(__name__)
+
+
+if is_torch_available():
+ import torch
+
+
+TEXT_REPR_BBOX_OPEN = "<box>"
+TEXT_REPR_BBOX_CLOSE = "</box>"
+TEXT_REPR_POINT_OPEN = "<point>"
+TEXT_REPR_POINT_CLOSE = "</point>"
+
+TOKEN_BBOX_OPEN_STRING = "<0x00>" # <box>
+TOKEN_BBOX_CLOSE_STRING = "<0x01>" # </box>
+TOKEN_POINT_OPEN_STRING = "<0x02>" # <point>
+TOKEN_POINT_CLOSE_STRING = "<0x03>" # </point>
+BEGINNING_OF_ANSWER_STRING = "<0x04>" # <boa>
+
+
+class FuyuProcessorKwargs(ProcessingKwargs, total=False):
+ _defaults = {
+ "text_kwargs": {
+ "add_special_tokens": True,
+ "padding": False,
+ "stride": 0,
+ "return_attention_mask": True,
+ "return_overflowing_tokens": False,
+ "return_special_tokens_mask": False,
+ "return_offsets_mapping": False,
+ "return_token_type_ids": False,
+ "return_length": False,
+ "verbose": True,
+ "return_mm_token_type_ids": False,
+ },
+ "images_kwargs": {},
+ }
+
+
+def full_unpacked_stream_to_tensor(
+ all_bi_tokens_to_place: list[int],
+ full_unpacked_stream: list["torch.Tensor"],
+ fill_value: int,
+ batch_size: int,
+ new_seq_len: int,
+ offset: int,
+) -> "torch.Tensor":
+ """Takes an unpacked stream of tokens (i.e. a list of tensors, one for each item in the batch) and does
+ the required padding to create a single tensor for the batch of shape batch_size x new_seq_len.
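+
+ Example (an illustrative sketch of the padding behaviour):
+
+ ```python
+ >>> import torch
+ >>> streams = [torch.arange(5), torch.arange(3)]
+ >>> full_unpacked_stream_to_tensor(
+ ... [5, 3], streams, fill_value=0, batch_size=2, new_seq_len=6, offset=0
+ ... )
+ tensor([[0, 1, 2, 3, 4, 0],
+ [0, 1, 2, 0, 0, 0]])
+ ```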
+ """
+
+ assert len(all_bi_tokens_to_place) == batch_size
+ assert len(full_unpacked_stream) == batch_size
+
+ # Create padded tensors for the full batch.
+ new_padded_tensor = torch.full(
+ [batch_size, new_seq_len],
+ fill_value=fill_value,
+ dtype=full_unpacked_stream[0].dtype,
+ device=full_unpacked_stream[0].device,
+ )
+
+ # Place each batch entry into the batch tensor.
+ for bi in range(batch_size):
+ tokens_to_place = all_bi_tokens_to_place[bi]
+ new_padded_tensor[bi, :tokens_to_place] = full_unpacked_stream[bi][offset : tokens_to_place + offset]
+
+ return new_padded_tensor
+
+
+def construct_full_unpacked_stream(
+ num_real_text_tokens: Union[list[list[int]], "torch.Tensor"],
+ input_stream: "torch.Tensor",
+ image_tokens: list[list["torch.Tensor"]],
+ batch_size: int,
+ num_sub_sequences: int,
+) -> list["torch.Tensor"]:
+ """Takes an input_stream tensor of shape B x S x ?. For each subsequence, adds any required
+ padding to account for images and then unpacks the subsequences to create a single sequence per item in the batch.
+ Returns a list of tensors, one for each item in the batch."""
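+ # Note: for each batch item, the image placeholder ids produced by the image processor
+ # (`image_tokens[batch_index][0]`) are prepended to the text tokens, and anything beyond
+ # `num_real_tokens` is dropped, so each returned tensor has length
+ # image_tokens[batch_index][0].shape[0] + num_real_text_tokens[batch_index][0].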
+
+ all_bi_stream = []
+
+ for batch_index in range(batch_size):
+ all_si_stream = []
+
+ # First, construct full token stream (including image placeholder tokens) and loss mask for each subsequence
+ # and append to lists. We use lists rather than tensors because each subsequence is variable-sized.
+ # TODO Remove this logic in a subsequent release since subsequences are not supported.
+ image_adjustment = image_tokens[batch_index][0]
+ subsequence_stream = torch.cat([image_adjustment, input_stream[batch_index, 0]], dim=0)
+ num_real_tokens = image_adjustment.shape[0] + num_real_text_tokens[batch_index][0]
+ all_si_stream.append(subsequence_stream[:num_real_tokens])
+ all_bi_stream.append(torch.cat(all_si_stream, dim=0))
+
+ return all_bi_stream
+
+
+def _replace_string_repr_with_token_tags(prompt: str) -> str:
+ prompt = prompt.replace(TEXT_REPR_POINT_OPEN, TOKEN_POINT_OPEN_STRING)
+ prompt = prompt.replace(TEXT_REPR_POINT_CLOSE, TOKEN_POINT_CLOSE_STRING)
+ prompt = prompt.replace(TEXT_REPR_BBOX_OPEN, TOKEN_BBOX_OPEN_STRING)
+ prompt = prompt.replace(TEXT_REPR_BBOX_CLOSE, TOKEN_BBOX_CLOSE_STRING)
+ return prompt
+
+
+def _segment_prompt_into_text_token_conversions(prompt: str) -> list:
+ """
+ Given a string prompt, converts the prompt into a list of TextTokenConversions.
+ """
+ # Wherever we notice one of the [TOKEN_OPEN_STRING, TOKEN_CLOSE_STRING] markers, we split the prompt
+ prompt_text_list: list = []
+ regex_pattern = re.compile(
+ f"({TOKEN_BBOX_OPEN_STRING}|{TOKEN_BBOX_CLOSE_STRING}|{TOKEN_POINT_OPEN_STRING}|{TOKEN_POINT_CLOSE_STRING})"
+ )
+ # Split by the regex pattern
+ prompt_split = regex_pattern.split(prompt)
+ for i, elem in enumerate(prompt_split):
+ if len(elem) == 0 or elem in [
+ TOKEN_BBOX_OPEN_STRING,
+ TOKEN_BBOX_CLOSE_STRING,
+ TOKEN_POINT_OPEN_STRING,
+ TOKEN_POINT_CLOSE_STRING,
+ ]:
+ continue
+ prompt_text_list.append(
+ (elem, i > 1 and prompt_split[i - 1] in [TOKEN_BBOX_OPEN_STRING, TOKEN_POINT_OPEN_STRING])
+ )
+ return prompt_text_list
+
+
+def _transform_coordinates_and_tokenize(prompt: str, scale_factor: float, tokenizer) -> list[int]:
+ """
+ This function transforms the prompt in the following fashion:
+ - <box></box> and <point></point> to their respective token mappings
+ - extract the coordinates from the tag
+ - transform the coordinates into the transformed image space
+ - return the prompt tokens with the transformed coordinates and new tags
+
+ Bounding boxes and points MUST be in the following format: <box>y1, x1, y2, x2</box> <point>x, y</point> The spaces
+ and punctuation added above are NOT optional.
+ """
+ # Make a namedtuple that stores "text" and "is_bbox"
+
+ # We want to do the following: Tokenize the code normally -> when we see a point or box, tokenize using the tokenize_within_tag function
+ # When point or box close tag, continue tokenizing normally
+ # First, we replace the point and box tags with their respective tokens
+ prompt = _replace_string_repr_with_token_tags(prompt)
+ # Tokenize the prompt
+ # Convert prompt into a list split
+ prompt_text_list = _segment_prompt_into_text_token_conversions(prompt)
+ transformed_prompt_tokens: list[int] = []
+ for elem in prompt_text_list:
+ if elem[1]:
+ # This is a location, we need to tokenize it
+ within_tag_tokenized = _transform_within_tags(elem[0], scale_factor, tokenizer)
+ # Surround the text with the open and close tags
+ transformed_prompt_tokens.extend(within_tag_tokenized)
+ else:
+ transformed_prompt_tokens.extend(tokenizer(elem[0], add_special_tokens=False).input_ids)
+ return transformed_prompt_tokens
+
+
+def _transform_within_tags(text: str, scale_factor: float, tokenizer) -> list[int]:
+ """
+ Given a bounding box of the fashion <box>1, 2, 3, 4</box> | <point>1, 2</point> This function is responsible for
+ converting 1, 2, 3, 4 into tokens of 1 2 3 4 without any commas.
+ """
+ # Convert the text into a list of strings.
+ num_int_strs = text.split(",")
+ if len(num_int_strs) == 2:
+ # If there are any open or close tags, remove them.
+ token_space_open_string = tokenizer.vocab[TOKEN_POINT_OPEN_STRING]
+ token_space_close_string = tokenizer.vocab[TOKEN_POINT_CLOSE_STRING]
+ else:
+ token_space_open_string = tokenizer.vocab[TOKEN_BBOX_OPEN_STRING]
+ token_space_close_string = tokenizer.vocab[TOKEN_BBOX_CLOSE_STRING]
+
+ # Remove all spaces from num_ints
+ num_ints = [float(num.strip()) for num in num_int_strs]
+ # scale to transformed image size
+ if len(num_ints) == 2:
+ num_ints_translated = scale_point_to_transformed_image(x=num_ints[0], y=num_ints[1], scale_factor=scale_factor)
+ elif len(num_ints) == 4:
+ num_ints_translated = scale_bbox_to_transformed_image(
+ top=num_ints[0],
+ left=num_ints[1],
+ bottom=num_ints[2],
+ right=num_ints[3],
+ scale_factor=scale_factor,
+ )
+ else:
+ raise ValueError(f"Invalid number of ints: {len(num_ints)}")
+ # Look up the vocab ids for the transformed coordinates; the open/close tag tokens are added around them below
+ tokens = [tokenizer.vocab[str(num)] for num in num_ints_translated]
+ return [token_space_open_string] + tokens + [token_space_close_string]
+
+
+def _tokenize_prompts_with_image_and_batch(
+ tokenizer,
+ prompts: list[list[str]],
+ scale_factors: Optional[list[list["torch.Tensor"]]],
+ max_tokens_to_generate: int,
+ max_position_embeddings: int,
+ add_BOS: bool, # Same issue with types as above
+ add_beginning_of_answer_token: bool,
+) -> tuple["torch.Tensor", "torch.Tensor"]:
+ """
+ Given a set of prompts and number of tokens to generate:
+ - tokenize prompts
+ - set the sequence length to be the max of length of prompts plus the number of tokens we would like to generate
+ - pad all the sequences to this length so we can convert them into a 3D tensor.
+ """
+
+ # If not tool use, transform the coordinates while tokenizing
+ if scale_factors is not None:
+ transformed_prompt_tokens = []
+ for prompt_seq, scale_factor_seq in zip(prompts, scale_factors):
+ transformed_prompt_tokens.append(
+ [
+ _transform_coordinates_and_tokenize(prompt, scale_factor.item(), tokenizer)
+ for prompt, scale_factor in zip(prompt_seq, scale_factor_seq)
+ ]
+ )
+ else:
+ transformed_prompt_tokens = [[tokenizer.tokenize(prompt) for prompt in prompt_seq] for prompt_seq in prompts]
+
+ prompts_tokens = transformed_prompt_tokens
+
+ if add_BOS:
+ bos_token = tokenizer.vocab["<s>"]
+ else:
+ bos_token = tokenizer.vocab["|ENDOFTEXT|"]
+ prompts_tokens = [[[bos_token] + x for x in prompt_seq] for prompt_seq in prompts_tokens]
+ if add_beginning_of_answer_token:
+ beginning_of_answer = tokenizer.vocab[BEGINNING_OF_ANSWER_STRING]
+ # Only add bbox open token to the last subsequence since that is what will be completed
+ for token_seq in prompts_tokens:
+ token_seq[-1].append(beginning_of_answer)
+
+ # Now we have a list of list of tokens which each list has a different
+ # size. We want to extend this list to:
+ # - incorporate the tokens that need to be generated
+ # - make all the sequences equal length.
+ # Get the prompts length.
+
+ prompts_length = [[len(x) for x in prompts_tokens_seq] for prompts_tokens_seq in prompts_tokens]
+ # Get the max prompts length.
+ max_prompt_len: int = np.max(prompts_length)
+ # Number of tokens in the each sample of the batch.
+ samples_length = min(max_prompt_len + max_tokens_to_generate, max_position_embeddings)
+ if max_prompt_len + max_tokens_to_generate > max_position_embeddings:
+ logger.warning(
+ f"Max subsequence prompt length of {max_prompt_len} + max tokens to generate {max_tokens_to_generate}",
+ f"exceeds context length of {max_position_embeddings}. Will generate as many tokens as possible.",
+ )
+ # Now update the list of list to be of the same size: samples_length.
+ for prompt_tokens_seq, prompts_length_seq in zip(prompts_tokens, prompts_length):
+ for prompt_tokens, prompt_length in zip(prompt_tokens_seq, prompts_length_seq):
+ if len(prompt_tokens) > samples_length:
+ raise ValueError("Length of subsequence prompt exceeds sequence length.")
+ padding_size = samples_length - prompt_length
+ prompt_tokens.extend([tokenizer.vocab["|ENDOFTEXT|"]] * padding_size)
+
+ # Now we are in a structured format, we can convert to tensors.
+ prompts_tokens_tensor = torch.tensor(prompts_tokens, dtype=torch.int64)
+ prompts_length_tensor = torch.tensor(prompts_length, dtype=torch.int64)
+
+ return prompts_tokens_tensor, prompts_length_tensor
+
+
+# Simplified assuming self.crop_top = self.padding_top = 0
+def original_to_transformed_h_coords(original_coords, scale_h):
+ return np.round(original_coords * scale_h).astype(np.int32)
+
+
+# Simplified assuming self.crop_left = self.padding_left = 0
+def original_to_transformed_w_coords(original_coords, scale_w):
+ return np.round(original_coords * scale_w).astype(np.int32)
+
+
+def scale_point_to_transformed_image(x: float, y: float, scale_factor: float) -> list[int]:
+ x_scaled = original_to_transformed_w_coords(np.array([x / 2]), scale_factor)[0]
+ y_scaled = original_to_transformed_h_coords(np.array([y / 2]), scale_factor)[0]
+ return [x_scaled, y_scaled]
+
+
+def scale_bbox_to_transformed_image(
+ top: float, left: float, bottom: float, right: float, scale_factor: float
+) -> list[int]:
+ top_scaled = original_to_transformed_w_coords(np.array([top / 2]), scale_factor)[0]
+ left_scaled = original_to_transformed_h_coords(np.array([left / 2]), scale_factor)[0]
+ bottom_scaled = original_to_transformed_w_coords(np.array([bottom / 2]), scale_factor)[0]
+ right_scaled = original_to_transformed_h_coords(np.array([right / 2]), scale_factor)[0]
+ return [top_scaled, left_scaled, bottom_scaled, right_scaled]
+
+
+@requires(backends=("vision",))
+class FuyuProcessor(ProcessorMixin):
+ r"""
+ Constructs a Fuyu processor which wraps a Fuyu image processor and a Llama tokenizer into a single processor.
+
+ [`FuyuProcessor`] offers all the functionalities of [`FuyuImageProcessor`] and [`LlamaTokenizerFast`]. See the
+ [`~FuyuProcessor.__call__`] and [`~FuyuProcessor.decode`] for more information.
+
+ Args:
+ image_processor ([`FuyuImageProcessor`]):
+ The image processor is a required input.
+ tokenizer ([`LlamaTokenizerFast`]):
+ The tokenizer is a required input.
+ """
+
+ attributes = ["image_processor", "tokenizer"]
+ image_processor_class = "FuyuImageProcessor"
+ tokenizer_class = "AutoTokenizer"
+
+ def __init__(self, image_processor, tokenizer, **kwargs):
+ super().__init__(image_processor=image_processor, tokenizer=tokenizer)
+ self.image_processor = image_processor
+ self.tokenizer = tokenizer
+ self.max_tokens_to_generate = 10
+ self.max_position_embeddings = 16384 # TODO Can't derive this from model files: where to set it?
+ self.pad_token_id = 0
+ self.dummy_image_index = -1
+ self.image_token_id = tokenizer.encode("|SPEAKER|", add_special_tokens=False)[1]
+ self.image_newline_id = tokenizer.encode("|NEWLINE|", add_special_tokens=False)[1]
+
+ def _left_pad_inputs_with_attention_mask(self, model_inputs: list[dict], return_attention_mask: bool):
+ max_length_input_ids = max(entry["input_ids"].shape[1] for entry in model_inputs)
+ max_length_image_patch_indices = max(entry["image_patches_indices"].shape[1] for entry in model_inputs)
+
+ batched_inputs = {"input_ids": [], "image_patches": [], "image_patches_indices": [], "attention_mask": []}
+
+ for entry in model_inputs:
+ for key, tensor in entry.items():
+ if key == "input_ids":
+ num_padding_tokens = max_length_input_ids - tensor.shape[1]
+ padded_input_ids = torch.cat(
+ [
+ torch.full((tensor.shape[0], num_padding_tokens), self.pad_token_id, dtype=torch.long),
+ tensor,
+ ],
+ dim=1,
+ )
+ batched_inputs[key].append(padded_input_ids)
+
+ attention_mask = torch.cat(
+ [torch.zeros(tensor.shape[0], num_padding_tokens, dtype=torch.long), torch.ones_like(tensor)],
+ dim=1,
+ )
+ batched_inputs["attention_mask"].append(attention_mask)
+
+ elif key == "image_patches":
+ # For image_patches, we don't pad but just append them to the list.
+ batched_inputs[key].append(tensor)
+
+ else: # for image_patches_indices
+ num_padding_indices = max_length_image_patch_indices - tensor.shape[1]
+ padded_indices = torch.cat(
+ [
+ torch.full(
+ (tensor.shape[0], num_padding_indices), self.dummy_image_index, dtype=torch.long
+ ),
+ tensor,
+ ],
+ dim=1,
+ )
+ batched_inputs[key].append(padded_indices)
+ batched_keys = ["input_ids", "image_patches_indices"]
+ if return_attention_mask:
+ batched_keys.append("attention_mask")
+ for key in batched_keys:
+ batched_inputs[key] = torch.cat(batched_inputs[key], dim=0)
+
+ # Cast images to tensor as well, if only one image passed and no padding needed
+ # NOTE: vLLM expects all processor outputs to be a tensor
+ if len(batched_inputs["image_patches"]) == 1:
+ batched_inputs["image_patches"] = torch.cat(batched_inputs["image_patches"], dim=0)
+
+ return batched_inputs
+
+ def get_sample_encoding(
+ self,
+ prompts,
+ scale_factors,
+ image_unpadded_heights,
+ image_unpadded_widths,
+ image_placeholder_id,
+ image_newline_id,
+ tensor_batch_images,
+ ):
+ image_present = torch.ones(1, 1, 1)
+ model_image_input = self.image_processor.preprocess_with_tokenizer_info(
+ image_input=tensor_batch_images,
+ image_present=image_present,
+ image_unpadded_h=image_unpadded_heights,
+ image_unpadded_w=image_unpadded_widths,
+ image_placeholder_id=image_placeholder_id,
+ image_newline_id=image_newline_id,
+ variable_sized=True,
+ )
+ # FIXME max_tokens_to_generate is embedded into this processor's call.
+ prompt_tokens, prompts_length = _tokenize_prompts_with_image_and_batch(
+ tokenizer=self.tokenizer,
+ prompts=prompts,
+ scale_factors=scale_factors,
+ max_tokens_to_generate=self.max_tokens_to_generate,
+ max_position_embeddings=self.max_position_embeddings,
+ add_BOS=True,
+ add_beginning_of_answer_token=True,
+ )
+ image_padded_unpacked_tokens = construct_full_unpacked_stream(
+ num_real_text_tokens=prompts_length,
+ input_stream=prompt_tokens,
+ image_tokens=model_image_input["image_input_ids"],
+ batch_size=1,
+ num_sub_sequences=self.subsequence_length,
+ )
+ # Construct inputs for image patch indices.
+ unpacked_image_patch_indices_per_batch = construct_full_unpacked_stream(
+ num_real_text_tokens=prompts_length,
+ input_stream=torch.full_like(prompt_tokens, -1),
+ image_tokens=model_image_input["image_patch_indices_per_batch"],
+ batch_size=1,
+ num_sub_sequences=self.subsequence_length,
+ )
+ max_prompt_length = max(x.shape[-1] for x in image_padded_unpacked_tokens)
+ max_seq_len_batch = min(max_prompt_length + self.max_tokens_to_generate, self.max_position_embeddings)
+ tokens_to_place = min(max_seq_len_batch, max(0, image_padded_unpacked_tokens[0].shape[0]))
+
+ # Use same packing logic for the image patch indices.
+ image_patch_input_indices = full_unpacked_stream_to_tensor(
+ all_bi_tokens_to_place=[tokens_to_place],
+ full_unpacked_stream=unpacked_image_patch_indices_per_batch,
+ fill_value=-1,
+ batch_size=1,
+ new_seq_len=max_seq_len_batch,
+ offset=0,
+ )
+ image_patches_tensor = torch.stack([img[0] for img in model_image_input["image_patches"]])
+ batch_encoding = {
+ "input_ids": image_padded_unpacked_tokens[0].unsqueeze(0),
+ "image_patches": image_patches_tensor,
+ "image_patches_indices": image_patch_input_indices,
+ }
+ return batch_encoding
+
+ def __call__(
+ self,
+ images: Optional[ImageInput] = None,
+ text: Optional[Union[str, list[str], TextInput, PreTokenizedInput]] = None,
+ audio=None,
+ videos=None,
+ **kwargs: Unpack[FuyuProcessorKwargs],
+ ) -> "FuyuBatchFeature":
+ """
+ Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text`
+ and `kwargs` arguments to LlamaTokenizerFast's [`~LlamaTokenizerFast.__call__`] if `text` is not `None` to
+ encode the text. To prepare the image(s), this method forwards the `images` and `kwargs` arguments to
+ FuyuImageProcessor's [`~FuyuImageProcessor.__call__`] if `images` is not `None`. Please refer to the docstring
+ of the above two methods for more information.
+
+ Args:
+ images (`PIL.Image.Image`, `list[PIL.Image.Image]`):
+ The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
+ tensor. Both channels-first and channels-last formats are supported.
+ text (`str`, `list[str]`):
+ The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
+ (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
+ `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
+
+ Returns:
+ [`FuyuBatchFeature`]: A [`FuyuBatchFeature`] with the following fields:
+
+ - **input_ids** -- Tensor of token ids to be fed to a model. Returned when `text` is not `None`.
+ - **image_patches** -- List of Tensor of image patches. Returned when `images` is not `None`.
+ - **image_patches_indices** -- Tensor of indices where patch embeddings have to be inserted by the model.
+ - **attention_mask** -- List of indices specifying which tokens should be attended to by the model when
+ `return_attention_mask=True`.
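+
+ Example (a minimal sketch; assumes the `adept/fuyu-8b` checkpoint and network access):
+
+ ```python
+ >>> from transformers import FuyuProcessor
+ >>> from PIL import Image
+ >>> import requests
+
+ >>> processor = FuyuProcessor.from_pretrained("adept/fuyu-8b")
+ >>> url = "https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/bus.png"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+ >>> inputs = processor(images=image, text="Generate a coco-style caption.")
+ >>> sorted(inputs.keys())
+ ['attention_mask', 'image_patches', 'image_patches_indices', 'input_ids']
+ ```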
+ """
+ requires_backends(self, ["torch"])
+
+ # --- Check input validity ---
+ if text is None and images is None:
+ raise ValueError("You have to specify either text or images. Both cannot be None.")
+
+ output_kwargs = self._merge_kwargs(
+ FuyuProcessorKwargs,
+ tokenizer_init_kwargs=self.tokenizer.init_kwargs,
+ **kwargs,
+ )
+ return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", False)
+
+ if not output_kwargs["text_kwargs"].setdefault("return_attention_mask", True):
+ raise ValueError("`return_attention_mask=False` is not supported for this model.")
+
+ if text is not None and images is None:
+ logger.warning("You are processing a text with no associated image. Make sure it is intended.")
+ self.current_processor = self.tokenizer
+ text_encoding = self.tokenizer(text, **output_kwargs["text_kwargs"])
+ return text_encoding
+
+ if text is None and images is not None:
+ logger.warning("You are processing an image with no associated text. Make sure it is intended.")
+ prompts = [[""]]
+ if text is not None and images is not None:
+ if isinstance(text, str):
+ prompts = [[text]]
+ elif isinstance(text, list):
+ prompts = [[text_seq] for text_seq in text]
+
+ # --- Preprocess images using self.image_processor ---
+
+ # FIXME - We hard code "pt" here because the rest of the processing assumes torch tensors
+ output_kwargs["images_kwargs"]["return_tensors"] = "pt"
+ image_encoding = self.image_processor.preprocess(images, **output_kwargs["images_kwargs"])
+ batch_images = image_encoding["images"]
+ image_unpadded_heights = image_encoding["image_unpadded_heights"]
+ image_unpadded_widths = image_encoding["image_unpadded_widths"]
+ scale_factors = image_encoding["image_scale_factors"]
+ self.subsequence_length = 1 # Each batch contains only one sequence.
+ self.batch_size = len(batch_images)
+
+ # --- Use self.tokenizer to get the ids of special tokens to insert into image ids ---
+
+ tensor_batch_images = torch.stack([img[0] for img in batch_images if img]).unsqueeze(1)
+
+ # --- Use self.image_processor again to obtain the full token ids and batch inputs ---
+ all_encodings = []
+
+ for prompt, scale_factor, image_unpadded_height, image_unpadded_width, tensor_batch_image in zip(
+ prompts, scale_factors, image_unpadded_heights, image_unpadded_widths, tensor_batch_images
+ ):
+ sample_encoding = self.get_sample_encoding(
+ prompts=[prompt],
+ scale_factors=[scale_factor],
+ image_unpadded_heights=torch.tensor([image_unpadded_height]),
+ image_unpadded_widths=torch.tensor([image_unpadded_width]),
+ image_placeholder_id=self.image_token_id,
+ image_newline_id=self.image_newline_id,
+ tensor_batch_images=tensor_batch_image.unsqueeze(0),
+ )
+ all_encodings.append(sample_encoding)
+
+ batch_encoding = self._left_pad_inputs_with_attention_mask(
+ model_inputs=all_encodings, return_attention_mask=True
+ )
+ if return_mm_token_type_ids:
+ input_ids = batch_encoding["input_ids"]
+ mm_token_type_ids = torch.zeros_like(input_ids)
+ mm_token_type_ids[input_ids == self.image_token_id] = 1
+ mm_token_type_ids[input_ids == self.image_newline_id] = 1
+ batch_encoding["mm_token_type_ids"] = mm_token_type_ids
+
+ return FuyuBatchFeature(data=batch_encoding)
+
+ def _get_num_multimodal_tokens(self, image_sizes=None, **kwargs):
+ """
+ Computes the number of placeholder tokens needed for multimodal inputs with the given sizes.
+
+ Args:
+ image_sizes (`list[list[int]]`, *optional*):
+ The input sizes formatted as (height, width) per each image.
+
+ Returns:
+ `MultiModalData`: A `MultiModalData` object holding number of tokens per each of the provided
+ input modalities, along with other useful data.
+ """
+
+ vision_data = {}
+ if image_sizes is not None:
+ size = kwargs.get("size") or self.image_processor.size
+ padded_height, padded_width = size["height"], size["width"]
+
+ num_image_tokens = []
+ num_image_patches = [1] * len(image_sizes)
+ for image_size in image_sizes:
+ height_scale_factor = padded_height / image_size[0]
+ width_scale_factor = padded_width / image_size[1]
+ optimal_scale_factor = min(height_scale_factor, width_scale_factor)
+
+ image_unpadded_h = min(int(image_size[0] * optimal_scale_factor), image_size[0])
+ image_unpadded_w = min(int(image_size[1] * optimal_scale_factor), image_size[1])
+
+ # We can use torch here because Fuyu processor has hard dependency on torch. NOTE: Fuyu can't do multi-image
+ # thus the below (1, 1, 1) is hardcoded. Same as when calling the processor
+ model_image_input = self.image_processor.preprocess_with_tokenizer_info(
+ image_input=torch.zeros(1, 1, 3, padded_height, padded_width),
+ image_present=torch.ones(1, 1, 1),
+ image_unpadded_h=torch.tensor([[image_unpadded_h]]),
+ image_unpadded_w=torch.tensor([[image_unpadded_w]]),
+ image_placeholder_id=0, # dummy ids, we can be sure `id=0` is never out-of-range
+ image_newline_id=0,
+ variable_sized=True,
+ )
+ num_image_tokens.append(model_image_input["image_input_ids"][0][0].shape[-1])
+ vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches})
+ return MultiModalData(**vision_data)
+
+ def post_process_box_coordinates(self, outputs, target_sizes=None):
+ """
+ Transforms raw coordinates detected by [`FuyuForCausalLM`] to the original images' coordinate space.
+ Coordinates will be returned in "box" format, with the following pattern:
+ `top, left, bottom, right`
+
+ Point coordinates are not supported yet.
+
+ Args:
+ outputs ([`GenerateOutput`]):
+ Raw outputs from `generate`.
+ target_sizes (`torch.Tensor`, *optional*):
+ Tensor of shape (batch_size, 2) where each entry is the (height, width) of the corresponding image in
+ the batch. If set, found coordinates in the output sequence are rescaled to the target sizes. If left
+ to None, coordinates will not be rescaled.
+
+ Returns:
+ `GenerateOutput`: Same output type returned by `generate`, with output token ids replaced with
+ boxed and possible rescaled coordinates.
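+
+ Example (a minimal sketch; assumes `outputs` come from `model.generate` and `image` is the original `PIL.Image`):
+
+ ```python
+ >>> import torch
+
+ >>> target_sizes = torch.tensor([[image.height, image.width]])
+ >>> rescaled = processor.post_process_box_coordinates(outputs, target_sizes=target_sizes)
+ >>> text = processor.batch_decode(rescaled, skip_special_tokens=True)
+ ```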
+ """
+
+ def scale_factor_to_fit(original_size, target_size=None):
+ height, width = original_size
+ if target_size is None:
+ max_height = self.image_processor.size["height"]
+ max_width = self.image_processor.size["width"]
+ else:
+ max_height, max_width = target_size
+ if width <= max_width and height <= max_height:
+ return 1.0
+ return min(max_height / height, max_width / width)
+
+ def find_delimiters_pair(tokens, start_token, end_token):
+ start_id = self.tokenizer.convert_tokens_to_ids(start_token)
+ end_id = self.tokenizer.convert_tokens_to_ids(end_token)
+
+ starting_positions = (tokens == start_id).nonzero(as_tuple=True)[0]
+ ending_positions = (tokens == end_id).nonzero(as_tuple=True)[0]
+
+ if torch.any(starting_positions) and torch.any(ending_positions):
+ return (starting_positions[0], ending_positions[0])
+ return (None, None)
+
+ def tokens_to_boxes(tokens, original_size):
+ while (pair := find_delimiters_pair(tokens, TOKEN_BBOX_OPEN_STRING, TOKEN_BBOX_CLOSE_STRING)) != (
+ None,
+ None,
+ ):
+ start, end = pair
+ if end != start + 5:
+ continue
+
+ # Retrieve transformed coordinates from tokens
+ coords = self.tokenizer.convert_ids_to_tokens(tokens[start + 1 : end])
+
+ # Scale back to original image size and multiply by 2
+ scale = scale_factor_to_fit(original_size)
+ top, left, bottom, right = [2 * int(float(c) / scale) for c in coords]
+
+ # Replace the IDs so they get detokenized right
+ replacement = f" {TEXT_REPR_BBOX_OPEN}{top}, {left}, {bottom}, {right}{TEXT_REPR_BBOX_CLOSE}"
+ replacement = self.tokenizer.tokenize(replacement)[1:]
+ replacement = self.tokenizer.convert_tokens_to_ids(replacement)
+ replacement = torch.tensor(replacement).to(tokens)
+
+ tokens = torch.cat([tokens[:start], replacement, tokens[end + 1 :]], 0)
+ return tokens
+
+ def tokens_to_points(tokens, original_size):
+ while (pair := find_delimiters_pair(tokens, TOKEN_POINT_OPEN_STRING, TOKEN_POINT_CLOSE_STRING)) != (
+ None,
+ None,
+ ):
+ start, end = pair
+ if end != start + 3:
+ continue
+
+ # Retrieve transformed coordinates from tokens
+ coords = self.tokenizer.convert_ids_to_tokens(tokens[start + 1 : end])
+
+ # Scale back to original image size and multiply by 2
+ scale = scale_factor_to_fit(original_size)
+ x, y = [2 * int(float(c) / scale) for c in coords]
+
+ # Replace the IDs so they get detokenized right
+ replacement = f" {TEXT_REPR_POINT_OPEN}{x}, {y}{TEXT_REPR_POINT_CLOSE}"
+ replacement = self.tokenizer.tokenize(replacement)[1:]
+ replacement = self.tokenizer.convert_tokens_to_ids(replacement)
+ replacement = torch.tensor(replacement).to(tokens)
+
+ tokens = torch.cat([tokens[:start], replacement, tokens[end + 1 :]], 0)
+ return tokens
+
+ if target_sizes is None:
+ target_sizes = ((self.image_processor.size["height"], self.image_processor.size["width"]),) * len(outputs)
+ elif target_sizes.shape[1] != 2:
+ raise ValueError("Each element of target_sizes must contain the size (h, w) of each image of the batch")
+
+ if len(outputs) != len(target_sizes):
+ raise ValueError("Make sure that you pass in as many target sizes as output sequences")
+
+ results = []
+ for seq, size in zip(outputs, target_sizes):
+ seq = tokens_to_boxes(seq, size)
+ seq = tokens_to_points(seq, size)
+ results.append(seq)
+
+ return results
+
+ def post_process_image_text_to_text(self, generated_outputs, skip_special_tokens=True, **kwargs):
+ """
+ Post-processes the output of `FuyuForConditionalGeneration` to only return the text output.
+
+ Args:
+ generated_outputs (`torch.Tensor` or `np.ndarray`):
+ The output of the model. The output is expected to be a tensor of shape `(batch_size, sequence_length)`
+ containing the token ids of the generated sequences.
+ skip_special_tokens (`bool`, *optional*, defaults to `True`):
+ Whether or not to remove special tokens in the output. Argument passed to the tokenizer's `batch_decode` method.
+ **kwargs:
+ Additional arguments to be passed to the tokenizer's `batch_decode method`.
+
+ Returns:
+ `list[str]`: The decoded text output.
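+
+ Example (a minimal sketch; `generated_ids` are assumed to come from `model.generate` on inputs built with this processor):
+
+ ```python
+ >>> captions = processor.post_process_image_text_to_text(generated_ids)  # text after the beginning-of-answer token
+ ```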
+ """
+ beginning_of_answer = self.tokenizer.convert_tokens_to_ids(BEGINNING_OF_ANSWER_STRING)
+ # get boa index for each outputted sequence tensor
+ # start all generated sequences from the beginning of the answer token, pad to have consistent length
+ unpadded_output_sequences = [
+ seq[(seq == beginning_of_answer).nonzero(as_tuple=True)[0] + 1 :] for seq in generated_outputs
+ ]
+ max_len = max(len(seq) for seq in unpadded_output_sequences)
+ # convert to torch and pad sequences
+ padded_output_sequences = torch.full((len(unpadded_output_sequences), max_len), self.pad_token_id)
+ for i, seq in enumerate(unpadded_output_sequences):
+ padded_output_sequences[i, : len(seq)] = torch.tensor(seq)
+
+ return self.batch_decode(padded_output_sequences, skip_special_tokens=skip_special_tokens, **kwargs)
+
+ @property
+ def model_input_names(self):
+ tokenizer_input_names = self.tokenizer.model_input_names
+ image_processor_input_names = self.image_processor.model_input_names
+
+ # Make a copy of list when removing otherwise `self.image_processor.model_input_names` is also modified
+ extra_image_inputs = [
+ "image_input_ids",
+ "image_patch_indices_per_subsequence",
+ "images",
+ "image_patch_indices_per_batch",
+ ]
+ image_processor_input_names = [name for name in image_processor_input_names if name not in extra_image_inputs]
+ return list(tokenizer_input_names + image_processor_input_names + ["image_patches_indices"])
+
+
+__all__ = ["FuyuProcessor"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/glm4/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/glm4/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..6e92a8a2b9c9761d39526ccc0c12c26604fa2a49
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/glm4/__init__.py
@@ -0,0 +1,27 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_glm4 import *
+ from .modeling_glm4 import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/glm4/configuration_glm4.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/glm4/configuration_glm4.py
new file mode 100644
index 0000000000000000000000000000000000000000..46dc929826e4c37be0717600ade76d3f03be53b7
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/glm4/configuration_glm4.py
@@ -0,0 +1,152 @@
+# coding=utf-8
+# Copyright 2025 The GLM4 & ZhipuAI team and HuggingFace Inc. team. All rights reserved.
+#
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ...configuration_utils import PretrainedConfig
+
+
+class Glm4Config(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`Glm4Model`]. It is used to instantiate an Glm4
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
+ defaults will yield a similar configuration to that of the Glm4-4-9b-chat.
+ e.g. [THUDM/GLM-4-9B-0414](https://huggingface.co/THUDM/GLM-4-9B-0414)
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+ Args:
+ vocab_size (`int`, *optional*, defaults to 151552):
+ Vocabulary size of the Glm4 model. Defines the number of different tokens that can be represented by the
+ `inputs_ids` passed when calling [`Glm4Model`]
+ hidden_size (`int`, *optional*, defaults to 4096):
+ Dimension of the hidden representations.
+ intermediate_size (`int`, *optional*, defaults to 13696):
+ Dimension of the MLP representations.
+ num_hidden_layers (`int`, *optional*, defaults to 40):
+ Number of hidden layers in the Transformer decoder.
+ num_attention_heads (`int`, *optional*, defaults to 32):
+ Number of attention heads for each attention layer in the Transformer decoder.
+ num_key_value_heads (`int`, *optional*, defaults to 2):
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
+ `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+ by meanpooling all the original heads within that group. For more details, check out [this
+ paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to
+ `num_attention_heads`.
+ partial_rotary_factor (`float`, *optional*, defaults to 0.5): The factor of the partial rotary position.
+ head_dim (`int`, *optional*, defaults to 128):
+ The attention head dimension.
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+ The non-linear activation function (function or string) used in the decoder MLP.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ max_position_embeddings (`int`, *optional*, defaults to 131072):
+ The maximum sequence length that this model might ever be used with.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ rms_norm_eps (`float`, *optional*, defaults to 1.5625e-07):
+ The epsilon used by the rms normalization layers.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
+ relevant if `config.is_decoder=True`.
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+ Whether to tie weight embeddings
+ rope_theta (`float`, *optional*, defaults to 10000.0):
+ The base period of the RoPE embeddings.
+ pad_token_id (`int`, *optional*, defaults to 151329):
+ Padding token id.
+ eos_token_id (`int` | `list`, *optional*, defaults to `[151329, 151336, 151338]`):
+ End of stream token id.
+ bos_token_id (`int`, *optional*):
+ Beginning of stream token id.
+ attention_bias (`bool`, *optional*, defaults to `True`):
+ Whether to use a bias in the query, key, value and output projection layers during self-attention.
+ ```python
+ >>> from transformers import Glm4Model, Glm4Config
+ >>> # Initializing a Glm4 glm4-4-9b-chat style configuration
+ >>> configuration = Glm4Config()
+ >>> # Initializing a model from the glm4-4-9b-chat style configuration
+ >>> model = Glm4Model(configuration)
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "glm4"
+ keys_to_ignore_at_inference = ["past_key_values"]
+ base_model_tp_plan = {
+ "layers.*.self_attn.q_proj": "colwise",
+ "layers.*.self_attn.k_proj": "colwise",
+ "layers.*.self_attn.v_proj": "colwise",
+ "layers.*.self_attn.o_proj": "rowwise",
+ "layers.*.mlp.gate_up_proj": "colwise_rep", # we need to replicate here due to the `chunk` operation
+ "layers.*.mlp.down_proj": "rowwise_rep", # we need to replicate here due to the `chunk` operation
+ }
+ base_model_pp_plan = {
+ "embed_tokens": (["input_ids"], ["inputs_embeds"]),
+ "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
+ "norm": (["hidden_states"], ["hidden_states"]),
+ }
+
+ def __init__(
+ self,
+ vocab_size=151552,
+ hidden_size=4096,
+ intermediate_size=13696,
+ num_hidden_layers=40,
+ num_attention_heads=32,
+ num_key_value_heads=2,
+ partial_rotary_factor=0.5,
+ head_dim=128,
+ hidden_act="silu",
+ attention_dropout=0.0,
+ max_position_embeddings=131072,
+ initializer_range=0.02,
+ rms_norm_eps=0.00000015625,
+ use_cache=True,
+ tie_word_embeddings=False,
+ rope_theta=10000.0,
+ pad_token_id=151329,
+ eos_token_id=[151329, 151336, 151338],
+ bos_token_id=None,
+ attention_bias=True,
+ **kwargs,
+ ):
+ self.vocab_size = vocab_size
+ self.max_position_embeddings = max_position_embeddings
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.partial_rotary_factor = partial_rotary_factor
+ self.head_dim = head_dim
+ self.num_key_value_heads = num_key_value_heads
+ self.hidden_act = hidden_act
+ self.initializer_range = initializer_range
+ self.rms_norm_eps = rms_norm_eps
+ self.use_cache = use_cache
+ self.rope_theta = rope_theta
+ self.attention_bias = attention_bias
+ self.attention_dropout = attention_dropout
+
+ super().__init__(
+ pad_token_id=pad_token_id,
+ bos_token_id=bos_token_id,
+ eos_token_id=eos_token_id,
+ tie_word_embeddings=tie_word_embeddings,
+ **kwargs,
+ )
+
+
+__all__ = ["Glm4Config"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/glm4/modeling_glm4.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/glm4/modeling_glm4.py
new file mode 100644
index 0000000000000000000000000000000000000000..dafab297f5667610284121a2bfe818bfcc12b3df
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/glm4/modeling_glm4.py
@@ -0,0 +1,521 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/glm4/modular_glm4.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_glm4.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2025 The GLM4 & ZhipuAI team and HuggingFace Inc. team. All rights reserved.
+#
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Callable, Optional, Union
+
+import torch
+import torch.nn as nn
+
+from ...activations import ACT2FN
+from ...cache_utils import Cache, DynamicCache
+from ...generation import GenerationMixin
+from ...integrations import use_kernel_forward_from_hub
+from ...masking_utils import create_causal_mask
+from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_layers import (
+ GenericForSequenceClassification,
+ GenericForTokenClassification,
+ GradientCheckpointingLayer,
+)
+from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
+from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
+from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
+from ...utils.deprecation import deprecate_kwarg
+from ...utils.generic import check_model_inputs
+from .configuration_glm4 import Glm4Config
+
+
+class Glm4MLP(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+
+ self.config = config
+ self.gate_up_proj = nn.Linear(config.hidden_size, 2 * config.intermediate_size, bias=False)
+ self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
+ self.activation_fn = ACT2FN[config.hidden_act]
+
+ def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
+ up_states = self.gate_up_proj(hidden_states)
+
+ gate, up_states = up_states.chunk(2, dim=-1)
+ up_states = up_states * self.activation_fn(gate)
+
+ return self.down_proj(up_states)
+
+
+class Glm4DecoderLayer(GradientCheckpointingLayer):
+ def __init__(self, config: Glm4Config, layer_idx: int):
+ super().__init__()
+ self.hidden_size = config.hidden_size
+ self.self_attn = Glm4Attention(config=config, layer_idx=layer_idx)
+
+ self.mlp = Glm4MLP(config)
+ self.input_layernorm = Glm4RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.post_attention_layernorm = Glm4RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.post_self_attn_layernorm = Glm4RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.post_mlp_layernorm = Glm4RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ use_cache: Optional[bool] = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
+ **kwargs: Unpack[FlashAttentionKwargs],
+ ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
+ residual = hidden_states
+ hidden_states = self.input_layernorm(hidden_states)
+
+ hidden_states, _ = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ **kwargs,
+ )
+
+ hidden_states = self.post_self_attn_layernorm(hidden_states)
+ hidden_states = residual + hidden_states
+
+ residual = hidden_states
+ hidden_states = self.post_attention_layernorm(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = self.post_mlp_layernorm(hidden_states)
+ hidden_states = residual + hidden_states
+ return hidden_states
+
+
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+ """
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+ """
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+ if n_rep == 1:
+ return hidden_states
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
+def eager_attention_forward(
+ module: nn.Module,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_mask: Optional[torch.Tensor],
+ scaling: float,
+ dropout: float = 0.0,
+ **kwargs: Unpack[TransformersKwargs],
+):
+ key_states = repeat_kv(key, module.num_key_value_groups)
+ value_states = repeat_kv(value, module.num_key_value_groups)
+
+ attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
+ if attention_mask is not None:
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+ attn_weights = attn_weights + causal_mask
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+ attn_output = torch.matmul(attn_weights, value_states)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ return attn_output, attn_weights
+
+
+def rotate_half(x):
+ """Rotates half the hidden dims of the input."""
+ x1 = x[..., 0::2]
+ x2 = x[..., 1::2]
+ return torch.stack((-x2, x1), dim=-1).flatten(-2)
+
+
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+ """Applies Rotary Position Embedding to the query and key tensors.
+
+ Args:
+ q (`torch.Tensor`): The query tensor.
+ k (`torch.Tensor`): The key tensor.
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
+ position_ids (`torch.Tensor`, *optional*):
+ Deprecated and unused.
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+ Returns:
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+ """
+ cos = cos.unsqueeze(unsqueeze_dim)
+ sin = sin.unsqueeze(unsqueeze_dim)
+
+ # Interleave them instead of usual shape
+ cos = cos[..., : cos.shape[-1] // 2].repeat_interleave(2, dim=-1)
+ sin = sin[..., : sin.shape[-1] // 2].repeat_interleave(2, dim=-1)
+
+ # Keep half or full tensor for later concatenation
+ rotary_dim = cos.shape[-1]
+ q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
+ k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:]
+
+ # Apply rotary embeddings on the first half or full tensor
+ q_embed = (q_rot * cos) + (rotate_half(q_rot) * sin)
+ k_embed = (k_rot * cos) + (rotate_half(k_rot) * sin)
+
+ # Concatenate back to full shape
+ q_embed = torch.cat([q_embed, q_pass], dim=-1)
+ k_embed = torch.cat([k_embed, k_pass], dim=-1)
+ return q_embed, k_embed
+
+
+class Glm4Attention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(self, config: Glm4Config, layer_idx: Optional[int] = None):
+ super().__init__()
+ self.config = config
+ self.layer_idx = layer_idx
+ self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+ self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
+ self.scaling = self.head_dim**-0.5
+ self.attention_dropout = config.attention_dropout
+ self.is_causal = True
+
+ self.q_proj = nn.Linear(
+ config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
+ )
+ self.k_proj = nn.Linear(
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
+ )
+ self.v_proj = nn.Linear(
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
+ )
+ self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
+
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
+ attention_mask: Optional[torch.Tensor],
+ past_key_values: Optional[Cache] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ input_shape = hidden_states.shape[:-1]
+ hidden_shape = (*input_shape, -1, self.head_dim)
+
+ query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+ key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+ value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+
+ cos, sin = position_embeddings
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_values is not None:
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ attn_output, attn_weights = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ dropout=0.0 if not self.training else self.attention_dropout,
+ scaling=self.scaling,
+ **kwargs,
+ )
+
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+ attn_output = self.o_proj(attn_output)
+ return attn_output, attn_weights
+
+
+@use_kernel_forward_from_hub("RMSNorm")
+class Glm4RMSNorm(nn.Module):
+ def __init__(self, hidden_size, eps=1e-6):
+ """
+ Glm4RMSNorm is equivalent to T5LayerNorm
+ """
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+ self.variance_epsilon = eps
+
+ def forward(self, hidden_states):
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+ return self.weight * hidden_states.to(input_dtype)
+
+ def extra_repr(self):
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
+
+
+class Glm4RotaryEmbedding(nn.Module):
+ inv_freq: torch.Tensor # fix linting for `register_buffer`
+
+ def __init__(self, config: Glm4Config, device=None):
+ super().__init__()
+ # BC: "rope_type" was originally "type"
+ if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
+ else:
+ self.rope_type = "default"
+ self.max_seq_len_cached = config.max_position_embeddings
+ self.original_max_seq_len = config.max_position_embeddings
+
+ self.config = config
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
+ self.original_inv_freq = self.inv_freq
+
+ @torch.no_grad()
+ @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
+ def forward(self, x, position_ids):
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
+ position_ids_expanded = position_ids[:, None, :].float()
+
+ device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
+ with torch.autocast(device_type=device_type, enabled=False): # Force float32
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+ emb = torch.cat((freqs, freqs), dim=-1)
+ cos = emb.cos() * self.attention_scaling
+ sin = emb.sin() * self.attention_scaling
+
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
+@auto_docstring
+class Glm4PreTrainedModel(PreTrainedModel):
+ config: Glm4Config
+ base_model_prefix = "model"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["Glm4DecoderLayer"]
+ _skip_keys_device_placement = ["past_key_values"]
+ _supports_flash_attn = True
+ _supports_sdpa = True
+ _supports_flex_attn = True
+
+ _can_compile_fullgraph = True
+ _supports_attention_backend = True
+ _can_record_outputs = {
+ "hidden_states": Glm4DecoderLayer,
+ "attentions": Glm4Attention,
+ }
+
+
+@auto_docstring
+class Glm4Model(Glm4PreTrainedModel):
+ def __init__(self, config: Glm4Config):
+ super().__init__(config)
+ self.padding_idx = config.pad_token_id
+ self.vocab_size = config.vocab_size
+
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+ self.layers = nn.ModuleList(
+ [Glm4DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+ )
+ self.norm = Glm4RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.rotary_emb = Glm4RotaryEmbedding(config=config)
+ self.gradient_checkpointing = False
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @check_model_inputs
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> BaseModelOutputWithPast:
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ if inputs_embeds is None:
+ inputs_embeds: torch.Tensor = self.embed_tokens(input_ids)
+
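+        # If caching is requested but no cache was provided, lazily create a DynamicCache.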
+ if use_cache and past_key_values is None:
+ past_key_values = DynamicCache(config=self.config)
+
+ if cache_position is None:
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ cache_position: torch.Tensor = torch.arange(
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+ )
+
+ if position_ids is None:
+ position_ids = cache_position.unsqueeze(0)
+
+ causal_mask = create_causal_mask(
+ config=self.config,
+ input_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ cache_position=cache_position,
+ past_key_values=past_key_values,
+ position_ids=position_ids,
+ )
+
+ hidden_states = inputs_embeds
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+ for decoder_layer in self.layers[: self.config.num_hidden_layers]:
+ hidden_states = decoder_layer(
+ hidden_states,
+ attention_mask=causal_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ **kwargs,
+ )
+
+ hidden_states = self.norm(hidden_states)
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=past_key_values,
+ )
+
+
+@auto_docstring
+class Glm4ForCausalLM(Glm4PreTrainedModel, GenerationMixin):
+ _tied_weights_keys = ["lm_head.weight"]
+ _tp_plan = {"lm_head": "colwise_rep"}
+ _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.model = Glm4Model(config)
+ self.vocab_size = config.vocab_size
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ logits_to_keep: Union[int, torch.Tensor] = 0,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> Union[tuple, CausalLMOutputWithPast]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, Glm4ForCausalLM
+
+ >>> model = Glm4ForCausalLM.from_pretrained("THUDM/GLM-4-9B-0414")
+ >>> tokenizer = AutoTokenizer.from_pretrained("THUDM/GLM-4-9B-0414")
+
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+ >>> # Generate
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+ ```"""
+ outputs: BaseModelOutputWithPast = self.model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+ hidden_states = outputs.last_hidden_state
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
+
+ return CausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+class Glm4ForSequenceClassification(GenericForSequenceClassification, Glm4PreTrainedModel):
+ pass
+
+
+class Glm4ForTokenClassification(GenericForTokenClassification, Glm4PreTrainedModel):
+ pass
+
+
+__all__ = [
+ "Glm4PreTrainedModel",
+ "Glm4Model",
+ "Glm4ForCausalLM",
+ "Glm4ForSequenceClassification",
+ "Glm4ForTokenClassification",
+]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/glm4/modular_glm4.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/glm4/modular_glm4.py
new file mode 100644
index 0000000000000000000000000000000000000000..6bbc9b601f591d539e1bb529456a39337589f417
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/glm4/modular_glm4.py
@@ -0,0 +1,139 @@
+# coding=utf-8
+# Copyright 2025 The GLM4 & ZhipuAI team and HuggingFace Inc. team. All rights reserved.
+#
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Optional, Union
+
+import torch
+
+from ...cache_utils import Cache
+from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_outputs import CausalLMOutputWithPast
+from ...processing_utils import Unpack
+from ...utils import TransformersKwargs, logging
+from ...utils.deprecation import deprecate_kwarg
+from ..glm.modeling_glm import GlmAttention, GlmForCausalLM, GlmForSequenceClassification, GlmForTokenClassification
+from ..phi3.modeling_phi3 import Phi3MLP
+from .configuration_glm4 import Glm4Config
+from .modeling_glm4 import Glm4RMSNorm
+
+
+logger = logging.get_logger(__name__)
+
+_CHECKPOINT_FOR_DOC = "THUDM/GLM-4-9B-0414"
+
+
+class Glm4MLP(Phi3MLP):
+ pass
+
+
+class Glm4DecoderLayer(GradientCheckpointingLayer):
+ def __init__(self, config: Glm4Config, layer_idx: int):
+ super().__init__()
+ self.hidden_size = config.hidden_size
+ self.self_attn = Glm4Attention(config=config, layer_idx=layer_idx)
+
+ self.mlp = Glm4MLP(config)
+ self.input_layernorm = Glm4RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.post_attention_layernorm = Glm4RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.post_self_attn_layernorm = Glm4RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.post_mlp_layernorm = Glm4RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ use_cache: Optional[bool] = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
+ **kwargs: Unpack[FlashAttentionKwargs],
+ ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
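+        # GLM-4 "sandwich" layout: extra RMSNorms (post_self_attn_layernorm / post_mlp_layernorm)
+        # normalize the sub-layer outputs before the residual additions below.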
+ residual = hidden_states
+ hidden_states = self.input_layernorm(hidden_states)
+
+ hidden_states, _ = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ **kwargs,
+ )
+
+ hidden_states = self.post_self_attn_layernorm(hidden_states)
+ hidden_states = residual + hidden_states
+
+ residual = hidden_states
+ hidden_states = self.post_attention_layernorm(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = self.post_mlp_layernorm(hidden_states)
+ hidden_states = residual + hidden_states
+ return hidden_states
+
+
+class Glm4Attention(GlmAttention):
+ pass
+
+
+class Glm4ForCausalLM(GlmForCausalLM):
+ def forward(
+ self,
+ **super_kwargs: Unpack[TransformersKwargs],
+ ) -> Union[tuple, CausalLMOutputWithPast]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, Glm4ForCausalLM
+
+ >>> model = Glm4ForCausalLM.from_pretrained("THUDM/GLM-4-9B-0414")
+ >>> tokenizer = AutoTokenizer.from_pretrained("THUDM/GLM-4-9B-0414")
+
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+ >>> # Generate
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+ ```"""
+ return super().forward(**super_kwargs)
+
+
+class Glm4ForSequenceClassification(GlmForSequenceClassification):
+ pass
+
+
+class Glm4ForTokenClassification(GlmForTokenClassification):
+ pass
+
+
+__all__ = [
+ "Glm4PreTrainedModel", # noqa: F822
+ "Glm4Model", # noqa: F822
+ "Glm4ForCausalLM",
+ "Glm4ForSequenceClassification",
+ "Glm4ForTokenClassification",
+]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/glm4v_moe/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/glm4v_moe/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f99578a4be721ecdc5bcbd157fe75f8f16384086
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/glm4v_moe/__init__.py
@@ -0,0 +1,27 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_glm4v_moe import *
+ from .modeling_glm4v_moe import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/glm4v_moe/configuration_glm4v_moe.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/glm4v_moe/configuration_glm4v_moe.py
new file mode 100644
index 0000000000000000000000000000000000000000..b06642e250bcfedd82dcb7f47e2aae6d3f249dcb
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/glm4v_moe/configuration_glm4v_moe.py
@@ -0,0 +1,384 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/glm4v_moe/modular_glm4v_moe.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_glm4v_moe.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2025 The ZhipuAI Inc. team and HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from ...configuration_utils import PretrainedConfig
+from ...modeling_rope_utils import rope_config_validation
+
+
+class Glm4vMoeVisionConfig(PretrainedConfig):
+ r"""
+    This is the configuration class to store the configuration of a [`Glm4vMoeVisionModel`]. It is used to instantiate a Glm4vMoeVisionModel
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield
+ a similar configuration to that of
+ GLM-4.1V-9B-Thinking [THUDM/GLM-4.1V-9B-Thinking](https://huggingface.co/THUDM/GLM-4.1V-9B-Thinking).
+
+ Args:
+ hidden_size (`int`, *optional*, defaults to 1536):
+ Dimensionality of the encoder layers and the pooler layer.
+ depth (`int`, *optional*, defaults to 24):
+ Number of layers (depth) in the model.
+ attention_bias (`bool`, *optional*, defaults to `False`):
+ Whether to add a bias to the queries, keys and values.
+ intermediate_size (`int`, *optional*, defaults to 13696):
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
+        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+ `"relu"`, `"selu"` and `"gelu_new"` are supported.
+ hidden_dropout_prob (`float`, *optional*, defaults to 0.0):
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ Dropout probability for attention weights.
+ projection_dropout (`float`, *optional*, defaults to 0.0):
+ Dropout probability for the projection layer.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        image_size (`int` or `list[int]`, *optional*, defaults to `336`):
+ The size (resolution) of each image.
+ patch_size (`int`, *optional*, defaults to `14`):
+ The size (resolution) of each patch.
+ num_channels (`int`, *optional*, defaults to 3):
+ The number of input channels.
+ out_hidden_size (`int`, *optional*, defaults to 4096):
+ The output hidden size of the vision model.
+ rms_norm_eps (`float`, *optional*, defaults to 1e-05):
+ The epsilon used by the rms normalization layers.
+ spatial_merge_size (`int`, *optional*, defaults to 2):
+ The size used for merging spatial dimensions.
+ temporal_patch_size (`int`, *optional*, defaults to 2):
+ The size used for patches along the temporal dimension.
+ Example:
+
+ ```python
+ >>> from transformers import Glm4vMoeVisionConfig, Glm4vMoeVisionModel
+
+ >>> # Initializing a Glm4vMoeVisionConfig GLM-4.1V-9B style configuration
+ >>> configuration = Glm4vMoeVisionConfig()
+
+ >>> # Initializing a model (with random weights) from the GLM-4.1V-9B configuration
+ >>> model = Glm4vMoeVisionModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "glm4v_moe"
+ base_config_key = "vision_config"
+
+ def __init__(
+ self,
+ depth=24,
+ hidden_size=1536,
+ hidden_act="silu",
+ attention_bias=False,
+ attention_dropout=0.0,
+ num_heads=12,
+ in_channels=3,
+ image_size=336,
+ patch_size=14,
+ rms_norm_eps=1e-05,
+ spatial_merge_size=2,
+ temporal_patch_size=2,
+ out_hidden_size=4096,
+ intermediate_size=13696,
+ initializer_range=0.02,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+
+ self.depth = depth
+ self.hidden_size = hidden_size
+ self.hidden_act = hidden_act
+ self.num_heads = num_heads
+ self.in_channels = in_channels
+ self.image_size = image_size
+ self.patch_size = patch_size
+ self.spatial_merge_size = spatial_merge_size
+ self.temporal_patch_size = temporal_patch_size
+ self.out_hidden_size = out_hidden_size
+ self.intermediate_size = intermediate_size
+ self.initializer_range = initializer_range
+ self.rms_norm_eps = rms_norm_eps
+ self.attention_bias = attention_bias
+ self.attention_dropout = attention_dropout
+
+
+class Glm4vMoeTextConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`Glm4vMoeModel`]. It is used to instantiate a
+ GLM-4.5V model according to the specified arguments, defining the model architecture. Instantiating a
+ configuration with the defaults will yield a similar configuration to that of
+ GLM-4.5V [zai-org/GLM-4.5V](https://huggingface.co/zai-org/GLM-4.5V).
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 151424):
+ Vocabulary size of the Glm4vMoe model. Defines the number of different tokens that can be represented by the
+            `input_ids` passed when calling [`Glm4vMoeModel`].
+ hidden_size (`int`, *optional*, defaults to 4096):
+ Dimension of the hidden representations.
+ intermediate_size (`int`, *optional*, defaults to 10944):
+ Dimension of the MLP representations.
+ num_hidden_layers (`int`, *optional*, defaults to 46):
+ Number of hidden layers in the Transformer encoder.
+ num_attention_heads (`int`, *optional*, defaults to 96):
+ Number of attention heads for each attention layer in the Transformer encoder.
+        partial_rotary_factor (`float`, *optional*, defaults to 0.5):
+            The fraction of the head dimension that rotary position embeddings are applied to.
+ num_key_value_heads (`int`, *optional*, defaults to 8):
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
+ `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+            by meanpooling all the original heads within that group. For more details, check out [this
+            paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to `8`.
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+ The non-linear activation function (function or string) in the decoder.
+ max_position_embeddings (`int`, *optional*, defaults to 65536):
+ The maximum sequence length that this model might ever be used with.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ rms_norm_eps (`float`, *optional*, defaults to 1e-05):
+ The epsilon used by the rms normalization layers.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
+ relevant if `config.is_decoder=True`.
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+ Whether the model's input and output word embeddings should be tied.
+ rope_theta (`float`, *optional*, defaults to 10000.0):
+ The base period of the RoPE embeddings.
+ rope_scaling (`Dict`, *optional*):
+            Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply a new rope type
+            and expect the model to work on a longer `max_position_embeddings`, we recommend updating this value
+            accordingly.
+ Expected contents:
+ `rope_type` (`str`):
+ The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
+ 'llama3'], with 'default' being the original RoPE implementation.
+ `factor` (`float`, *optional*):
+ Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
+ most scaling types, a `factor` of x will enable the model to handle sequences of length x *
+ original maximum pre-trained length.
+ `original_max_position_embeddings` (`int`, *optional*):
+ Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
+ pretraining.
+ `attention_factor` (`float`, *optional*):
+ Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
+ computation. If unspecified, it defaults to value recommended by the implementation, using the
+ `factor` field to infer the suggested value.
+        attention_bias (`bool`, *optional*, defaults to `True`):
+ Whether to use a bias in the query, key, value and output projection layers during self-attention.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ moe_intermediate_size (`int`, *optional*, defaults to 1408):
+ Intermediate size of the routed expert.
+ num_experts_per_tok (`int`, *optional*, defaults to 8):
+            Number of experts per token.
+ n_shared_experts (`int`, *optional*, defaults to 1):
+ Number of shared experts.
+ n_routed_experts (`int`, *optional*, defaults to 128):
+ Number of routed experts.
+ routed_scaling_factor (`float`, *optional*, defaults to 1.0):
+            Scaling factor for routed experts.
+ n_group (`int`, *optional*, defaults to 1):
+ Number of groups for routed experts.
+ topk_group (`int`, *optional*, defaults to 1):
+            Number of selected groups for each token (ensuring that the selected experts are only within `topk_group` groups).
+ first_k_dense_replace (`int`, *optional*, defaults to 1):
+            Number of dense layers in the shallow layers (embed -> dense -> dense -> ... -> dense -> moe -> moe ... -> lm_head);
+            the first `first_k_dense_replace` layers use a dense MLP instead of MoE.
+ norm_topk_prob (`bool`, *optional*, defaults to `True`):
+ Whether to normalize the topk probabilities.
+
+ ```python
+    >>> from transformers import Glm4vMoeTextModel, Glm4vMoeTextConfig
+
+ >>> # Initializing a GLM-4.5V style configuration
+    >>> configuration = Glm4vMoeTextConfig()
+
+ >>> # Initializing a model from the GLM-4.5V style configuration
+ >>> model = Glm4vMoeTextModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "Glm4vMoe_text"
+ keys_to_ignore_at_inference = ["past_key_values"]
+ # Default tensor parallel plan for base model `Glm4vMoe`
+ base_model_tp_plan = {
+ "layers.*.self_attn.q_proj": "colwise",
+ "layers.*.self_attn.k_proj": "colwise",
+ "layers.*.self_attn.v_proj": "colwise",
+ "layers.*.self_attn.o_proj": "rowwise",
+ "layers.*.mlp.gate_up_proj": "colwise_rep", # we need to replicate here due to the `chunk` operation
+ "layers.*.mlp.down_proj": "rowwise_rep", # we need to replicate here due to the `chunk` operation
+ }
+ base_model_pp_plan = {
+ "embed_tokens": (["input_ids"], ["inputs_embeds"]),
+ "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
+ "norm": (["hidden_states"], ["hidden_states"]),
+ }
+ base_config_key = "text_config"
+
+ def __init__(
+ self,
+ vocab_size=151424,
+ hidden_size=4096,
+ intermediate_size=10944,
+ num_hidden_layers=46,
+ num_attention_heads=96,
+ partial_rotary_factor=0.5,
+ num_key_value_heads=8,
+ hidden_act="silu",
+ max_position_embeddings=65536,
+ initializer_range=0.02,
+ rms_norm_eps=1e-5,
+ use_cache=True,
+ tie_word_embeddings=False,
+ rope_theta=10000.0,
+ rope_scaling=None,
+ attention_bias=True,
+ attention_dropout=0.0,
+ moe_intermediate_size=1408,
+ num_experts_per_tok=8,
+ n_shared_experts=1,
+ n_routed_experts=128,
+ routed_scaling_factor=1.0,
+ n_group=1,
+ topk_group=1,
+ first_k_dense_replace=1,
+ norm_topk_prob=True,
+ **kwargs,
+ ):
+ super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
+ self.vocab_size = vocab_size
+ self.max_position_embeddings = max_position_embeddings
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.partial_rotary_factor = partial_rotary_factor
+
+ self.num_key_value_heads = num_key_value_heads
+ self.hidden_act = hidden_act
+ self.initializer_range = initializer_range
+ self.rms_norm_eps = rms_norm_eps
+ self.use_cache = use_cache
+ self.rope_theta = rope_theta
+ self.rope_scaling = rope_scaling
+ self.attention_bias = attention_bias
+ self.attention_dropout = attention_dropout
+ # Validate the correctness of rotary position embeddings parameters
+ # BC: if there is a 'type' field, move it to 'rope_type'.
+ if self.rope_scaling is not None and "type" in self.rope_scaling:
+ self.rope_scaling["rope_type"] = self.rope_scaling["type"]
+ rope_config_validation(self, ignore_keys={"mrope_section"})
+
+ # MoE arguments
+ self.moe_intermediate_size = moe_intermediate_size
+ self.num_experts_per_tok = num_experts_per_tok
+ self.n_group = n_group
+ self.topk_group = topk_group
+ self.n_shared_experts = n_shared_experts
+ self.n_routed_experts = n_routed_experts
+ self.routed_scaling_factor = routed_scaling_factor
+ self.first_k_dense_replace = first_k_dense_replace
+ self.norm_topk_prob = norm_topk_prob
+
+
+class Glm4vMoeConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`Glm4vMoeModel`]. It is used to instantiate a
+ GLM-4.5V model according to the specified arguments, defining the model architecture. Instantiating a
+ configuration with the defaults will yield a similar configuration to that of
+ GLM-4.5V [zai-org/GLM-4.5V](https://huggingface.co/zai-org/GLM-4.5V).
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+
+ Args:
+ text_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `Glm4vMoeTextConfig`):
+ The config object or dictionary of the text backbone.
+ vision_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `Glm4vMoeVisionConfig`):
+ The config object or dictionary of the vision backbone.
+ image_token_id (`int`, *optional*, defaults to 151363):
+ The image token index to encode the image prompt.
+ video_token_id (`int`, *optional*, defaults to 151364):
+            The video token index to encode the video prompt.
+ image_start_token_id (`int`, *optional*, defaults to 151339):
+ The image start token index to encode the start of image.
+ image_end_token_id (`int`, *optional*, defaults to 151340):
+ The image end token index to encode the end of image.
+ video_start_token_id (`int`, *optional*, defaults to 151341):
+ The video start token index to encode the start of video.
+ video_end_token_id (`int`, *optional*, defaults to 151342):
+ The video end token index to encode the end of video.
+
+ ```python
+ >>> from transformers import Glm4vMoeForConditionalGeneration, Glm4vMoeConfig
+
+ >>> # Initializing a GLM-4.5V style configuration
+ >>> configuration = Glm4vMoeConfig()
+
+ >>> # Initializing a model from the GLM-4.5V style configuration
+ >>> model = Glm4vMoeForConditionalGeneration(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "glm4v_moe"
+ sub_configs = {"vision_config": Glm4vMoeVisionConfig, "text_config": Glm4vMoeTextConfig}
+ keys_to_ignore_at_inference = ["past_key_values"]
+
+ def __init__(
+ self,
+ text_config=None,
+ vision_config=None,
+ image_token_id=151363,
+ video_token_id=151364,
+ image_start_token_id=151339,
+ image_end_token_id=151340,
+ video_start_token_id=151341,
+ video_end_token_id=151342,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ if isinstance(vision_config, dict):
+ self.vision_config = self.sub_configs["vision_config"](**vision_config)
+ elif vision_config is None:
+ self.vision_config = self.sub_configs["vision_config"]()
+
+ if isinstance(text_config, dict):
+ self.text_config = self.sub_configs["text_config"](**text_config)
+ elif text_config is None:
+ self.text_config = self.sub_configs["text_config"](**kwargs)
+
+ self.image_token_id = image_token_id
+ self.video_token_id = video_token_id
+ self.video_start_token_id = video_start_token_id
+ self.video_end_token_id = video_end_token_id
+ self.image_start_token_id = image_start_token_id
+ self.image_end_token_id = image_end_token_id
+
+
+__all__ = ["Glm4vMoeConfig", "Glm4vMoeTextConfig"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/glm4v_moe/modeling_glm4v_moe.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/glm4v_moe/modeling_glm4v_moe.py
new file mode 100644
index 0000000000000000000000000000000000000000..045e78df5233becd5d44bd38ccb3745db51de407
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/glm4v_moe/modeling_glm4v_moe.py
@@ -0,0 +1,1752 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/glm4v_moe/modular_glm4v_moe.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_glm4v_moe.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2025 The ZhipuAI Inc. team and HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import itertools
+from dataclasses import dataclass
+from typing import Any, Callable, Optional, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn import LayerNorm
+
+from ...activations import ACT2FN
+from ...cache_utils import Cache, DynamicCache
+from ...generation import GenerationMixin
+from ...integrations import use_kernel_forward_from_hub
+from ...masking_utils import create_causal_mask
+from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput
+from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
+from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torchdynamo_compiling
+from ...utils.deprecation import deprecate_kwarg
+from ...utils.generic import check_model_inputs
+from .configuration_glm4v_moe import Glm4vMoeConfig, Glm4vMoeTextConfig, Glm4vMoeVisionConfig
+
+
+@use_kernel_forward_from_hub("RMSNorm")
+class Glm4vMoeRMSNorm(nn.Module):
+ def __init__(self, hidden_size, eps=1e-6):
+ """
+ Glm4vMoeRMSNorm is equivalent to T5LayerNorm
+ """
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+ self.variance_epsilon = eps
+
+ def forward(self, hidden_states):
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+ return self.weight * hidden_states.to(input_dtype)
+
+ def extra_repr(self):
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
+
+
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+ """
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+ """
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+ if n_rep == 1:
+ return hidden_states
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
+def eager_attention_forward(
+ module: nn.Module,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_mask: Optional[torch.Tensor],
+ scaling: float,
+ dropout: float = 0.0,
+ **kwargs: Unpack[TransformersKwargs],
+):
+ key_states = repeat_kv(key, module.num_key_value_groups)
+ value_states = repeat_kv(value, module.num_key_value_groups)
+
+ attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
+ if attention_mask is not None:
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+ attn_weights = attn_weights + causal_mask
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+ attn_output = torch.matmul(attn_weights, value_states)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ return attn_output, attn_weights
+
+
+def rotate_half(x):
+ """Rotates half the hidden dims of the input."""
+ x1 = x[..., : x.shape[-1] // 2]
+ x2 = x[..., x.shape[-1] // 2 :]
+ return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim=1):
+ """Applies Rotary Position Embedding with Multimodal Sections to the query and key tensors (https://qwenlm.github.io/blog/qwen2-vl/).
+
+ Explanation:
+        Multimodal 3D rotary position embedding is an extension of 1D rotary position embedding. The input embedding
+        sequence may contain vision (image / video) embeddings together with text embeddings, or text embeddings only. For
+        the vision part, we apply rotary position embedding on the temporal, height and width dimensions separately.
+        Here we split the channel dimension into 3 chunks for the temporal, height and width rotary position embeddings.
+        For the text part, we just apply 1D rotary position embedding. The three rotary position indices (temporal,
+        height and width) of a text token are always the same, so its rotary position embedding is no different from
+        that of modern LLMs.
+
+ Args:
+ q (`torch.Tensor`): The query tensor.
+ k (`torch.Tensor`): The key tensor.
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
+        mrope_section (`list[int]`):
+            Multimodal rope section describing how the channel dimension is split between temporal, height and width in the rope calculation.
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+ Returns:
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+ """
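+    # cos/sin hold two stacked copies of each rotary frequency along the last dim, so the per-axis
+    # section sizes are doubled before the channels are split between temporal, height and width.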
+ mrope_section = mrope_section * 2
+ cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(
+ unsqueeze_dim
+ )
+ sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(
+ unsqueeze_dim
+ )
+
+ # Keep half or full tensor for later concatenation
+ rotary_dim = cos.shape[-1]
+ q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
+ k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:]
+
+ # Apply rotary embeddings on the first half or full tensor
+ q_embed = (q_rot * cos) + (rotate_half(q_rot) * sin)
+ k_embed = (k_rot * cos) + (rotate_half(k_rot) * sin)
+
+ # Concatenate back to full shape
+ q_embed = torch.cat([q_embed, q_pass], dim=-1)
+ k_embed = torch.cat([k_embed, k_pass], dim=-1)
+
+ return q_embed, k_embed
+
+
+class Glm4vMoeTextAttention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(self, config: Glm4vMoeTextConfig, layer_idx: Optional[int] = None):
+ super().__init__()
+ self.config = config
+ self.layer_idx = layer_idx
+ self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+ self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
+ self.scaling = self.head_dim**-0.5
+ self.attention_dropout = config.attention_dropout
+ self.is_causal = True
+
+ self.q_proj = nn.Linear(
+ config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
+ )
+ self.k_proj = nn.Linear(
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
+ )
+ self.v_proj = nn.Linear(
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
+ )
+ self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
+ self.rope_scaling = config.rope_scaling
+
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
+ attention_mask: Optional[torch.Tensor],
+ past_key_values: Optional[Cache] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs: Unpack[FlashAttentionKwargs],
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
+ input_shape = hidden_states.shape[:-1]
+ hidden_shape = (*input_shape, -1, self.head_dim)
+
+ query_states = self.q_proj(hidden_states).view(hidden_shape)
+ key_states = self.k_proj(hidden_states).view(hidden_shape)
+ value_states = self.v_proj(hidden_states).view(hidden_shape)
+
+ query_states = query_states.transpose(1, 2)
+ key_states = key_states.transpose(1, 2)
+ value_states = value_states.transpose(1, 2)
+
+ cos, sin = position_embeddings
+ query_states, key_states = apply_multimodal_rotary_pos_emb( # diff with Llama
+ query_states, key_states, cos, sin, self.rope_scaling["mrope_section"]
+ )
+
+ if past_key_values is not None:
+            # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ attn_output, attn_weights = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ dropout=0.0 if not self.training else self.attention_dropout,
+ scaling=self.scaling,
+ **kwargs,
+ )
+
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+ attn_output = self.o_proj(attn_output)
+ return attn_output, attn_weights
+
+
+class Glm4vMoeTextTopkRouter(nn.Module):
+ def __init__(self, config: Glm4vMoeTextConfig):
+ super().__init__()
+ self.config = config
+ self.top_k = config.num_experts_per_tok
+ self.n_routed_experts = config.n_routed_experts
+ self.routed_scaling_factor = config.routed_scaling_factor
+ self.n_group = config.n_group
+ self.topk_group = config.topk_group
+ self.norm_topk_prob = config.norm_topk_prob
+
+ self.weight = nn.Parameter(torch.empty((self.n_routed_experts, config.hidden_size)))
+ self.register_buffer("e_score_correction_bias", torch.zeros((self.n_routed_experts), dtype=torch.float32))
+
+ @torch.no_grad()
+ def get_topk_indices(self, scores):
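+        # Group-limited routing: add the per-expert correction bias, score each group by the sum of
+        # its top-2 expert scores, keep only the best `topk_group` groups, then pick the overall
+        # top-k experts among the surviving groups.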
+ scores_for_choice = scores.view(-1, self.n_routed_experts) + self.e_score_correction_bias.unsqueeze(0)
+ group_scores = (
+ scores_for_choice.view(-1, self.n_group, self.n_routed_experts // self.n_group)
+ .topk(2, dim=-1)[0]
+ .sum(dim=-1)
+ )
+ group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1]
+ group_mask = torch.zeros_like(group_scores)
+ group_mask.scatter_(1, group_idx, 1)
+ score_mask = (
+ group_mask.unsqueeze(-1)
+ .expand(-1, self.n_group, self.n_routed_experts // self.n_group)
+ .reshape(-1, self.n_routed_experts)
+ )
+ scores_for_choice = scores_for_choice.masked_fill(~score_mask.bool(), 0.0)
+ topk_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=False)[1]
+ return topk_indices
+
+ def forward(self, hidden_states):
+ hidden_states = hidden_states.view(-1, self.config.hidden_size)
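+        # Router logits are computed in float32; a sigmoid gate (rather than a softmax) produces the
+        # per-expert affinity scores.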
+ router_logits = F.linear(hidden_states.type(torch.float32), self.weight.type(torch.float32))
+ scores = router_logits.sigmoid()
+ topk_indices = self.get_topk_indices(scores)
+ topk_weights = scores.gather(1, topk_indices)
+ if self.norm_topk_prob:
+ denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20
+ topk_weights /= denominator
+ topk_weights = topk_weights * self.routed_scaling_factor
+ return topk_indices, topk_weights
+
+
+class Glm4vMoeTextMoE(nn.Module):
+ """
+ A mixed expert module containing shared experts.
+ """
+
+ def __init__(self, config: Glm4vMoeTextConfig):
+ super().__init__()
+ self.config = config
+ self.experts = nn.ModuleList(
+ [
+ Glm4vMoeTextMLP(config, intermediate_size=config.moe_intermediate_size)
+ for _ in range(config.n_routed_experts)
+ ]
+ )
+ self.gate = Glm4vMoeTextTopkRouter(config)
+ self.shared_experts = Glm4vMoeTextMLP(
+ config=config, intermediate_size=config.moe_intermediate_size * config.n_shared_experts
+ )
+
+ def moe(self, hidden_states: torch.Tensor, topk_indices: torch.Tensor, topk_weights: torch.Tensor):
+ r"""
+        CALL FOR CONTRIBUTION! This loop is not optimised yet: the expert weights need to be fused
+        so that we do not have to iterate over the experts here (DeepSeek has 256 experts).
+ """
+ final_hidden_states = torch.zeros_like(hidden_states, dtype=topk_weights.dtype)
+ expert_mask = torch.nn.functional.one_hot(topk_indices, num_classes=len(self.experts))
+ expert_mask = expert_mask.permute(2, 0, 1)
+
+ for expert_idx in range(len(self.experts)):
+ expert = self.experts[expert_idx]
+ mask = expert_mask[expert_idx]
+ token_indices, weight_indices = torch.where(mask)
+
+ if token_indices.numel() > 0:
+ expert_weights = topk_weights[token_indices, weight_indices]
+ expert_input = hidden_states[token_indices]
+ expert_output = expert(expert_input)
+ weighted_output = expert_output * expert_weights.unsqueeze(-1)
+ final_hidden_states.index_add_(0, token_indices, weighted_output)
+
+        # In the original DeepSeek implementation, the expert outputs are gathered once we leave this module,
+        # so the MoE module is itself an IsolatedParallel module,
+        # and all experts are "local", meaning we shard but do not gather.
+ return final_hidden_states.type(hidden_states.dtype)
+
+ def forward(self, hidden_states):
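+        # Route each token to its top-k experts, run the sparse expert mixture, then add the
+        # shared-expert branch computed on the original (pre-routing) hidden states.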
+ residuals = hidden_states
+ orig_shape = hidden_states.shape
+ topk_indices, topk_weights = self.gate(hidden_states)
+ hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
+ hidden_states = self.moe(hidden_states, topk_indices, topk_weights).view(*orig_shape)
+ hidden_states = hidden_states + self.shared_experts(residuals)
+ return hidden_states
+
+
+class Glm4vMoeTextMLP(nn.Module):
+ def __init__(self, config, hidden_size=None, intermediate_size=None):
+ super().__init__()
+ self.config = config
+ self.hidden_size = config.hidden_size if hidden_size is None else hidden_size
+ self.intermediate_size = config.intermediate_size if intermediate_size is None else intermediate_size
+
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+ self.act_fn = ACT2FN[config.hidden_act]
+
+ def forward(self, x):
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+ return down_proj
+
+
+@use_kernel_forward_from_hub("RMSNorm")
+class Glm4vMoeTextRMSNorm(nn.Module):
+ def __init__(self, hidden_size, eps=1e-6):
+ """
+ Glm4vMoeTextRMSNorm is equivalent to T5LayerNorm
+ """
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+ self.variance_epsilon = eps
+
+ def forward(self, hidden_states):
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+ return self.weight * hidden_states.to(input_dtype)
+
+ def extra_repr(self):
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
+
+
+class Glm4vMoeTextDecoderLayer(GradientCheckpointingLayer):
+ def __init__(self, config: Glm4vMoeTextConfig, layer_idx: int):
+ super().__init__()
+ self.hidden_size = config.hidden_size
+
+ self.self_attn = Glm4vMoeTextAttention(config=config, layer_idx=layer_idx)
+
+ if layer_idx >= config.first_k_dense_replace:
+ self.mlp = Glm4vMoeTextMoE(config)
+ else:
+ self.mlp = Glm4vMoeTextMLP(config)
+
+ self.input_layernorm = Glm4vMoeTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.post_attention_layernorm = Glm4vMoeTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ use_cache: Optional[bool] = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> torch.Tensor:
+ residual = hidden_states
+ hidden_states = self.input_layernorm(hidden_states)
+ # Self Attention
+ hidden_states, _ = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ **kwargs,
+ )
+ hidden_states = residual + hidden_states
+
+ # Fully Connected
+ residual = hidden_states
+ hidden_states = self.post_attention_layernorm(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = residual + hidden_states
+ return hidden_states
+
+
+@auto_docstring
+class Glm4vMoePreTrainedModel(PreTrainedModel):
+ config: Glm4vMoeConfig
+ base_model_prefix = ""
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["Glm4vMoeTextDecoderLayer", "Glm4vMoeVisionBlock"]
+ _skip_keys_device_placement = "past_key_values"
+ _supports_flash_attn = True
+ _supports_sdpa = True
+ _supports_flex_attn = True
+ _can_compile_fullgraph = False
+ _supports_attention_backend = True
+
+ _can_record_outputs = {
+ "hidden_states": Glm4vMoeTextDecoderLayer,
+ "attentions": Glm4vMoeTextAttention,
+ }
+
+ def _init_weights(self, module):
+ super()._init_weights(module)
+ if isinstance(module, Glm4vMoeTextTopkRouter):
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+
+
+class Glm4vMoeisionMlp(nn.Module):
+ def __init__(self, config, bias: bool = False):
+ super().__init__()
+ self.hidden_size = config.hidden_size
+ self.intermediate_size = config.out_hidden_size
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=bias)
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=bias)
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=bias)
+ self.act_fn = ACT2FN[config.hidden_act]
+
+ def forward(self, hidden_state):
+ return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))
+
+
+class Glm4vMoeVisionPatchEmbed(nn.Module):
+ def __init__(self, config: Glm4vMoeVisionConfig) -> None:
+ super().__init__()
+ self.patch_size = config.patch_size
+ self.temporal_patch_size = config.temporal_patch_size
+ self.in_channels = config.in_channels
+ self.embed_dim = config.hidden_size
+
+ kernel_size = [self.temporal_patch_size, self.patch_size, self.patch_size]
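+        # A single Conv3d whose kernel_size equals its stride embeds each
+        # (temporal_patch_size, patch_size, patch_size) voxel patch into one `embed_dim`-dimensional vector.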
+ self.proj = nn.Conv3d(self.in_channels, self.embed_dim, kernel_size=kernel_size, stride=kernel_size)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ target_dtype = self.proj.weight.dtype
+ hidden_states = hidden_states.view(
+ -1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size
+ )
+ hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).view(-1, self.embed_dim)
+ return hidden_states
+
+
+class Glm4vMoeVisionRotaryEmbedding(nn.Module):
+ inv_freq: torch.Tensor # fix linting for `register_buffer`
+
+ def __init__(self, dim: int, theta: float = 10000.0) -> None:
+ super().__init__()
+ inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
+
+ def forward(self, seqlen: int) -> torch.Tensor:
+ seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
+ freqs = torch.outer(seq, self.inv_freq)
+ return freqs
+
+
+class Glm4vMoeVisionPatchMerger(nn.Module):
+ def __init__(self, dim: int, context_dim: int, hidden_act: str, bias: bool = False) -> None:
+ super().__init__()
+ self.proj = nn.Linear(dim, dim, bias=bias)
+ self.post_projection_norm = LayerNorm(dim)
+ self.gate_proj = nn.Linear(dim, context_dim, bias=bias)
+ self.up_proj = nn.Linear(dim, context_dim, bias=bias)
+ self.down_proj = nn.Linear(context_dim, dim, bias=bias)
+ self.act1 = nn.GELU()
+ self.act_fn = ACT2FN[hidden_act]
+
+ def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
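+ # Project and normalize the merged patch features, then apply a gated (SwiGLU-style) MLP.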
+ hidden_state = self.proj(hidden_state)
+ hidden_state = self.act1(self.post_projection_norm(hidden_state))
+ return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))
+
+
+class Glm4vMoeVisionEmbeddings(nn.Module):
+ def __init__(self, config: Glm4vMoeVisionConfig):
+ super().__init__()
+ self.config = config
+ self.embed_dim = config.hidden_size
+ self.image_size = config.image_size
+ self.patch_size = config.patch_size
+
+ self.num_patches = (self.image_size // self.patch_size) ** 2
+ self.num_positions = self.num_patches
+ self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
+ self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)
+
+ def forward(self, embeddings, lengths, image_shapes, h_coords, w_coords) -> torch.Tensor:
+ """
+ Forward pass with integrated position encoding adaptation using 2D interpolation.
+
+ Args:
+ embeddings: Input embeddings tensor
+ lengths (torch.Tensor): Sequence lengths for each image in the batch.
+ image_shapes (torch.Tensor): Tensor of shape [batch_size, 3] representing the image shapes (t, h, w).
+ h_coords (torch.Tensor): Tensor of shape [total_seq] representing the h coordinate for each patch.
+ w_coords (torch.Tensor): Tensor of shape [total_seq] representing the w coordinate for each patch.
+
+ Returns:
+ torch.Tensor: Embeddings with adapted position encoding added.
+ """
+ # Get position embedding parameters
+ pos_embed_weight = self.position_embedding.weight
+ hidden_size = pos_embed_weight.shape[1]
+ total_seq = h_coords.shape[0]
+ device = pos_embed_weight.device
+
+ # Move coordinates to correct device
+ h_coords, w_coords = h_coords.to(device), w_coords.to(device)
+
+ # Handle empty sequence case
+ if total_seq == 0:
+ adapted_pos_embed = torch.empty(0, hidden_size, device=device, dtype=pos_embed_weight.dtype)
+ else:
+ # Convert inputs to tensors if needed
+ if isinstance(lengths, list):
+ lengths = torch.tensor(lengths, device=device, dtype=torch.long)
+ if not isinstance(image_shapes, torch.Tensor):
+ image_shapes = torch.tensor(image_shapes, device=device, dtype=torch.long)
+
+ # Prepare 2D position embedding
+ orig_size_sq = pos_embed_weight.shape[0]
+ orig_size = int(orig_size_sq**0.5)
+ pos_embed_2d = (
+ pos_embed_weight.view(orig_size, orig_size, hidden_size)
+ .permute(2, 0, 1)
+ .unsqueeze(0)
+ .to(device=device, dtype=torch.float32)
+ )
+
+ # Calculate target dimensions for each patch
+ target_h = torch.cat([image_shapes[i, 1].repeat(lengths[i]) for i in range(len(lengths))]).to(
+ device=device, dtype=torch.float32
+ )
+ target_w = torch.cat([image_shapes[i, 2].repeat(lengths[i]) for i in range(len(lengths))]).to(
+ device=device, dtype=torch.float32
+ )
+
+ # Normalize coordinates to [-1, 1] range for grid_sample
+ h_coords = h_coords.to(device=device, dtype=torch.float32)
+ w_coords = w_coords.to(device=device, dtype=torch.float32)
+ norm_w = ((w_coords + 0.5) / target_w) * 2 - 1
+ norm_h = ((h_coords + 0.5) / target_h) * 2 - 1
+
+ # Create sampling grid
+ grid = torch.stack((norm_w, norm_h), dim=-1).unsqueeze(0).unsqueeze(2)
+
+ # Perform bicubic interpolation
+ interpolated_embed_fp32 = F.grid_sample(
+ pos_embed_2d, grid, mode="bicubic", align_corners=False, padding_mode="border"
+ )
+
+ # Reshape and convert back to original dtype
+ adapted_pos_embed_fp32 = interpolated_embed_fp32.squeeze(0).squeeze(-1).permute(1, 0)
+ adapted_pos_embed = adapted_pos_embed_fp32.to(pos_embed_weight.dtype).to(embeddings.device)
+
+ # Add adapted position encoding to embeddings
+ embeddings = embeddings + adapted_pos_embed
+ return embeddings
+
+
+def apply_rotary_pos_emb_vision(
+ q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
+) -> tuple[torch.Tensor, torch.Tensor]:
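+ # Apply rotary embeddings in float32 for numerical stability, then cast back to the original dtypes.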
+ orig_q_dtype = q.dtype
+ orig_k_dtype = k.dtype
+ q, k = q.float(), k.float()
+ cos, sin = cos.unsqueeze(-2).float(), sin.unsqueeze(-2).float()
+ q_embed = (q * cos) + (rotate_half(q) * sin)
+ k_embed = (k * cos) + (rotate_half(k) * sin)
+ q_embed = q_embed.to(orig_q_dtype)
+ k_embed = k_embed.to(orig_k_dtype)
+ return q_embed, k_embed
+
+
+class Glm4vMoeVisionAttention(nn.Module):
+ def __init__(self, config: Glm4vMoeVisionConfig) -> None:
+ super().__init__()
+ self.dim = config.hidden_size
+ self.num_heads = config.num_heads
+ self.head_dim = self.dim // self.num_heads
+ self.num_key_value_groups = 1 # needed for eager attention
+ self.qkv = nn.Linear(config.hidden_size, config.hidden_size * 3, bias=config.attention_bias)
+ self.proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
+ self.scaling = self.head_dim**-0.5
+ self.config = config
+ self.attention_dropout = config.attention_dropout
+ self.is_causal = False
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ cu_seqlens: torch.Tensor,
+ rotary_pos_emb: Optional[torch.Tensor] = None,
+ position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
+ **kwargs,
+ ) -> torch.Tensor:
+ seq_length = hidden_states.shape[0]
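+ # Fused QKV projection: reshape to (3, seq_len, num_heads, head_dim) and unbind into q, k, v.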
+ query_states, key_states, value_states = (
+ self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
+ )
+ cos, sin = position_embeddings
+ query_states, key_states = apply_rotary_pos_emb_vision(query_states, key_states, cos, sin)
+
+ query_states = query_states.transpose(0, 1).unsqueeze(0)
+ key_states = key_states.transpose(0, 1).unsqueeze(0)
+ value_states = value_states.transpose(0, 1).unsqueeze(0)
+
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ if self.config._attn_implementation == "flash_attention_2":
+ # Flash Attention 2: Use cu_seqlens for variable length attention
+ max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
+ attn_output, _ = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask=None,
+ scaling=self.scaling,
+ dropout=0.0 if not self.training else self.attention_dropout,
+ cu_seq_lens_q=cu_seqlens,
+ cu_seq_lens_k=cu_seqlens,
+ max_length_q=max_seqlen,
+ max_length_k=max_seqlen,
+ is_causal=False,
+ **kwargs,
+ )
+ else:
+ # Other implementations: Process each chunk separately
+ lengths = cu_seqlens[1:] - cu_seqlens[:-1]
+ splits = [
+ torch.split(tensor, lengths.tolist(), dim=2) for tensor in (query_states, key_states, value_states)
+ ]
+
+ attn_outputs = [
+ attention_interface(
+ self,
+ q,
+ k,
+ v,
+ attention_mask=None,
+ scaling=self.scaling,
+ dropout=0.0 if not self.training else self.attention_dropout,
+ is_causal=False,
+ **kwargs,
+ )[0]
+ for q, k, v in zip(*splits)
+ ]
+ attn_output = torch.cat(attn_outputs, dim=1)
+
+ attn_output = attn_output.reshape(seq_length, -1).contiguous()
+ attn_output = self.proj(attn_output)
+ return attn_output
+
+
+class Glm4vMoeVisionBlock(GradientCheckpointingLayer):
+ def __init__(self, config) -> None:
+ super().__init__()
+ self.norm1 = Glm4vMoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.norm2 = Glm4vMoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.attn = Glm4vMoeVisionAttention(config)
+ self.mlp = Glm4vMoeisionMlp(config, bias=False)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ cu_seqlens: torch.Tensor,
+ rotary_pos_emb: Optional[torch.Tensor] = None,
+ position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
+ **kwargs,
+ ) -> torch.Tensor:
+ hidden_states = hidden_states + self.attn(
+ self.norm1(hidden_states),
+ cu_seqlens=cu_seqlens,
+ rotary_pos_emb=rotary_pos_emb,
+ position_embeddings=position_embeddings,
+ **kwargs,
+ )
+ hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
+ return hidden_states
+
+
+class Glm4vMoeTextRotaryEmbedding(nn.Module):
+ inv_freq: torch.Tensor # fix linting for `register_buffer`
+
+ def __init__(self, config: Glm4vMoeTextConfig, device=None):
+ super().__init__()
+ # BC: "rope_type" was originally "type"
+ if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
+ else:
+ self.rope_type = "default"
+ self.max_seq_len_cached = config.max_position_embeddings
+ self.original_max_seq_len = config.max_position_embeddings
+
+ self.config = config
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
+ self.original_inv_freq = self.inv_freq
+
+ @torch.no_grad()
+ @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
+ def forward(self, x, position_ids):
+ # In contrast to other models, Glm4vMoeText has different position ids for the grids
+ # So we expand the inv_freq to shape (3, ...)
+ inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1)
+ position_ids_expanded = position_ids[:, :, None, :].float() # shape (3, bs, 1, positions)
+
+ device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
+ with torch.autocast(device_type=device_type, enabled=False): # Force float32
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3)
+ emb = torch.cat((freqs, freqs), dim=-1)
+ cos = emb.cos() * self.attention_scaling
+ sin = emb.sin() * self.attention_scaling
+
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Base class for Glm4vMoe model outputs, with hidden states and attentions.
+ """
+)
+class Glm4vMoeModelOutputWithPast(ModelOutput):
+ r"""
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
+ `past_key_values` input) to speed up sequential decoding.
+ rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
+ The rope index difference between sequence length and multimodal rope.
+ """
+
+ last_hidden_state: Optional[torch.FloatTensor] = None
+ past_key_values: Optional[Cache] = None
+ hidden_states: Optional[tuple[torch.FloatTensor]] = None
+ attentions: Optional[tuple[torch.FloatTensor]] = None
+ rope_deltas: Optional[torch.LongTensor] = None
+
+
+class Glm4vMoeVisionModel(Glm4vMoePreTrainedModel):
+ config: Glm4vMoeVisionConfig
+ _no_split_modules = ["Glm4vMoeVisionBlock"]
+
+ def __init__(self, config) -> None:
+ super().__init__(config)
+ self.spatial_merge_size = config.spatial_merge_size
+ self.patch_size = config.patch_size
+
+ self.embeddings = Glm4vMoeVisionEmbeddings(config)
+ self.patch_embed = Glm4vMoeVisionPatchEmbed(config)
+
+ head_dim = config.hidden_size // config.num_heads
+ self.rotary_pos_emb = Glm4vMoeVisionRotaryEmbedding(head_dim // 2)
+
+ self.blocks = nn.ModuleList([Glm4vMoeVisionBlock(config) for _ in range(config.depth)])
+ self.merger = Glm4vMoeVisionPatchMerger(
+ dim=config.out_hidden_size, context_dim=config.intermediate_size, hidden_act=config.hidden_act
+ )
+
+ self.post_conv_layernorm = Glm4vMoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.downsample = nn.Conv2d(
+ in_channels=config.hidden_size,
+ out_channels=config.out_hidden_size,
+ kernel_size=config.spatial_merge_size,
+ stride=config.spatial_merge_size,
+ )
+ self.post_layernorm = Glm4vMoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+ self.gradient_checkpointing = False
+ self.post_init()
+
+ def rot_pos_emb(self, grid_thw):
+ pos_ids = []
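+ # For each (t, h, w) grid, build per-patch (h, w) indices, permuted so that patches belonging to the
+ # same spatial_merge_size window are contiguous in the flattened sequence.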
+ for t, h, w in grid_thw:
+ hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
+ hpos_ids = hpos_ids.reshape(
+ h // self.spatial_merge_size,
+ self.spatial_merge_size,
+ w // self.spatial_merge_size,
+ self.spatial_merge_size,
+ )
+ hpos_ids = hpos_ids.permute(0, 2, 1, 3)
+ hpos_ids = hpos_ids.flatten()
+
+ wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
+ wpos_ids = wpos_ids.reshape(
+ h // self.spatial_merge_size,
+ self.spatial_merge_size,
+ w // self.spatial_merge_size,
+ self.spatial_merge_size,
+ )
+ wpos_ids = wpos_ids.permute(0, 2, 1, 3)
+ wpos_ids = wpos_ids.flatten()
+ pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
+ pos_ids = torch.cat(pos_ids, dim=0)
+ max_grid_size = grid_thw[:, 1:].max()
+ rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
+ rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
+ return rotary_pos_emb, pos_ids
+
+ def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor:
+ """
+ Args:
+ hidden_states (`torch.Tensor` of shape `(seq_len, hidden_size)`):
+ The flattened image patches to be embedded.
+ grid_thw (`torch.Tensor` of shape `(num_images_or_videos, 3)`):
+ The temporal, height and width of feature shape of each image in LLM.
+
+ Returns:
+ `torch.Tensor`: hidden_states.
+ """
+ hidden_states = self.patch_embed(hidden_states)
+ hidden_states = self.post_conv_layernorm(hidden_states)
+
+ rotary_pos_emb, image_type_ids = self.rot_pos_emb(grid_thw)
+ emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
+ position_embeddings = (emb.cos(), emb.sin())
+
+ cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
+ dim=0,
+ # Select dtype based on the following factors:
+ # - FA2 requires that cu_seqlens_q must have dtype int32
+ # - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw
+ # See https://github.com/huggingface/transformers/pull/34852 for more information
+ dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
+ )
+ cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
+ seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
+ hidden_states = self.embeddings(hidden_states, seqlens, grid_thw, image_type_ids[:, 0], image_type_ids[:, 1])
+
+ for blk in self.blocks:
+ hidden_states = blk(
+ hidden_states,
+ cu_seqlens=cu_seqlens,
+ position_embeddings=position_embeddings,
+ )
+
+ hidden_states = self.post_layernorm(hidden_states)
+
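+ # Spatial merging: regroup patches into merge windows and downsample them with a strided Conv2d,
+ # then fuse the result in the patch merger.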
+ hidden_states = hidden_states.view(
+ -1, self.spatial_merge_size, self.spatial_merge_size, hidden_states.shape[-1]
+ )
+ hidden_states = hidden_states.permute(0, 3, 1, 2)
+ hidden_states = self.downsample(hidden_states).view(-1, self.config.out_hidden_size)
+
+ hidden_states = self.merger(hidden_states)
+ return hidden_states
+
+
+@auto_docstring
+class Glm4vMoeTextModel(Glm4vMoePreTrainedModel):
+ config: Glm4vMoeTextConfig
+
+ def __init__(self, config: Glm4vMoeTextConfig):
+ super().__init__(config)
+ self.padding_idx = config.pad_token_id
+ self.vocab_size = config.vocab_size
+
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+ self.layers = nn.ModuleList(
+ [Glm4vMoeTextDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+ )
+ self.norm = Glm4vMoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.rotary_emb = Glm4vMoeTextRotaryEmbedding(config=config)
+
+ self.gradient_checkpointing = False
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @auto_docstring
+ @check_model_inputs
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs: Unpack[FlashAttentionKwargs],
+ ) -> Union[tuple, BaseModelOutputWithPast]:
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ # torch.jit.trace() doesn't support cache objects in the output
+ if use_cache and past_key_values is None and not torch.jit.is_tracing():
+ past_key_values = DynamicCache(config=self.config)
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+
+ if cache_position is None:
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ cache_position = torch.arange(
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+ )
+
+ # the hard coded `3` is for temporal, height and width.
+ if position_ids is None:
+ position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1)
+ elif position_ids.dim() == 2:
+ position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1)
+
+ causal_mask = create_causal_mask(
+ config=self.config,
+ input_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ cache_position=cache_position,
+ past_key_values=past_key_values,
+ position_ids=position_ids,
+ )
+
+ hidden_states = inputs_embeds
+
+ # create position embeddings to be shared across the decoder layers
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+ for decoder_layer in self.layers:
+ layer_outputs = decoder_layer(
+ hidden_states,
+ position_embeddings=position_embeddings,
+ attention_mask=causal_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ cache_position=cache_position,
+ **kwargs,
+ )
+ hidden_states = layer_outputs
+
+ hidden_states = self.norm(hidden_states)
+
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=past_key_values,
+ )
+
+
+@auto_docstring
+class Glm4vMoeModel(Glm4vMoePreTrainedModel):
+ base_model_prefix = ""
+ _checkpoint_conversion_mapping = {}
+ # Reference: fix gemma3 grad acc #37208
+ accepts_loss_kwargs = False
+ config: Glm4vMoeConfig
+ _no_split_modules = ["Glm4vMoeTextDecoderLayer", "Glm4vMoeVisionBlock"]
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.visual = Glm4vMoeVisionModel._from_config(config.vision_config)
+ self.language_model = Glm4vMoeTextModel._from_config(config.text_config)
+ self.rope_deltas = None # cache rope_deltas here
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.language_model.get_input_embeddings()
+
+ def set_input_embeddings(self, value):
+ self.language_model.set_input_embeddings(value)
+
+ def set_decoder(self, decoder):
+ self.language_model = decoder
+
+ def get_decoder(self):
+ return self.language_model
+
+ def get_rope_index(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ image_grid_thw: Optional[torch.LongTensor] = None,
+ video_grid_thw: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ """
+ Calculate the 3D rope index based on image and video's temporal, height and width in LLM.
+
+ Explanation:
+ Each embedding sequence contains vision embedding and text embedding or just contains text embedding.
+
+ For pure text embedding sequence, the rotary position embedding has no difference with modern LLMs.
+ Examples:
+ input_ids: [T T T T T], here T is for text.
+ temporal position_ids: [0, 1, 2, 3, 4]
+ height position_ids: [0, 1, 2, 3, 4]
+ width position_ids: [0, 1, 2, 3, 4]
+
+ For vision and text embedding sequence, we calculate 3D rotary position embedding for vision part
+ and 1D rotary position embedding for text part.
+ Examples:
+ Temporal (Time): 3 patches, representing different segments of the video in time.
+ Height: 2 patches, dividing each frame vertically.
+ Width: 2 patches, dividing each frame horizontally.
+ We also have some important parameters:
+ fps (Frames Per Second): The video's frame rate, set to 1. This means one frame is processed each second.
+ tokens_per_second: This is a crucial parameter. It dictates how many "time-steps" or "temporal tokens" are conceptually packed into a one-second interval of the video. In this case, we have 25 tokens per second. So each second of the video will be represented with 25 separate time points. It essentially defines the temporal granularity.
+ temporal_patch_size: The number of frames that compose one temporal patch. Here, it's 2 frames.
+ interval: The step size for the temporal position IDs, calculated as tokens_per_second * temporal_patch_size / fps. In this case, 25 * 2 / 1 = 50. This means that each temporal patch will have a difference of 50 in the temporal position IDs.
+ input_ids: [V V V V V V V V V V V V T T T T T], here V is for vision.
+ vision temporal position_ids: [0, 0, 0, 0, 50, 50, 50, 50, 100, 100, 100, 100]
+ vision height position_ids: [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1]
+ vision width position_ids: [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1]
+ text temporal position_ids: [101, 102, 103, 104, 105]
+ text height position_ids: [101, 102, 103, 104, 105]
+ text width position_ids: [101, 102, 103, 104, 105]
+ Here we calculate the text start position_ids as the max vision position_ids plus 1.
+
+ Args:
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
+ it.
+ image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
+ The temporal, height and width of feature shape of each image in LLM.
+ video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
+ The temporal, height and width of feature shape of each video in LLM.
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ Returns:
+ position_ids (`torch.LongTensor` of shape `(3, batch_size, sequence_length)`)
+ mrope_position_deltas (`torch.Tensor` of shape `(batch_size)`)
+ """
+
+ spatial_merge_size = self.config.vision_config.spatial_merge_size
+ image_token_id = self.config.image_token_id
+ video_start_token_id = self.config.video_start_token_id
+ video_end_token_id = self.config.video_end_token_id
+
+ mrope_position_deltas = []
+ if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None):
+ total_input_ids = input_ids
+ if attention_mask is None:
+ attention_mask = torch.ones_like(total_input_ids)
+ position_ids = torch.ones(
+ 3,
+ input_ids.shape[0],
+ input_ids.shape[1],
+ dtype=input_ids.dtype,
+ device=input_ids.device,
+ )
+ image_index, video_index = 0, 0
+ video_group_index = 0
+ attention_mask = attention_mask.to(total_input_ids.device)
+ for i, input_ids in enumerate(total_input_ids):
+ input_ids = input_ids[attention_mask[i] == 1]
+ input_tokens = input_ids.tolist()
+
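+ # Classify each token as image / video / text: image placeholder tokens that appear between the
+ # video start and end markers are treated as video tokens.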
+ input_token_type = []
+ video_check_flg = False
+ for token in input_tokens:
+ if token == video_start_token_id:
+ video_check_flg = True
+ elif token == video_end_token_id:
+ video_check_flg = False
+
+ if token == image_token_id and not video_check_flg:
+ input_token_type.append("image")
+ elif token == image_token_id and video_check_flg:
+ input_token_type.append("video")
+ else:
+ input_token_type.append("text")
+
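+ # Collapse consecutive tokens of the same modality into (modality, start_index, end_index) spans.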
+ input_type_group = []
+ for key, group in itertools.groupby(enumerate(input_token_type), lambda x: x[1]):
+ group = list(group)
+ start_index = group[0][0]
+ end_index = group[-1][0] + 1
+ input_type_group.append((key, start_index, end_index))
+
+ llm_pos_ids_list = []
+ video_frame_num = 1
+ for modality_type, start_idx, end_idx in input_type_group:
+ st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
+
+ if modality_type == "image":
+ t, h, w = (
+ image_grid_thw[image_index][0],
+ image_grid_thw[image_index][1],
+ image_grid_thw[image_index][2],
+ )
+ llm_grid_t, llm_grid_h, llm_grid_w = (
+ t.item(),
+ h.item() // spatial_merge_size,
+ w.item() // spatial_merge_size,
+ )
+
+ t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten()
+ h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten()
+ w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten()
+ llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + st_idx)
+
+ image_index += 1
+ video_frame_num = 1
+
+ elif modality_type == "video":
+ t, h, w = (
+ video_frame_num,
+ video_grid_thw[video_index][1],
+ video_grid_thw[video_index][2],
+ )
+
+ llm_grid_t, llm_grid_h, llm_grid_w = (
+ t,
+ h.item() // spatial_merge_size,
+ w.item() // spatial_merge_size,
+ )
+
+ for t_idx in range(llm_grid_t):
+ t_index = torch.tensor(t_idx).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten()
+
+ h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(1, -1, llm_grid_w).flatten()
+ w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(1, llm_grid_h, -1).flatten()
+ llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + st_idx)
+
+ video_group_index += 1
+
+ if video_group_index >= video_grid_thw[video_index][0]:
+ video_index += 1
+ video_group_index = 0
+
+ video_frame_num += 1
+
+ else:
+ text_len = end_idx - start_idx
+ llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
+
+ video_frame_num = 1
+
+ llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
+ position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device)
+ mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i]))
+ mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1)
+ return position_ids, mrope_position_deltas
+ else:
+ if attention_mask is not None:
+ position_ids = attention_mask.long().cumsum(-1) - 1
+ position_ids.masked_fill_(attention_mask == 0, 1)
+ position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device)
+ max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0]
+ mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1]
+ else:
+ position_ids = (
+ torch.arange(input_ids.shape[1], device=input_ids.device)
+ .view(1, 1, -1)
+ .expand(3, input_ids.shape[0], -1)
+ )
+ mrope_position_deltas = torch.zeros(
+ [input_ids.shape[0], 1],
+ device=input_ids.device,
+ dtype=input_ids.dtype,
+ )
+
+ return position_ids, mrope_position_deltas
+
+ def get_video_features(
+ self, pixel_values_videos: torch.FloatTensor, video_grid_thw: Optional[torch.LongTensor] = None
+ ):
+ """
+ Encodes videos into continuous embeddings that can be forwarded to the language model.
+
+ Args:
+ pixel_values_videos (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
+ The tensors corresponding to the input videos.
+ video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
+ The temporal, height and width of feature shape of each video in LLM.
+ """
+ pixel_values_videos = pixel_values_videos.type(self.visual.dtype)
+ # reshape video_grid_thw -> [b, 3] -> [1, h, w] * frames
+ temp_frames_hw = []
+ for t, h, w in video_grid_thw:
+ repeated_row = torch.tensor([1, h.item(), w.item()]).unsqueeze(0).repeat(t, 1)
+ temp_frames_hw.append(repeated_row)
+ flattened_video_grid_thw = torch.cat(temp_frames_hw, dim=0)
+ video_embeds = self.visual(pixel_values_videos, grid_thw=flattened_video_grid_thw)
+ split_sizes = (video_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist()
+ video_embeds = torch.split(video_embeds, split_sizes)
+ return video_embeds
+
+ def get_image_features(self, pixel_values: torch.FloatTensor, image_grid_thw: Optional[torch.LongTensor] = None):
+ """
+ Encodes images into continuous embeddings that can be forwarded to the language model.
+
+ Args:
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
+ The tensors corresponding to the input images.
+ image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
+ The temporal, height and width of feature shape of each image in LLM.
+ """
+ pixel_values = pixel_values.type(self.visual.dtype)
+ image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
+ split_sizes = (image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist()
+ image_embeds = torch.split(image_embeds, split_sizes)
+ return image_embeds
+
+ def get_placeholder_mask(
+ self,
+ input_ids: torch.LongTensor,
+ inputs_embeds: torch.FloatTensor,
+ image_features: Optional[torch.FloatTensor] = None,
+ video_features: Optional[torch.FloatTensor] = None,
+ ):
+ """
+ Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is
+ equal to the length of multimodal features. If the lengths are different, an error is raised.
+ """
+ if input_ids is None:
+ special_image_mask = inputs_embeds == self.get_input_embeddings()(
+ torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
+ )
+ special_image_mask = special_image_mask.all(-1)
+ special_video_mask = inputs_embeds == self.get_input_embeddings()(
+ torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device)
+ )
+ special_video_mask = special_video_mask.all(-1)
+ else:
+ # In GLM-4.1V and GLM-4.5V, video frames reuse the image placeholder token, so the video mask is the image mask
+ special_image_mask = input_ids == self.config.image_token_id
+ special_video_mask = input_ids == self.config.image_token_id
+
+ n_image_tokens = special_image_mask.sum()
+ special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
+ if image_features is not None and inputs_embeds[special_image_mask].numel() != image_features.numel():
+ raise ValueError(
+ f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {image_features.shape[0]}"
+ )
+
+ n_video_tokens = special_video_mask.sum()
+ special_video_mask = special_video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
+ if video_features is not None and inputs_embeds[special_video_mask].numel() != video_features.numel():
+ raise ValueError(
+ f"Videos features and video tokens do not match: tokens: {n_video_tokens}, features {video_features.shape[0]}"
+ )
+
+ return special_image_mask, special_video_mask
+
+ @auto_docstring
+ @can_return_tuple
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ pixel_values: Optional[torch.Tensor] = None,
+ pixel_values_videos: Optional[torch.FloatTensor] = None,
+ image_grid_thw: Optional[torch.LongTensor] = None,
+ video_grid_thw: Optional[torch.LongTensor] = None,
+ rope_deltas: Optional[torch.LongTensor] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> Union[tuple, Glm4vMoeModelOutputWithPast]:
+ r"""
+ image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
+ The temporal, height and width of feature shape of each image in LLM.
+ video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
+ The temporal, height and width of feature shape of each video in LLM.
+ rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
+ The rope index difference between sequence length and multimodal rope.
+ """
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ if inputs_embeds is None:
+ inputs_embeds = self.get_input_embeddings()(input_ids)
+
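+ # Encode images/videos with the vision tower and scatter the resulting embeddings into the
+ # corresponding placeholder positions of the text embeddings.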
+ if pixel_values is not None:
+ image_embeds = self.get_image_features(pixel_values, image_grid_thw)
+ image_embeds = torch.cat(image_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype)
+ image_mask, _ = self.get_placeholder_mask(input_ids, inputs_embeds, image_features=image_embeds)
+ inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
+
+ if pixel_values_videos is not None:
+ video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw)
+ video_embeds = torch.cat(video_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype)
+ _, video_mask = self.get_placeholder_mask(input_ids, inputs_embeds, video_features=video_embeds)
+ inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
+
+ if position_ids is None:
+ attention_mask_tensor = (
+ attention_mask if not isinstance(attention_mask, dict) else attention_mask["full_attention"]
+ )
+ if attention_mask_tensor is not None and attention_mask_tensor.ndim == 4:
+ attention_mask_tensor = torch.diagonal(attention_mask_tensor[:, 0], dim1=1, dim2=2)
+ # Only apply conversion for floating point tensors (inverted masks)
+ if attention_mask_tensor.dtype.is_floating_point:
+ attention_mask_tensor = attention_mask_tensor / torch.finfo(attention_mask_tensor.dtype).min
+ attention_mask_tensor = (1.0 - attention_mask_tensor).int()
+
+ # Calculate RoPE index once per generation in the pre-fill stage only.
+ # When compiling, we can't check tensor values thus we check only input length
+ # It is safe to assume that `length!=1` means we're in pre-fill because compiled
+ # models currently cannot do assisted decoding
+ prefill_compiled_stage = is_torchdynamo_compiling() and (
+ (input_ids is not None and input_ids.shape[1] != 1)
+ or (inputs_embeds is not None and inputs_embeds.shape[1] != 1)
+ )
+ prefill_noncompiled_stage = not is_torchdynamo_compiling() and (
+ (cache_position is not None and cache_position[0] == 0)
+ or (past_key_values is None or past_key_values.get_seq_length() == 0)
+ )
+ if (prefill_compiled_stage or prefill_noncompiled_stage) or self.rope_deltas is None:
+ position_ids, rope_deltas = self.get_rope_index(
+ input_ids,
+ image_grid_thw,
+ video_grid_thw,
+ attention_mask=attention_mask_tensor,
+ )
+ self.rope_deltas = rope_deltas
+ # then use the prev pre-calculated rope-deltas to get the correct position ids
+ else:
+ batch_size, seq_length, _ = inputs_embeds.shape
+ delta = (
+ (cache_position[0] + self.rope_deltas).to(inputs_embeds.device)
+ if cache_position is not None
+ else 0
+ )
+ position_ids = torch.arange(seq_length, device=inputs_embeds.device)
+ position_ids = position_ids.view(1, -1).expand(batch_size, -1)
+ if cache_position is not None: # otherwise `deltas` is an int `0`
+ delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
+ position_ids = position_ids.add(delta)
+ position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
+
+ outputs = self.language_model(
+ input_ids=None,
+ position_ids=position_ids,
+ attention_mask=attention_mask,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+ return Glm4vMoeModelOutputWithPast(
+ last_hidden_state=outputs.last_hidden_state,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ rope_deltas=self.rope_deltas,
+ )
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Base class for Glm4vMoe causal language model (or autoregressive) outputs.
+ """
+)
+class Glm4vMoeCausalLMOutputWithPast(ModelOutput):
+ r"""
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Language modeling loss (for next-token prediction).
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
+ `past_key_values` input) to speed up sequential decoding.
+ rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
+ The rope index difference between sequence length and multimodal rope.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ logits: Optional[torch.FloatTensor] = None
+ past_key_values: Optional[Cache] = None
+ hidden_states: Optional[tuple[torch.FloatTensor]] = None
+ attentions: Optional[tuple[torch.FloatTensor]] = None
+ rope_deltas: Optional[torch.LongTensor] = None
+
+
+class Glm4vMoeForConditionalGeneration(Glm4vMoePreTrainedModel, GenerationMixin):
+ _checkpoint_conversion_mapping = {}
+ _tied_weights_keys = ["lm_head.weight"]
+ # Reference: fix gemma3 grad acc #37208
+ accepts_loss_kwargs = False
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.model = Glm4vMoeModel(config)
+ self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
+
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.get_input_embeddings()
+
+ def set_input_embeddings(self, value):
+ self.model.set_input_embeddings(value)
+
+ def set_decoder(self, decoder):
+ self.model.set_decoder(decoder)
+
+ def get_decoder(self):
+ return self.model.get_decoder()
+
+ def get_video_features(
+ self, pixel_values_videos: torch.FloatTensor, video_grid_thw: Optional[torch.LongTensor] = None
+ ):
+ return self.model.get_video_features(pixel_values_videos, video_grid_thw)
+
+ def get_image_features(self, pixel_values: torch.FloatTensor, image_grid_thw: Optional[torch.LongTensor] = None):
+ return self.model.get_image_features(pixel_values, image_grid_thw)
+
+ # Make modules available through conditional class for BC
+ @property
+ def language_model(self):
+ return self.model.language_model
+
+ @property
+ def visual(self):
+ return self.model.visual
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ pixel_values: Optional[torch.Tensor] = None,
+ pixel_values_videos: Optional[torch.FloatTensor] = None,
+ image_grid_thw: Optional[torch.LongTensor] = None,
+ video_grid_thw: Optional[torch.LongTensor] = None,
+ rope_deltas: Optional[torch.LongTensor] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ logits_to_keep: Union[int, torch.Tensor] = 0,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> Union[tuple, Glm4vMoeCausalLMOutputWithPast]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+ image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
+ The temporal, height and width of feature shape of each image in LLM.
+ video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
+ The temporal, height and width of feature shape of each video in LLM.
+ rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
+ The rope index difference between sequence length and multimodal rope.
+
+ Example:
+
+ ```python
+ >>> from PIL import Image
+ >>> import requests
+ >>> from transformers import AutoProcessor, Glm4vMoeForConditionalGeneration
+
+ >>> model = Glm4vMoeForConditionalGeneration.from_pretrained("THUDM/GLM-4.1V-9B-Thinking")
+ >>> processor = AutoProcessor.from_pretrained("THUDM/GLM-4.1V-9B-Thinking")
+
+ >>> messages = [
+ {
+ "role": "user",
+ "content": [
+ {"type": "image"},
+ {"type": "text", "text": "What is shown in this image?"},
+ ],
+ },
+ ]
+ >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+ >>> inputs = processor(text=[text], images=[image], return_tensors="pt")
+
+ >>> # Generate
+ >>> generate_ids = model.generate(**inputs, max_length=30)
+ >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ "The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..."
+ ```"""
+ outputs = self.model(
+ input_ids=input_ids,
+ pixel_values=pixel_values,
+ pixel_values_videos=pixel_values_videos,
+ image_grid_thw=image_grid_thw,
+ video_grid_thw=video_grid_thw,
+ position_ids=position_ids,
+ attention_mask=attention_mask,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+ hidden_states = outputs[0]
+
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)
+
+ return Glm4vMoeCausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ rope_deltas=outputs.rope_deltas,
+ )
+
+ def prepare_inputs_for_generation(
+ self,
+ input_ids,
+ past_key_values=None,
+ attention_mask=None,
+ inputs_embeds=None,
+ cache_position=None,
+ position_ids=None,
+ use_cache=True,
+ pixel_values=None,
+ pixel_values_videos=None,
+ image_grid_thw=None,
+ video_grid_thw=None,
+ **kwargs,
+ ):
+ # Overwritten -- in specific circumstances we don't want to forward image inputs to the model
+
+ model_inputs = super().prepare_inputs_for_generation(
+ input_ids,
+ past_key_values=past_key_values,
+ attention_mask=attention_mask,
+ inputs_embeds=inputs_embeds,
+ cache_position=cache_position,
+ position_ids=position_ids,
+ pixel_values=pixel_values,
+ pixel_values_videos=pixel_values_videos,
+ image_grid_thw=image_grid_thw,
+ video_grid_thw=video_grid_thw,
+ use_cache=use_cache,
+ **kwargs,
+ )
+
+ # GLM-4.1V position_ids are prepared with rope_deltas in forward
+ model_inputs["position_ids"] = None
+
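+ # Vision inputs are only needed for the prefill forward pass; drop them once the cache is non-empty.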
+ if cache_position[0] != 0:
+ model_inputs["pixel_values"] = None
+ model_inputs["pixel_values_videos"] = None
+
+ return model_inputs
+
+ def _get_image_nums_and_video_nums(
+ self,
+ input_ids: Optional[torch.LongTensor],
+ inputs_embeds: Optional[torch.Tensor] = None,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ """
+ Get the number of images and videos for each sample to calculate the separation length of the sample tensor.
+ These parameters are not passed through the processor to avoid unpredictable impacts from interface modifications.
+
+ Args:
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary.
+
+ Returns:
+ image_nums (`torch.LongTensor` of shape `(batch_size, num_images_sample)`)
+ video_nums (`torch.LongTensor` of shape `(batch_size, num_videos_sample)`)
+ """
+
+ if inputs_embeds is not None:
+ is_image = (
+ inputs_embeds
+ == self.get_input_embeddings()(
+ torch.tensor(self.config.image_start_token_id, dtype=torch.long, device=inputs_embeds.device)
+ )
+ )[..., 0]
+ is_video_start = (
+ inputs_embeds
+ == self.get_input_embeddings()(
+ torch.tensor(self.config.video_start_token_id, dtype=torch.long, device=inputs_embeds.device)
+ )
+ )[..., 0]
+ is_video_end = (
+ inputs_embeds
+ == self.get_input_embeddings()(
+ torch.tensor(self.config.video_end_token_id, dtype=torch.long, device=inputs_embeds.device)
+ )
+ )[..., 0]
+ else:
+ is_image = input_ids == self.config.image_start_token_id
+ is_video_start = input_ids == self.config.video_start_token_id
+ is_video_end = input_ids == self.config.video_end_token_id
+
+ # Cumulative sum to track if we're inside a video span
+ # We'll assume well-formed video tags (i.e. matching starts and ends)
+ video_level = torch.cumsum(is_video_start.int() - is_video_end.int(), dim=1)
+ inside_video = video_level > 0 # shape (batch_size, seq_length)
+
+ # Mask out image tokens that are inside video spans
+ standalone_images = is_image & (~inside_video)
+
+ # Count per batch
+ image_counts = standalone_images.sum(dim=1)
+ video_counts = is_video_start.sum(dim=1)
+
+ return image_counts, video_counts
+
+ def _expand_inputs_for_generation(
+ self,
+ expand_size: int = 1,
+ is_encoder_decoder: bool = False,
+ input_ids: Optional[torch.LongTensor] = None,
+ **model_kwargs,
+ ) -> tuple[torch.LongTensor, dict[str, Any]]:
+ # Overwritten -- Support for expanding tensors without a batch size dimension
+ # e.g., pixel_values, image_grid_thw, pixel_values_videos, video_grid_thw, second_per_grid_t
+ # pixel_values.shape[0] is sum(seqlen_images for samples)
+ # image_grid_thw.shape[0] is sum(num_images for samples)
+
+ if expand_size == 1:
+ return input_ids, model_kwargs
+
+ visual_keys = ["pixel_values", "image_grid_thw", "pixel_values_videos", "video_grid_thw", "second_per_grid_ts"]
+
+ def _expand_dict_for_generation_visual(dict_to_expand):
+ image_grid_thw = model_kwargs.get("image_grid_thw", None)
+ video_grid_thw = model_kwargs.get("video_grid_thw", None)
+ image_nums, video_nums = self._get_image_nums_and_video_nums(
+ input_ids, inputs_embeds=model_kwargs.get("inputs_embeds", None)
+ )
+
+ def _repeat_interleave_samples(x, lengths, repeat_times):
+ samples = torch.split(x, lengths)
+ repeat_args = [repeat_times] + [1] * (x.dim() - 1)
+ result = torch.cat([sample.repeat(*repeat_args) for sample in samples], dim=0)
+ return result
+
+ for key in dict_to_expand:
+ if key == "pixel_values":
+ # split images into samples
+ samples = torch.split(image_grid_thw, list(image_nums))
+ # compute the sequence length of images for each sample
+ lengths = [torch.prod(sample, dim=1).sum() for sample in samples]
+ dict_to_expand[key] = _repeat_interleave_samples(
+ dict_to_expand[key], lengths=lengths, repeat_times=expand_size
+ )
+ elif key == "image_grid_thw":
+ # get the num of images for each sample
+ lengths = list(image_nums)
+ dict_to_expand[key] = _repeat_interleave_samples(
+ dict_to_expand[key], lengths=lengths, repeat_times=expand_size
+ )
+ elif key == "pixel_values_videos":
+ samples = torch.split(video_grid_thw, list(video_nums))
+ lengths = [torch.prod(sample, dim=1).sum() for sample in samples]
+ dict_to_expand[key] = _repeat_interleave_samples(
+ dict_to_expand[key], lengths=lengths, repeat_times=expand_size
+ )
+ elif key == "video_grid_thw":
+ lengths = list(video_nums)
+ dict_to_expand[key] = _repeat_interleave_samples(
+ dict_to_expand[key], lengths=lengths, repeat_times=expand_size
+ )
+ elif key == "second_per_grid_ts":
+ dict_to_expand[key] = _repeat_interleave_samples(
+ dict_to_expand[key], lengths=list(video_nums), repeat_times=expand_size
+ )
+ return dict_to_expand
+
+ def _expand_dict_for_generation(dict_to_expand):
+ for key in dict_to_expand:
+ if (
+ key != "cache_position"
+ and dict_to_expand[key] is not None
+ and isinstance(dict_to_expand[key], torch.Tensor)
+ and key not in visual_keys
+ ):
+ dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0)
+ return dict_to_expand
+
+ model_kwargs = _expand_dict_for_generation_visual(model_kwargs)
+
+ if input_ids is not None:
+ input_ids = input_ids.repeat_interleave(expand_size, dim=0)
+
+ model_kwargs = _expand_dict_for_generation(model_kwargs)
+
+ if is_encoder_decoder:
+ if model_kwargs.get("encoder_outputs") is None:
+ raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.")
+ model_kwargs["encoder_outputs"] = _expand_dict_for_generation(model_kwargs["encoder_outputs"])
+
+ return input_ids, model_kwargs
+
+
+__all__ = ["Glm4vMoeForConditionalGeneration", "Glm4vMoeModel", "Glm4vMoePreTrainedModel", "Glm4vMoeTextModel"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/glm4v_moe/modular_glm4v_moe.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/glm4v_moe/modular_glm4v_moe.py
new file mode 100644
index 0000000000000000000000000000000000000000..0dfe28ff19da878a689afb0ab6621e8cfb35f340
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/glm4v_moe/modular_glm4v_moe.py
@@ -0,0 +1,459 @@
+# coding=utf-8
+# Copyright 2025 The ZhipuAI Inc. team and HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Callable, Optional
+
+import torch
+import torch.nn as nn
+
+from ...cache_utils import Cache
+from ...configuration_utils import PretrainedConfig
+from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_rope_utils import rope_config_validation
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
+from ...processing_utils import Unpack
+from ...utils import logging
+from ..glm4.modeling_glm4 import Glm4Attention
+from ..glm4_moe.configuration_glm4_moe import Glm4MoeConfig
+from ..glm4_moe.modeling_glm4_moe import (
+ Glm4MoeDecoderLayer,
+ Glm4MoeMLP,
+ Glm4MoeMoE,
+ Glm4MoePreTrainedModel,
+ Glm4MoeRMSNorm,
+ Glm4MoeTopkRouter,
+ eager_attention_forward,
+)
+from ..glm4v.configuration_glm4v import Glm4vConfig, Glm4vVisionConfig
+from ..glm4v.modeling_glm4v import (
+ Glm4vForConditionalGeneration,
+ rotate_half,
+)
+
+
+logger = logging.get_logger(__name__)
+
+
+class Glm4vMoeVisionConfig(Glm4vVisionConfig):
+ pass
+
+
+class Glm4vMoeTextConfig(Glm4MoeConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`Glm4vMoeModel`]. It is used to instantiate a
+ GLM-4.5V model according to the specified arguments, defining the model architecture. Instantiating a
+ configuration with the defaults will yield a similar configuration to that of
+ GLM-4.5V [zai-org/GLM-4.5V](https://huggingface.co/zai-org/GLM-4.5V).
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 151424):
+ Vocabulary size of the Glm4vMoe model. Defines the number of different tokens that can be represented by the
+ `inputs_ids` passed when calling [`Glm4vMoeModel`]
+ hidden_size (`int`, *optional*, defaults to 4096):
+ Dimension of the hidden representations.
+ intermediate_size (`int`, *optional*, defaults to 10944):
+ Dimension of the MLP representations.
+ num_hidden_layers (`int`, *optional*, defaults to 46):
+ Number of hidden layers in the Transformer encoder.
+ num_attention_heads (`int`, *optional*, defaults to 96):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ partial_rotary_factor (`float`, *optional*, defaults to 0.5): The factor of the partial rotary position.
+ num_key_value_heads (`int`, *optional*, defaults to 8):
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
+ `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+ by meanpooling all the original heads within that group. For more details checkout [this
+ paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to `8`.
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+ The non-linear activation function (function or string) in the decoder.
+ max_position_embeddings (`int`, *optional*, defaults to 65536):
+ The maximum sequence length that this model might ever be used with.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ rms_norm_eps (`float`, *optional*, defaults to 1e-05):
+ The epsilon used by the rms normalization layers.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
+ relevant if `config.is_decoder=True`.
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+ Whether the model's input and output word embeddings should be tied.
+ rope_theta (`float`, *optional*, defaults to 10000.0):
+ The base period of the RoPE embeddings.
+ rope_scaling (`Dict`, *optional*):
+ Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
+ and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
+ accordingly.
+ Expected contents:
+ `rope_type` (`str`):
+ The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
+ 'llama3'], with 'default' being the original RoPE implementation.
+ `factor` (`float`, *optional*):
+ Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
+ most scaling types, a `factor` of x will enable the model to handle sequences of length x *
+ original maximum pre-trained length.
+ `original_max_position_embeddings` (`int`, *optional*):
+ Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
+ pretraining.
+ `attention_factor` (`float`, *optional*):
+ Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
+ computation. If unspecified, it defaults to value recommended by the implementation, using the
+ `factor` field to infer the suggested value.
+ attention_bias (`bool`, *optional*, defaults to `True`):
+ Whether to use a bias in the query, key, value and output projection layers during self-attention.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ moe_intermediate_size (`int`, *optional*, defaults to 1408):
+ Intermediate size of the routed expert.
+ num_experts_per_tok (`int`, *optional*, defaults to 8):
+ Number of experts per token.
+ n_shared_experts (`int`, *optional*, defaults to 1):
+ Number of shared experts.
+ n_routed_experts (`int`, *optional*, defaults to 128):
+ Number of routed experts.
+ routed_scaling_factor (`float`, *optional*, defaults to 1.0):
+ Scaling factor for routed experts.
+ n_group (`int`, *optional*, defaults to 1):
+ Number of groups for routed experts.
+ topk_group (`int`, *optional*, defaults to 1):
+ Number of selected groups for each token (for each token, ensuring the selected experts are only within `topk_group` groups).
+ first_k_dense_replace (`int`, *optional*, defaults to 1):
+ Number of dense layers in the shallow part of the network, i.e. the first k layers after the embedding are dense and the remaining layers are MoE (embed -> k dense layers -> moe layers -> lm_head).
+ norm_topk_prob (`bool`, *optional*, defaults to `True`):
+ Whether to normalize the topk probabilities.
+
+ ```python
+ >>> from transformers import Glm4vMoeTextModel, Glm4vMoeConfig
+
+ >>> # Initializing a GLM-4.5V style configuration
+ >>> configuration = Glm4vMoeConfig()
+
+ >>> # Initializing a model from the GLM-4.5V style configuration
+ >>> model = Glm4vMoeTextModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "Glm4vMoe_text"
+ base_config_key = "text_config"
+ keys_to_ignore_at_inference = ["past_key_values"]
+ # Default tensor parallel plan for base model `Glm4vMoe`
+ base_model_tp_plan = {
+ "layers.*.self_attn.q_proj": "colwise",
+ "layers.*.self_attn.k_proj": "colwise",
+ "layers.*.self_attn.v_proj": "colwise",
+ "layers.*.self_attn.o_proj": "rowwise",
+ "layers.*.mlp.gate_up_proj": "colwise_rep", # we need to replicate here due to the `chunk` operation
+ "layers.*.mlp.down_proj": "rowwise_rep", # we need to replicate here due to the `chunk` operation
+ }
+ base_model_pp_plan = {
+ "embed_tokens": (["input_ids"], ["inputs_embeds"]),
+ "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
+ "norm": (["hidden_states"], ["hidden_states"]),
+ }
+
+ def __init__(
+ self,
+ vocab_size=151424,
+ hidden_size=4096,
+ intermediate_size=10944,
+ num_hidden_layers=46,
+ num_attention_heads=96,
+ partial_rotary_factor=0.5,
+ num_key_value_heads=8,
+ hidden_act="silu",
+ max_position_embeddings=65536,
+ initializer_range=0.02,
+ rms_norm_eps=1e-5,
+ use_cache=True,
+ tie_word_embeddings=False,
+ rope_theta=10000.0,
+ rope_scaling=None,
+ attention_bias=True,
+ attention_dropout=0.0,
+ moe_intermediate_size=1408,
+ num_experts_per_tok=8,
+ n_shared_experts=1,
+ n_routed_experts=128,
+ routed_scaling_factor=1.0,
+ n_group=1,
+ topk_group=1,
+ first_k_dense_replace=1,
+ norm_topk_prob=True,
+ **kwargs,
+ ):
+ PretrainedConfig.__init__(self, tie_word_embeddings=tie_word_embeddings, **kwargs)
+ self.vocab_size = vocab_size
+ self.max_position_embeddings = max_position_embeddings
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.partial_rotary_factor = partial_rotary_factor
+
+ self.num_key_value_heads = num_key_value_heads
+ self.hidden_act = hidden_act
+ self.initializer_range = initializer_range
+ self.rms_norm_eps = rms_norm_eps
+ self.use_cache = use_cache
+ self.rope_theta = rope_theta
+ self.rope_scaling = rope_scaling
+ self.attention_bias = attention_bias
+ self.attention_dropout = attention_dropout
+ # Validate the correctness of rotary position embeddings parameters
+ # BC: if there is a 'type' field, move it to 'rope_type'.
+ if self.rope_scaling is not None and "type" in self.rope_scaling:
+ self.rope_scaling["rope_type"] = self.rope_scaling["type"]
+ rope_config_validation(self, ignore_keys={"mrope_section"})
+
+ # MoE arguments
+ self.moe_intermediate_size = moe_intermediate_size
+ self.num_experts_per_tok = num_experts_per_tok
+ self.n_group = n_group
+ self.topk_group = topk_group
+ self.n_shared_experts = n_shared_experts
+ self.n_routed_experts = n_routed_experts
+ self.routed_scaling_factor = routed_scaling_factor
+ self.first_k_dense_replace = first_k_dense_replace
+ self.norm_topk_prob = norm_topk_prob
+
+
+class Glm4vMoeConfig(Glm4vConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`Glm4vMoeModel`]. It is used to instantiate a
+ GLM-4.5V model according to the specified arguments, defining the model architecture. Instantiating a
+ configuration with the defaults will yield a similar configuration to that of
+ GLM-4.5V [zai-org/GLM-4.5V](https://huggingface.co/zai-org/GLM-4.5V).
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+
+ Args:
+ text_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `Glm4vMoeTextConfig`):
+ The config object or dictionary of the text backbone.
+ vision_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `Glm4vMoeVisionConfig`):
+ The config object or dictionary of the vision backbone.
+ image_token_id (`int`, *optional*, defaults to 151363):
+ The image token index to encode the image prompt.
+ video_token_id (`int`, *optional*, defaults to 151364):
+ The video token index to encode the video prompt.
+ image_start_token_id (`int`, *optional*, defaults to 151339):
+ The image start token index to encode the start of image.
+ image_end_token_id (`int`, *optional*, defaults to 151340):
+ The image end token index to encode the end of image.
+ video_start_token_id (`int`, *optional*, defaults to 151341):
+ The video start token index to encode the start of video.
+ video_end_token_id (`int`, *optional*, defaults to 151342):
+ The video end token index to encode the end of video.
+
+ ```python
+ >>> from transformers import Glm4vMoeForConditionalGeneration, Glm4vMoeConfig
+
+ >>> # Initializing a GLM-4.5V style configuration
+ >>> configuration = Glm4vMoeConfig()
+
+ >>> # Initializing a model from the GLM-4.5V style configuration
+ >>> model = Glm4vMoeForConditionalGeneration(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ def __init__(
+ self,
+ text_config=None,
+ vision_config=None,
+ image_token_id=151363,
+ video_token_id=151364,
+ image_start_token_id=151339,
+ image_end_token_id=151340,
+ video_start_token_id=151341,
+ video_end_token_id=151342,
+ **kwargs,
+ ):
+ super().__init__(
+ text_config=text_config,
+ vision_config=vision_config,
+ image_token_id=image_token_id,
+ video_token_id=video_token_id,
+ image_start_token_id=image_start_token_id,
+ image_end_token_id=image_end_token_id,
+ video_start_token_id=video_start_token_id,
+ video_end_token_id=video_end_token_id,
+ **kwargs,
+ )
+
+
+class Glm4vMoeRMSNorm(Glm4MoeRMSNorm):
+ pass
+
+
+def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim=1):
+ """Applies Rotary Position Embedding with Multimodal Sections to the query and key tensors (https://qwenlm.github.io/blog/qwen2-vl/).
+
+ Explanation:
+ Multimodal 3D rotary position embedding is an extension of 1D rotary position embedding. The input embedding
+ sequence contains vision (image / video) embeddings together with text embeddings, or only text embeddings. For
+ the vision part, we apply rotary position embedding on the temporal, height and width dimensions separately,
+ splitting the channel dimension into 3 chunks for the temporal, height and width rotary position embeddings.
+ For the text part, we just apply 1D rotary position embedding. The three rotary position indices (temporal,
+ height and width) of a text token are always identical, so the text rotary position embedding is no different
+ from that of modern LLMs.
+
+ Args:
+ q (`torch.Tensor`): The query tensor.
+ k (`torch.Tensor`): The key tensor.
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
+ mrope_section (`list[int]`):
+ Multimodal rope section for the channel dimension of temporal, height and width in rope calculation.
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+ Returns:
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+ """
+ mrope_section = mrope_section * 2
+ cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(
+ unsqueeze_dim
+ )
+ sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(
+ unsqueeze_dim
+ )
+
+ # Keep half or full tensor for later concatenation
+ rotary_dim = cos.shape[-1]
+ q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
+ k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:]
+
+ # Apply rotary embeddings on the first half or full tensor
+ q_embed = (q_rot * cos) + (rotate_half(q_rot) * sin)
+ k_embed = (k_rot * cos) + (rotate_half(k_rot) * sin)
+
+ # Concatenate back to full shape
+ q_embed = torch.cat([q_embed, q_pass], dim=-1)
+ k_embed = torch.cat([k_embed, k_pass], dim=-1)
+
+ return q_embed, k_embed
+
+
+class Glm4vMoeTextAttention(Glm4Attention):
+ def __init__(self, config: Glm4vMoeTextConfig, layer_idx: Optional[int] = None):
+ super().__init__(config, layer_idx)
+ self.rope_scaling = config.rope_scaling
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
+ attention_mask: Optional[torch.Tensor],
+ past_key_values: Optional[Cache] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs: Unpack[FlashAttentionKwargs],
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+ input_shape = hidden_states.shape[:-1]
+ hidden_shape = (*input_shape, -1, self.head_dim)
+
+ query_states = self.q_proj(hidden_states).view(hidden_shape)
+ key_states = self.k_proj(hidden_states).view(hidden_shape)
+ value_states = self.v_proj(hidden_states).view(hidden_shape)
+
+ query_states = query_states.transpose(1, 2)
+ key_states = key_states.transpose(1, 2)
+ value_states = value_states.transpose(1, 2)
+
+ cos, sin = position_embeddings
+ query_states, key_states = apply_multimodal_rotary_pos_emb( # diff with Llama
+ query_states, key_states, cos, sin, self.rope_scaling["mrope_section"]
+ )
+
+ if past_key_values is not None:
+ # sin and cos are specific to RoPE models; position_ids needed for the static cache
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ attn_output, attn_weights = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ dropout=0.0 if not self.training else self.attention_dropout,
+ scaling=self.scaling,
+ **kwargs,
+ )
+
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+ attn_output = self.o_proj(attn_output)
+ return attn_output, attn_weights
+
+
+class Glm4vMoeTextTopkRouter(Glm4MoeTopkRouter, nn.Module):
+ def __init__(self, config: Glm4vMoeTextConfig):
+ super().__init__(config)
+
+
+class Glm4vMoeTextMoE(Glm4MoeMoE):
+ def __init__(self, config: Glm4vMoeTextConfig):
+ super().__init__(config)
+ self.config = config
+ self.experts = nn.ModuleList(
+ [
+ Glm4vMoeTextMLP(config, intermediate_size=config.moe_intermediate_size)
+ for _ in range(config.n_routed_experts)
+ ]
+ )
+ self.gate = Glm4vMoeTextTopkRouter(config)
+ self.shared_experts = Glm4vMoeTextMLP(
+ config=config, intermediate_size=config.moe_intermediate_size * config.n_shared_experts
+ )
+
+
+class Glm4vMoeTextMLP(Glm4MoeMLP):
+ pass
+
+
+class Glm4vMoeTextDecoderLayer(Glm4MoeDecoderLayer):
+ def __init__(self, config: Glm4vMoeTextConfig, layer_idx: int):
+ super().__init__(config, layer_idx)
+
+
+class Glm4vMoePreTrainedModel(Glm4MoePreTrainedModel):
+ config: Glm4vMoeConfig
+ base_model_prefix = ""
+ _no_split_modules = ["Glm4vMoeTextDecoderLayer", "Glm4vMoeVisionBlock"]
+ _skip_keys_device_placement = "past_key_values"
+
+ _can_record_outputs = {
+ "hidden_states": Glm4vMoeTextDecoderLayer,
+ "attentions": Glm4vMoeTextAttention,
+ }
+
+
+class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration):
+ pass
+
+
+__all__ = [
+ "Glm4vMoeConfig",
+ "Glm4vMoeTextConfig",
+ "Glm4vMoeForConditionalGeneration",
+ "Glm4vMoeModel", # noqa: F822
+ "Glm4vMoePreTrainedModel",
+ "Glm4vMoeTextModel", # noqa: F822
+]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/got_ocr2/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/got_ocr2/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..00b6ccc53fc0efb0fc88c2f95586276cd40010fe
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/got_ocr2/__init__.py
@@ -0,0 +1,32 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_got_ocr2 import *
+ from .image_processing_got_ocr2 import *
+ from .image_processing_got_ocr2_fast import *
+ from .modeling_got_ocr2 import *
+ from .processing_got_ocr2 import *
+
+
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/got_ocr2/configuration_got_ocr2.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/got_ocr2/configuration_got_ocr2.py
new file mode 100644
index 0000000000000000000000000000000000000000..eb039f958950c3fbc3cd4401b425348181b13a00
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/got_ocr2/configuration_got_ocr2.py
@@ -0,0 +1,211 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/got_ocr2/modular_got_ocr2.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_got_ocr2.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2024 HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from ...configuration_utils import PretrainedConfig
+from ..auto import CONFIG_MAPPING, AutoConfig
+
+
+class GotOcr2VisionConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`GotOcr2VisionModel`]. It is used to instantiate a GOT_OCR2
+ vision encoder according to the specified arguments, defining the model architecture. Instantiating a
+ configuration with the defaults will yield a similar configuration to that of the SAM ViT-h
+ [facebook/sam-vit-huge](https://huggingface.co/facebook/sam-vit-huge) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ hidden_size (`int`, *optional*, defaults to 768):
+ Dimensionality of the encoder layers and the pooler layer.
+ output_channels (`int`, *optional*, defaults to 256):
+ Dimensionality of the output channels in the Patch Encoder.
+ num_hidden_layers (`int`, *optional*, defaults to 12):
+ Number of hidden layers in the Transformer encoder.
+ num_attention_heads (`int`, *optional*, defaults to 12):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ num_channels (`int`, *optional*, defaults to 3):
+ Number of channels in the input image.
+ image_size (`int`, *optional*, defaults to 1024):
+ Expected resolution. Target size of the resized input image.
+ patch_size (`int`, *optional*, defaults to 16):
+ Size of the patches to be extracted from the input image.
+ hidden_act (`str`, *optional*, defaults to `"gelu"`):
+ The non-linear activation function (function or string)
+ layer_norm_eps (`float`, *optional*, defaults to 1e-06):
+ The epsilon used by the layer normalization layers.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ initializer_range (`float`, *optional*, defaults to 1e-10):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ qkv_bias (`bool`, *optional*, defaults to `True`):
+ Whether to add a bias to query, key, value projections.
+ use_abs_pos (`bool`, *optional*, defaults to `True`):
+ Whether to use absolute position embedding.
+ use_rel_pos (`bool`, *optional*, defaults to `True`):
+ Whether to use relative position embedding.
+ window_size (`int`, *optional*, defaults to 14):
+ Window size for relative position.
+ global_attn_indexes (`list[int]`, *optional*, defaults to `[2, 5, 8, 11]`):
+ The indexes of the global attention layers.
+ mlp_dim (`int`, *optional*, defaults to 3072):
+ The dimensionality of the MLP layer in the Transformer encoder.
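+
+ Example (illustrative):
+
+ ```python
+ >>> from transformers import GotOcr2VisionConfig
+
+ >>> # Initializing a GotOcr2 vision configuration with the default values
+ >>> configuration = GotOcr2VisionConfig()
+ ```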
+ """
+
+ base_config_key = "vision_config"
+
+ def __init__(
+ self,
+ hidden_size=768,
+ output_channels=256,
+ num_hidden_layers=12,
+ num_attention_heads=12,
+ num_channels=3,
+ image_size=1024,
+ patch_size=16,
+ hidden_act="gelu",
+ layer_norm_eps=1e-06,
+ attention_dropout=0.0,
+ initializer_range=1e-10,
+ qkv_bias=True,
+ use_abs_pos=True,
+ use_rel_pos=True,
+ window_size=14,
+ global_attn_indexes=[2, 5, 8, 11],
+ mlp_dim=3072,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+
+ self.hidden_size = hidden_size
+ self.output_channels = output_channels
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.num_channels = num_channels
+ self.image_size = image_size
+ self.patch_size = patch_size
+ self.hidden_act = hidden_act
+ self.layer_norm_eps = layer_norm_eps
+ self.attention_dropout = attention_dropout
+ self.initializer_range = initializer_range
+ self.qkv_bias = qkv_bias
+ self.use_abs_pos = use_abs_pos
+ self.use_rel_pos = use_rel_pos
+ self.window_size = window_size
+ self.global_attn_indexes = global_attn_indexes
+ self.mlp_dim = mlp_dim
+
+
+class GotOcr2Config(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`GotOcr2ForConditionalGeneration`]. It is used to instantiate a
+ GotOcr2 model according to the specified arguments, defining the model architecture. Instantiating a configuration
+ with the defaults will yield a similar configuration to that of GOT-OCR-2.0.
+
+ e.g. [stepfun-ai/GOT-OCR-2.0-hf](https://huggingface.co/stepfun-ai/GOT-OCR-2.0-hf)
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+
+ Args:
+ vision_config (`Union[AutoConfig, dict]`, *optional*, defaults to `CLIPVisionConfig`):
+ The config object or dictionary of the vision backbone.
+ text_config (`Union[AutoConfig, dict]`, *optional*, defaults to `LlamaConfig`):
+ The config object or dictionary of the text backbone.
+ image_token_index (`int`, *optional*, defaults to 151859):
+ The image token index to encode the image prompt.
+ image_seq_length (`int`, *optional*, defaults to 576):
+ Sequence length of one image embedding.
+ pad_token_id (`int`, *optional*, defaults to -1):
+ Padding token id.
+
+ ```python
+ >>> from transformers import GotOcr2ForConditionalGeneration, GotOcr2Config
+
+ >>> # Initializing a GotOcr2 style configuration
+ >>> configuration = GotOcr2Config()
+
+ >>> # Initializing a model from the GOT-OCR-2.0 style configuration
+ >>> model = GotOcr2ForConditionalGeneration(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "got_ocr2"
+ attribute_map = {
+ "image_token_id": "image_token_index",
+ }
+ sub_configs = {"text_config": AutoConfig, "vision_config": GotOcr2VisionConfig}
+
+ def __init__(
+ self,
+ vision_config=None,
+ text_config=None,
+ image_token_index=151859,
+ image_seq_length=576,
+ pad_token_id=-1,
+ **kwargs,
+ ):
+ self.image_token_index = image_token_index
+ self.image_seq_length = image_seq_length
+ self.pad_token_id = pad_token_id
+
+ if vision_config is None:
+ self.vision_config = GotOcr2VisionConfig()
+ elif isinstance(vision_config, dict):
+ self.vision_config = GotOcr2VisionConfig(**vision_config)
+ elif isinstance(vision_config, GotOcr2VisionConfig):
+ self.vision_config = vision_config
+
+ if isinstance(text_config, dict):
+ text_config["model_type"] = text_config.get("model_type", "qwen2")
+ text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
+ elif text_config is None:
+ text_config = CONFIG_MAPPING["qwen2"](
+ vocab_size=151860,
+ hidden_size=1024,
+ intermediate_size=2816,
+ num_hidden_layers=24,
+ num_attention_heads=16,
+ num_key_value_heads=16,
+ hidden_act="silu",
+ max_position_embeddings=32768,
+ initializer_range=0.02,
+ rms_norm_eps=1e-6,
+ use_cache=True,
+ tie_word_embeddings=True,
+ rope_theta=1000000.0,
+ rope_scaling=None,
+ use_sliding_window=False,
+ sliding_window=4096,
+ max_window_layers=21,
+ attention_dropout=0.0,
+ )
+
+ self.text_config = text_config
+
+ super().__init__(**kwargs)
+
+
+__all__ = ["GotOcr2VisionConfig", "GotOcr2Config"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/got_ocr2/image_processing_got_ocr2.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/got_ocr2/image_processing_got_ocr2.py
new file mode 100644
index 0000000000000000000000000000000000000000..209ac88ea2fbecc5ddd2c4d82b92830992caac18
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/got_ocr2/image_processing_got_ocr2.py
@@ -0,0 +1,526 @@
+# coding=utf-8
+# Copyright 2024 HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Image processor class for Got-OCR-2."""
+
+from functools import lru_cache
+from typing import Optional, Union
+
+import numpy as np
+
+from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
+from ...image_transforms import (
+ convert_to_rgb,
+ resize,
+ to_channel_dimension_format,
+)
+from ...image_utils import (
+ OPENAI_CLIP_MEAN,
+ OPENAI_CLIP_STD,
+ ChannelDimension,
+ ImageInput,
+ PILImageResampling,
+ infer_channel_dimension_format,
+ is_scaled_image,
+ make_flat_list_of_images,
+ to_numpy_array,
+ valid_images,
+ validate_preprocess_arguments,
+)
+from ...utils import TensorType, filter_out_non_signature_kwargs, is_vision_available, logging
+
+
+if is_vision_available():
+ import PIL
+
+
+logger = logging.get_logger(__name__)
+
+
+# Similar to image_processing_mllama.get_all_supported_aspect_ratios
+@lru_cache(maxsize=10)
+def get_all_supported_aspect_ratios(min_image_tiles: int, max_image_tiles: int) -> list[tuple[int, int]]:
+ """
+ Computes all allowed aspect ratios for a given minimum and maximum number of input tiles.
+
+ This function calculates all possible arrangements of tiles that can be formed
+ within the constraint of the minimum and maximum number of tiles. Each arrangement is
+ represented by its aspect ratio (width/height) and the corresponding tile configuration.
+
+ Args:
+ min_image_tiles (`int`):
+ The minimum number of tiles allowed.
+ max_image_tiles (`int`):
+ The maximum number of tiles allowed.
+
+ Returns:
+ `list[tuple[int, int]]`: A list of tuples, each tuple representing a valid (width, height)
+ configuration in terms of number of tiles.
+
+ Example:
+ >>> get_all_supported_aspect_ratios(1, 4)
+ [(1, 1), (1, 2), (2, 1), (1, 3), (3, 1), (1, 4), (2, 2), (4, 1)]
+
+ """
+ aspect_ratios = []
+ for width in range(1, max_image_tiles + 1):
+ for height in range(1, max_image_tiles + 1):
+ if width * height <= max_image_tiles and width * height >= min_image_tiles:
+ aspect_ratios.append((width, height))
+
+ aspect_ratios = sorted(aspect_ratios, key=lambda x: x[0] * x[1])
+
+ return aspect_ratios
+
+
+@lru_cache(maxsize=100)
+def get_optimal_tiled_canvas(
+ original_image_size: tuple[int, int],
+ target_tile_size: tuple[int, int],
+ min_image_tiles: int,
+ max_image_tiles: int,
+) -> tuple[int, int]:
+ """
+ Given a minimum and maximum number of tiles, find the canvas with the closest aspect ratio to the
+ original image aspect ratio.
+ When two canvases have the same aspect ratio difference, we favor the canvas with more tiles, as long as the
+ area covered by the tiles does not exceed twice the original image area, in order to avoid unnecessarily
+ excessive tiling.
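+
+ Example (illustrative):
+
+ >>> # a 1000x500 (height, width) image tiled into 384x384 tiles is best covered by a 1-column x 2-row grid
+ >>> get_optimal_tiled_canvas((1000, 500), (384, 384), 1, 12)
+ (1, 2)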
+ """
+ possible_tile_arrangements = get_all_supported_aspect_ratios(min_image_tiles, max_image_tiles)
+
+ original_height, original_width = original_image_size
+ target_tile_height, target_tile_width = target_tile_size
+ aspect_ratio = original_width / original_height
+ area = original_width * original_height
+
+ # find the grid with the best aspect ratio
+ best_ratio_diff = float("inf")
+ best_grid = (1, 1)
+ for grid in possible_tile_arrangements:
+ grid_aspect_ratio = grid[0] / grid[1]
+ ratio_diff = abs(aspect_ratio - grid_aspect_ratio)
+ if ratio_diff < best_ratio_diff:
+ best_ratio_diff = ratio_diff
+ best_grid = grid
+ elif ratio_diff == best_ratio_diff:
+ # if the aspect ratio difference is the same, we favor the grid with more patches
+ # until the area covered by the patches is more than twice the original image area
+ if area > 0.5 * target_tile_height * target_tile_width * grid[0] * grid[1]:
+ best_grid = grid
+
+ return best_grid
+
+
+class GotOcr2ImageProcessor(BaseImageProcessor):
+ r"""
+ Constructs a GOT_OCR2 image processor.
+
+ Args:
+ do_resize (`bool`, *optional*, defaults to `True`):
+ Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by the
+ `do_resize` parameter in the `preprocess` method.
+ size (`dict`, *optional*, defaults to `{"height": 384, "width": 384}`):
+ Size of the output image after resizing. Can be overridden by the `size` parameter in the `preprocess`
+ method.
+ crop_to_patches (`bool`, *optional*, defaults to `False`):
+ Whether to crop the image to patches. Can be overridden by the `crop_to_patches` parameter in the
+ `preprocess` method.
+ min_patches (`int`, *optional*, defaults to 1):
+ The minimum number of patches to be extracted from the image. Only has an effect if `crop_to_patches` is
+ set to `True`. Can be overridden by the `min_patches` parameter in the `preprocess` method.
+ max_patches (`int`, *optional*, defaults to 12):
+ The maximum number of patches to be extracted from the image. Only has an effect if `crop_to_patches` is
+ set to `True`. Can be overridden by the `max_patches` parameter in the `preprocess` method.
+ resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`):
+ Resampling filter to use if resizing the image. Only has an effect if `do_resize` is set to `True`. Can be
+ overridden by the `resample` parameter in the `preprocess` method.
+ do_rescale (`bool`, *optional*, defaults to `True`):
+ Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the
+ `do_rescale` parameter in the `preprocess` method.
+ rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
+ Scale factor to use if rescaling the image. Only has an effect if `do_rescale` is set to `True`. Can be
+ overridden by the `rescale_factor` parameter in the `preprocess` method.
+ do_normalize (`bool`, *optional*, defaults to `True`):
+ Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
+ method.
+ image_mean (`float` or `list[float]`, *optional*, defaults to `OPENAI_CLIP_MEAN`):
+ Mean to use if normalizing the image. This is a float or list of floats the length of the number of
+ channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
+ image_std (`float` or `list[float]`, *optional*, defaults to `OPENAI_CLIP_STD`):
+ Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
+ number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess`
+ method.
+ do_convert_rgb (`bool`, *optional*, defaults to `True`):
+ Whether to convert the image to RGB.
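+
+ Example (illustrative):
+
+ ```python
+ >>> import numpy as np
+ >>> from PIL import Image
+ >>> from transformers import GotOcr2ImageProcessor
+
+ >>> processor = GotOcr2ImageProcessor()
+ >>> image = Image.fromarray(np.zeros((640, 480, 3), dtype=np.uint8))
+ >>> inputs = processor(images=image, return_tensors="pt")
+ >>> list(inputs.keys())
+ ['pixel_values', 'num_patches']
+ ```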
+ """
+
+ model_input_names = ["pixel_values"]
+
+ def __init__(
+ self,
+ do_resize: bool = True,
+ size: Optional[dict[str, int]] = None,
+ crop_to_patches: bool = False,
+ min_patches: int = 1,
+ max_patches: int = 12,
+ resample: PILImageResampling = PILImageResampling.BICUBIC,
+ do_rescale: bool = True,
+ rescale_factor: Union[int, float] = 1 / 255,
+ do_normalize: bool = True,
+ image_mean: Optional[Union[float, list[float]]] = None,
+ image_std: Optional[Union[float, list[float]]] = None,
+ do_convert_rgb: bool = True,
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+ size = size if size is not None else {"height": 384, "width": 384}
+ size = get_size_dict(size, default_to_square=True)
+
+ self.do_resize = do_resize
+ self.size = size
+ self.crop_to_patches = crop_to_patches
+ self.min_patches = min_patches
+ self.max_patches = max_patches
+ self.resample = resample
+ self.do_rescale = do_rescale
+ self.rescale_factor = rescale_factor
+ self.do_normalize = do_normalize
+ self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN
+ self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD
+ self.do_convert_rgb = do_convert_rgb
+
+ def resize(
+ self,
+ image: np.ndarray,
+ size: dict[str, int],
+ resample: PILImageResampling = PILImageResampling.BICUBIC,
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ **kwargs,
+ ) -> np.ndarray:
+ """
+ Resize an image to `(size["height"], size["width"])`.
+
+ Args:
+ image (`np.ndarray`):
+ Image to resize.
+ size (`dict[str, int]`):
+ Dictionary in the format `{"height": int, "width": int}` specifying the size of the output image.
+ resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
+ `PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BICUBIC`.
+ data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format for the output image. If unset, the channel dimension format of the input
+ image is used. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
+ input_data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
+ from the input image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
+
+ Returns:
+ `np.ndarray`: The resized image.
+ """
+ size = get_size_dict(size)
+ if "height" not in size or "width" not in size:
+ raise ValueError(f"The `size` dictionary must contain the keys `height` and `width`. Got {size.keys()}")
+ output_size = (size["height"], size["width"])
+ return resize(
+ image,
+ size=output_size,
+ resample=resample,
+ data_format=data_format,
+ input_data_format=input_data_format,
+ **kwargs,
+ )
+
+ @filter_out_non_signature_kwargs()
+ def preprocess(
+ self,
+ images: ImageInput,
+ do_resize: Optional[bool] = None,
+ size: Optional[dict[str, int]] = None,
+ crop_to_patches: Optional[bool] = None,
+ min_patches: Optional[int] = None,
+ max_patches: Optional[int] = None,
+ resample: Optional[PILImageResampling] = None,
+ do_rescale: Optional[bool] = None,
+ rescale_factor: Optional[float] = None,
+ do_normalize: Optional[bool] = None,
+ image_mean: Optional[Union[float, list[float]]] = None,
+ image_std: Optional[Union[float, list[float]]] = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ do_convert_rgb: Optional[bool] = None,
+ data_format: ChannelDimension = ChannelDimension.FIRST,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ ) -> PIL.Image.Image:
+ """
+ Preprocess an image or batch of images.
+
+ Args:
+ images (`ImageInput`):
+ Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
+ passing in images with pixel values between 0 and 1, set `do_rescale=False`.
+ do_resize (`bool`, *optional*, defaults to `self.do_resize`):
+ Whether to resize the image.
+ size (`dict[str, int]`, *optional*, defaults to `self.size`):
+ Size of the output image after resizing, in the format `{"height": int, "width": int}`. When
+ `crop_to_patches` is set to `True`, this is also the size of each extracted patch.
+ crop_to_patches (`bool`, *optional*, defaults to `self.crop_to_patches`):
+ Whether to crop the image to patches.
+ min_patches (`int`, *optional*, defaults to `self.min_patches`):
+ The minimum number of patches to be extracted from the image. Only has an effect if `crop_to_patches` is
+ set to `True`.
+ max_patches (`int`, *optional*, defaults to `self.max_patches`):
+ The maximum number of patches to be extracted from the image. Only has an effect if `crop_to_patches` is
+ set to `True`.
+ resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
+ Resampling filter to use if resizing the image. Only has an effect if `do_resize` is set to `True`.
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
+ Whether to rescale the image values between [0 - 1].
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
+ Rescale factor to rescale the image by if `do_rescale` is set to `True`.
+ do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
+ Whether to normalize the image.
+ image_mean (`float` or `list[float]`, *optional*, defaults to `self.image_mean`):
+ Image mean to normalize the image by if `do_normalize` is set to `True`.
+ image_std (`float` or `list[float]`, *optional*, defaults to `self.image_std`):
+ Image standard deviation to normalize the image by if `do_normalize` is set to `True`.
+ do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
+ Whether to convert the image to RGB.
+ return_tensors (`str` or `TensorType`, *optional*):
+ The type of tensors to return. Can be one of:
+ - Unset: Return a list of `np.ndarray`.
+ - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
+ - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
+ data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
+ The channel dimension format for the output image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - Unset: Use the channel dimension format of the input image.
+ input_data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
+ from the input image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
+ """
+ do_resize = do_resize if do_resize is not None else self.do_resize
+ crop_to_patches = crop_to_patches if crop_to_patches is not None else self.crop_to_patches
+ min_patches = min_patches if min_patches is not None else self.min_patches
+ max_patches = max_patches if max_patches is not None else self.max_patches
+ resample = resample if resample is not None else self.resample
+ do_rescale = do_rescale if do_rescale is not None else self.do_rescale
+ rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
+ do_normalize = do_normalize if do_normalize is not None else self.do_normalize
+ image_mean = image_mean if image_mean is not None else self.image_mean
+ image_std = image_std if image_std is not None else self.image_std
+ do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
+
+ size = size if size is not None else self.size
+ size = get_size_dict(size, default_to_square=False)
+
+ images = self.fetch_images(images)
+ images = make_flat_list_of_images(images)
+
+ if not valid_images(images):
+ raise ValueError(
+ "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
+ "torch.Tensor, tf.Tensor or jax.ndarray."
+ )
+
+ validate_preprocess_arguments(
+ do_rescale=do_rescale,
+ rescale_factor=rescale_factor,
+ do_normalize=do_normalize,
+ image_mean=image_mean,
+ image_std=image_std,
+ do_resize=do_resize,
+ size=size,
+ resample=resample,
+ )
+ # PIL RGBA images are converted to RGB
+ if do_convert_rgb:
+ images = [convert_to_rgb(image) for image in images]
+
+ # All transformations expect numpy arrays.
+ images = [to_numpy_array(image) for image in images]
+
+ if do_rescale and is_scaled_image(images[0]):
+ logger.warning_once(
+ "It looks like you are trying to rescale already rescaled images. If the input"
+ " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
+ )
+
+ if input_data_format is None:
+ # We assume that all images have the same channel dimension format.
+ input_data_format = infer_channel_dimension_format(images[0])
+
+ if crop_to_patches and max_patches > 1:
+ images = [
+ self.crop_image_to_patches(
+ image,
+ min_patches=min_patches,
+ max_patches=max_patches,
+ patch_size=size,
+ data_format=input_data_format,
+ )
+ for image in images
+ ]
+ num_patches = np.array([len(image) for image in images])
+ images = [image for images_list in images for image in images_list]
+ else:
+ num_patches = np.array([1] * len(images))
+
+ for i, image in enumerate(images):
+ if do_resize:
+ images[i] = self.resize(image, size=size, resample=resample, input_data_format=input_data_format)
+
+ if do_rescale:
+ images[i] = self.rescale(image=images[i], scale=rescale_factor, input_data_format=input_data_format)
+
+ if do_normalize:
+ images[i] = self.normalize(
+ image=images[i],
+ mean=image_mean,
+ std=image_std,
+ input_data_format=input_data_format,
+ )
+
+ images[i] = to_channel_dimension_format(images[i], data_format, input_channel_dim=input_data_format)
+
+ encoded_outputs = BatchFeature(
+ data={"pixel_values": images, "num_patches": num_patches}, tensor_type=return_tensors
+ )
+
+ return encoded_outputs
+
+ def crop_image_to_patches(
+ self,
+ images: np.ndarray,
+ min_patches: int,
+ max_patches: int,
+ use_thumbnail: bool = True,
+ patch_size: Optional[Union[tuple, int, dict]] = None,
+ data_format: Optional[ChannelDimension] = None,
+ ):
+ """
+ Crop the image to patches and return a list of cropped images.
+ The number of patches and their grid arrangement are determined by the original image size,
+ the target patch size and the minimum and maximum number of patches.
+ The aspect ratio of the patches grid is chosen to be the closest to the original image aspect ratio.
+
+ Args:
+ images (`np.ndarray`):
+ The image to be cropped.
+ min_patches (`int`):
+ The minimum number of patches to be extracted from the image.
+ max_patches (`int`):
+ The maximum number of patches to be extracted from the image.
+ use_thumbnail (`bool`, *optional*, defaults to `True`):
+ Whether to add a thumbnail image to the list of cropped patches.
+ patch_size (`int`, `tuple[int, int]`, `dict`, *optional*):
+ The size of the output patches.
+ data_format (`ChannelDimension`, *optional*):
+ The format of the image data. If `None`, the format is inferred from the input image.
+
+ Returns:
+ `list[np.ndarray]`: The list of cropped patches (plus a thumbnail of the full image when `use_thumbnail`
+ is `True` and more than one patch is extracted).
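+
+ Example (illustrative sketch; the sizes below are assumptions, not values from a released checkpoint):
+
+ >>> import numpy as np
+ >>> from transformers import GotOcr2ImageProcessor
+
+ >>> processor = GotOcr2ImageProcessor()
+ >>> image = np.zeros((3, 768, 384), dtype=np.uint8)  # channels-first, 2:1 portrait image
+ >>> patches = processor.crop_image_to_patches(image, min_patches=1, max_patches=12, patch_size={"height": 384, "width": 384})
+ >>> len(patches)  # a 1x2 grid of patches plus a thumbnail of the full image
+ 3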
+ """
+ if data_format is None:
+ data_format = infer_channel_dimension_format(images)
+ images = to_channel_dimension_format(images, ChannelDimension.FIRST, data_format)
+ patch_size_height, patch_size_width = patch_size["height"], patch_size["width"]
+ original_height, original_width = images.shape[-2:]
+ # find the closest aspect ratio to the target
+ num_columns, num_rows = get_optimal_tiled_canvas(
+ (original_height, original_width), (patch_size_height, patch_size_width), min_patches, max_patches
+ )
+
+ # calculate the target width and height
+ target_width = patch_size_width * num_columns
+ target_height = patch_size_height * num_rows
+ num_blocks = num_columns * num_rows
+
+ # resize the image so that each patch is of patch_size
+ resized_image = self.resize(
+ images,
+ {"height": target_height, "width": target_width},
+ data_format=ChannelDimension.FIRST,
+ input_data_format=ChannelDimension.FIRST,
+ )
+ # split the image into patches
+ processed_images = []
+ for i in range(num_blocks):
+ column = i % num_columns
+ row = i // num_columns
+ box = (
+ column * patch_size_width,
+ row * patch_size_height,
+ (column + 1) * patch_size_width,
+ (row + 1) * patch_size_height,
+ )
+ # split the image
+ patch_image = resized_image[..., box[1] : box[3], box[0] : box[2]]
+ patch_image = to_channel_dimension_format(patch_image, data_format, ChannelDimension.FIRST)
+ processed_images.append(patch_image)
+
+ if use_thumbnail and len(processed_images) != 1:
+ thumbnail_img = self.resize(
+ images, patch_size, data_format=data_format, input_data_format=ChannelDimension.FIRST
+ )
+ processed_images.append(thumbnail_img)
+
+ return processed_images
+
+ def get_number_of_image_patches(self, height: int, width: int, images_kwargs=None):
+ """
+ A utility that returns the number of patches for a given image size.
+
+ Args:
+ height (`int`):
+ Height of the input image.
+ width (`int`):
+ Width of the input image.
+ images_kwargs (`dict`, *optional*):
+ Any kwargs to override defaults of the image processor.
+ Returns:
+ `int`: Number of patches per image.
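+
+ Example (illustrative):
+
+ >>> from transformers import GotOcr2ImageProcessor
+
+ >>> processor = GotOcr2ImageProcessor(crop_to_patches=True)
+ >>> processor.get_number_of_image_patches(height=1000, width=500, images_kwargs={})
+ 3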
+ """
+ min_patches = images_kwargs.get("min_patches", self.min_patches)
+ max_patches = images_kwargs.get("max_patches", self.max_patches)
+ patch_size = images_kwargs.get("patch_size", self.size)
+ crop_to_patches = images_kwargs.get("crop_to_patches", self.crop_to_patches)
+
+ num_patches = 1
+ if crop_to_patches and max_patches > 1:
+ num_columns, num_rows = get_optimal_tiled_canvas(
+ (height, width), (patch_size["height"], patch_size["width"]), min_patches, max_patches
+ )
+ if num_columns * num_rows > 1:
+ num_patches += num_columns * num_rows
+
+ return num_patches
+
+
+__all__ = ["GotOcr2ImageProcessor"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/got_ocr2/image_processing_got_ocr2_fast.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/got_ocr2/image_processing_got_ocr2_fast.py
new file mode 100644
index 0000000000000000000000000000000000000000..a47a1422a5dc5dcee6c9d75cb7436a1180c32829
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/got_ocr2/image_processing_got_ocr2_fast.py
@@ -0,0 +1,247 @@
+# coding=utf-8
+# Copyright 2025 HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Fast Image processor class for Got-OCR-2."""
+
+from typing import Optional, Union
+
+import torch
+from torchvision.transforms.v2 import functional as F
+
+from ...image_processing_utils import BatchFeature
+from ...image_processing_utils_fast import (
+ BaseImageProcessorFast,
+ DefaultFastImageProcessorKwargs,
+ group_images_by_shape,
+ reorder_images,
+)
+from ...image_utils import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD, ImageInput, PILImageResampling, SizeDict
+from ...processing_utils import Unpack
+from ...utils import (
+ TensorType,
+ auto_docstring,
+)
+from .image_processing_got_ocr2 import get_optimal_tiled_canvas
+
+
+class GotOcr2FastImageProcessorKwargs(DefaultFastImageProcessorKwargs):
+ """
+ crop_to_patches (`bool`, *optional*, defaults to `False`):
+ Whether to crop the image to patches. Can be overridden by the `crop_to_patches` parameter in the
+ `preprocess` method.
+ min_patches (`int`, *optional*, defaults to 1):
+ The minimum number of patches to be extracted from the image. Only has an effect if `crop_to_patches` is
+ set to `True`. Can be overridden by the `min_patches` parameter in the `preprocess` method.
+ max_patches (`int`, *optional*, defaults to 12):
+ The maximum number of patches to be extracted from the image. Only has an effect if `crop_to_patches` is
+ set to `True`. Can be overridden by the `max_patches` parameter in the `preprocess` method.
+ """
+
+ crop_to_patches: Optional[bool]
+ min_patches: Optional[int]
+ max_patches: Optional[int]
+
+
+@auto_docstring
+class GotOcr2ImageProcessorFast(BaseImageProcessorFast):
+ resample = PILImageResampling.BICUBIC
+ image_mean = OPENAI_CLIP_MEAN
+ image_std = OPENAI_CLIP_STD
+ size = {"height": 384, "width": 384}
+ do_resize = True
+ do_rescale = True
+ do_normalize = True
+ do_convert_rgb = True
+ crop_to_patches = False
+ min_patches = 1
+ max_patches = 12
+ valid_kwargs = GotOcr2FastImageProcessorKwargs
+
+ def __init__(self, **kwargs: Unpack[GotOcr2FastImageProcessorKwargs]):
+ super().__init__(**kwargs)
+
+ @auto_docstring
+ def preprocess(self, images: ImageInput, **kwargs: Unpack[GotOcr2FastImageProcessorKwargs]) -> BatchFeature:
+ return super().preprocess(images, **kwargs)
+
+ def crop_image_to_patches(
+ self,
+ images: "torch.Tensor",
+ min_patches: int,
+ max_patches: int,
+ use_thumbnail: bool = True,
+ patch_size: Optional[Union[tuple, int, dict]] = None,
+ interpolation: Optional["F.InterpolationMode"] = None,
+ ):
+ """
+ Crop the images to patches and return a list of cropped images.
+ The number of patches and their grid arrangement are determined by the original image size,
+ the target patch size and the minimum and maximum number of patches.
+ The aspect ratio of the patches grid is chosen to be the closest to the original image aspect ratio.
+
+ Args:
+ images (`torch.Tensor`):
+ The images to be cropped.
+ min_patches (`int`):
+ The minimum number of patches to be extracted from the image.
+ max_patches (`int`):
+ The maximum number of patches to be extracted from the image.
+ use_thumbnail (`bool`, *optional*, defaults to `True`):
+ Whether to add a thumbnail image to the list of cropped patches.
+ patch_size (`int`, `tuple[int, int]`, `dict`, *optional*):
+ The size of the output patches.
+ interpolation (`InterpolationMode`, *optional*):
+ The interpolation mode to use when resizing the images and the thumbnail.
+
+ Returns:
+ `torch.Tensor`: The cropped patches as a tensor of shape `(batch_size, num_patches, channels, patch_height, patch_width)`.
+ """
+ patch_size_height, patch_size_width = patch_size.height, patch_size.width
+ original_height, original_width = images.shape[-2:]
+ # find the closest aspect ratio to the target
+ num_columns, num_rows = get_optimal_tiled_canvas(
+ (original_height, original_width), (patch_size_height, patch_size_width), min_patches, max_patches
+ )
+
+ # calculate the target width and height
+ target_width = patch_size_width * num_columns
+ target_height = patch_size_height * num_rows
+ num_blocks = num_columns * num_rows
+
+ # resize the image so that each patch is of patch_size
+ resized_image = self.resize(
+ images, SizeDict(height=target_height, width=target_width), interpolation=interpolation
+ )
+ # split the image into patches
+ processed_images = []
+ for i in range(num_blocks):
+ column = i % num_columns
+ row = i // num_columns
+ box = (
+ column * patch_size_width,
+ row * patch_size_height,
+ (column + 1) * patch_size_width,
+ (row + 1) * patch_size_height,
+ )
+ # split the image
+ patch_image = resized_image[..., box[1] : box[3], box[0] : box[2]]
+ processed_images.append(patch_image)
+
+ if use_thumbnail and len(processed_images) != 1:
+ thumbnail_img = self.resize(images, patch_size, interpolation=interpolation)
+ processed_images.append(thumbnail_img)
+
+ processed_images = torch.stack(processed_images, dim=0).transpose(0, 1).contiguous()
+
+ return processed_images
+
+ def _preprocess(
+ self,
+ images: list["torch.Tensor"],
+ do_resize: bool,
+ size: SizeDict,
+ crop_to_patches: bool,
+ min_patches: int,
+ max_patches: int,
+ interpolation: Optional["F.InterpolationMode"],
+ do_center_crop: bool,
+ crop_size: SizeDict,
+ do_rescale: bool,
+ rescale_factor: float,
+ do_normalize: bool,
+ image_mean: Optional[Union[float, list[float]]],
+ image_std: Optional[Union[float, list[float]]],
+ disable_grouping: Optional[bool],
+ return_tensors: Optional[Union[str, TensorType]],
+ **kwargs,
+ ) -> BatchFeature:
+ if crop_to_patches:
+ grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
+ processed_images_grouped = {}
+ num_patches = {}
+ for shape, stacked_images in grouped_images.items():
+ stacked_images = self.crop_image_to_patches(
+ stacked_images,
+ min_patches,
+ max_patches,
+ patch_size=size,
+ interpolation=interpolation,
+ )
+ processed_images_grouped[shape] = stacked_images
+ num_patches[shape] = [stacked_images.shape[1]] * stacked_images.shape[0]
+ images = reorder_images(processed_images_grouped, grouped_images_index)
+ images = [image for images_list in images for image in images_list]
+ num_patches = reorder_images(num_patches, grouped_images_index)
+ else:
+ num_patches = [1] * len(images)
+
+ # Group images by size for batched resizing
+ grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
+ resized_images_grouped = {}
+ for shape, stacked_images in grouped_images.items():
+ if do_resize:
+ stacked_images = self.resize(image=stacked_images, size=size, interpolation=interpolation)
+ resized_images_grouped[shape] = stacked_images
+ resized_images = reorder_images(resized_images_grouped, grouped_images_index)
+
+ # Group images by size for further processing
+ # Needed in case do_resize is False, or resize returns images with different sizes
+ grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping)
+ processed_images_grouped = {}
+ for shape, stacked_images in grouped_images.items():
+ if do_center_crop:
+ stacked_images = self.center_crop(stacked_images, crop_size)
+ # Fused rescale and normalize
+ stacked_images = self.rescale_and_normalize(
+ stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std
+ )
+ processed_images_grouped[shape] = stacked_images
+
+ processed_images = reorder_images(processed_images_grouped, grouped_images_index)
+ processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
+
+ return BatchFeature(
+ data={"pixel_values": processed_images, "num_patches": num_patches}, tensor_type=return_tensors
+ )
+
+ def get_number_of_image_patches(self, height: int, width: int, images_kwargs=None):
+ """
+ A utility that returns the number of patches for a given image size.
+
+ Args:
+ height (`int`):
+ Height of the input image.
+ width (`int`):
+ Width of the input image.
+ images_kwargs (`dict`, *optional*):
+ Any kwargs to override defaults of the image processor.
+ Returns:
+ `int`: Number of patches per image.
+ """
+ min_patches = images_kwargs.get("min_patches", self.min_patches)
+ max_patches = images_kwargs.get("max_patches", self.max_patches)
+ patch_size = images_kwargs.get("patch_size", self.size)
+ crop_to_patches = images_kwargs.get("crop_to_patches", self.crop_to_patches)
+
+ num_patches = 1
+ if crop_to_patches and max_patches > 1:
+ num_columns, num_rows = get_optimal_tiled_canvas(
+ (height, width), (patch_size["height"], patch_size["width"]), min_patches, max_patches
+ )
+ if num_columns * num_rows > 1:
+ num_patches += num_columns * num_rows
+
+ return num_patches
+
+
+__all__ = ["GotOcr2ImageProcessorFast"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/got_ocr2/modeling_got_ocr2.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/got_ocr2/modeling_got_ocr2.py
new file mode 100644
index 0000000000000000000000000000000000000000..eb22e62bfd7df1f56b8c5b13ac6fcacb90df6696
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/got_ocr2/modeling_got_ocr2.py
@@ -0,0 +1,840 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/got_ocr2/modular_got_ocr2.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_got_ocr2.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2024 HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import collections
+from dataclasses import dataclass
+from typing import Optional, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from transformers.utils.generic import check_model_inputs
+
+from ...activations import ACT2FN
+from ...cache_utils import Cache
+from ...generation import GenerationMixin
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput
+from ...modeling_utils import PreTrainedModel
+from ...processing_utils import Unpack
+from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
+from ..auto import AutoModel
+from .configuration_got_ocr2 import GotOcr2Config, GotOcr2VisionConfig
+
+
+class GotOcr2MLPBlock(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.lin1 = nn.Linear(config.hidden_size, config.mlp_dim)
+ self.lin2 = nn.Linear(config.mlp_dim, config.hidden_size)
+ self.act = ACT2FN[config.hidden_act]
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.lin1(hidden_states)
+ hidden_states = self.act(hidden_states)
+ hidden_states = self.lin2(hidden_states)
+ return hidden_states
+
+
+class GotOcr2VisionAttention(nn.Module):
+ """Multi-head Attention block with relative position embeddings."""
+
+ def __init__(self, config, window_size):
+ super().__init__()
+ input_size = (
+ (config.image_size // config.patch_size, config.image_size // config.patch_size)
+ if window_size == 0
+ else (window_size, window_size)
+ )
+
+ self.num_attention_heads = config.num_attention_heads
+ head_dim = config.hidden_size // config.num_attention_heads
+ self.scale = head_dim**-0.5
+ self.dropout = config.attention_dropout
+
+ self.qkv = nn.Linear(config.hidden_size, config.hidden_size * 3, bias=config.qkv_bias)
+ self.proj = nn.Linear(config.hidden_size, config.hidden_size)
+
+ self.use_rel_pos = config.use_rel_pos
+ if self.use_rel_pos:
+ if input_size is None:
+ raise ValueError("Input size must be provided if using relative positional encoding.")
+
+ # initialize relative positional embeddings
+ self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
+ self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))
+
+ def get_rel_pos(self, q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
+ """
+ Get relative positional embeddings according to the relative positions of
+ query and key sizes.
+
+ Args:
+ q_size (int):
+ size of the query.
+ k_size (int):
+ size of key k.
+ rel_pos (`torch.Tensor`):
+ relative position embeddings (L, channel).
+
+ Returns:
+ Extracted positional embeddings according to relative positions.
+ """
+ max_rel_dist = int(2 * max(q_size, k_size) - 1)
+ # Interpolate rel pos.
+ rel_pos_resized = F.interpolate(
+ rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
+ size=max_rel_dist,
+ mode="linear",
+ )
+ rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
+
+ # Scale the coords with short length if shapes for q and k are different.
+ q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
+ k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
+ relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
+
+ return rel_pos_resized[relative_coords.long()]
+
+ def get_decomposed_rel_pos(
+ self,
+ query: torch.Tensor,
+ rel_pos_h: torch.Tensor,
+ rel_pos_w: torch.Tensor,
+ q_size: tuple[int, int],
+ k_size: tuple[int, int],
+ ) -> torch.Tensor:
+ """
+ Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
+ https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py
+
+ Args:
+ query (`torch.Tensor`):
+ query q in the attention layer with shape (batch_size, query_height * query_width, channel).
+ rel_pos_h (`torch.Tensor`):
+ relative position embeddings (Lh, channel) for height axis.
+ rel_pos_w (`torch.Tensor`):
+ relative position embeddings (Lw, channel) for width axis.
+ q_size (tuple):
+ spatial sequence size of query q with (query_height, query_width).
+ k_size (tuple):
+ spatial sequence size of key k with (key_height, key_width).
+
+ Returns:
+ decomposed_rel_pos (`torch.Tensor`):
+ decomposed relative position embeddings.
+ """
+ query_height, query_width = q_size
+ key_height, key_width = k_size
+ relative_position_height = self.get_rel_pos(query_height, key_height, rel_pos_h)
+ relative_position_width = self.get_rel_pos(query_width, key_width, rel_pos_w)
+
+ batch_size, _, dim = query.shape
+ reshaped_query = query.reshape(batch_size, query_height, query_width, dim)
+ rel_h = torch.einsum("bhwc,hkc->bhwk", reshaped_query, relative_position_height)
+ rel_w = torch.einsum("bhwc,wkc->bhwk", reshaped_query, relative_position_width)
+
+ decomposed_rel_pos = rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
+
+ return decomposed_rel_pos
+
+ def forward(self, hidden_states: torch.Tensor, output_attentions=None) -> tuple[torch.Tensor, torch.Tensor]:
+ batch_size, height, width, _ = hidden_states.shape
+ # qkv with shape (3, batch_size, nHead, height * width, channel)
+ qkv = (
+ self.qkv(hidden_states)
+ .reshape(batch_size, height * width, 3, self.num_attention_heads, -1)
+ .permute(2, 0, 3, 1, 4)
+ )
+ # q, k, v with shape (batch_size * nHead, height * width, channel)
+ query, key, value = qkv.reshape(3, batch_size * self.num_attention_heads, height * width, -1).unbind(0)
+
+ attn_weights = (query * self.scale) @ key.transpose(-2, -1)
+
+ if self.use_rel_pos:
+ decomposed_rel_pos = self.get_decomposed_rel_pos(
+ query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width)
+ )
+ decomposed_rel_pos = decomposed_rel_pos.reshape_as(attn_weights)
+ attn_weights = attn_weights + decomposed_rel_pos
+
+ attn_weights = torch.nn.functional.softmax(attn_weights, dtype=torch.float32, dim=-1).to(query.dtype)
+
+ attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
+
+ attn_output = (attn_probs @ value).reshape(batch_size, self.num_attention_heads, height, width, -1)
+ attn_output = attn_output.permute(0, 2, 3, 1, 4).reshape(batch_size, height, width, -1)
+
+ attn_output = self.proj(attn_output)
+ return attn_output, attn_weights
+
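+# Shape walkthrough for the decomposed relative-position path above (illustrative): with a
+# square window of side S, `query` has shape (batch * num_heads, S * S, head_dim);
+# `get_decomposed_rel_pos` reshapes it to (batch * num_heads, S, S, head_dim), contracts it
+# against `rel_pos_h` / `rel_pos_w` into rel_h and rel_w of shape (batch * num_heads, S, S, S),
+# and broadcasts their sum to (batch * num_heads, S, S, S, S), which `reshape_as` then flattens
+# onto the (batch * num_heads, S * S, S * S) attention-weight matrix before the softmax.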
+
+class GotOcr2VisionLayer(GradientCheckpointingLayer):
+ def __init__(self, config, window_size):
+ super().__init__()
+ self.layer_norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.attn = GotOcr2VisionAttention(config, window_size)
+ self.layer_norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.mlp = GotOcr2MLPBlock(config)
+ self.window_size = window_size
+
+ def window_partition(self, hidden_states: torch.Tensor, window_size: int) -> tuple[torch.Tensor, tuple[int, int]]:
+ """
+        Partition into non-overlapping windows with padding if needed.
+
+        Args:
+            hidden_states (`torch.Tensor`):
+                Input tokens with [batch_size, height, width, channel].
+            window_size (`int`):
+                Window size.
+
+ Returns:
+ windows: windows after partition with [batch_size * num_windows, window_size, window_size, channel].
+ (pad_height, pad_width): padded height and width before partition
+ """
+ batch_size, height, width, channel = hidden_states.shape
+
+ pad_h = (window_size - height % window_size) % window_size
+ pad_w = (window_size - width % window_size) % window_size
+ hidden_states = F.pad(hidden_states, (0, 0, 0, pad_w, 0, pad_h))
+ pad_height, pad_width = height + pad_h, width + pad_w
+
+ hidden_states = hidden_states.reshape(
+ batch_size, pad_height // window_size, window_size, pad_width // window_size, window_size, channel
+ )
+ windows = hidden_states.permute(0, 1, 3, 2, 4, 5).contiguous().reshape(-1, window_size, window_size, channel)
+ return windows, (pad_height, pad_width)
+
+ def window_unpartition(
+ self, windows: torch.Tensor, window_size: int, padding_shape: tuple[int, int], original_shape: tuple[int, int]
+ ) -> torch.Tensor:
+ """
+        Unpartition windows back into the original sequences, removing padding.
+
+        Args:
+            windows (`torch.Tensor`):
+                Input tokens with [batch_size * num_windows, window_size, window_size, channel].
+            window_size (`int`):
+                Window size.
+            padding_shape (`tuple[int, int]`):
+                Padded height and width (pad_height, pad_width).
+            original_shape (`tuple[int, int]`):
+                Original height and width (height, width) before padding.
+
+ Returns:
+ hidden_states: unpartitioned sequences with [batch_size, height, width, channel].
+ """
+ pad_height, pad_width = padding_shape
+ height, width = original_shape
+ batch_size = windows.shape[0] // (pad_height * pad_width // window_size // window_size)
+ hidden_states = windows.reshape(
+ batch_size, pad_height // window_size, pad_width // window_size, window_size, window_size, -1
+ )
+ hidden_states = (
+ hidden_states.permute(0, 1, 3, 2, 4, 5).contiguous().reshape(batch_size, pad_height, pad_width, -1)
+ )
+
+ hidden_states = hidden_states[:, :height, :width, :].contiguous()
+ return hidden_states
+
+ def forward(self, hidden_states: torch.Tensor) -> tuple[torch.FloatTensor]:
+ residual = hidden_states
+ hidden_states = self.layer_norm1(hidden_states)
+ # Window partition
+ if self.window_size > 0:
+ height, width = hidden_states.shape[1], hidden_states.shape[2]
+ hidden_states, padding_shape = self.window_partition(hidden_states, self.window_size)
+
+ hidden_states, attn_weights = self.attn(
+ hidden_states=hidden_states,
+ )
+ # Reverse window partition
+ if self.window_size > 0:
+ hidden_states = self.window_unpartition(hidden_states, self.window_size, padding_shape, (height, width))
+
+ hidden_states = residual + hidden_states
+ layernorm_output = self.layer_norm2(hidden_states)
+ hidden_states = hidden_states + self.mlp(layernorm_output)
+ return hidden_states
+
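+# Round-trip property of the two window helpers above (illustrative, for a
+# GotOcr2VisionLayer instance `layer`): for any `hidden_states` of shape
+# (batch, height, width, channel) and window_size > 0,
+#
+#     windows, padded = layer.window_partition(hidden_states, window_size)
+#     restored = layer.window_unpartition(windows, window_size, padded, (height, width))
+#
+# `restored` equals `hidden_states` exactly, because the partition only zero-pads on the
+# bottom/right and the unpartition slices that padding off again.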
+
+@auto_docstring
+class GotOcr2PreTrainedModel(PreTrainedModel):
+ config: GotOcr2Config
+ base_model_prefix = ""
+ supports_gradient_checkpointing = True
+ _skip_keys_device_placement = "past_key_values"
+ _supports_flash_attn = False
+ _supports_sdpa = False
+
+ _can_compile_fullgraph = True
+ _supports_flex_attn = False
+ _supports_attention_backend = True
+
+ def _init_weights(self, module):
+ super()._init_weights(module)
+ if isinstance(module, GotOcr2VisionAttention):
+ if module.use_rel_pos:
+ module.rel_pos_h.data.zero_()
+ module.rel_pos_w.data.zero_()
+ elif isinstance(module, GotOcr2VisionEncoder):
+ if module.pos_embed is not None:
+ module.pos_embed.data.zero_()
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Base class for got_ocr2 vision model's outputs that also contains image embeddings obtained by applying the projection
+ layer to the pooler_output.
+ """
+)
+class GotOcr2VisionEncoderOutput(ModelOutput):
+ r"""
+    image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when the model is initialized with `with_projection=True`):
+ The image embeddings obtained by applying the projection layer to the pooler_output.
+ """
+
+ image_embeds: Optional[torch.FloatTensor] = None
+ last_hidden_state: Optional[torch.FloatTensor] = None
+ hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+ attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+
+
+class GotOcr2PatchEmbeddings(nn.Module):
+ """
+ This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
+ `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
+ Transformer.
+ """
+
+ def __init__(self, config):
+ super().__init__()
+ image_size, patch_size = config.image_size, config.patch_size
+ num_channels, hidden_size = config.num_channels, config.hidden_size
+ image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
+ patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
+ num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
+ self.image_size = image_size
+ self.patch_size = patch_size
+ self.num_channels = num_channels
+ self.num_patches = num_patches
+
+ self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
+
+ def forward(self, pixel_values):
+ batch_size, num_channels, height, width = pixel_values.shape
+ if num_channels != self.num_channels:
+ raise ValueError(
+ "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
+ )
+ if height != self.image_size[0] or width != self.image_size[1]:
+ raise ValueError(
+ f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})."
+ )
+ embeddings = self.projection(pixel_values).permute(0, 2, 3, 1)
+ return embeddings
+
+
+class GotOcr2LayerNorm(nn.LayerNorm):
+ r"""LayerNorm that supports two data formats: channels_last (default) or channels_first.
+ The ordering of the dimensions in the inputs. channels_last corresponds to inputs with shape (batch_size, height,
+ width, channels) while channels_first corresponds to inputs with shape (batch_size, channels, height, width).
+ """
+
+ def __init__(self, normalized_shape, *, eps=1e-6, data_format="channels_last", **kwargs):
+ super().__init__(normalized_shape, eps=eps, **kwargs)
+ if data_format not in ["channels_last", "channels_first"]:
+ raise NotImplementedError(f"Unsupported data format: {data_format}")
+ self.data_format = data_format
+
+ def forward(self, features: torch.Tensor) -> torch.Tensor:
+ """
+ Args:
+ features: Tensor of shape (batch_size, channels, height, width) OR (batch_size, height, width, channels)
+ """
+ if self.data_format == "channels_first":
+ features = features.permute(0, 2, 3, 1)
+ features = super().forward(features)
+ features = features.permute(0, 3, 1, 2)
+ else:
+ features = super().forward(features)
+ return features
+
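+# Illustration of the two data formats supported above: the vision neck feeds channels-first
+# activations straight out of nn.Conv2d, so it builds the norm as
+# GotOcr2LayerNorm(config.output_channels, data_format="channels_first") and passes
+# (batch, channels, height, width) tensors; a channels-last caller would instead keep the
+# default data_format and pass (batch, height, width, channels).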
+
+class GotOcr2VisionNeck(nn.Module):
+ def __init__(self, config: GotOcr2VisionConfig):
+ super().__init__()
+ self.config = config
+
+ self.conv1 = nn.Conv2d(config.hidden_size, config.output_channels, kernel_size=1, bias=False)
+ self.layer_norm1 = GotOcr2LayerNorm(config.output_channels, data_format="channels_first")
+ self.conv2 = nn.Conv2d(config.output_channels, config.output_channels, kernel_size=3, padding=1, bias=False)
+ self.layer_norm2 = GotOcr2LayerNorm(config.output_channels, data_format="channels_first")
+
+ def forward(self, hidden_states):
+ hidden_states = hidden_states.permute(0, 3, 1, 2)
+ hidden_states = self.conv1(hidden_states)
+ hidden_states = self.layer_norm1(hidden_states)
+
+ hidden_states = self.conv2(hidden_states)
+ hidden_states = self.layer_norm2(hidden_states)
+ return hidden_states
+
+
+class GotOcr2VisionEncoder(GotOcr2PreTrainedModel):
+ _can_record_outputs = {"hidden_states": GotOcr2VisionLayer, "attentions": GotOcr2VisionAttention}
+
+ def __init__(self, config: GotOcr2VisionConfig):
+ super().__init__(config)
+ self.config = config
+ self.image_size = config.image_size
+ self.patch_embed = GotOcr2PatchEmbeddings(config)
+
+ self.pos_embed = None
+ if config.use_abs_pos:
+ # Initialize absolute positional embedding with pretrain image size.
+ self.pos_embed = nn.Parameter(
+ torch.zeros(
+ 1,
+ config.image_size // config.patch_size,
+ config.image_size // config.patch_size,
+ config.hidden_size,
+ )
+ )
+
+ self.layers = nn.ModuleList()
+ for i in range(config.num_hidden_layers):
+ layer = GotOcr2VisionLayer(
+ config,
+ window_size=config.window_size if i not in config.global_attn_indexes else 0,
+ )
+ self.layers.append(layer)
+
+ self.neck = GotOcr2VisionNeck(config)
+
+ self.gradient_checkpointing = False
+
+ def get_input_embeddings(self):
+ return self.patch_embed
+
+ @check_model_inputs(tie_last_hidden_states=False)
+ def forward(
+ self, pixel_values: Optional[torch.FloatTensor] = None, **kwargs: Unpack[TransformersKwargs]
+ ) -> GotOcr2VisionEncoderOutput:
+ if pixel_values is None:
+ raise ValueError("You have to specify pixel_values")
+
+ hidden_states = self.patch_embed(pixel_values)
+ if self.pos_embed is not None:
+ hidden_states = hidden_states + self.pos_embed
+ for layer_module in self.layers:
+ hidden_states = layer_module(hidden_states)
+ hidden_states = self.neck(hidden_states)
+ return GotOcr2VisionEncoderOutput(
+ last_hidden_state=hidden_states,
+ )
+
+
+class GotOcr2MultiModalProjector(nn.Module):
+ def __init__(self, config: GotOcr2Config):
+ super().__init__()
+ vision_output_channels = config.vision_config.output_channels
+ language_hidden_size = config.text_config.hidden_size
+ self.conv_upsampler1 = nn.Conv2d(
+ vision_output_channels, vision_output_channels * 2, kernel_size=3, stride=2, padding=1, bias=False
+ )
+ self.conv_upsampler2 = nn.Conv2d(
+ vision_output_channels * 2, language_hidden_size, kernel_size=3, stride=2, padding=1, bias=False
+ )
+ self.multimodal_projector = nn.Linear(language_hidden_size, language_hidden_size)
+
+ def forward(self, vision_embeddings: torch.Tensor) -> torch.Tensor:
+ hidden_state = self.conv_upsampler1(vision_embeddings)
+ hidden_state = self.conv_upsampler2(hidden_state)
+ hidden_state = hidden_state.flatten(2).permute(0, 2, 1)
+ hidden_state = self.multimodal_projector(hidden_state)
+ return hidden_state
+
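+# Shape flow through the projector above with the default configs (1024x1024 input,
+# 16-pixel patches, 256 vision output channels, 1024 text hidden size):
+# neck output (batch, 256, 64, 64) -> conv_upsampler1 -> (batch, 512, 32, 32)
+# -> conv_upsampler2 -> (batch, 1024, 16, 16) -> flatten/permute -> (batch, 256, 1024),
+# i.e. 256 image tokens per image, matching the processor's default `num_image_tokens`.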
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Base class for GotOcr2 causal language model (or autoregressive) outputs.
+ """
+)
+class GotOcr2CausalLMOutputWithPast(ModelOutput):
+ r"""
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Language modeling loss (for next-token prediction).
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
+ `past_key_values` input) to speed up sequential decoding.
+ image_hidden_states (`torch.FloatTensor`, *optional*):
+ A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
+ image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ logits: Optional[torch.FloatTensor] = None
+ past_key_values: Optional[Cache] = None
+ hidden_states: Optional[tuple[torch.FloatTensor]] = None
+ attentions: Optional[tuple[torch.FloatTensor]] = None
+ image_hidden_states: Optional[torch.FloatTensor] = None
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Base class for GotOcr2 outputs, with hidden states and attentions.
+ """
+)
+class GotOcr2ModelOutputWithPast(BaseModelOutputWithPast):
+ r"""
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
+ `past_key_values` input) to speed up sequential decoding.
+ image_hidden_states (`torch.FloatTensor`, *optional*):
+ A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
+ image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
+ """
+
+ image_hidden_states: Optional[torch.FloatTensor] = None
+
+
+@auto_docstring(
+ custom_intro="""
+ The GotOcr2 model which consists of a vision backbone and a language model, without a language modeling head.
+ """
+)
+class GotOcr2Model(GotOcr2PreTrainedModel):
+ _checkpoint_conversion_mapping = {"language_model.model": "language_model"}
+
+ def __init__(self, config: GotOcr2Config):
+ super().__init__(config)
+ self.vision_tower = GotOcr2VisionEncoder(config.vision_config)
+
+ self.multi_modal_projector = GotOcr2MultiModalProjector(config)
+ self.language_model = AutoModel.from_config(config.text_config)
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.language_model.get_input_embeddings()
+
+ def set_input_embeddings(self, value):
+ self.language_model.set_input_embeddings(value)
+
+ def set_decoder(self, decoder):
+ self.language_model = decoder
+
+ def get_decoder(self):
+ return self.language_model
+
+ def get_image_features(
+ self,
+ pixel_values: torch.FloatTensor,
+ ):
+ """
+        Obtains image last hidden states from the vision tower and applies the multimodal projection.
+
+        Args:
+            pixel_values (`torch.FloatTensor` of shape `(batch_size, channels, height, width)`):
+                The tensors corresponding to the input images.
+        Returns:
+            image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`.
+ """
+ image_outputs = self.vision_tower(pixel_values).last_hidden_state
+ return self.multi_modal_projector(image_outputs)
+
+ def get_placeholder_mask(
+ self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, image_features: torch.FloatTensor
+ ):
+ """
+ Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is
+ equal to the length of multimodal features. If the lengths are different, an error is raised.
+ """
+ if input_ids is None:
+ special_image_mask = inputs_embeds == self.get_input_embeddings()(
+ torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
+ )
+ special_image_mask = special_image_mask.all(-1)
+ else:
+ special_image_mask = input_ids == self.config.image_token_id
+
+ n_image_tokens = special_image_mask.sum()
+ special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
+ n_image_features = image_features.shape[0] * image_features.shape[1]
+ if inputs_embeds[special_image_mask].numel() != image_features.numel():
+ raise ValueError(
+ f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
+ )
+ return special_image_mask
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> Union[tuple, GotOcr2ModelOutputWithPast]:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ if inputs_embeds is None:
+ inputs_embeds = self.get_input_embeddings()(input_ids)
+
+ if pixel_values is not None:
+ image_features = self.get_image_features(pixel_values=pixel_values.to(inputs_embeds.dtype))
+ image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
+ special_image_mask = self.get_placeholder_mask(
+ input_ids, inputs_embeds=inputs_embeds, image_features=image_features
+ )
+ inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
+
+ outputs = self.language_model(
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=True,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+ return GotOcr2ModelOutputWithPast(
+ last_hidden_state=outputs.last_hidden_state,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ image_hidden_states=image_features if pixel_values is not None else None,
+ )
+
+
+@auto_docstring(
+ custom_intro="""
+ The GOT_OCR2 model which consists of a vision backbone and a language model.
+ """
+)
+class GotOcr2ForConditionalGeneration(GotOcr2PreTrainedModel, GenerationMixin):
+ _checkpoint_conversion_mapping = {
+ "^language_model.model": "model.language_model",
+ "^vision_tower": "model.vision_tower",
+ "^multi_modal_projector": "model.multi_modal_projector",
+ "^language_model.lm_head": "lm_head",
+ }
+ _tied_weights_keys = ["lm_head.weight"]
+
+ def __init__(self, config: GotOcr2Config):
+ super().__init__(config)
+ self.model = GotOcr2Model(config)
+ self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.get_input_embeddings()
+
+ def set_input_embeddings(self, value):
+ self.model.set_input_embeddings(value)
+
+ def get_output_embeddings(self) -> nn.Module:
+ return self.lm_head
+
+ def set_decoder(self, decoder):
+ self.model.set_decoder(decoder)
+
+ def get_decoder(self):
+ return self.model.get_decoder()
+
+ def get_image_features(
+ self,
+ pixel_values: torch.FloatTensor,
+ vision_feature_layer: Optional[Union[int, list[int]]] = None,
+ vision_feature_select_strategy: Optional[str] = None,
+ **kwargs,
+ ):
+ return self.model.get_image_features(
+ pixel_values=pixel_values,
+ vision_feature_layer=vision_feature_layer,
+ vision_feature_select_strategy=vision_feature_select_strategy,
+ **kwargs,
+ )
+
+ # Make modules available through conditional class for BC
+ @property
+ def language_model(self):
+ return self.model.language_model
+
+ @property
+ def vision_tower(self):
+ return self.model.vision_tower
+
+ @property
+ def multi_modal_projector(self):
+ return self.model.multi_modal_projector
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ logits_to_keep: Union[int, torch.Tensor] = 0,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> Union[tuple, GotOcr2CausalLMOutputWithPast]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+ Example:
+
+ ```python
+ >>> from PIL import Image
+ >>> import requests
+ >>> from transformers import AutoProcessor, GotOcr2ForConditionalGeneration, TextStreamer
+
+ >>> model = GotOcr2ForConditionalGeneration.from_pretrained("stepfun-ai/GOT-OCR-2.0-hf").to("cuda")
+ >>> processor = AutoProcessor.from_pretrained("stepfun-ai/GOT-OCR-2.0-hf")
+
+ >>> url = "https://huggingface.co/datasets/hf-internal-testing/fixtures_got_ocr/resolve/main/multi_box.png"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> inputs = processor(image, return_tensors="pt", color="green").to("cuda")
+
+ >>> # Generate
+ >>> streamer = TextStreamer(processor.tokenizer, skip_prompt=True, skip_special_tokens=True)
+ >>> generate_ids = model.generate(
+ ... **inputs,
+ ... do_sample=False,
+ ... tokenizer = processor.tokenizer,
+ ... stop_strings='<|im_end|>',
+ ... streamer=streamer,
+ ... max_new_tokens=4096,
+ ... )
+ "You should keep in mind what features from the module should be used, especially
+ when you're planning to sell a template."
+ ```"""
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.model(
+ input_ids=input_ids,
+ pixel_values=pixel_values,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=True,
+ cache_position=cache_position,
+ logits_to_keep=logits_to_keep,
+ **kwargs,
+ )
+
+ hidden_states = outputs[0]
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(
+ logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
+ )
+
+ return GotOcr2CausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ image_hidden_states=outputs.image_hidden_states,
+ )
+
+ def prepare_inputs_for_generation(
+ self,
+ input_ids,
+ past_key_values=None,
+ inputs_embeds=None,
+ pixel_values=None,
+ attention_mask=None,
+ cache_position=None,
+ logits_to_keep=None,
+ **kwargs,
+ ):
+ # Overwritten -- in specific circumstances we don't want to forward image inputs to the model
+
+ model_inputs = super().prepare_inputs_for_generation(
+ input_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ cache_position=cache_position,
+ logits_to_keep=logits_to_keep,
+ **kwargs,
+ )
+
+ if cache_position[0] == 0:
+            # Pixel values are only needed on the first (prefill) forward pass; in the cached decoding stage
+            # they stay None because the input ids no longer contain the special image token
+ model_inputs["pixel_values"] = pixel_values
+
+ return model_inputs
+
+
+__all__ = ["GotOcr2PreTrainedModel", "GotOcr2Model", "GotOcr2ForConditionalGeneration"]
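+
+# Usage sketch for the base model above (illustrative; assumes an already-loaded PIL image
+# `image` and the checkpoint used in the generation example):
+#
+# >>> from transformers import AutoProcessor, GotOcr2Model
+# >>> model = GotOcr2Model.from_pretrained("stepfun-ai/GOT-OCR-2.0-hf")
+# >>> processor = AutoProcessor.from_pretrained("stepfun-ai/GOT-OCR-2.0-hf")
+# >>> inputs = processor(image, return_tensors="pt")
+# >>> outputs = model(**inputs)
+# >>> outputs.last_hidden_state.shape    # (batch, sequence_length, text hidden size)
+# >>> outputs.image_hidden_states.shape  # roughly (num_images, 256, text hidden size) with the defaults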
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/got_ocr2/modular_got_ocr2.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/got_ocr2/modular_got_ocr2.py
new file mode 100644
index 0000000000000000000000000000000000000000..0ecf39fcd03b577b406868a639d3ce8ee9425e3d
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/got_ocr2/modular_got_ocr2.py
@@ -0,0 +1,483 @@
+# coding=utf-8
+# Copyright 2024 HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from typing import Optional, Union
+
+import torch
+import torch.nn as nn
+
+from ...cache_utils import Cache
+from ...configuration_utils import PretrainedConfig
+from ...modeling_utils import PreTrainedModel
+from ...processing_utils import Unpack
+from ...utils import auto_docstring, can_return_tuple, logging
+from ..auto import CONFIG_MAPPING, AutoConfig
+from ..llava.modeling_llava import (
+ LlavaCausalLMOutputWithPast,
+ LlavaForConditionalGeneration,
+ LlavaModel,
+ LlavaModelOutputWithPast,
+ LlavaPreTrainedModel,
+ TransformersKwargs,
+)
+from ..sam.modeling_sam import (
+ SamMLPBlock,
+ SamPreTrainedModel,
+ SamVisionAttention,
+ SamVisionEncoder,
+ SamVisionLayer,
+)
+
+
+logger = logging.get_logger(__name__)
+
+
+class GotOcr2VisionConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`GotOcr2VisionModel`]. It is used to instantiate a GOT_OCR2
+ vision encoder according to the specified arguments, defining the model architecture. Instantiating a configuration
+    with the defaults will yield a similar configuration to that of the SAM ViT-h
+ [facebook/sam-vit-huge](https://huggingface.co/facebook/sam-vit-huge) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ hidden_size (`int`, *optional*, defaults to 768):
+ Dimensionality of the encoder layers and the pooler layer.
+ output_channels (`int`, *optional*, defaults to 256):
+ Dimensionality of the output channels in the Patch Encoder.
+ num_hidden_layers (`int`, *optional*, defaults to 12):
+ Number of hidden layers in the Transformer encoder.
+ num_attention_heads (`int`, *optional*, defaults to 12):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ num_channels (`int`, *optional*, defaults to 3):
+ Number of channels in the input image.
+ image_size (`int`, *optional*, defaults to 1024):
+ Expected resolution. Target size of the resized input image.
+ patch_size (`int`, *optional*, defaults to 16):
+ Size of the patches to be extracted from the input image.
+ hidden_act (`str`, *optional*, defaults to `"gelu"`):
+ The non-linear activation function (function or string)
+ layer_norm_eps (`float`, *optional*, defaults to 1e-06):
+ The epsilon used by the layer normalization layers.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ initializer_range (`float`, *optional*, defaults to 1e-10):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ qkv_bias (`bool`, *optional*, defaults to `True`):
+ Whether to add a bias to query, key, value projections.
+ use_abs_pos (`bool`, *optional*, defaults to `True`):
+ Whether to use absolute position embedding.
+ use_rel_pos (`bool`, *optional*, defaults to `True`):
+ Whether to use relative position embedding.
+ window_size (`int`, *optional*, defaults to 14):
+ Window size for relative position.
+ global_attn_indexes (`list[int]`, *optional*, defaults to `[2, 5, 8, 11]`):
+ The indexes of the global attention layers.
+ mlp_dim (`int`, *optional*, defaults to 3072):
+ The dimensionality of the MLP layer in the Transformer encoder.
+ """
+
+ base_config_key = "vision_config"
+
+ def __init__(
+ self,
+ hidden_size=768,
+ output_channels=256,
+ num_hidden_layers=12,
+ num_attention_heads=12,
+ num_channels=3,
+ image_size=1024,
+ patch_size=16,
+ hidden_act="gelu",
+ layer_norm_eps=1e-06,
+ attention_dropout=0.0,
+ initializer_range=1e-10,
+ qkv_bias=True,
+ use_abs_pos=True,
+ use_rel_pos=True,
+ window_size=14,
+ global_attn_indexes=[2, 5, 8, 11],
+ mlp_dim=3072,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+
+ self.hidden_size = hidden_size
+ self.output_channels = output_channels
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.num_channels = num_channels
+ self.image_size = image_size
+ self.patch_size = patch_size
+ self.hidden_act = hidden_act
+ self.layer_norm_eps = layer_norm_eps
+ self.attention_dropout = attention_dropout
+ self.initializer_range = initializer_range
+ self.qkv_bias = qkv_bias
+ self.use_abs_pos = use_abs_pos
+ self.use_rel_pos = use_rel_pos
+ self.window_size = window_size
+ self.global_attn_indexes = global_attn_indexes
+ self.mlp_dim = mlp_dim
+
+
+class GotOcr2Config(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`GotOcr2ForConditionalGeneration`]. It is used to instantiate a
+ GotOcr2 model according to the specified arguments, defining the model architecture. Instantiating a configuration
+ with the defaults will yield a similar configuration to that of GOT-OCR-2.0.
+
+    e.g. [stepfun-ai/GOT-OCR-2.0-hf](https://huggingface.co/stepfun-ai/GOT-OCR-2.0-hf)
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+
+ Args:
+ vision_config (`Union[AutoConfig, dict]`, *optional*, defaults to `CLIPVisionConfig`):
+ The config object or dictionary of the vision backbone.
+ text_config (`Union[AutoConfig, dict]`, *optional*, defaults to `LlamaConfig`):
+ The config object or dictionary of the text backbone.
+ image_token_index (`int`, *optional*, defaults to 151859):
+ The image token index to encode the image prompt.
+ image_seq_length (`int`, *optional*, defaults to 576):
+ Sequence length of one image embedding.
+ pad_token_id (`int`, *optional*, defaults to -1):
+ Padding token id.
+
+ ```python
+ >>> from transformers import GotOcr2ForConditionalGeneration, GotOcr2Config
+
+ >>> # Initializing a GotOcr2 style configuration
+ >>> configuration = GotOcr2Config()
+
+    >>> # Initializing a model from the GOT-OCR-2.0 style configuration
+ >>> model = GotOcr2ForConditionalGeneration(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "got_ocr2"
+ attribute_map = {
+ "image_token_id": "image_token_index",
+ }
+ sub_configs = {"text_config": AutoConfig, "vision_config": GotOcr2VisionConfig}
+
+ def __init__(
+ self,
+ vision_config=None,
+ text_config=None,
+ image_token_index=151859,
+ image_seq_length=576,
+ pad_token_id=-1,
+ **kwargs,
+ ):
+ self.image_token_index = image_token_index
+ self.image_seq_length = image_seq_length
+ self.pad_token_id = pad_token_id
+
+ if vision_config is None:
+ self.vision_config = GotOcr2VisionConfig()
+ elif isinstance(vision_config, dict):
+ self.vision_config = GotOcr2VisionConfig(**vision_config)
+ elif isinstance(vision_config, GotOcr2VisionConfig):
+ self.vision_config = vision_config
+
+ if isinstance(text_config, dict):
+ text_config["model_type"] = text_config.get("model_type", "qwen2")
+ text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
+ elif text_config is None:
+ text_config = CONFIG_MAPPING["qwen2"](
+ vocab_size=151860,
+ hidden_size=1024,
+ intermediate_size=2816,
+ num_hidden_layers=24,
+ num_attention_heads=16,
+ num_key_value_heads=16,
+ hidden_act="silu",
+ max_position_embeddings=32768,
+ initializer_range=0.02,
+ rms_norm_eps=1e-6,
+ use_cache=True,
+ tie_word_embeddings=True,
+ rope_theta=1000000.0,
+ rope_scaling=None,
+ use_sliding_window=False,
+ sliding_window=4096,
+ max_window_layers=21,
+ attention_dropout=0.0,
+ )
+
+ self.text_config = text_config
+
+ super().__init__(**kwargs)
+
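+# Composition sketch (illustrative): a dict `text_config` is routed through CONFIG_MAPPING
+# using its "model_type" (defaulting to "qwen2"), and the `attribute_map` above lets
+# `image_token_id` resolve to `image_token_index`:
+#
+# >>> config = GotOcr2Config(text_config={"model_type": "qwen2", "hidden_size": 1024})
+# >>> config.image_token_id
+# 151859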
+
+class GotOcr2MLPBlock(SamMLPBlock):
+ pass
+
+
+class GotOcr2VisionAttention(SamVisionAttention):
+ pass
+
+
+class GotOcr2VisionLayer(SamVisionLayer):
+ def __init__(self, config, window_size):
+ super().__init__(config, window_size)
+ self.layer_norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.attn = GotOcr2VisionAttention(config, window_size)
+ self.layer_norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.mlp = GotOcr2MLPBlock(config)
+ self.window_size = window_size
+
+
+class GotOcr2PreTrainedModel(SamPreTrainedModel):
+ pass
+
+
+class GotOcr2VisionEncoder(SamVisionEncoder, GotOcr2PreTrainedModel):
+ pass
+
+
+class GotOcr2MultiModalProjector(nn.Module):
+ def __init__(self, config: GotOcr2Config):
+ super().__init__()
+ vision_output_channels = config.vision_config.output_channels
+ language_hidden_size = config.text_config.hidden_size
+ self.conv_upsampler1 = nn.Conv2d(
+ vision_output_channels, vision_output_channels * 2, kernel_size=3, stride=2, padding=1, bias=False
+ )
+ self.conv_upsampler2 = nn.Conv2d(
+ vision_output_channels * 2, language_hidden_size, kernel_size=3, stride=2, padding=1, bias=False
+ )
+ self.multimodal_projector = nn.Linear(language_hidden_size, language_hidden_size)
+
+ def forward(self, vision_embeddings: torch.Tensor) -> torch.Tensor:
+ hidden_state = self.conv_upsampler1(vision_embeddings)
+ hidden_state = self.conv_upsampler2(hidden_state)
+ hidden_state = hidden_state.flatten(2).permute(0, 2, 1)
+ hidden_state = self.multimodal_projector(hidden_state)
+ return hidden_state
+
+
+class GotOcr2CausalLMOutputWithPast(LlavaCausalLMOutputWithPast):
+ pass
+
+
+class GotOcr2ModelOutputWithPast(LlavaModelOutputWithPast):
+ pass
+
+
+class GotOcr2PreTrainedModel(LlavaPreTrainedModel):
+ _supports_flash_attn = False
+ _supports_sdpa = False
+ _supports_flex_attn = False
+
+ def _init_weights(self, module):
+ PreTrainedModel._init_weights(self, module)
+ if isinstance(module, GotOcr2VisionAttention):
+ if module.use_rel_pos:
+ module.rel_pos_h.data.zero_()
+ module.rel_pos_w.data.zero_()
+ elif isinstance(module, GotOcr2VisionEncoder):
+ if module.pos_embed is not None:
+ module.pos_embed.data.zero_()
+
+
+class GotOcr2Model(LlavaModel):
+ def __init__(self, config: GotOcr2Config):
+ super().__init__(config)
+ self.vision_tower = GotOcr2VisionEncoder(config.vision_config)
+
+ def get_image_features(
+ self,
+ pixel_values: torch.FloatTensor,
+ ):
+ """
+        Obtains image last hidden states from the vision tower and applies the multimodal projection.
+
+        Args:
+            pixel_values (`torch.FloatTensor` of shape `(batch_size, channels, height, width)`):
+                The tensors corresponding to the input images.
+        Returns:
+            image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`.
+ """
+ image_outputs = self.vision_tower(pixel_values).last_hidden_state
+ return self.multi_modal_projector(image_outputs)
+
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> Union[tuple, GotOcr2ModelOutputWithPast]:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ if inputs_embeds is None:
+ inputs_embeds = self.get_input_embeddings()(input_ids)
+
+ if pixel_values is not None:
+ image_features = self.get_image_features(pixel_values=pixel_values.to(inputs_embeds.dtype))
+ image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
+ special_image_mask = self.get_placeholder_mask(
+ input_ids, inputs_embeds=inputs_embeds, image_features=image_features
+ )
+ inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
+
+ outputs = self.language_model(
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=True,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+ return GotOcr2ModelOutputWithPast(
+ last_hidden_state=outputs.last_hidden_state,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ image_hidden_states=image_features if pixel_values is not None else None,
+ )
+
+
+class GotOcr2ForConditionalGeneration(LlavaForConditionalGeneration):
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ logits_to_keep: Union[int, torch.Tensor] = 0,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> Union[tuple, GotOcr2CausalLMOutputWithPast]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+ Example:
+
+ ```python
+ >>> from PIL import Image
+ >>> import requests
+ >>> from transformers import AutoProcessor, GotOcr2ForConditionalGeneration, TextStreamer
+
+ >>> model = GotOcr2ForConditionalGeneration.from_pretrained("stepfun-ai/GOT-OCR-2.0-hf").to("cuda")
+ >>> processor = AutoProcessor.from_pretrained("stepfun-ai/GOT-OCR-2.0-hf")
+
+ >>> url = "https://huggingface.co/datasets/hf-internal-testing/fixtures_got_ocr/resolve/main/multi_box.png"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> inputs = processor(image, return_tensors="pt", color="green").to("cuda")
+
+ >>> # Generate
+ >>> streamer = TextStreamer(processor.tokenizer, skip_prompt=True, skip_special_tokens=True)
+ >>> generate_ids = model.generate(
+ ... **inputs,
+ ... do_sample=False,
+ ... tokenizer = processor.tokenizer,
+ ... stop_strings='<|im_end|>',
+ ... streamer=streamer,
+ ... max_new_tokens=4096,
+ ... )
+ "You should keep in mind what features from the module should be used, especially
+ when you're planning to sell a template."
+ ```"""
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.model(
+ input_ids=input_ids,
+ pixel_values=pixel_values,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=True,
+ cache_position=cache_position,
+ logits_to_keep=logits_to_keep,
+ **kwargs,
+ )
+
+ hidden_states = outputs[0]
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(
+ logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
+ )
+
+ return GotOcr2CausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ image_hidden_states=outputs.image_hidden_states,
+ )
+
+
+__all__ = [
+ "GotOcr2VisionConfig",
+ "GotOcr2Config",
+ "GotOcr2PreTrainedModel",
+ "GotOcr2Model",
+ "GotOcr2ForConditionalGeneration",
+]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/got_ocr2/processing_got_ocr2.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/got_ocr2/processing_got_ocr2.py
new file mode 100644
index 0000000000000000000000000000000000000000..16c062ec63ade1310971f8797291439b59bad5ac
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/got_ocr2/processing_got_ocr2.py
@@ -0,0 +1,261 @@
+# coding=utf-8
+# Copyright 2024 HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from typing import Optional, Union
+
+import numpy as np
+
+from transformers.processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin, TextKwargs, Unpack
+from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
+
+from ...image_processing_utils import BatchFeature
+from ...image_utils import ImageInput
+from ...utils import is_vision_available, logging
+
+
+if is_vision_available():
+ from ...image_utils import load_images
+
+logger = logging.get_logger(__name__)
+
+
+class GotOcr2TextKwargs(TextKwargs, total=False):
+ format: Optional[bool]
+
+
+class GotOcr2ImagesKwargs(ImagesKwargs, total=False):
+ box: Optional[Union[list, tuple[float, float], tuple[float, float, float, float]]]
+ color: Optional[str]
+ num_image_tokens: Optional[int]
+ multi_page: Optional[bool]
+ crop_to_patches: Optional[bool]
+ min_patches: Optional[int]
+ max_patches: Optional[int]
+
+
+class GotOcr2ProcessorKwargs(ProcessingKwargs, total=False):
+ text_kwargs: GotOcr2TextKwargs
+ images_kwargs: GotOcr2ImagesKwargs
+ _defaults = {
+ "text_kwargs": {
+ "padding": False,
+ "format": False,
+ },
+ "images_kwargs": {
+ "num_image_tokens": 256,
+ "multi_page": False,
+ "crop_to_patches": False,
+ "min_patches": 1,
+ "max_patches": 12,
+ },
+ }
+
+
+def preprocess_box_annotation(box: Union[list, tuple], image_size: tuple[int, int]) -> list:
+ """
+ Convert box annotation to the format [x1, y1, x2, y2] in the range [0, 1000].
+ """
+ width, height = image_size
+ if len(box) == 4:
+ box[0] = int(box[0] / width * 1000)
+ box[1] = int(box[1] / height * 1000)
+ box[2] = int(box[2] / width * 1000)
+ box[3] = int(box[3] / height * 1000)
+ else:
+ raise ValueError("Box must be a list or tuple of lists in the form [x1, y1, x2, y2].")
+
+ return list(box)
+
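+# Worked example for the normalisation above: on an 800x600 image, the pixel box
+# [100, 50, 400, 300] becomes
+# [int(100 / 800 * 1000), int(50 / 600 * 1000), int(400 / 800 * 1000), int(300 / 600 * 1000)]
+# = [125, 83, 500, 500], i.e. coordinates rescaled to the 0-1000 range used in the query string.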
+
+class GotOcr2Processor(ProcessorMixin):
+ r"""
+ Constructs a GotOcr2 processor which wraps a [`GotOcr2ImageProcessor`] and
+    [`PreTrainedTokenizerFast`] tokenizer into a single processor that inherits both the image processor and
+ tokenizer functionalities. See the [`~GotOcr2Processor.__call__`] and [`~GotOcr2Processor.decode`] for more information.
+ Args:
+ image_processor ([`GotOcr2ImageProcessor`], *optional*):
+ The image processor is a required input.
+ tokenizer ([`PreTrainedTokenizer`, `PreTrainedTokenizerFast`], *optional*):
+ The tokenizer is a required input.
+ chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
+ in a chat into a tokenizable string.
+ """
+
+ attributes = ["image_processor", "tokenizer"]
+ image_processor_class = "AutoImageProcessor"
+ tokenizer_class = "PreTrainedTokenizerFast"
+
+ def __init__(self, image_processor=None, tokenizer=None, chat_template=None, **kwargs):
+ super().__init__(image_processor, tokenizer, chat_template=chat_template)
+
+ self.message_start_token = "<|im_start|>"
+ self.message_end_token = "<|im_end|>"
+        self.img_start_token = "<img>"
+        self.img_end_token = "</img>"
+        self.img_pad_token = "<imgpad>"
+        self.image_token = "<imgpad>"  # keep the above for BC, but we need to call it `image_token`
+ self.image_token_id = tokenizer.convert_tokens_to_ids(self.image_token)
+ self.system_query = "system\nYou should follow the instructions carefully and explain your answers in detail."
+
+ def _make_list_of_inputs(self, images, text, box, color, multi_page):
+ if not isinstance(images, (list, tuple)):
+ images = [images]
+ if multi_page:
+ logger.warning("Multi-page inference is enabled but only one image is passed.")
+ images = [images]
+ elif isinstance(images[0], (list, tuple)) and not multi_page:
+ raise ValueError("Nested images are only supported with `multi_page` set to `True`.")
+ elif not isinstance(images[0], (list, tuple)) and multi_page:
+ images = [images]
+
+ if isinstance(text, str):
+ text = [text]
+
+ if not isinstance(box[0], (list, tuple)):
+ # Use the same box for all images
+ box = [box for _ in range(len(images))]
+ if not isinstance(color, (list, tuple)):
+ color = [color for _ in range(len(images))]
+
+ return images, text, box, color
+
+ def __call__(
+ self,
+ images: Optional[ImageInput] = None,
+ text: Optional[Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]]] = None,
+ audio=None,
+ videos=None,
+ **kwargs: Unpack[GotOcr2ProcessorKwargs],
+ ) -> BatchFeature:
+ """
+        Main method to prepare one or several sequence(s) and image(s) for the model. This method forwards the `text`
+        and `kwargs` arguments to PreTrainedTokenizerFast's [`~PreTrainedTokenizerFast.__call__`] to encode the text if `text`
+        is not `None`, otherwise encodes default OCR queries, which depend on the `format`, `box`, `color`, `multi_page` and
+ `crop_to_patches` arguments. To prepare the vision inputs, this method forwards the `images` and `kwargs` arguments to
+ GotOcr2ImageProcessor's [`~GotOcr2ImageProcessor.__call__`] if `images` is not `None`.
+
+ Args:
+ images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `list[PIL.Image.Image]`, `list[np.ndarray]`, `list[torch.Tensor]`):
+ The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
+ tensor. Both channels-first and channels-last formats are supported.
+ text (`str`, `list[str]`, `list[list[str]]`):
+ The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
+ (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
+ `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
+ format (`bool`, *optional*):
+ If set, will add the format token to the query, and the model will return the OCR result with formatting.
+ box (`list[float]`, `list[tuple[float, float]]`, `list[tuple[float, float, float, float]]`, *optional*):
+ The box annotation to be added to the query. If a list of floats or a tuple of floats is provided, it
+ will be interpreted as [x1, y1, x2, y2]. If a list of tuples is provided, each tuple should be in the
+ form (x1, y1, x2, y2).
+ color (`str`, *optional*):
+ The color annotation to be added to the query. The model will return the OCR result within the box with
+ the specified color.
+ multi_page (`bool`, *optional*):
+ If set, will enable multi-page inference. The model will return the OCR result across multiple pages.
+ crop_to_patches (`bool`, *optional*):
+ If set, will crop the image to patches. The model will return the OCR result upon the patch reference.
+ min_patches (`int`, *optional*):
+ The minimum number of patches to be cropped from the image. Only used when `crop_to_patches` is set to
+ `True`.
+ max_patches (`int`, *optional*):
+ The maximum number of patches to be cropped from the image. Only used when `crop_to_patches` is set to
+ `True`.
+
+ return_tensors (`str` or [`~utils.TensorType`], *optional*):
+ If set, will return tensors of a particular framework. Acceptable values are:
+ - `'tf'`: Return TensorFlow `tf.constant` objects.
+ - `'pt'`: Return PyTorch `torch.Tensor` objects.
+ - `'np'`: Return NumPy `np.ndarray` objects.
+ - `'jax'`: Return JAX `jnp.ndarray` objects.
+
+ Returns:
+ [`BatchFeature`]: A [`BatchFeature`] with the following fields:
+
+ - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
+ - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
+ `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
+ `None`).
+ - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
+ """
+
+ output_kwargs = self._merge_kwargs(
+ GotOcr2ProcessorKwargs,
+ tokenizer_init_kwargs=self.tokenizer.init_kwargs,
+ **kwargs,
+ )
+ format_output = output_kwargs["text_kwargs"].pop("format")
+ num_image_tokens = output_kwargs["images_kwargs"].pop("num_image_tokens")
+ box = output_kwargs["images_kwargs"].pop("box", [None])
+ color = output_kwargs["images_kwargs"].pop("color", None)
+ multi_page = output_kwargs["images_kwargs"].pop("multi_page")
+
+ crop_to_patches = output_kwargs["images_kwargs"].get("crop_to_patches")
+ images, text, box, color = self._make_list_of_inputs(images, text, box, color, multi_page)
+ if multi_page:
+ # save the number of pages per batch
+ num_pages_per_batch = [len(image_group) for image_group in images]
+ # flatten the list of images
+ images = [image for image_group in images for image in image_group]
+ else:
+ num_pages_per_batch = [1 for _ in range(len(images))]
+ # Load images as we need to know the image size
+ images = load_images(images)
+ image_sizes = [image.size for image in images]
+ image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"])
+ num_patches_array = image_inputs.pop("num_patches")
+ if text is None:
+ text = []
+ patch_indices = np.cumsum(num_pages_per_batch)
+ for index, (num_pages, box_single, color_single) in enumerate(zip(num_pages_per_batch, box, color)):
+ current_patch_index = patch_indices[index - 1] if index > 0 else 0
+ num_patches = sum(num_patches_array[current_patch_index : current_patch_index + num_pages])
+ if box_single[0] is not None:
+ box_single = preprocess_box_annotation(box_single, image_sizes[index])
+ query = (
+ f"{f'[{color_single}] ' if color_single is not None else ''}"
+ f"{str(box_single) if box_single[0] is not None else ''} "
+ "OCR"
+ f"{' with format' if format_output else ''}"
+ f"{' across multi pages' if multi_page else ''}"
+ f"{' upon the patch reference' if crop_to_patches else ''}"
+ ": "
+ )
+ prompt = (
+ self.message_start_token
+ + self.system_query
+ + self.message_end_token
+ + self.message_start_token
+ + "user\n"
+ + self.img_start_token
+ + self.img_pad_token * num_image_tokens * num_patches
+ + self.img_end_token
+ + "\n"
+ + query
+ + self.message_end_token
+ + self.message_start_token
+ + "assistant\n"
+ )
+ text.append(prompt)
+
+ return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
+ text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
+ self._check_special_mm_tokens(text, text_inputs, modalities=["image"])
+
+ return BatchFeature(data={**text_inputs, **image_inputs}, tensor_type=return_tensors)
+
+
+__all__ = ["GotOcr2Processor"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/gpt2/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/gpt2/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f01899e668e3a86548db3f59c7f42d70746385ab
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/gpt2/__init__.py
@@ -0,0 +1,32 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_gpt2 import *
+ from .modeling_flax_gpt2 import *
+ from .modeling_gpt2 import *
+ from .modeling_tf_gpt2 import *
+ from .tokenization_gpt2 import *
+ from .tokenization_gpt2_fast import *
+ from .tokenization_gpt2_tf import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
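+# Editorial note (not part of the library source): under `TYPE_CHECKING` the real
+# submodule imports are visible to static type checkers, while at runtime the module
+# object is replaced by a `_LazyModule` that only imports e.g. `configuration_gpt2`
+# the first time one of its attributes is accessed. A small illustration:
+#
+#   from transformers.models import gpt2
+#   config_cls = gpt2.GPT2Config  # triggers the lazy import of configuration_gpt2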
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/gpt2/configuration_gpt2.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/gpt2/configuration_gpt2.py
new file mode 100644
index 0000000000000000000000000000000000000000..db5151a2ba15635a7943744799b0689fc96790d3
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/gpt2/configuration_gpt2.py
@@ -0,0 +1,274 @@
+# coding=utf-8
+# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
+# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""OpenAI GPT-2 configuration"""
+
+from collections import OrderedDict
+from collections.abc import Mapping
+from typing import Any, Optional
+
+from ... import PreTrainedTokenizer, TensorType, is_torch_available
+from ...configuration_utils import PretrainedConfig
+from ...onnx import OnnxConfigWithPast, PatchingSpec
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class GPT2Config(PretrainedConfig):
+ """
+ This is the configuration class to store the configuration of a [`GPT2Model`] or a [`TFGPT2Model`]. It is used to
+ instantiate a GPT-2 model according to the specified arguments, defining the model architecture. Instantiating a
+ configuration with the defaults will yield a similar configuration to that of the GPT-2
+ [openai-community/gpt2](https://huggingface.co/openai-community/gpt2) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 50257):
+ Vocabulary size of the GPT-2 model. Defines the number of different tokens that can be represented by the
+ `inputs_ids` passed when calling [`GPT2Model`] or [`TFGPT2Model`].
+ n_positions (`int`, *optional*, defaults to 1024):
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
+ just in case (e.g., 512 or 1024 or 2048).
+ n_embd (`int`, *optional*, defaults to 768):
+ Dimensionality of the embeddings and hidden states.
+ n_layer (`int`, *optional*, defaults to 12):
+ Number of hidden layers in the Transformer encoder.
+ n_head (`int`, *optional*, defaults to 12):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ n_inner (`int`, *optional*):
+ Dimensionality of the inner feed-forward layers. `None` will set it to 4 times `n_embd`.
+ activation_function (`str`, *optional*, defaults to `"gelu_new"`):
+ Activation function, to be selected in the list `["relu", "silu", "gelu", "tanh", "gelu_new"]`.
+ resid_pdrop (`float`, *optional*, defaults to 0.1):
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+ embd_pdrop (`float`, *optional*, defaults to 0.1):
+ The dropout ratio for the embeddings.
+ attn_pdrop (`float`, *optional*, defaults to 0.1):
+ The dropout ratio for the attention.
+ layer_norm_epsilon (`float`, *optional*, defaults to 1e-05):
+ The epsilon to use in the layer normalization layers.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ summary_type (`str`, *optional*, defaults to `"cls_index"`):
+ Argument used when doing sequence summary, used in the models [`GPT2DoubleHeadsModel`] and
+ [`TFGPT2DoubleHeadsModel`].
+
+ Has to be one of the following options:
+
+ - `"last"`: Take the last token hidden state (like XLNet).
+ - `"first"`: Take the first token hidden state (like BERT).
+ - `"mean"`: Take the mean of all tokens hidden states.
+ - `"cls_index"`: Supply a Tensor of classification token position (like GPT/GPT-2).
+ - `"attn"`: Not implemented now, use multi-head attention.
+ summary_use_proj (`bool`, *optional*, defaults to `True`):
+ Argument used when doing sequence summary, used in the models [`GPT2DoubleHeadsModel`] and
+ [`TFGPT2DoubleHeadsModel`].
+
+ Whether or not to add a projection after the vector extraction.
+ summary_activation (`str`, *optional*):
+ Argument used when doing sequence summary. Used for the multiple choice head in
+ [`GPT2DoubleHeadsModel`].
+
+ Pass `"tanh"` for a tanh activation to the output, any other value will result in no activation.
+ summary_proj_to_labels (`bool`, *optional*, defaults to `True`):
+ Argument used when doing sequence summary, used in the models [`GPT2DoubleHeadsModel`] and
+ [`TFGPT2DoubleHeadsModel`].
+
+ Whether the projection outputs should have `config.num_labels` or `config.hidden_size` classes.
+ summary_first_dropout (`float`, *optional*, defaults to 0.1):
+ Argument used when doing sequence summary, used in the models [`GPT2DoubleHeadsModel`] and
+ [`TFGPT2DoubleHeadsModel`].
+
+ The dropout ratio to be used after the projection and activation.
+ scale_attn_weights (`bool`, *optional*, defaults to `True`):
+ Scale attention weights by dividing by sqrt(hidden_size).
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return the last key/values attentions (not used by all models).
+ bos_token_id (`int`, *optional*, defaults to 50256):
+ Id of the beginning of sentence token in the vocabulary.
+ eos_token_id (`int`, *optional*, defaults to 50256):
+ Id of the end of sentence token in the vocabulary.
+ scale_attn_by_inverse_layer_idx (`bool`, *optional*, defaults to `False`):
+ Whether to additionally scale attention weights by `1 / (layer_idx + 1)`.
+ reorder_and_upcast_attn (`bool`, *optional*, defaults to `False`):
+ Whether to scale keys (K) prior to computing attention (dot-product) and upcast attention
+ dot-product/softmax to `float32` when training with mixed precision.
+
+ Example:
+
+ ```python
+ >>> from transformers import GPT2Config, GPT2Model
+
+ >>> # Initializing a GPT2 configuration
+ >>> configuration = GPT2Config()
+
+ >>> # Initializing a model (with random weights) from the configuration
+ >>> model = GPT2Model(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "gpt2"
+ keys_to_ignore_at_inference = ["past_key_values"]
+ attribute_map = {
+ "hidden_size": "n_embd",
+ "max_position_embeddings": "n_positions",
+ "num_attention_heads": "n_head",
+ "num_hidden_layers": "n_layer",
+ }
+
+ def __init__(
+ self,
+ vocab_size=50257,
+ n_positions=1024,
+ n_embd=768,
+ n_layer=12,
+ n_head=12,
+ n_inner=None,
+ activation_function="gelu_new",
+ resid_pdrop=0.1,
+ embd_pdrop=0.1,
+ attn_pdrop=0.1,
+ layer_norm_epsilon=1e-5,
+ initializer_range=0.02,
+ summary_type="cls_index",
+ summary_use_proj=True,
+ summary_activation=None,
+ summary_proj_to_labels=True,
+ summary_first_dropout=0.1,
+ scale_attn_weights=True,
+ use_cache=True,
+ bos_token_id=50256,
+ eos_token_id=50256,
+ scale_attn_by_inverse_layer_idx=False,
+ reorder_and_upcast_attn=False,
+ **kwargs,
+ ):
+ self.vocab_size = vocab_size
+ self.n_positions = n_positions
+ self.n_embd = n_embd
+ self.n_layer = n_layer
+ self.n_head = n_head
+ self.n_inner = n_inner
+ self.activation_function = activation_function
+ self.resid_pdrop = resid_pdrop
+ self.embd_pdrop = embd_pdrop
+ self.attn_pdrop = attn_pdrop
+ self.layer_norm_epsilon = layer_norm_epsilon
+ self.initializer_range = initializer_range
+ self.summary_type = summary_type
+ self.summary_use_proj = summary_use_proj
+ self.summary_activation = summary_activation
+ self.summary_first_dropout = summary_first_dropout
+ self.summary_proj_to_labels = summary_proj_to_labels
+ self.scale_attn_weights = scale_attn_weights
+ self.use_cache = use_cache
+ self.scale_attn_by_inverse_layer_idx = scale_attn_by_inverse_layer_idx
+ self.reorder_and_upcast_attn = reorder_and_upcast_attn
+
+ self.bos_token_id = bos_token_id
+ self.eos_token_id = eos_token_id
+
+ super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
+
+
+class GPT2OnnxConfig(OnnxConfigWithPast):
+ def __init__(
+ self,
+ config: PretrainedConfig,
+ task: str = "default",
+ patching_specs: Optional[list[PatchingSpec]] = None,
+ use_past: bool = False,
+ ):
+ super().__init__(config, task=task, patching_specs=patching_specs, use_past=use_past)
+ if not getattr(self._config, "pad_token_id", None):
+ # TODO: how to do that better?
+ self._config.pad_token_id = 0
+
+ @property
+ def inputs(self) -> Mapping[str, Mapping[int, str]]:
+ common_inputs = OrderedDict({"input_ids": {0: "batch", 1: "sequence"}})
+ if self.use_past:
+ self.fill_with_past_key_values_(common_inputs, direction="inputs")
+ common_inputs["attention_mask"] = {0: "batch", 1: "past_sequence + sequence"}
+ else:
+ common_inputs["attention_mask"] = {0: "batch", 1: "sequence"}
+
+ return common_inputs
+
+ @property
+ def num_layers(self) -> int:
+ return self._config.n_layer
+
+ @property
+ def num_attention_heads(self) -> int:
+ return self._config.n_head
+
+ def generate_dummy_inputs(
+ self,
+ tokenizer: PreTrainedTokenizer,
+ batch_size: int = -1,
+ seq_length: int = -1,
+ is_pair: bool = False,
+ framework: Optional[TensorType] = None,
+ ) -> Mapping[str, Any]:
+ common_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs(
+ tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework
+ )
+
+ # We need to order the inputs in the way they appear in the forward()
+ ordered_inputs = OrderedDict({"input_ids": common_inputs["input_ids"]})
+
+ # Need to add the past_keys
+ if self.use_past:
+ if not is_torch_available():
+ raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.")
+ else:
+ import torch
+
+ batch, seqlen = common_inputs["input_ids"].shape
+ # Not using the same length for past_key_values
+ past_key_values_length = seqlen + 2
+ past_shape = (
+ batch,
+ self.num_attention_heads,
+ past_key_values_length,
+ self._config.hidden_size // self.num_attention_heads,
+ )
+ ordered_inputs["past_key_values"] = [
+ (torch.zeros(past_shape), torch.zeros(past_shape)) for _ in range(self.num_layers)
+ ]
+
+ ordered_inputs["attention_mask"] = common_inputs["attention_mask"]
+ if self.use_past:
+ mask_dtype = ordered_inputs["attention_mask"].dtype
+ ordered_inputs["attention_mask"] = torch.cat(
+ [ordered_inputs["attention_mask"], torch.ones(batch, past_key_values_length, dtype=mask_dtype)], dim=1
+ )
+
+ return ordered_inputs
+
+ @property
+ def default_onnx_opset(self) -> int:
+ return 13
+
+
+__all__ = ["GPT2Config", "GPT2OnnxConfig"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/gpt2/modeling_flax_gpt2.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/gpt2/modeling_flax_gpt2.py
new file mode 100644
index 0000000000000000000000000000000000000000..8e419217c5a3642ee27f6f3df87e1c27c0d5ac79
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/gpt2/modeling_flax_gpt2.py
@@ -0,0 +1,782 @@
+# coding=utf-8
+# Copyright 2021 The Google Flax Team Authors and The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Optional
+
+import flax.linen as nn
+import jax
+import jax.numpy as jnp
+from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
+from flax.linen import combine_masks, make_causal_mask
+from flax.linen.attention import dot_product_attention_weights
+from flax.traverse_util import flatten_dict, unflatten_dict
+from jax import lax
+
+from ...modeling_flax_outputs import (
+ FlaxBaseModelOutputWithPastAndCrossAttentions,
+ FlaxCausalLMOutputWithCrossAttentions,
+)
+from ...modeling_flax_utils import ACT2FN, FlaxPreTrainedModel, append_call_sample_docstring
+from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging
+from .configuration_gpt2 import GPT2Config
+
+
+logger = logging.get_logger(__name__)
+
+_CHECKPOINT_FOR_DOC = "openai-community/gpt2"
+_CONFIG_FOR_DOC = "GPT2Config"
+
+
+GPT2_START_DOCSTRING = r"""
+
+ This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads,
+ etc.)
+
+ This model is also a Flax Linen
+ [flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a
+ regular Flax Module and refer to the Flax documentation for all matters related to general usage and behavior.
+
+ Finally, this model supports inherent JAX features such as:
+
+ - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
+ - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
+ - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
+ - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
+
+ Parameters:
+ config ([`GPT2Config`]): Model configuration class with all the parameters of the model.
+ Initializing with a config file does not load the weights associated with the model, only the
+ configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights.
+ dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
+ The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
+ `jax.numpy.bfloat16` (on TPUs).
+
+ This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
+ specified, all the computation will be performed with the given `dtype`.
+
+ **Note that this only specifies the dtype of the computation and does not influence the dtype of model
+ parameters.**
+
+ If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and
+ [`~FlaxPreTrainedModel.to_bf16`].
+"""
+
+GPT2_INPUTS_DOCSTRING = r"""
+ Args:
+ input_ids (`numpy.ndarray` of shape `(batch_size, input_ids_length)`):
+ `input_ids_length` = `sequence_length`. Indices of input sequence tokens in the vocabulary.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ attention_mask (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ position_ids (`numpy.ndarray` of shape `(batch_size, input_ids_length)`, *optional*):
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+ config.max_position_embeddings - 1]`.
+ past_key_values (`dict[str, np.ndarray]`, *optional*, returned by `init_cache` or when passing previous `past_key_values`):
+ Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast
+ auto-regressive decoding. Pre-computed key and value hidden-states are of shape *[batch_size, max_length]*.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+class FlaxConv1D(nn.Module):
+ features: int
+ use_bias: bool = True
+ dtype: Any = jnp.float32
+ precision: Any = None
+
+ @nn.compact
+ def __call__(self, inputs):
+ inputs = jnp.asarray(inputs, self.dtype)
+ kernel = self.param("kernel", jax.nn.initializers.normal(stddev=0.02), (self.features, inputs.shape[-1]))
+ kernel = jnp.asarray(kernel.transpose(), self.dtype)
+ y = lax.dot_general(inputs, kernel, (((inputs.ndim - 1,), (0,)), ((), ())), precision=self.precision)
+ if self.use_bias:
+ bias = self.param("bias", jax.nn.initializers.zeros, (self.features,))
+ bias = jnp.asarray(bias, self.dtype)
+ y = y + bias
+ return y
+
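+# Editorial note (not part of the library source): FlaxConv1D mirrors the PyTorch
+# `Conv1D` used by GPT-2, i.e. a dense layer whose kernel is stored as
+# (features, in_features) and transposed at call time. A tiny shape check,
+# assuming `jax` and `flax` are installed:
+#
+#   import jax, jax.numpy as jnp
+#   layer = FlaxConv1D(features=8)
+#   params = layer.init(jax.random.PRNGKey(0), jnp.ones((2, 5, 16)))
+#   y = layer.apply(params, jnp.ones((2, 5, 16)))
+#   assert y.shape == (2, 5, 8)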
+
+class FlaxGPT2Attention(nn.Module):
+ config: GPT2Config
+ dtype: jnp.dtype = jnp.float32
+ causal: bool = True
+ is_cross_attention: bool = False
+
+ def setup(self):
+ config = self.config
+ self.embed_dim = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.head_dim = self.embed_dim // self.num_heads
+
+ if self.is_cross_attention:
+ self.c_attn = FlaxConv1D(2 * self.embed_dim, dtype=self.dtype)
+ self.q_attn = FlaxConv1D(self.embed_dim, dtype=self.dtype)
+ else:
+ self.c_attn = FlaxConv1D(3 * self.embed_dim, dtype=self.dtype)
+ self.c_proj = FlaxConv1D(self.embed_dim, dtype=self.dtype)
+
+ self.resid_dropout = nn.Dropout(rate=config.resid_pdrop)
+
+ if self.causal:
+ self.causal_mask = make_causal_mask(
+ jnp.ones((1, config.max_position_embeddings), dtype="bool"), dtype="bool"
+ )
+
+ def _split_heads(self, hidden_states):
+ return hidden_states.reshape(hidden_states.shape[:2] + (self.num_heads, self.head_dim))
+
+ def _merge_heads(self, hidden_states):
+ return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,))
+
+ @nn.compact
+ def _concatenate_to_cache(self, key, value, query, attention_mask):
+ """
+ This function takes projected key, value states from a single input token and concatenates the states to cached
+ states from previous steps. This function is slightly adapted from the official Flax repository:
+ https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252
+ """
+ # detect if we're initializing by absence of existing cache data.
+ is_initialized = self.has_variable("cache", "cached_key")
+ cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype)
+ cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype)
+ cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32))
+
+ if is_initialized:
+ *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape
+ # update key, value caches with our new 1d spatial slices
+ cur_index = cache_index.value
+ indices = (0,) * len(batch_dims) + (cur_index, 0, 0)
+ key = lax.dynamic_update_slice(cached_key.value, key, indices)
+ value = lax.dynamic_update_slice(cached_value.value, value, indices)
+ cached_key.value = key
+ cached_value.value = value
+ num_updated_cache_vectors = query.shape[1]
+ cache_index.value = cache_index.value + num_updated_cache_vectors
+ # causal mask for cached decoder self-attention: our single query position should only attend to those key positions that have already been generated and cached, not the remaining zero elements.
+ pad_mask = jnp.broadcast_to(
+ jnp.arange(max_length) < cur_index + num_updated_cache_vectors,
+ tuple(batch_dims) + (1, num_updated_cache_vectors, max_length),
+ )
+ attention_mask = combine_masks(pad_mask, attention_mask)
+ return key, value, attention_mask
+
+ def __call__(
+ self,
+ hidden_states,
+ key_value_states: Optional[jnp.ndarray] = None,
+ attention_mask=None,
+ deterministic: bool = True,
+ init_cache: bool = False,
+ output_attentions: bool = False,
+ ):
+ # if key_value_states are provided this layer is used as a cross-attention layer
+ # for the decoder
+ is_cross_attention = key_value_states is not None
+ batch_size = hidden_states.shape[0]
+
+ if not is_cross_attention:
+ qkv_out = self.c_attn(hidden_states)
+ query, key, value = jnp.split(qkv_out, 3, axis=2)
+ else:
+ q_out = self.q_attn(hidden_states)
+ (query,) = jnp.split(q_out, 1, axis=2)
+ kv_out = self.c_attn(key_value_states)
+ key, value = jnp.split(kv_out, 2, axis=2)
+
+ query = self._split_heads(query)
+ key = self._split_heads(key)
+ value = self._split_heads(value)
+
+ query_length, key_length = query.shape[1], key.shape[1]
+
+ if self.causal:
+ if self.has_variable("cache", "cached_key"):
+ mask_shift = self.variables["cache"]["cache_index"]
+ max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
+ causal_mask = lax.dynamic_slice(
+ self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length)
+ )
+ else:
+ causal_mask = self.causal_mask[:, :, :query_length, :key_length]
+ causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:])
+
+ # combine masks if needed
+ if attention_mask is not None and self.causal:
+ attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape)
+ attention_mask = combine_masks(attention_mask, causal_mask)
+ elif self.causal:
+ attention_mask = causal_mask
+ elif attention_mask is not None:
+ attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
+
+ dropout_rng = None
+ if not deterministic and self.config.attn_pdrop > 0.0:
+ dropout_rng = self.make_rng("dropout")
+
+ # During fast autoregressive decoding, we feed one position at a time,
+ # and cache the keys and values step by step.
+ if self.causal and (self.has_variable("cache", "cached_key") or init_cache):
+ key, value, attention_mask = self._concatenate_to_cache(key, value, query, attention_mask)
+
+ # transform boolean mask into float mask
+ if attention_mask is not None:
+ attention_bias = lax.select(
+ attention_mask > 0,
+ jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
+ jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
+ )
+ else:
+ attention_bias = None
+
+ # usual dot product attention
+ attn_weights = dot_product_attention_weights(
+ query,
+ key,
+ bias=attention_bias,
+ dropout_rng=dropout_rng,
+ dropout_rate=self.config.attn_pdrop,
+ deterministic=deterministic,
+ dtype=self.dtype,
+ precision=None,
+ )
+
+ attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value)
+ attn_output = self._merge_heads(attn_output)
+ attn_output = self.c_proj(attn_output)
+ attn_output = self.resid_dropout(attn_output, deterministic=deterministic)
+
+ outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)
+ return outputs
+
+
+class FlaxGPT2MLP(nn.Module):
+ config: GPT2Config
+ intermediate_size: int
+ dtype: jnp.dtype = jnp.float32
+
+ def setup(self):
+ embed_dim = self.config.hidden_size
+ self.c_fc = FlaxConv1D(self.intermediate_size, dtype=self.dtype)
+ self.c_proj = FlaxConv1D(embed_dim, dtype=self.dtype)
+ self.act = ACT2FN[self.config.activation_function]
+ self.dropout = nn.Dropout(rate=self.config.resid_pdrop)
+
+ def __call__(self, hidden_states, deterministic: bool = True):
+ hidden_states = self.c_fc(hidden_states)
+ hidden_states = self.act(hidden_states)
+ hidden_states = self.c_proj(hidden_states)
+ hidden_states = self.dropout(hidden_states, deterministic=deterministic)
+ return hidden_states
+
+
+class FlaxGPT2Block(nn.Module):
+ config: GPT2Config
+ dtype: jnp.dtype = jnp.float32
+
+ def setup(self):
+ hidden_size = self.config.hidden_size
+ inner_dim = self.config.n_inner if self.config.n_inner is not None else 4 * hidden_size
+
+ self.ln_1 = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon, dtype=self.dtype)
+ self.attn = FlaxGPT2Attention(self.config, dtype=self.dtype)
+ self.ln_2 = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon, dtype=self.dtype)
+
+ if self.config.add_cross_attention:
+ self.crossattention = FlaxGPT2Attention(
+ config=self.config, dtype=self.dtype, causal=False, is_cross_attention=True
+ )
+ self.ln_cross_attn = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon, dtype=self.dtype)
+
+ self.mlp = FlaxGPT2MLP(self.config, inner_dim, dtype=self.dtype)
+
+ def __call__(
+ self,
+ hidden_states,
+ attention_mask=None,
+ encoder_hidden_states: Optional[jnp.ndarray] = None,
+ encoder_attention_mask: Optional[jnp.ndarray] = None,
+ deterministic: bool = True,
+ init_cache: bool = False,
+ output_attentions: bool = False,
+ ):
+ residual = hidden_states
+ hidden_states = self.ln_1(hidden_states)
+ attn_outputs = self.attn(
+ hidden_states,
+ attention_mask=attention_mask,
+ deterministic=deterministic,
+ init_cache=init_cache,
+ output_attentions=output_attentions,
+ )
+ attn_output = attn_outputs[0]  # output_attn: attn_output, (attentions)
+ outputs = attn_outputs[1:]
+ # residual connection
+ hidden_states = attn_output + residual
+
+ # Cross-Attention Block
+ if encoder_hidden_states is not None:
+ # add one self-attention block for cross-attention
+ if not hasattr(self, "crossattention"):
+ raise ValueError(
+ f"If `encoder_hidden_states` are passed, {self} has to be instantiated with "
+ "cross-attention layers by setting `config.add_cross_attention=True`"
+ )
+ residual = hidden_states
+ hidden_states = self.ln_cross_attn(hidden_states)
+ cross_attn_outputs = self.crossattention(
+ hidden_states,
+ key_value_states=encoder_hidden_states,
+ attention_mask=encoder_attention_mask,
+ deterministic=deterministic,
+ output_attentions=output_attentions,
+ )
+ attn_output = cross_attn_outputs[0]
+ # residual connection
+ hidden_states = residual + attn_output
+ outputs = outputs + cross_attn_outputs[1:] # add cross attentions if we output attention weights
+
+ residual = hidden_states
+ hidden_states = self.ln_2(hidden_states)
+ feed_forward_hidden_states = self.mlp(hidden_states, deterministic=deterministic)
+ # residual connection
+ hidden_states = residual + feed_forward_hidden_states
+
+ outputs = (hidden_states,) + outputs
+
+ return outputs
+
+
+class FlaxGPT2PreTrainedModel(FlaxPreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = GPT2Config
+ base_model_prefix = "transformer"
+ module_class: nn.Module = None
+
+ def __init__(
+ self,
+ config: GPT2Config,
+ input_shape: tuple = (1, 1),
+ seed: int = 0,
+ dtype: jnp.dtype = jnp.float32,
+ _do_init: bool = True,
+ **kwargs,
+ ):
+ module = self.module_class(config=config, dtype=dtype, **kwargs)
+ super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
+
+ def init_weights(self, rng: jax.random.PRNGKey, input_shape: tuple, params: FrozenDict = None) -> FrozenDict:
+ # init input tensors
+ input_ids = jnp.zeros(input_shape, dtype="i4")
+ attention_mask = jnp.ones_like(input_ids)
+ position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape)
+ params_rng, dropout_rng = jax.random.split(rng)
+ rngs = {"params": params_rng, "dropout": dropout_rng}
+
+ if self.config.add_cross_attention:
+ encoder_hidden_states = jnp.zeros(input_shape + (self.config.n_embd,))
+ encoder_attention_mask = attention_mask
+ module_init_outputs = self.module.init(
+ rngs,
+ input_ids,
+ attention_mask,
+ position_ids,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ return_dict=False,
+ )
+ else:
+ module_init_outputs = self.module.init(rngs, input_ids, attention_mask, position_ids, return_dict=False)
+
+ random_params = module_init_outputs["params"]
+
+ if params is not None:
+ random_params = flatten_dict(unfreeze(random_params))
+ params = flatten_dict(unfreeze(params))
+ for missing_key in self._missing_keys:
+ params[missing_key] = random_params[missing_key]
+ self._missing_keys = set()
+ return freeze(unflatten_dict(params))
+ else:
+ return random_params
+
+ def init_cache(self, batch_size, max_length):
+ r"""
+ Args:
+ batch_size (`int`):
+ batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
+ max_length (`int`):
+ maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized
+ cache.
+ """
+ # init input variables to retrieve cache
+ input_ids = jnp.ones((batch_size, max_length))
+ attention_mask = jnp.ones_like(input_ids)
+ position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)
+
+ init_variables = self.module.init(
+ jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True
+ )
+ return unfreeze(init_variables["cache"])
+
+ @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
+ def __call__(
+ self,
+ input_ids,
+ attention_mask=None,
+ position_ids=None,
+ encoder_hidden_states: Optional[jnp.ndarray] = None,
+ encoder_attention_mask: Optional[jnp.ndarray] = None,
+ params: Optional[dict] = None,
+ past_key_values: Optional[dict] = None,
+ dropout_rng: jax.random.PRNGKey = None,
+ train: bool = False,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ):
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
+
+ if encoder_hidden_states is not None and encoder_attention_mask is None:
+ batch_size, sequence_length = encoder_hidden_states.shape[:2]
+ encoder_attention_mask = jnp.ones((batch_size, sequence_length))
+
+ batch_size, sequence_length = input_ids.shape
+
+ if position_ids is None:
+ if past_key_values is not None:
+ raise ValueError("Make sure to provide `position_ids` when passing `past_key_values`.")
+
+ position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))
+
+ if attention_mask is None:
+ attention_mask = jnp.ones((batch_size, sequence_length))
+
+ # Handle any PRNG if needed
+ rngs = {}
+ if dropout_rng is not None:
+ rngs["dropout"] = dropout_rng
+
+ inputs = {"params": params or self.params}
+
+ # If past_key_values are passed, the cache is already initialized and a private flag `init_cache` has to be
+ # passed down to ensure the cache is used. The cache also has to be marked as mutable so that it can be
+ # changed by the FlaxGPT2Attention module.
+ if past_key_values:
+ inputs["cache"] = past_key_values
+ mutable = ["cache"]
+ else:
+ mutable = False
+
+ outputs = self.module.apply(
+ inputs,
+ jnp.array(input_ids, dtype="i4"),
+ jnp.array(attention_mask, dtype="i4"),
+ jnp.array(position_ids, dtype="i4"),
+ encoder_hidden_states,
+ encoder_attention_mask,
+ not train,
+ False,
+ output_attentions,
+ output_hidden_states,
+ return_dict,
+ rngs=rngs,
+ mutable=mutable,
+ )
+
+ # add updated cache to model output
+ if past_key_values is not None and return_dict:
+ outputs, past_key_values = outputs
+ outputs["past_key_values"] = unfreeze(past_key_values["cache"])
+ return outputs
+ elif past_key_values is not None and not return_dict:
+ outputs, past_key_values = outputs
+ outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:]
+
+ return outputs
+
+
+class FlaxGPT2BlockCollection(nn.Module):
+ config: GPT2Config
+ dtype: jnp.dtype = jnp.float32
+
+ def setup(self):
+ self.blocks = [
+ FlaxGPT2Block(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.num_hidden_layers)
+ ]
+
+ def __call__(
+ self,
+ hidden_states,
+ attention_mask=None,
+ encoder_hidden_states: Optional[jnp.ndarray] = None,
+ encoder_attention_mask: Optional[jnp.ndarray] = None,
+ deterministic: bool = True,
+ init_cache: bool = False,
+ output_attentions: bool = False,
+ output_hidden_states: bool = False,
+ return_dict: bool = True,
+ ):
+ all_attentions = () if output_attentions else None
+ all_hidden_states = () if output_hidden_states else None
+ all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
+
+ for block in self.blocks:
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ layer_outputs = block(
+ hidden_states,
+ attention_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ deterministic=deterministic,
+ init_cache=init_cache,
+ output_attentions=output_attentions,
+ )
+ hidden_states = layer_outputs[0]
+
+ if output_attentions:
+ all_attentions += (layer_outputs[1],)
+
+ if encoder_hidden_states is not None:
+ all_cross_attentions += (layer_outputs[2],)
+
+ # this contains possible `None` values - `FlaxGPT2Module` will filter them out
+ outputs = (hidden_states, all_hidden_states, all_attentions, all_cross_attentions)
+
+ return outputs
+
+
+class FlaxGPT2Module(nn.Module):
+ config: GPT2Config
+ dtype: jnp.dtype = jnp.float32
+
+ def setup(self):
+ self.embed_dim = self.config.hidden_size
+
+ self.wte = nn.Embed(
+ self.config.vocab_size,
+ self.embed_dim,
+ embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
+ dtype=self.dtype,
+ )
+ self.wpe = nn.Embed(
+ self.config.max_position_embeddings,
+ self.embed_dim,
+ embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
+ dtype=self.dtype,
+ )
+ self.dropout = nn.Dropout(rate=self.config.embd_pdrop)
+ self.h = FlaxGPT2BlockCollection(self.config, dtype=self.dtype)
+ self.ln_f = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon, dtype=self.dtype)
+
+ def __call__(
+ self,
+ input_ids,
+ attention_mask,
+ position_ids,
+ encoder_hidden_states: Optional[jnp.ndarray] = None,
+ encoder_attention_mask: Optional[jnp.ndarray] = None,
+ deterministic=True,
+ init_cache: bool = False,
+ output_attentions: bool = False,
+ output_hidden_states: bool = False,
+ return_dict: bool = True,
+ ):
+ input_embeds = self.wte(input_ids.astype("i4"))
+ position_embeds = self.wpe(position_ids.astype("i4"))
+
+ hidden_states = input_embeds + position_embeds
+ hidden_states = self.dropout(hidden_states, deterministic=deterministic)
+
+ outputs = self.h(
+ hidden_states,
+ attention_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ deterministic=deterministic,
+ init_cache=init_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ hidden_states = outputs[0]
+ hidden_states = self.ln_f(hidden_states)
+
+ if output_hidden_states:
+ all_hidden_states = outputs[1] + (hidden_states,)
+ outputs = (hidden_states, all_hidden_states) + outputs[2:]
+ else:
+ outputs = (hidden_states,) + outputs[1:]
+
+ if not return_dict:
+ return tuple(v for v in outputs if v is not None)
+
+ return FlaxBaseModelOutputWithPastAndCrossAttentions(
+ last_hidden_state=hidden_states,
+ hidden_states=outputs[1],
+ attentions=outputs[2],
+ cross_attentions=outputs[3],
+ )
+
+
+@add_start_docstrings(
+ "The bare GPT2 Model transformer outputting raw hidden-states without any specific head on top.",
+ GPT2_START_DOCSTRING,
+)
+class FlaxGPT2Model(FlaxGPT2PreTrainedModel):
+ module_class = FlaxGPT2Module
+
+
+append_call_sample_docstring(
+ FlaxGPT2Model,
+ _CHECKPOINT_FOR_DOC,
+ FlaxBaseModelOutputWithPastAndCrossAttentions,
+ _CONFIG_FOR_DOC,
+)
+
+
+class FlaxGPT2LMHeadModule(nn.Module):
+ config: GPT2Config
+ dtype: jnp.dtype = jnp.float32
+
+ def setup(self):
+ self.transformer = FlaxGPT2Module(self.config, dtype=self.dtype)
+ self.lm_head = nn.Dense(
+ self.config.vocab_size,
+ use_bias=False,
+ dtype=self.dtype,
+ kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
+ )
+
+ def __call__(
+ self,
+ input_ids,
+ attention_mask,
+ position_ids,
+ encoder_hidden_states: Optional[jnp.ndarray] = None,
+ encoder_attention_mask: Optional[jnp.ndarray] = None,
+ deterministic: bool = True,
+ init_cache: bool = False,
+ output_attentions: bool = False,
+ output_hidden_states: bool = False,
+ return_dict: bool = True,
+ ):
+ outputs = self.transformer(
+ input_ids,
+ attention_mask,
+ position_ids,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ deterministic=deterministic,
+ init_cache=init_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ hidden_states = outputs[0]
+
+ if self.config.tie_word_embeddings:
+ shared_kernel = self.transformer.variables["params"]["wte"]["embedding"].T
+ lm_logits = self.lm_head.apply({"params": {"kernel": shared_kernel}}, hidden_states)
+ else:
+ lm_logits = self.lm_head(hidden_states)
+
+ if not return_dict:
+ return (lm_logits,) + outputs[1:]
+
+ return FlaxCausalLMOutputWithCrossAttentions(
+ logits=lm_logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ cross_attentions=outputs.cross_attentions,
+ )
+
+
+@add_start_docstrings(
+ """
+ The GPT2 Model transformer with a language modeling head on top (linear layer with weights tied to the input
+ embeddings).
+ """,
+ GPT2_START_DOCSTRING,
+)
+class FlaxGPT2LMHeadModel(FlaxGPT2PreTrainedModel):
+ module_class = FlaxGPT2LMHeadModule
+
+ def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None):
+ # initializing the cache
+ batch_size, seq_length = input_ids.shape
+
+ past_key_values = self.init_cache(batch_size, max_length)
+ # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length.
+ # But since GPT2 uses a causal mask, those positions are masked anyways.
+ # Thus we can create a single static attention_mask here, which is more efficient for compilation
+ extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
+ if attention_mask is not None:
+ position_ids = attention_mask.cumsum(axis=-1) - 1
+ extended_attention_mask = lax.dynamic_update_slice(
+ extended_attention_mask, attention_mask.astype("i4"), (0, 0)
+ )
+ else:
+ position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length))
+
+ return {
+ "past_key_values": past_key_values,
+ "attention_mask": extended_attention_mask,
+ "position_ids": position_ids,
+ }
+
+ def update_inputs_for_generation(self, model_outputs, model_kwargs):
+ model_kwargs["past_key_values"] = model_outputs.past_key_values
+ model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1
+ return model_kwargs
+
+
+append_call_sample_docstring(
+ FlaxGPT2LMHeadModel,
+ _CHECKPOINT_FOR_DOC,
+ FlaxCausalLMOutputWithCrossAttentions,
+ _CONFIG_FOR_DOC,
+)
+
+
+__all__ = ["FlaxGPT2LMHeadModel", "FlaxGPT2Model", "FlaxGPT2PreTrainedModel"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/gpt2/modeling_gpt2.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/gpt2/modeling_gpt2.py
new file mode 100644
index 0000000000000000000000000000000000000000..ae0786179464115b880ab5d5b4c771292ad5b2db
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/gpt2/modeling_gpt2.py
@@ -0,0 +1,1638 @@
+# coding=utf-8
+# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
+# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch OpenAI GPT-2 model."""
+
+import math
+import os
+import warnings
+from dataclasses import dataclass
+from typing import Callable, Optional, Union
+
+import torch
+from torch import nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+
+from ...activations import ACT2FN, get_activation
+from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
+from ...generation import GenerationMixin
+from ...masking_utils import create_causal_mask
+from ...modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_outputs import (
+ BaseModelOutputWithPastAndCrossAttentions,
+ CausalLMOutputWithCrossAttentions,
+ QuestionAnsweringModelOutput,
+ SequenceClassifierOutputWithPast,
+ TokenClassifierOutput,
+)
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer
+from ...utils import (
+ ModelOutput,
+ add_start_docstrings,
+ auto_docstring,
+ logging,
+)
+from ...utils.deprecation import deprecate_kwarg
+from ...utils.model_parallel_utils import assert_device_map, get_device_map
+from .configuration_gpt2 import GPT2Config
+
+
+logger = logging.get_logger(__name__)
+
+
+def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path):
+ """Load tf checkpoints in a pytorch model"""
+ try:
+ import re
+
+ import tensorflow as tf
+ except ImportError:
+ logger.error(
+ "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
+ "https://www.tensorflow.org/install/ for installation instructions."
+ )
+ raise
+ tf_path = os.path.abspath(gpt2_checkpoint_path)
+ logger.info(f"Converting TensorFlow checkpoint from {tf_path}")
+ # Load weights from TF model
+ init_vars = tf.train.list_variables(tf_path)
+ names = []
+ arrays = []
+ for name, shape in init_vars:
+ logger.info(f"Loading TF weight {name} with shape {shape}")
+ array = tf.train.load_variable(tf_path, name)
+ names.append(name)
+ arrays.append(array.squeeze())
+
+ for name, array in zip(names, arrays):
+ name = name[6:] # skip "model/"
+ name = name.split("/")
+ pointer = model
+ for m_name in name:
+ if re.fullmatch(r"[A-Za-z]+\d+", m_name):
+ scope_names = re.split(r"(\d+)", m_name)
+ else:
+ scope_names = [m_name]
+ if scope_names[0] == "w" or scope_names[0] == "g":
+ pointer = getattr(pointer, "weight")
+ elif scope_names[0] == "b":
+ pointer = getattr(pointer, "bias")
+ elif scope_names[0] == "wpe" or scope_names[0] == "wte":
+ pointer = getattr(pointer, scope_names[0])
+ pointer = getattr(pointer, "weight")
+ else:
+ pointer = getattr(pointer, scope_names[0])
+ if len(scope_names) >= 2:
+ num = int(scope_names[1])
+ pointer = pointer[num]
+ try:
+ if pointer.shape != array.shape:
+ raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched")
+ except ValueError as e:
+ e.args += (pointer.shape, array.shape)
+ raise
+ logger.info(f"Initialize PyTorch weight {name}")
+ pointer.data = torch.from_numpy(array)
+ return model
+
+
+def eager_attention_forward(module, query, key, value, attention_mask, head_mask=None, **kwargs):
+ attn_weights = torch.matmul(query, key.transpose(-1, -2))
+
+ if module.scale_attn_weights:
+ attn_weights = attn_weights / torch.full(
+ [], value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device
+ )
+
+ # Layer-wise attention scaling
+ if module.scale_attn_by_inverse_layer_idx:
+ attn_weights = attn_weights / float(module.layer_idx + 1)
+
+ if not module.is_cross_attention:
+ # if only "normal" attention layer implements causal mask
+ query_length, key_length = query.size(-2), key.size(-2)
+ causal_mask = module.bias[:, :, key_length - query_length : key_length, :key_length]
+ mask_value = torch.finfo(attn_weights.dtype).min
+ # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
+ # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
+ mask_value = torch.full([], mask_value, dtype=attn_weights.dtype, device=attn_weights.device)
+ attn_weights = torch.where(causal_mask, attn_weights.to(attn_weights.dtype), mask_value)
+
+ if attention_mask is not None:
+ # Apply the attention mask
+ causal_mask = attention_mask[:, :, :, : key.shape[-2]]
+ attn_weights = attn_weights + causal_mask
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+
+ # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op otherwise
+ attn_weights = attn_weights.type(value.dtype)
+ attn_weights = module.attn_dropout(attn_weights)
+
+ # Mask heads if we want to
+ if head_mask is not None:
+ attn_weights = attn_weights * head_mask
+
+ attn_output = torch.matmul(attn_weights, value)
+ attn_output = attn_output.transpose(1, 2)
+
+ return attn_output, attn_weights
+
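+# Editorial note (not part of the library source): the eager path above is standard
+# scaled dot-product attention with a boolean causal mask. A tiny standalone sketch
+# of the same computation on random tensors, assuming only `torch` is available:
+#
+#   import torch
+#   q = k = v = torch.randn(1, 2, 4, 8)  # (batch, heads, seq, head_dim)
+#   scores = (q @ k.transpose(-1, -2)) / (v.size(-1) ** 0.5)
+#   causal = torch.tril(torch.ones(4, 4, dtype=torch.bool))
+#   scores = scores.masked_fill(~causal, torch.finfo(scores.dtype).min)
+#   out = torch.softmax(scores, dim=-1) @ v  # same shape as v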
+
+class GPT2Attention(nn.Module):
+ def __init__(self, config, is_cross_attention=False, layer_idx=None):
+ super().__init__()
+ self.config = config
+ max_positions = config.max_position_embeddings
+ self.register_buffer(
+ "bias",
+ torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
+ 1, 1, max_positions, max_positions
+ ),
+ persistent=False,
+ )
+ self.register_buffer("masked_bias", torch.tensor(-1e4), persistent=False)
+
+ self.embed_dim = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.head_dim = self.embed_dim // self.num_heads
+ self.split_size = self.embed_dim
+ if self.head_dim * self.num_heads != self.embed_dim:
+ raise ValueError(
+ f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
+ f" {self.num_heads})."
+ )
+
+ self.scale_attn_weights = config.scale_attn_weights
+ self.is_cross_attention = is_cross_attention
+
+ # Layer-wise attention scaling, reordering, and upcasting
+ self.scale_attn_by_inverse_layer_idx = config.scale_attn_by_inverse_layer_idx
+ self.layer_idx = layer_idx
+ self.reorder_and_upcast_attn = config.reorder_and_upcast_attn
+
+ if self.is_cross_attention:
+ self.c_attn = Conv1D(2 * self.embed_dim, self.embed_dim)
+ self.q_attn = Conv1D(self.embed_dim, self.embed_dim)
+ else:
+ self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim)
+ self.c_proj = Conv1D(self.embed_dim, self.embed_dim)
+
+ self.attn_dropout = nn.Dropout(config.attn_pdrop)
+ self.resid_dropout = nn.Dropout(config.resid_pdrop)
+ self.is_causal = True
+
+ self.pruned_heads = set()
+
+ def prune_heads(self, heads):
+ if len(heads) == 0:
+ return
+ heads, index = find_pruneable_heads_and_indices(heads, self.num_heads, self.head_dim, self.pruned_heads)
+ index_attn = torch.cat([index, index + self.split_size, index + (2 * self.split_size)])
+
+ # Prune conv1d layers
+ self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1)
+ self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0)
+
+ # Update hyper params
+ self.split_size = (self.split_size // self.num_heads) * (self.num_heads - len(heads))
+ self.num_heads = self.num_heads - len(heads)
+ self.pruned_heads = self.pruned_heads.union(heads)
+
+ def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, head_mask=None):
+ # Use `torch.baddbmm` (a bit more efficient w/ alpha param for scaling -- from Megatron-LM)
+ bsz, num_heads, q_seq_len, dk = query.size()
+ _, _, k_seq_len, _ = key.size()
+
+ # Preallocate attn_weights for `baddbmm`
+ attn_weights = torch.empty(bsz * num_heads, q_seq_len, k_seq_len, dtype=torch.float32, device=query.device)
+
+ # Compute Scale Factor
+ scale_factor = 1.0
+ if self.scale_attn_weights:
+ scale_factor /= float(value.size(-1)) ** 0.5
+
+ if self.scale_attn_by_inverse_layer_idx:
+ scale_factor /= float(self.layer_idx + 1)
+
+ # Upcast (turn off autocast) and reorder (Scale K by 1 / root(dk))
+ with torch.autocast(query.device.type, enabled=False):
+ q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len)
+ attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor)
+ attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)
+
+ if not self.is_cross_attention:
+ # if only "normal" attention layer implements causal mask
+ query_length, key_length = query.size(-2), key.size(-2)
+ causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
+ mask_value = torch.finfo(attn_weights.dtype).min
+ # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
+ # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
+ mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype, device=attn_weights.device)
+ attn_weights = torch.where(causal_mask, attn_weights, mask_value)
+
+ if attention_mask is not None:
+ # Apply the attention mask
+ attn_weights = attn_weights + attention_mask
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+
+ # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op otherwise
+ if attn_weights.dtype != torch.float32:
+ raise RuntimeError("Error with upcasting, attn_weights does not have dtype torch.float32")
+ attn_weights = attn_weights.type(value.dtype)
+ attn_weights = self.attn_dropout(attn_weights)
+
+ # Mask heads if we want to
+ if head_mask is not None:
+ attn_weights = attn_weights * head_mask
+
+ attn_output = torch.matmul(attn_weights, value)
+ attn_output = attn_output.transpose(1, 2)
+
+ return attn_output, attn_weights
+
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+ def forward(
+ self,
+ hidden_states: Optional[tuple[torch.FloatTensor]],
+ past_key_values: Optional[Cache] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ output_attentions: Optional[bool] = False,
+ **kwargs,
+ ) -> tuple[Union[torch.Tensor, tuple[torch.Tensor]], ...]:
+ is_cross_attention = encoder_hidden_states is not None
+ if past_key_values is not None:
+ if isinstance(past_key_values, EncoderDecoderCache):
+ is_updated = past_key_values.is_updated.get(self.layer_idx)
+ if is_cross_attention:
+ # after the first generated id, we can subsequently re-use all key/value_layer from cache
+ curr_past_key_value = past_key_values.cross_attention_cache
+ else:
+ curr_past_key_value = past_key_values.self_attention_cache
+ else:
+ curr_past_key_value = past_key_values
+
+ if is_cross_attention:
+ if not hasattr(self, "q_attn"):
+ raise ValueError(
+ "If class is used as cross attention, the weights `q_attn` have to be defined. "
+ "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`."
+ )
+ query_states = self.q_attn(hidden_states)
+ attention_mask = encoder_attention_mask
+
+ # Try to get key/value states from cache if possible
+ if past_key_values is not None and is_updated:
+ key_states = curr_past_key_value.layers[self.layer_idx].keys
+ value_states = curr_past_key_value.layers[self.layer_idx].values
+ else:
+ key_states, value_states = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
+ shape_kv = (*key_states.shape[:-1], -1, self.head_dim)
+ key_states = key_states.view(shape_kv).transpose(1, 2)
+ value_states = value_states.view(shape_kv).transpose(1, 2)
+ else:
+ query_states, key_states, value_states = self.c_attn(hidden_states).split(self.split_size, dim=2)
+ shape_kv = (*key_states.shape[:-1], -1, self.head_dim)
+ key_states = key_states.view(shape_kv).transpose(1, 2)
+ value_states = value_states.view(shape_kv).transpose(1, 2)
+
+ shape_q = (*query_states.shape[:-1], -1, self.head_dim)
+ query_states = query_states.view(shape_q).transpose(1, 2)
+
+ if (past_key_values is not None and not is_cross_attention) or (
+ past_key_values is not None and is_cross_attention and not is_updated
+ ):
+ # save all key/value_layer to cache to be re-used for fast auto-regressive generation
+ cache_position = cache_position if not is_cross_attention else None
+ key_states, value_states = curr_past_key_value.update(
+ key_states, value_states, self.layer_idx, {"cache_position": cache_position}
+ )
+ # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
+ if is_cross_attention:
+ past_key_values.is_updated[self.layer_idx] = True
+
+ is_causal = attention_mask is None and query_states.shape[-2] > 1 and not is_cross_attention
+
+ using_eager = self.config._attn_implementation == "eager"
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ if using_eager and self.reorder_and_upcast_attn:
+ attn_output, attn_weights = self._upcast_and_reordered_attn(
+ query_states, key_states, value_states, attention_mask, head_mask
+ )
+ else:
+ attn_output, attn_weights = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ head_mask=head_mask,
+ dropout=self.attn_dropout.p if self.training else 0.0,
+ is_causal=is_causal,
+ **kwargs,
+ )
+
+ attn_output = attn_output.reshape(*attn_output.shape[:-2], -1).contiguous()
+ attn_output = self.c_proj(attn_output)
+ attn_output = self.resid_dropout(attn_output)
+
+ return attn_output, attn_weights
+
+
+class GPT2MLP(nn.Module):
+ def __init__(self, intermediate_size, config):
+ super().__init__()
+ embed_dim = config.hidden_size
+ self.c_fc = Conv1D(intermediate_size, embed_dim)
+ self.c_proj = Conv1D(embed_dim, intermediate_size)
+ self.act = ACT2FN[config.activation_function]
+ self.dropout = nn.Dropout(config.resid_pdrop)
+
+ def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
+ hidden_states = self.c_fc(hidden_states)
+ hidden_states = self.act(hidden_states)
+ hidden_states = self.c_proj(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ return hidden_states
+
+
+class GPT2Block(GradientCheckpointingLayer):
+ def __init__(self, config, layer_idx=None):
+ super().__init__()
+ hidden_size = config.hidden_size
+ inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size
+
+ self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
+ self.attn = GPT2Attention(config=config, layer_idx=layer_idx)
+ self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
+
+ if config.add_cross_attention:
+ self.crossattention = GPT2Attention(config=config, is_cross_attention=True, layer_idx=layer_idx)
+ self.ln_cross_attn = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
+
+ self.mlp = GPT2MLP(inner_dim, config)
+
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+ def forward(
+ self,
+ hidden_states: Optional[tuple[torch.FloatTensor]],
+ past_key_values: Optional[Cache] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = False,
+ output_attentions: Optional[bool] = False,
+ **kwargs,
+ ) -> Union[tuple[torch.Tensor], Optional[tuple[torch.Tensor, tuple[torch.FloatTensor, ...]]]]:
+ residual = hidden_states
+ hidden_states = self.ln_1(hidden_states)
+ attn_output, self_attn_weights = self.attn(
+ hidden_states,
+ past_key_values=past_key_values,
+ cache_position=cache_position,
+ attention_mask=attention_mask,
+ head_mask=head_mask,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ **kwargs,
+ )
+ # residual connection
+ hidden_states = attn_output + residual
+
+ if encoder_hidden_states is not None:
+ # add one self-attention block for cross-attention
+ if not hasattr(self, "crossattention"):
+ raise ValueError(
+ f"If `encoder_hidden_states` are passed, {self} has to be instantiated with "
+ "cross-attention layers by setting `config.add_cross_attention=True`"
+ )
+ residual = hidden_states
+ hidden_states = self.ln_cross_attn(hidden_states)
+ cross_attn_output, cross_attn_weights = self.crossattention(
+ hidden_states,
+ past_key_values=past_key_values,
+ attention_mask=attention_mask,
+ head_mask=head_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ output_attentions=output_attentions,
+ )
+ # residual connection
+ hidden_states = residual + cross_attn_output
+
+ residual = hidden_states
+ hidden_states = self.ln_2(hidden_states)
+ feed_forward_hidden_states = self.mlp(hidden_states)
+ # residual connection
+ hidden_states = residual + feed_forward_hidden_states
+
+ outputs = (hidden_states,)
+ if output_attentions:
+ outputs += (self_attn_weights,)
+ if encoder_hidden_states is not None:
+ outputs += (cross_attn_weights,)
+
+ return outputs
+
+
+# Copied from transformers.models.xlm.modeling_xlm.XLMSequenceSummary with XLM->GPT2
+class GPT2SequenceSummary(nn.Module):
+ r"""
+ Compute a single vector summary of a sequence hidden states.
+
+ Args:
+ config ([`GPT2Config`]):
+ The config used by the model. Relevant arguments in the config class of the model are (refer to the actual
+ config class of your model for the default values it uses):
+
+ - **summary_type** (`str`) -- The method to use to make this summary. Accepted values are:
+
+ - `"last"` -- Take the last token hidden state (like XLNet)
+ - `"first"` -- Take the first token hidden state (like Bert)
+ - `"mean"` -- Take the mean of all tokens hidden states
+ - `"cls_index"` -- Supply a Tensor of classification token position (GPT/GPT-2)
+ - `"attn"` -- Not implemented now, use multi-head attention
+
+ - **summary_use_proj** (`bool`) -- Add a projection after the vector extraction.
+ - **summary_proj_to_labels** (`bool`) -- If `True`, the projection outputs to `config.num_labels` classes
+ (otherwise to `config.hidden_size`).
+ - **summary_activation** (`Optional[str]`) -- Set to `"tanh"` to add a tanh activation to the output,
+ another string or `None` will add no activation.
+ - **summary_first_dropout** (`float`) -- Optional dropout probability before the projection and activation.
+ - **summary_last_dropout** (`float`) -- Optional dropout probability after the projection and activation.
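+
+ Example (a minimal illustrative sketch: the module is built directly from a `GPT2Config` and pools with an
+ explicit `cls_index`; shapes depend on the config you pass):
+
+ ```python
+ >>> import torch
+ >>> from transformers import GPT2Config
+ >>> from transformers.models.gpt2.modeling_gpt2 import GPT2SequenceSummary
+
+ >>> config = GPT2Config(summary_type="cls_index", summary_use_proj=True, num_labels=1)
+ >>> summary = GPT2SequenceSummary(config)
+ >>> hidden_states = torch.randn(2, 10, config.hidden_size)  # (batch, seq_len, hidden)
+ >>> cls_index = torch.tensor([9, 4])  # position of the classification token in each sequence
+ >>> pooled = summary(hidden_states, cls_index)  # shape (2, 1) since num_labels == 1
+ ```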
+ """
+
+ def __init__(self, config: GPT2Config):
+ super().__init__()
+
+ self.summary_type = getattr(config, "summary_type", "last")
+ if self.summary_type == "attn":
+ # We should use a standard multi-head attention module with absolute positional embedding for that.
+ # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276
+ # We can probably just use the multi-head attention module of PyTorch >=1.1.0
+ raise NotImplementedError
+
+ self.summary = nn.Identity()
+ if hasattr(config, "summary_use_proj") and config.summary_use_proj:
+ if hasattr(config, "summary_proj_to_labels") and config.summary_proj_to_labels and config.num_labels > 0:
+ num_classes = config.num_labels
+ else:
+ num_classes = config.hidden_size
+ self.summary = nn.Linear(config.hidden_size, num_classes)
+
+ activation_string = getattr(config, "summary_activation", None)
+ self.activation: Callable = get_activation(activation_string) if activation_string else nn.Identity()
+
+ self.first_dropout = nn.Identity()
+ if hasattr(config, "summary_first_dropout") and config.summary_first_dropout > 0:
+ self.first_dropout = nn.Dropout(config.summary_first_dropout)
+
+ self.last_dropout = nn.Identity()
+ if hasattr(config, "summary_last_dropout") and config.summary_last_dropout > 0:
+ self.last_dropout = nn.Dropout(config.summary_last_dropout)
+
+ def forward(
+ self, hidden_states: torch.FloatTensor, cls_index: Optional[torch.LongTensor] = None
+ ) -> torch.FloatTensor:
+ """
+ Compute a single vector summary of a sequence hidden states.
+
+ Args:
+ hidden_states (`torch.FloatTensor` of shape `[batch_size, seq_len, hidden_size]`):
+ The hidden states of the last layer.
+ cls_index (`torch.LongTensor` of shape `[batch_size]` or `[batch_size, ...]` where ... are optional leading dimensions of `hidden_states`, *optional*):
+ Used if `summary_type == "cls_index"` and takes the last token of the sequence as classification token.
+
+ Returns:
+ `torch.FloatTensor`: The summary of the sequence hidden states.
+ """
+ if self.summary_type == "last":
+ output = hidden_states[:, -1]
+ elif self.summary_type == "first":
+ output = hidden_states[:, 0]
+ elif self.summary_type == "mean":
+ output = hidden_states.mean(dim=1)
+ elif self.summary_type == "cls_index":
+ if cls_index is None:
+ cls_index = torch.full_like(
+ hidden_states[..., :1, :],
+ hidden_states.shape[-2] - 1,
+ dtype=torch.long,
+ )
+ else:
+ cls_index = cls_index.unsqueeze(-1).unsqueeze(-1)
+ cls_index = cls_index.expand((-1,) * (cls_index.dim() - 1) + (hidden_states.size(-1),))
+ # shape of cls_index: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states
+ output = hidden_states.gather(-2, cls_index).squeeze(-2) # shape (bsz, XX, hidden_size)
+ elif self.summary_type == "attn":
+ raise NotImplementedError
+
+ output = self.first_dropout(output)
+ output = self.summary(output)
+ output = self.activation(output)
+ output = self.last_dropout(output)
+
+ return output
+
+
+@auto_docstring
+class GPT2PreTrainedModel(PreTrainedModel):
+ config: GPT2Config
+ load_tf_weights = load_tf_weights_in_gpt2
+ base_model_prefix = "transformer"
+ is_parallelizable = True
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["GPT2Block"]
+ _skip_keys_device_placement = "past_key_values"
+ _supports_flash_attn = True
+ _supports_sdpa = True
+ _supports_attention_backend = True
+
+ _can_compile_fullgraph = True
+
+ def __init__(self, *inputs, **kwargs):
+ super().__init__(*inputs, **kwargs)
+
+ def _init_weights(self, module):
+ """Initialize the weights."""
+ if isinstance(module, (nn.Linear, Conv1D)):
+ # Slightly different from the TF version which uses truncated_normal for initialization
+ # cf https://github.com/pytorch/pytorch/pull/5617
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+ elif isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+
+ # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
+ # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
+ # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/
+ #
+ # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
+ for name, p in module.named_parameters():
+ if name == "c_proj.weight":
+ # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
+ p.data.normal_(mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.n_layer)))
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Base class for outputs of models predicting if two sentences are consecutive or not.
+ """
+)
+class GPT2DoubleHeadsModelOutput(ModelOutput):
+ r"""
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Language modeling loss.
+ mc_loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `mc_labels` is provided):
+ Multiple choice classification loss.
+ logits (`torch.FloatTensor` of shape `(batch_size, num_choices, sequence_length, config.vocab_size)`):
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+ mc_logits (`torch.FloatTensor` of shape `(batch_size, num_choices)`):
+ Prediction scores of the multiple choice classification head (scores for each choice before SoftMax).
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
+
+ Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see
+ `past_key_values` input) to speed up sequential decoding.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ mc_loss: Optional[torch.FloatTensor] = None
+ logits: Optional[torch.FloatTensor] = None
+ mc_logits: Optional[torch.FloatTensor] = None
+ past_key_values: Optional[Cache] = None
+ hidden_states: Optional[tuple[torch.FloatTensor]] = None
+ attentions: Optional[tuple[torch.FloatTensor]] = None
+
+
+PARALLELIZE_DOCSTRING = r"""
+ This is an experimental feature and is subject to change at a moment's notice.
+
+ Uses a device map to distribute attention modules of the model across several devices. If no device map is given,
+ it will evenly distribute blocks across all devices.
+
+ Args:
+ device_map (`dict[int, list]`, *optional*):
+ A dictionary that maps attention modules to devices. Note that the embedding module and LMHead are always
+ automatically mapped to the first device (for esoteric reasons). That means that the first device should
+ have fewer attention modules mapped to it than other devices. For reference, the gpt2 models have the
+ following number of attention modules:
+
+ - openai-community/gpt2: 12
+ - openai-community/gpt2-medium: 24
+ - openai-community/gpt2-large: 36
+ - openai-community/gpt2-xl: 48
+
+ Example:
+
+ ```python
+ # Here is an example of a device map on a machine with 4 GPUs using gpt2-xl, which has a total of 48 attention modules:
+ model = GPT2LMHeadModel.from_pretrained("openai-community/gpt2-xl")
+ device_map = {
+ 0: [0, 1, 2, 3, 4, 5, 6, 7, 8],
+ 1: [9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21],
+ 2: [22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34],
+ 3: [35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47],
+ }
+ model.parallelize(device_map)
+ ```
+"""
+DEPARALLELIZE_DOCSTRING = r"""
+ Moves the model to cpu from a model parallel state.
+
+ Example:
+
+ ```python
+ # On a 4 GPU machine with openai-community/gpt2-large:
+ model = GPT2LMHeadModel.from_pretrained("openai-community/gpt2-large")
+ device_map = {
+ 0: [0, 1, 2, 3, 4, 5, 6, 7],
+ 1: [8, 9, 10, 11, 12, 13, 14, 15],
+ 2: [16, 17, 18, 19, 20, 21, 22, 23],
+ 3: [24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35],
+ }
+ model.parallelize(device_map) # Splits the model across several devices
+ model.deparallelize() # Put the model back on cpu and cleans memory by calling torch.cuda.empty_cache()
+ ```
+"""
+
+
+@auto_docstring
+class GPT2Model(GPT2PreTrainedModel):
+ _supports_param_buffer_assignment = False
+
+ def __init__(self, config):
+ super().__init__(config)
+
+ self.embed_dim = config.hidden_size
+
+ self.wte = nn.Embedding(config.vocab_size, self.embed_dim)
+ self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
+
+ self.drop = nn.Dropout(config.embd_pdrop)
+ self.h = nn.ModuleList([GPT2Block(config, layer_idx=i) for i in range(config.num_hidden_layers)])
+ self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
+
+ # Model parallel
+ self.model_parallel = False
+ self.device_map = None
+ self.gradient_checkpointing = False
+ self._attn_implementation = config._attn_implementation
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @add_start_docstrings(PARALLELIZE_DOCSTRING)
+ def parallelize(self, device_map=None):
+ # Check validity of device_map
+ warnings.warn(
+ "`GPT2Model.parallelize` is deprecated and will be removed in v5 of Transformers, you should load your"
+ " model with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own"
+ " `device_map` but it needs to be a dictionary module_name to device, so for instance {'h.0': 0, 'h.1': 1,"
+ " ...}",
+ FutureWarning,
+ )
+ self.device_map = (
+ get_device_map(len(self.h), range(torch.cuda.device_count())) if device_map is None else device_map
+ )
+ assert_device_map(self.device_map, len(self.h))
+ self.model_parallel = True
+ self.first_device = "cpu" if "cpu" in self.device_map else "cuda:" + str(min(self.device_map.keys()))
+ self.last_device = "cuda:" + str(max(self.device_map.keys()))
+ self.wte = self.wte.to(self.first_device)
+ self.wpe = self.wpe.to(self.first_device)
+ # Load onto devices
+ for k, v in self.device_map.items():
+ for block in v:
+ cuda_device = "cuda:" + str(k)
+ self.h[block] = self.h[block].to(cuda_device)
+ # ln_f to last
+ self.ln_f = self.ln_f.to(self.last_device)
+
+ @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
+ def deparallelize(self):
+ warnings.warn(
+ "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.",
+ FutureWarning,
+ )
+ self.model_parallel = False
+ self.device_map = None
+ self.first_device = "cpu"
+ self.last_device = "cpu"
+ self.wte = self.wte.to("cpu")
+ self.wpe = self.wpe.to("cpu")
+ for index in range(len(self.h)):
+ self.h[index] = self.h[index].to("cpu")
+ self.ln_f = self.ln_f.to("cpu")
+ torch.cuda.empty_cache()
+
+ def get_input_embeddings(self):
+ return self.wte
+
+ def set_input_embeddings(self, new_embeddings):
+ self.wte = new_embeddings
+
+ def _prune_heads(self, heads_to_prune):
+ """
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
+ """
+ for layer, heads in heads_to_prune.items():
+ self.h[layer].attn.prune_heads(heads)
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ **kwargs,
+ ) -> Union[tuple, BaseModelOutputWithPastAndCrossAttentions]:
+ r"""
+ input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
+ `input_ids_length` = `sequence_length` if `past_key_values` is `None` else
+ `past_key_values.get_seq_length()` (`sequence_length` of input past key value states). Indices of input
+ sequence tokens in the vocabulary.
+
+ If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
+ `input_ids`.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+ elif input_ids is not None:
+ self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
+ input_shape = input_ids.size()
+ input_ids = input_ids.view(-1, input_shape[-1])
+ batch_size = input_ids.shape[0]
+ elif inputs_embeds is not None:
+ input_shape = inputs_embeds.size()[:-1]
+ batch_size = inputs_embeds.shape[0]
+ else:
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
+
+ if token_type_ids is not None:
+ token_type_ids = token_type_ids.view(-1, input_shape[-1])
+
+ if self.gradient_checkpointing and self.training:
+ if use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+ )
+ use_cache = False
+
+ # based on pattern from src/transformers/models/whisper/modeling_whisper.py::WhisperDecoder
+ if use_cache:
+ if past_key_values is None:
+ past_key_values = DynamicCache(config=self.config)
+ elif isinstance(past_key_values, tuple):
+ logger.warning_once(
+ "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.53.0. "
+ "You should pass an instance of `Cache` instead, e.g. "
+ "`past_key_values=DynamicCache.from_legacy_cache(past_key_values)`."
+ )
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+
+ if self.config.add_cross_attention and not isinstance(past_key_values, EncoderDecoderCache):
+ past_key_values = EncoderDecoderCache(past_key_values, DynamicCache(config=self.config))
+
+ if inputs_embeds is None:
+ inputs_embeds = self.wte(input_ids)
+
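+ # `cache_position` holds the absolute indices of the current tokens within the full (past + current) sequence;
+ # when not provided, it is derived from the cache length, and `position_ids` defaults to it.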
+ if cache_position is None:
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ cache_position = torch.arange(
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+ )
+ if position_ids is None:
+ position_ids = cache_position.unsqueeze(0)
+
+ position_embeds = self.wpe(position_ids)
+ hidden_states = inputs_embeds + position_embeds.to(inputs_embeds.device)
+
+ # Attention mask.
+ # ._update_causal_mask() and ._prepare_4d_causal_attention_mask_with_cache_position() copied from LlamaModel
+ if attention_mask is not None and attention_mask.ndim < 4:
+ attention_mask = attention_mask.view(batch_size, -1)
+
+ causal_mask = create_causal_mask(
+ config=self.config,
+ input_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ cache_position=cache_position,
+ past_key_values=past_key_values,
+ position_ids=position_ids,
+ )
+
+ # If a 2D or 3D attention mask is provided for the cross-attention,
+ # we need to make it broadcastable to [batch_size, num_heads, seq_length, seq_length]
+ _use_sdpa = self._attn_implementation == "sdpa" and output_attentions is False and head_mask is None
+ if self.config.add_cross_attention and encoder_hidden_states is not None:
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
+ if encoder_attention_mask is None:
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
+ if _use_sdpa:
+ encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
+ mask=encoder_attention_mask, dtype=inputs_embeds.dtype, tgt_len=input_shape[-1]
+ )
+ elif self._attn_implementation != "flash_attention_2":
+ encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
+ else:
+ encoder_attention_mask = None
+
+ # Prepare head mask if needed
+ # 1.0 in head_mask indicates we keep the head
+ # attention_probs has shape bsz x n_heads x N x N
+ # head_mask has shape n_layer x batch x n_heads x N x N
+ head_mask = self.get_head_mask(head_mask, self.config.n_layer)
+
+ if token_type_ids is not None:
+ token_type_embeds = self.wte(token_type_ids)
+ hidden_states = hidden_states + token_type_embeds
+
+ hidden_states = self.drop(hidden_states)
+
+ output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),)
+
+ all_self_attentions = () if output_attentions else None
+ all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
+ all_hidden_states = () if output_hidden_states else None
+ for i, block in enumerate(self.h):
+ # Model parallel
+ if self.model_parallel:
+ torch.cuda.set_device(hidden_states.device)
+ if isinstance(head_mask, torch.Tensor):
+ head_mask = head_mask.to(hidden_states.device)
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ outputs = block(
+ hidden_states,
+ past_key_values if not (self.gradient_checkpointing and self.training) else None,
+ cache_position,
+ causal_mask,
+ head_mask[i],
+ encoder_hidden_states, # as a positional argument for gradient checkpointing
+ encoder_attention_mask=encoder_attention_mask,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ **kwargs,
+ )
+
+ hidden_states = outputs[0]
+
+ if output_attentions:
+ all_self_attentions = all_self_attentions + (outputs[1],)
+ if self.config.add_cross_attention:
+ all_cross_attentions = all_cross_attentions + (outputs[2],)
+
+ # Model Parallel: If it's the last layer for that device, put things on the next device
+ if self.model_parallel:
+ for k, v in self.device_map.items():
+ if i == v[-1] and "cuda:" + str(k) != self.last_device:
+ hidden_states = hidden_states.to("cuda:" + str(k + 1))
+
+ hidden_states = self.ln_f(hidden_states)
+
+ hidden_states = hidden_states.view(output_shape)
+ # Add last hidden state
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ past_key_values = past_key_values if use_cache else None
+ if not return_dict:
+ return tuple(
+ v
+ for v in [hidden_states, past_key_values, all_hidden_states, all_self_attentions, all_cross_attentions]
+ if v is not None
+ )
+
+ return BaseModelOutputWithPastAndCrossAttentions(
+ last_hidden_state=hidden_states,
+ past_key_values=past_key_values,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attentions,
+ cross_attentions=all_cross_attentions,
+ )
+
+
+@auto_docstring(
+ custom_intro="""
+ The GPT2 Model transformer with a language modeling head on top (linear layer with weights tied to the input
+ embeddings).
+ """
+)
+class GPT2LMHeadModel(GPT2PreTrainedModel, GenerationMixin):
+ _tied_weights_keys = ["lm_head.weight"]
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.transformer = GPT2Model(config)
+ self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
+
+ # Model parallel
+ self.model_parallel = False
+ self.device_map = None
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @add_start_docstrings(PARALLELIZE_DOCSTRING)
+ def parallelize(self, device_map=None):
+ warnings.warn(
+ "`GPT2LMHeadModel.parallelize` is deprecated and will be removed in v5 of Transformers, you should load"
+ " your model with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own"
+ " `device_map` but it needs to be a dictionary module_name to device, so for instance {'transformer.h.0':"
+ " 0, 'transformer.h.1': 1, ...}",
+ FutureWarning,
+ )
+ self.device_map = (
+ get_device_map(len(self.transformer.h), range(torch.cuda.device_count()))
+ if device_map is None
+ else device_map
+ )
+ assert_device_map(self.device_map, len(self.transformer.h))
+ self.transformer.parallelize(self.device_map)
+ self.lm_head = self.lm_head.to(self.transformer.first_device)
+ self.model_parallel = True
+
+ @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
+ def deparallelize(self):
+ warnings.warn(
+ "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.",
+ FutureWarning,
+ )
+ self.transformer.deparallelize()
+ self.transformer = self.transformer.to("cpu")
+ self.lm_head = self.lm_head.to("cpu")
+ self.model_parallel = False
+ torch.cuda.empty_cache()
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ logits_to_keep: Union[int, torch.Tensor] = 0,
+ **kwargs,
+ ) -> Union[tuple, CausalLMOutputWithCrossAttentions]:
+ r"""
+ input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
+ `input_ids_length` = `sequence_length` if `past_key_values` is `None` else
+ `past_key_values.get_seq_length()` (`sequence_length` of input past key value states). Indices of input
+ sequence tokens in the vocabulary.
+
+ If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
+ `input_ids`.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ labels (`torch.LongTensor` of shape `(batch_size, input_ids_length)`, *optional*):
+ Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
+ `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`. All labels set to `-100`
+ are ignored (masked); the loss is only computed for labels in `[0, ..., config.vocab_size - 1]`.
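+
+ Example (an illustrative causal-LM loss computation; `labels` simply reuses `input_ids` and the shift to
+ next-token targets happens inside the model):
+
+ ```python
+ >>> import torch
+ >>> from transformers import AutoTokenizer, GPT2LMHeadModel
+
+ >>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
+ >>> model = GPT2LMHeadModel.from_pretrained("openai-community/gpt2")
+
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
+ >>> outputs = model(**inputs, labels=inputs["input_ids"])
+ >>> loss, logits = outputs.loss, outputs.logits
+ ```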
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ transformer_outputs = self.transformer(
+ input_ids,
+ past_key_values=past_key_values,
+ attention_mask=attention_mask,
+ cache_position=cache_position,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ hidden_states = transformer_outputs[0]
+
+ # Set device for model parallelism
+ if self.model_parallel:
+ torch.cuda.set_device(self.transformer.first_device)
+ hidden_states = hidden_states.to(self.lm_head.weight.device)
+
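+ # Only compute logits for the last `logits_to_keep` tokens (0 keeps them all); this avoids materializing
+ # vocab-sized logits for the whole prompt during generation.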
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
+
+ loss = None
+ if labels is not None:
+ # Flatten the tokens
+ loss = self.loss_function(
+ logits,
+ labels,
+ vocab_size=self.config.vocab_size,
+ **kwargs,
+ )
+
+ if not return_dict:
+ output = (logits,) + transformer_outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return CausalLMOutputWithCrossAttentions(
+ loss=loss,
+ logits=logits,
+ past_key_values=transformer_outputs.past_key_values,
+ hidden_states=transformer_outputs.hidden_states,
+ attentions=transformer_outputs.attentions,
+ cross_attentions=transformer_outputs.cross_attentions,
+ )
+
+
+@auto_docstring(
+ custom_intro="""
+ The GPT2 Model transformer with a language modeling and a multiple-choice classification head on top, e.g. for
+ RocStories/SWAG tasks. The two heads are two linear layers. The language modeling head has its weights tied to the
+ input embeddings; the classification head takes as input the hidden state at a specified classification token index
+ in the input sequence.
+ """
+)
+class GPT2DoubleHeadsModel(GPT2PreTrainedModel, GenerationMixin):
+ _tied_weights_keys = ["lm_head.weight"]
+
+ def __init__(self, config):
+ super().__init__(config)
+ config.num_labels = 1
+ self.transformer = GPT2Model(config)
+ self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
+ self.multiple_choice_head = GPT2SequenceSummary(config)
+
+ # Model parallel
+ self.model_parallel = False
+ self.device_map = None
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @add_start_docstrings(PARALLELIZE_DOCSTRING)
+ def parallelize(self, device_map=None):
+ warnings.warn(
+ "`GPT2DoubleHeadsModel.parallelize` is deprecated and will be removed in v5 of Transformers, you should"
+ " load your model with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your"
+ " own `device_map` but it needs to be a dictionary module_name to device, so for instance"
+ " {'transformer.h.0': 0, 'transformer.h.1': 1, ...}",
+ FutureWarning,
+ )
+ self.device_map = (
+ get_device_map(len(self.transformer.h), range(torch.cuda.device_count()))
+ if device_map is None
+ else device_map
+ )
+ assert_device_map(self.device_map, len(self.transformer.h))
+ self.transformer.parallelize(self.device_map)
+ self.lm_head = self.lm_head.to(self.transformer.first_device)
+ self.multiple_choice_head = self.multiple_choice_head.to(self.transformer.first_device)
+ self.model_parallel = True
+
+ @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
+ def deparallelize(self):
+ warnings.warn(
+ "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.",
+ FutureWarning,
+ )
+ self.transformer.deparallelize()
+ self.transformer = self.transformer.to("cpu")
+ self.lm_head = self.lm_head.to("cpu")
+ self.multiple_choice_head = self.multiple_choice_head.to("cpu")
+ self.model_parallel = False
+ torch.cuda.empty_cache()
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ mc_token_ids: Optional[torch.LongTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ mc_labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ **kwargs,
+ ) -> Union[tuple, GPT2DoubleHeadsModelOutput]:
+ r"""
+ input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
+ `input_ids_length` = `sequence_length` if `past_key_values` is `None` else
+ `past_key_values.get_seq_length()` (`sequence_length` of input past key value states). Indices of input
+ sequence tokens in the vocabulary.
+
+ If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
+ `input_ids`.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ mc_token_ids (`torch.LongTensor` of shape `(batch_size, num_choices)`, *optional*, defaults to index of the last token of the input):
+ Index of the classification token in each input sequence. Selected in the range `[0, input_ids.size(-1) -
+ 1]`.
+ labels (`torch.LongTensor` of shape `(batch_size, input_ids_length)`, *optional*):
+ Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
+ `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`. All labels set to
+ `-100` are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size - 1]`
+ mc_labels (`torch.LongTensor` of shape `(batch_size)`, *optional*):
+ Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., num_choices]`
+ where *num_choices* is the size of the second dimension of the input tensors. (see *input_ids* above)
+
+ Example:
+
+ ```python
+ >>> import torch
+ >>> from transformers import AutoTokenizer, GPT2DoubleHeadsModel
+
+ >>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
+ >>> model = GPT2DoubleHeadsModel.from_pretrained("openai-community/gpt2")
+
+ >>> # Add a [CLS] to the vocabulary (we should train it also!)
+ >>> num_added_tokens = tokenizer.add_special_tokens({"cls_token": "[CLS]"})
+ >>> # Update the model embeddings with the new vocabulary size
+ >>> embedding_layer = model.resize_token_embeddings(len(tokenizer))
+
+ >>> choices = ["Hello, my dog is cute [CLS]", "Hello, my cat is cute [CLS]"]
+ >>> encoded_choices = [tokenizer.encode(s) for s in choices]
+ >>> cls_token_location = [tokens.index(tokenizer.cls_token_id) for tokens in encoded_choices]
+
+ >>> input_ids = torch.tensor(encoded_choices).unsqueeze(0) # Batch size: 1, number of choices: 2
+ >>> mc_token_ids = torch.tensor([cls_token_location]) # Batch size: 1
+
+ >>> outputs = model(input_ids, mc_token_ids=mc_token_ids)
+ >>> lm_logits = outputs.logits
+ >>> mc_logits = outputs.mc_logits
+ ```"""
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ transformer_outputs = self.transformer(
+ input_ids,
+ past_key_values=past_key_values,
+ cache_position=cache_position,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ hidden_states = transformer_outputs[0]
+
+ # Set device for model parallelism
+ if self.model_parallel:
+ torch.cuda.set_device(self.transformer.first_device)
+ hidden_states = hidden_states.to(self.lm_head.weight.device)
+
+ lm_logits = self.lm_head(hidden_states)
+ mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids).squeeze(-1)
+
+ mc_loss = None
+ if mc_labels is not None:
+ loss_fct = CrossEntropyLoss()
+ mc_loss = loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1))
+ lm_loss = None
+ if labels is not None:
+ labels = labels.to(lm_logits.device)
+ shift_logits = lm_logits[..., :-1, :].contiguous()
+ shift_labels = labels[..., 1:].contiguous()
+ loss_fct = CrossEntropyLoss()
+ lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
+
+ if not return_dict:
+ output = (lm_logits, mc_logits) + transformer_outputs[1:]
+ if mc_loss is not None:
+ output = (mc_loss,) + output
+ return ((lm_loss,) + output) if lm_loss is not None else output
+
+ return GPT2DoubleHeadsModelOutput(
+ loss=lm_loss,
+ mc_loss=mc_loss,
+ logits=lm_logits,
+ mc_logits=mc_logits,
+ past_key_values=transformer_outputs.past_key_values,
+ hidden_states=transformer_outputs.hidden_states,
+ attentions=transformer_outputs.attentions,
+ )
+
+
+@auto_docstring(
+ custom_intro="""
+ The GPT2 Model transformer with a sequence classification head on top (linear layer).
+
+ [`GPT2ForSequenceClassification`] uses the last token in order to do the classification, as other causal models
+ (e.g. GPT-1) do.
+
+ Since it does classification on the last token, it requires to know the position of the last token. If a
+ `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
+ no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
+ padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
+ each row of the batch).
+ """
+)
+class GPT2ForSequenceClassification(GPT2PreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+ self.transformer = GPT2Model(config)
+ self.score = nn.Linear(config.n_embd, self.num_labels, bias=False)
+
+ # Model parallel
+ self.model_parallel = False
+ self.device_map = None
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[tuple, SequenceClassifierOutputWithPast]:
+ r"""
+ input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
+ `input_ids_length` = `sequence_length` if `past_key_values` is `None` else
+ `past_key_values.get_seq_length()` (`sequence_length` of input past key value states). Indices of input
+ sequence tokens in the vocabulary.
+
+ If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
+ `input_ids`.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
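+
+ Example (an illustrative sketch; the sequence classification head on top of the pretrained checkpoint is newly
+ initialized, so the prediction is not meaningful until the model is fine-tuned):
+
+ ```python
+ >>> import torch
+ >>> from transformers import AutoTokenizer, GPT2ForSequenceClassification
+
+ >>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
+ >>> model = GPT2ForSequenceClassification.from_pretrained("openai-community/gpt2", num_labels=2)
+ >>> model.config.pad_token_id = tokenizer.eos_token_id  # GPT-2 has no pad token; needed for batch sizes > 1
+
+ >>> inputs = tokenizer("This movie was surprisingly good!", return_tensors="pt")
+ >>> with torch.no_grad():
+ ...     logits = model(**inputs).logits
+ >>> predicted_class_id = logits.argmax(-1).item()
+ ```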
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ transformer_outputs = self.transformer(
+ input_ids,
+ past_key_values=past_key_values,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ hidden_states = transformer_outputs[0]
+ logits = self.score(hidden_states)
+
+ if input_ids is not None:
+ batch_size, sequence_length = input_ids.shape[:2]
+ else:
+ batch_size, sequence_length = inputs_embeds.shape[:2]
+
+ if self.config.pad_token_id is None and batch_size != 1:
+ raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
+ if self.config.pad_token_id is None:
+ last_non_pad_token = -1
+ elif input_ids is not None:
+ # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
+ non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
+ token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
+ last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
+ else:
+ last_non_pad_token = -1
+ logger.warning_once(
+ f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
+ "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
+ )
+
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
+
+ loss = None
+ if labels is not None:
+ if self.config.problem_type is None:
+ if self.num_labels == 1:
+ self.config.problem_type = "regression"
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+ self.config.problem_type = "single_label_classification"
+ else:
+ self.config.problem_type = "multi_label_classification"
+
+ if self.config.problem_type == "regression":
+ loss_fct = MSELoss()
+ if self.num_labels == 1:
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
+ else:
+ loss = loss_fct(pooled_logits, labels)
+ elif self.config.problem_type == "single_label_classification":
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
+ elif self.config.problem_type == "multi_label_classification":
+ loss_fct = BCEWithLogitsLoss()
+ loss = loss_fct(pooled_logits, labels)
+ if not return_dict:
+ output = (pooled_logits,) + transformer_outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return SequenceClassifierOutputWithPast(
+ loss=loss,
+ logits=pooled_logits,
+ past_key_values=transformer_outputs.past_key_values,
+ hidden_states=transformer_outputs.hidden_states,
+ attentions=transformer_outputs.attentions,
+ )
+
+
+@auto_docstring
+class GPT2ForTokenClassification(GPT2PreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+
+ self.transformer = GPT2Model(config)
+ if hasattr(config, "classifier_dropout") and config.classifier_dropout is not None:
+ classifier_dropout = config.classifier_dropout
+ elif hasattr(config, "hidden_dropout") and config.hidden_dropout is not None:
+ classifier_dropout = config.hidden_dropout
+ else:
+ classifier_dropout = 0.1
+ self.dropout = nn.Dropout(classifier_dropout)
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+
+ # Model parallel
+ self.model_parallel = False
+ self.device_map = None
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[tuple, TokenClassifierOutput]:
+ r"""
+ input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
+ `input_ids_length` = `sequence_length` if `past_key_values` is `None` else
+ `past_key_values.get_seq_length()` (`sequence_length` of input past key value states). Indices of input
+ sequence tokens in the vocabulary.
+
+ If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
+ `input_ids`.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ transformer_outputs = self.transformer(
+ input_ids,
+ past_key_values=past_key_values,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ hidden_states = transformer_outputs[0]
+ hidden_states = self.dropout(hidden_states)
+ logits = self.classifier(hidden_states)
+
+ loss = None
+ if labels is not None:
+ labels = labels.to(logits.device)
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+
+ if not return_dict:
+ output = (logits,) + transformer_outputs[2:]
+ return ((loss,) + output) if loss is not None else output
+
+ return TokenClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=transformer_outputs.hidden_states,
+ attentions=transformer_outputs.attentions,
+ )
+
+
+@auto_docstring
+class GPT2ForQuestionAnswering(GPT2PreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+ self.transformer = GPT2Model(config)
+ self.qa_outputs = nn.Linear(config.hidden_size, 2)
+
+ # Model parallel
+ self.model_parallel = False
+ self.device_map = None
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ start_positions: Optional[torch.LongTensor] = None,
+ end_positions: Optional[torch.LongTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[tuple, QuestionAnsweringModelOutput]:
+ r"""
+ input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
+ `input_ids_length` = `sequence_length` if `past_key_values` is `None` else
+ `past_key_values.get_seq_length()` (`sequence_length` of input past key value states). Indices of input
+ sequence tokens in the vocabulary.
+
+ If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
+ `input_ids`.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
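+
+ Example (an illustrative sketch; the span-prediction head on top of the pretrained checkpoint is randomly
+ initialized, so the extracted span is not meaningful until the model is fine-tuned on a QA dataset):
+
+ ```python
+ >>> import torch
+ >>> from transformers import AutoTokenizer, GPT2ForQuestionAnswering
+
+ >>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
+ >>> model = GPT2ForQuestionAnswering.from_pretrained("openai-community/gpt2")
+
+ >>> question, context = "Who developed GPT-2?", "GPT-2 was developed by OpenAI."
+ >>> inputs = tokenizer(question, context, return_tensors="pt")
+ >>> with torch.no_grad():
+ ...     outputs = model(**inputs)
+ >>> start_index = outputs.start_logits.argmax(-1).item()
+ >>> end_index = outputs.end_logits.argmax(-1).item()
+ >>> answer = tokenizer.decode(inputs["input_ids"][0, start_index : end_index + 1])
+ ```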
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.transformer(
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ sequence_output = outputs[0]
+
+ logits = self.qa_outputs(sequence_output)
+ start_logits, end_logits = logits.split(1, dim=-1)
+ start_logits = start_logits.squeeze(-1).contiguous()
+ end_logits = end_logits.squeeze(-1).contiguous()
+
+ total_loss = None
+ if start_positions is not None and end_positions is not None:
+ # If we are on multi-GPU, splitting adds an extra dimension; squeeze it
+ if len(start_positions.size()) > 1:
+ start_positions = start_positions.squeeze(-1).to(start_logits.device)
+ if len(end_positions.size()) > 1:
+ end_positions = end_positions.squeeze(-1).to(end_logits.device)
+ # sometimes the start/end positions are outside our model inputs, we ignore these terms
+ ignored_index = start_logits.size(1)
+ start_positions = start_positions.clamp(0, ignored_index)
+ end_positions = end_positions.clamp(0, ignored_index)
+
+ loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
+ start_loss = loss_fct(start_logits, start_positions)
+ end_loss = loss_fct(end_logits, end_positions)
+ total_loss = (start_loss + end_loss) / 2
+
+ if not return_dict:
+ output = (start_logits, end_logits) + outputs[2:]
+ return ((total_loss,) + output) if total_loss is not None else output
+
+ return QuestionAnsweringModelOutput(
+ loss=total_loss,
+ start_logits=start_logits,
+ end_logits=end_logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+__all__ = [
+ "GPT2DoubleHeadsModel",
+ "GPT2ForQuestionAnswering",
+ "GPT2ForSequenceClassification",
+ "GPT2ForTokenClassification",
+ "GPT2LMHeadModel",
+ "GPT2Model",
+ "GPT2PreTrainedModel",
+ "load_tf_weights_in_gpt2",
+]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/gpt2/modeling_tf_gpt2.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/gpt2/modeling_tf_gpt2.py
new file mode 100644
index 0000000000000000000000000000000000000000..42e23fc290151f09d47a30efca1cb7f4e4a3d669
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/gpt2/modeling_tf_gpt2.py
@@ -0,0 +1,1238 @@
+# coding=utf-8
+# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
+# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""TF 2.0 OpenAI GPT-2 model."""
+
+from __future__ import annotations
+
+from dataclasses import dataclass
+
+import numpy as np
+import tensorflow as tf
+
+from ...activations_tf import get_tf_activation
+from ...modeling_tf_outputs import (
+ TFBaseModelOutputWithPastAndCrossAttentions,
+ TFCausalLMOutputWithCrossAttentions,
+ TFSequenceClassifierOutputWithPast,
+)
+from ...modeling_tf_utils import (
+ TFCausalLanguageModelingLoss,
+ TFConv1D,
+ TFModelInputType,
+ TFPreTrainedModel,
+ TFSequenceClassificationLoss,
+ TFSequenceSummary,
+ get_initializer,
+ keras,
+ keras_serializable,
+ unpack_inputs,
+)
+from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax
+from ...utils import (
+ ModelOutput,
+ add_code_sample_docstrings,
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ logging,
+ replace_return_docstrings,
+)
+from .configuration_gpt2 import GPT2Config
+
+
+logger = logging.get_logger(__name__)
+
+_CHECKPOINT_FOR_DOC = "openai-community/gpt2"
+_CONFIG_FOR_DOC = "GPT2Config"
+
+
+class TFAttention(keras.layers.Layer):
+ def __init__(self, nx, config, scale=False, is_cross_attention=False, **kwargs):
+ super().__init__(**kwargs)
+
+ n_state = nx # in Attention: n_state=768 (nx=n_embd)
+ # [switch nx => n_state from Block to Attention to keep identical to TF implementation]
+ assert n_state % config.n_head == 0
+ self.n_head = config.n_head
+ self.split_size = n_state
+ self.scale = scale
+ self.output_attentions = config.output_attentions
+
+ self.is_cross_attention = is_cross_attention
+
+ if self.is_cross_attention:
+ self.c_attn = TFConv1D(n_state * 2, nx, initializer_range=config.initializer_range, name="c_attn")
+ self.q_attn = TFConv1D(n_state, nx, initializer_range=config.initializer_range, name="q_attn")
+ else:
+ self.c_attn = TFConv1D(n_state * 3, nx, initializer_range=config.initializer_range, name="c_attn")
+
+ self.c_proj = TFConv1D(n_state, nx, initializer_range=config.initializer_range, name="c_proj")
+ self.attn_dropout = keras.layers.Dropout(config.attn_pdrop)
+ self.resid_dropout = keras.layers.Dropout(config.resid_pdrop)
+ self.pruned_heads = set()
+ self.embed_dim = n_state
+
+ def prune_heads(self, heads):
+ pass
+
+ @staticmethod
+ def causal_attention_mask(nd, ns, dtype):
+ """
+ 1's in the lower triangle, counting from the lower right corner. Same as tf.matrix_band_part(tf.ones([nd, ns]),
+ -1, ns-nd), but doesn't produce garbage on TPUs.
+ """
+ i = tf.range(nd)[:, None]
+ j = tf.range(ns)
+ m = i >= j - ns + nd
+ return tf.cast(m, dtype)
+
+ def _attn(self, q, k, v, attention_mask, head_mask, output_attentions, training=False):
+ # q, k, v have shape [batch, heads, sequence, features]
+ w = tf.matmul(q, k, transpose_b=True)
+ if self.scale:
+ dk = tf.cast(shape_list(k)[-1], dtype=w.dtype) # scale attention_scores
+ w = w / tf.math.sqrt(dk)
+
+ if not self.is_cross_attention:
+ # if only "normal" attention layer implements causal mask
+
+ # w has shape [batch, heads, dst_sequence, src_sequence], where information flows from src to dst.
+ _, _, nd, ns = shape_list(w)
+ b = self.causal_attention_mask(nd, ns, dtype=w.dtype)
+ b = tf.reshape(b, [1, 1, nd, ns])
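+ # keep allowed positions unchanged and push masked positions to a large negative value before the softmax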
+ w = w * b - 1e4 * (1 - b)
+
+ if attention_mask is not None:
+ # Apply the attention mask
+ attention_mask = tf.cast(attention_mask, dtype=w.dtype)
+ w = w + attention_mask
+
+ w = stable_softmax(w, axis=-1)
+ w = self.attn_dropout(w, training=training)
+
+ # Mask heads if we want to
+ if head_mask is not None:
+ w = w * head_mask
+
+ outputs = [tf.matmul(w, v)]
+ if output_attentions:
+ outputs.append(w)
+ return outputs
+
+ def merge_heads(self, x):
+ x = tf.transpose(x, [0, 2, 1, 3])
+ x_shape = shape_list(x)
+ new_x_shape = x_shape[:-2] + [x_shape[-2] * x_shape[-1]]
+ return tf.reshape(x, new_x_shape)
+
+ def split_heads(self, x):
+ x_shape = shape_list(x)
+ new_x_shape = x_shape[:-1] + [self.n_head, x_shape[-1] // self.n_head]
+ x = tf.reshape(x, new_x_shape)
+ return tf.transpose(x, (0, 2, 1, 3)) # (batch, head, seq_length, head_features)
+
+ def call(
+ self,
+ x,
+ layer_past,
+ attention_mask,
+ head_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ use_cache,
+ output_attentions,
+ training=False,
+ ):
+ if encoder_hidden_states is not None:
+ if not hasattr(self, "q_attn"):
+ raise ValueError(
+ "If class is used as cross attention, the weights `q_attn` have to be defined. "
+ "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`."
+ )
+
+ query = self.q_attn(x)
+ kv_out = self.c_attn(encoder_hidden_states)
+ key, value = tf.split(kv_out, 2, axis=2)
+ attention_mask = encoder_attention_mask
+ else:
+ x = self.c_attn(x)
+ query, key, value = tf.split(x, 3, axis=2)
+
+ query = self.split_heads(query)
+ key = self.split_heads(key)
+ value = self.split_heads(value)
+ if layer_past is not None:
+ past_key, past_value = tf.unstack(layer_past, axis=0, num=2)
+ key = tf.concat([past_key, key], axis=-2)
+ value = tf.concat([past_value, value], axis=-2)
+
+ # to cope with keras serialization
+ if use_cache:
+ present = tf.stack([key, value], axis=0)
+ else:
+ present = (None,)
+
+ attn_outputs = self._attn(query, key, value, attention_mask, head_mask, output_attentions, training=training)
+ a = attn_outputs[0]
+
+ a = self.merge_heads(a)
+ a = self.c_proj(a)
+ a = self.resid_dropout(a, training=training)
+
+ outputs = [a, present] + attn_outputs[1:]
+ return outputs # a, present, (attentions)
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if self.is_cross_attention:
+ c_attn_shape = 2 * self.embed_dim
+ else:
+ c_attn_shape = 3 * self.embed_dim
+ if getattr(self, "c_proj", None) is not None:
+ with tf.name_scope(self.c_proj.name):
+ self.c_proj.build([None, None, self.embed_dim])
+ if getattr(self, "c_attn", None) is not None:
+ with tf.name_scope(self.c_attn.name):
+ self.c_attn.build([None, None, c_attn_shape])
+ if getattr(self, "q_attn", None) is not None:
+ with tf.name_scope(self.q_attn.name):
+ self.q_attn.build([None, None, self.embed_dim])
+
+
+class TFMLP(keras.layers.Layer):
+ def __init__(self, n_state, config, **kwargs):
+ super().__init__(**kwargs)
+ nx = config.n_embd
+ self.c_fc = TFConv1D(n_state, nx, initializer_range=config.initializer_range, name="c_fc")
+ self.c_proj = TFConv1D(nx, n_state, initializer_range=config.initializer_range, name="c_proj")
+ self.act = get_tf_activation(config.activation_function)
+ self.dropout = keras.layers.Dropout(config.resid_pdrop)
+ self.intermediate_size = n_state
+ self.embed_dim = nx
+
+ def call(self, x, training=False):
+ h = self.act(self.c_fc(x))
+ h2 = self.c_proj(h)
+ h2 = self.dropout(h2, training=training)
+ return h2
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "c_fc", None) is not None:
+ with tf.name_scope(self.c_fc.name):
+ self.c_fc.build([None, None, self.intermediate_size])
+ if getattr(self, "c_proj", None) is not None:
+ with tf.name_scope(self.c_proj.name):
+ self.c_proj.build([None, None, self.embed_dim])
+
+
+class TFBlock(keras.layers.Layer):
+ def __init__(self, config, scale=False, **kwargs):
+ super().__init__(**kwargs)
+ nx = config.n_embd
+ inner_dim = config.n_inner if config.n_inner is not None else 4 * nx
+ self.ln_1 = keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon, name="ln_1")
+ self.attn = TFAttention(nx, config, scale, name="attn")
+ self.ln_2 = keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon, name="ln_2")
+
+ if config.add_cross_attention:
+ self.crossattention = TFAttention(nx, config, scale, name="crossattention", is_cross_attention=True)
+ self.ln_cross_attn = keras.layers.LayerNormalization(
+ epsilon=config.layer_norm_epsilon, name="ln_cross_attn"
+ )
+
+ self.mlp = TFMLP(inner_dim, config, name="mlp")
+ self.hidden_size = config.hidden_size
+
+ def call(
+ self,
+ x,
+ layer_past,
+ attention_mask,
+ head_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ use_cache,
+ output_attentions,
+ training=False,
+ ):
+ a = self.ln_1(x)
+ output_attn = self.attn(
+ a,
+ layer_past=layer_past,
+ attention_mask=attention_mask,
+ head_mask=head_mask,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ training=training,
+ )
+ a = output_attn[0] # output_attn: a, present, (attentions)
+ outputs = output_attn[1:]
+ x = x + a
+
+ # Cross-Attention Block
+ if encoder_hidden_states is not None:
+ # add one self-attention block for cross-attention
+ if not hasattr(self, "crossattention"):
+ raise ValueError(
+ f"If `encoder_hidden_states` are passed, {self} has to be instantiated with "
+ "cross-attention layers by setting `config.add_cross_attention=True`"
+ )
+
+ ca = self.ln_cross_attn(x)
+ output_cross_attn = self.crossattention(
+ ca,
+ layer_past=None,
+ attention_mask=attention_mask,
+ head_mask=head_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ use_cache=False,
+ output_attentions=output_attentions,
+ training=training,
+ )
+ ca = output_cross_attn[0] # output_attn: a, present, (cross_attentions)
+ x = x + ca
+ outputs = outputs + output_cross_attn[2:] # add cross attentions if we output attention weights
+
+ m = self.ln_2(x)
+ m = self.mlp(m, training=training)
+ x = x + m
+
+ outputs = [x] + outputs
+ return outputs # x, present, (attentions, cross_attentions)
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "ln_1", None) is not None:
+ with tf.name_scope(self.ln_1.name):
+ self.ln_1.build([None, None, self.hidden_size])
+ if getattr(self, "attn", None) is not None:
+ with tf.name_scope(self.attn.name):
+ self.attn.build(None)
+ if getattr(self, "ln_2", None) is not None:
+ with tf.name_scope(self.ln_2.name):
+ self.ln_2.build([None, None, self.hidden_size])
+ if getattr(self, "mlp", None) is not None:
+ with tf.name_scope(self.mlp.name):
+ self.mlp.build(None)
+ if getattr(self, "crossattention", None) is not None:
+ with tf.name_scope(self.crossattention.name):
+ self.crossattention.build(None)
+ if getattr(self, "ln_cross_attn", None) is not None:
+ with tf.name_scope(self.ln_cross_attn.name):
+ self.ln_cross_attn.build([None, None, self.hidden_size])
+
+
+@keras_serializable
+class TFGPT2MainLayer(keras.layers.Layer):
+ config_class = GPT2Config
+
+ def __init__(self, config, *inputs, **kwargs):
+ super().__init__(*inputs, **kwargs)
+
+ self.config = config
+ self.output_attentions = config.output_attentions
+ self.output_hidden_states = config.output_hidden_states
+ self.use_cache = config.use_cache
+ self.return_dict = config.use_return_dict
+
+ self.num_hidden_layers = config.n_layer
+ self.n_embd = config.n_embd
+ self.n_positions = config.n_positions
+ self.initializer_range = config.initializer_range
+
+ self.wte = keras.layers.Embedding(
+ input_dim=config.vocab_size,
+ output_dim=config.hidden_size,
+ embeddings_initializer=get_initializer(config.initializer_range),
+ name="wte",
+ )
+ self.wpe = keras.layers.Embedding(
+ input_dim=config.n_positions,
+ output_dim=config.n_embd,
+ embeddings_initializer=get_initializer(config.initializer_range),
+ name="wpe",
+ )
+ self.drop = keras.layers.Dropout(config.embd_pdrop)
+ self.h = [TFBlock(config, scale=True, name=f"h_._{i}") for i in range(config.n_layer)]
+ self.ln_f = keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon, name="ln_f")
+ self.embed_dim = config.hidden_size
+
+ def get_input_embeddings(self):
+ return self.wte
+
+ def set_input_embeddings(self, new_embeddings):
+ self.wte = new_embeddings
+
+ def _prune_heads(self, heads_to_prune):
+ """
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
+ """
+ raise NotImplementedError
+
+ @unpack_inputs
+ def call(
+ self,
+ input_ids: TFModelInputType | None = None,
+ past_key_values: tuple[tuple[np.ndarray | tf.Tensor]] | None = None,
+ attention_mask: np.ndarray | tf.Tensor | None = None,
+ token_type_ids: np.ndarray | tf.Tensor | None = None,
+ position_ids: np.ndarray | tf.Tensor | None = None,
+ head_mask: np.ndarray | tf.Tensor | None = None,
+ inputs_embeds: np.ndarray | tf.Tensor | None = None,
+ encoder_hidden_states: np.ndarray | tf.Tensor | None = None,
+ encoder_attention_mask: np.ndarray | tf.Tensor | None = None,
+ use_cache: bool | None = None,
+ output_attentions: bool | None = None,
+ output_hidden_states: bool | None = None,
+ return_dict: bool | None = None,
+ training: bool | None = False,
+ ) -> TFBaseModelOutputWithPastAndCrossAttentions | tuple[tf.Tensor]:
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+ elif input_ids is not None:
+ input_shape = shape_list(input_ids)
+ input_ids = tf.reshape(input_ids, [-1, input_shape[-1]])
+ elif inputs_embeds is not None:
+ input_shape = shape_list(inputs_embeds)[:-1]
+ else:
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+ if past_key_values is None:
+ past_length = 0
+ past_key_values = [None] * len(self.h)
+ else:
+ past_length = shape_list(past_key_values[0][0])[-2]
+
+ if position_ids is None:
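+ # default position ids continue from the length of the cached past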
+ position_ids = tf.expand_dims(tf.range(past_length, input_shape[-1] + past_length), axis=0)
+
+ if attention_mask is not None:
+ # We create a 3D attention mask from a 2D tensor mask.
+ # Sizes are [batch_size, 1, 1, to_seq_length]
+ # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
+ # this attention mask is more simple than the triangular masking of causal attention
+ # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
+ attention_mask_shape = shape_list(attention_mask)
+ attention_mask = tf.reshape(attention_mask, (attention_mask_shape[0], 1, 1, attention_mask_shape[1]))
+
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
+ # masked positions, this operation will create a tensor which is 0.0 for
+ # positions we want to attend and -10000.0 for masked positions.
+ # Since we are adding it to the raw scores before the softmax, this is
+ # effectively the same as removing these entirely.
+ one_cst = tf.constant(1.0)
+ attention_mask = tf.cast(attention_mask, dtype=one_cst.dtype)
+ attention_mask = tf.multiply(tf.subtract(one_cst, attention_mask), tf.constant(-10000.0))
+
+ # Copied from `modeling_tf_t5.py` with -1e9 -> -10000
+ if self.config.add_cross_attention and encoder_attention_mask is not None:
+ # If a 2D or 3D attention mask is provided for the cross-attention
+ # we need to make it broadcastable to [batch_size, num_heads, seq_length, seq_length]
+ encoder_attention_mask = tf.cast(encoder_attention_mask, dtype=encoder_hidden_states.dtype)
+ num_dims_encoder_attention_mask = len(shape_list(encoder_attention_mask))
+ if num_dims_encoder_attention_mask == 3:
+ encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :]
+ if num_dims_encoder_attention_mask == 2:
+ encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :]
+
+ # T5 has a mask that can compare sequence ids, we can simulate this here with this transposition
+ # Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow/transformer/transformer_layers.py#L270
+ # encoder_extended_attention_mask = tf.math.equal(encoder_extended_attention_mask,
+ # tf.transpose(encoder_extended_attention_mask, perm=(-1, -2)))
+
+ encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -10000.0
+ else:
+ encoder_extended_attention_mask = None
+
+ encoder_attention_mask = encoder_extended_attention_mask
+
+ # Prepare head mask if needed
+ # 1.0 in head_mask indicate we keep the head
+ # attention_probs has shape bsz x n_heads x N x N
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+ if head_mask is not None:
+ raise NotImplementedError
+ else:
+ head_mask = [None] * self.num_hidden_layers
+ # head_mask = tf.constant([0] * self.num_hidden_layers)
+
+ position_ids = tf.reshape(position_ids, [-1, shape_list(position_ids)[-1]])
+
+ if inputs_embeds is None:
+ check_embeddings_within_bounds(input_ids, self.config.vocab_size)
+ inputs_embeds = self.wte(input_ids)
+
+ position_embeds = self.wpe(position_ids)
+
+ if token_type_ids is not None:
+ token_type_ids = tf.reshape(token_type_ids, [-1, shape_list(token_type_ids)[-1]])
+ token_type_embeds = self.wte(token_type_ids)
+ else:
+ token_type_embeds = tf.constant(0.0)
+
+ position_embeds = tf.cast(position_embeds, dtype=inputs_embeds.dtype)
+ token_type_embeds = tf.cast(token_type_embeds, dtype=inputs_embeds.dtype)
+ hidden_states = inputs_embeds + position_embeds + token_type_embeds
+ hidden_states = self.drop(hidden_states, training=training)
+
+ output_shape = input_shape + [shape_list(hidden_states)[-1]]
+
+ presents = () if use_cache else None
+ all_attentions = () if output_attentions else None
+ all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
+ all_hidden_states = () if output_hidden_states else None
+ for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (tf.reshape(hidden_states, output_shape),)
+
+ outputs = block(
+ hidden_states,
+ layer_past,
+ attention_mask,
+ head_mask[i],
+ encoder_hidden_states,
+ encoder_attention_mask,
+ use_cache,
+ output_attentions,
+ training=training,
+ )
+
+ hidden_states, present = outputs[:2]
+ if use_cache:
+ presents = presents + (present,)
+
+ if output_attentions:
+ all_attentions = all_attentions + (outputs[2],)
+ if self.config.add_cross_attention and encoder_hidden_states is not None:
+ all_cross_attentions = all_cross_attentions + (outputs[3],)
+
+ hidden_states = self.ln_f(hidden_states)
+
+ hidden_states = tf.reshape(hidden_states, output_shape)
+ # Add last hidden state
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if output_attentions:
+ # let the number of heads free (-1) so we can extract attention even after head pruning
+ attention_output_shape = input_shape[:-1] + [-1] + shape_list(all_attentions[0])[-2:]
+ all_attentions = tuple(tf.reshape(t, attention_output_shape) for t in all_attentions)
+
+ if not return_dict:
+ return tuple(
+ v
+ for v in [hidden_states, presents, all_hidden_states, all_attentions, all_cross_attentions]
+ if v is not None
+ )
+
+ return TFBaseModelOutputWithPastAndCrossAttentions(
+ last_hidden_state=hidden_states,
+ past_key_values=presents,
+ hidden_states=all_hidden_states,
+ attentions=all_attentions,
+ cross_attentions=all_cross_attentions,
+ )
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "wte", None) is not None:
+ with tf.name_scope(self.wte.name):
+ self.wte.build(None)
+ if getattr(self, "wpe", None) is not None:
+ with tf.name_scope(self.wpe.name):
+ self.wpe.build(None)
+ if getattr(self, "ln_f", None) is not None:
+ with tf.name_scope(self.ln_f.name):
+ self.ln_f.build([None, None, self.embed_dim])
+ if getattr(self, "h", None) is not None:
+ for layer in self.h:
+ with tf.name_scope(layer.name):
+ layer.build(None)
+
+
+class TFGPT2PreTrainedModel(TFPreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = GPT2Config
+ base_model_prefix = "transformer"
+ # names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model
+ _keys_to_ignore_on_load_unexpected = [r"h.\d+.attn.bias", r"h.\d+.crossattention.bias"]
+
+ @property
+ def input_signature(self):
+ # Although GPT-2 supports token_type_ids in theory, in practice they are rarely used, and the implementation
+ # means that passing token_type_ids=0 yields different outputs from token_type_ids=None.
+ # Therefore, we remove the token_type_ids argument by default, even though it would usually be included.
+ return {
+ "input_ids": tf.TensorSpec((None, None), tf.int32, name="input_ids"),
+ "attention_mask": tf.TensorSpec((None, None), tf.int32, name="attention_mask"),
+ }
+
+
+@dataclass
+class TFGPT2DoubleHeadsModelOutput(ModelOutput):
+ """
+ Base class for outputs of [`TFGPT2DoubleHeadsModel`], which has a language modeling head and a multiple-choice
+ classification head on top of the transformer.
+
+ Args:
+ logits (`tf.Tensor` of shape `(batch_size, num_choices, sequence_length, config.vocab_size)`):
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+ mc_logits (`tf.Tensor` of shape `(batch_size, num_choices)`):
+ Prediction scores of the multiple choice classification head (scores for each choice before SoftMax).
+ past_key_values (`list[tf.Tensor]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ List of `tf.Tensor` of length `config.n_layers`, with each tensor of shape `(2, batch_size, num_heads,
+ sequence_length, embed_size_per_head)`).
+
+ Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see
+ `past_key_values` input) to speed up sequential decoding.
+ hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
+ `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+ attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attention weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ """
+
+ logits: tf.Tensor | None = None
+ mc_logits: tf.Tensor | None = None
+ past_key_values: list[tf.Tensor] | None = None
+ hidden_states: tuple[tf.Tensor] | None = None
+ attentions: tuple[tf.Tensor] | None = None
+
+
+GPT2_START_DOCSTRING = r"""
+
+ This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
+ etc.)
+
+ This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it
+ as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and
+ behavior.
+
+
+
+ TensorFlow models and layers in `transformers` accept two formats as input:
+
+ - having all inputs as keyword arguments (like PyTorch models), or
+ - having all inputs as a list, tuple or dict in the first positional argument.
+
+ The reason the second format is supported is that Keras methods prefer this format when passing inputs to models
+ and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just
+ pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second
+ format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with
+ the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first
+ positional argument:
+
+ - a single Tensor with `input_ids` only and nothing else: `model(input_ids)`
+ - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
+ `model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])`
+ - a dictionary with one or several input Tensors associated to the input names given in the docstring:
+ `model({"input_ids": input_ids, "token_type_ids": token_type_ids})`
+
+ Note that when creating models and layers with
+ [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry
+ about any of this, as you can just pass inputs like you would to any other Python function!
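+
+ As a quick illustration (assuming a `TFGPT2Model` instance named `model` and already-prepared `input_ids` and
+ `attention_mask` tensors), the three call styles above look like:
+
+ ```python
+ outputs = model(input_ids)  # a single tensor
+ outputs = model([input_ids, attention_mask])  # a list, in docstring order
+ outputs = model({"input_ids": input_ids, "attention_mask": attention_mask})  # a dict keyed by input names
+ ```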
+
+
+
+ Parameters:
+ config ([`GPT2Config`]): Model configuration class with all the parameters of the model.
+ Initializing with a config file does not load the weights associated with the model, only the
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+GPT2_INPUTS_DOCSTRING = r"""
+ Args:
+ input_ids (`Numpy array` or `tf.Tensor` of shape `(batch_size, input_ids_length)`):
+ `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values[0].shape[-2]`
+ (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary.
+
+ If `past_key_values` is used, only input IDs that do not have their past calculated should be passed as
+ `input_ids`.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.__call__`] and
+ [`PreTrainedTokenizer.encode`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ past_key_values (`list[tf.Tensor]` of length `config.n_layers`):
+ Contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model (see
+ `past_key_values` output below). Can be used to speed up sequential decoding. The token ids which have
+ their past given to this model should not be passed as input ids as they have already been computed.
+ attention_mask (`tf.Tensor` or `Numpy array` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ If `past_key_values` is used, `attention_mask` needs to contain the masking strategy that was used for
+ `past_key_values`. In other words, the `attention_mask` always has to have the length:
+ `len(past_key_values) + len(input_ids)`
+
+ [What are attention masks?](../glossary#attention-mask)
+ token_type_ids (`tf.Tensor` or `Numpy array` of shape `(batch_size, input_ids_length)`, *optional*):
+ Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
+ 1]`:
+
+ - 0 corresponds to a *sentence A* token,
+ - 1 corresponds to a *sentence B* token.
+
+ [What are token type IDs?](../glossary#token-type-ids)
+ position_ids (`tf.Tensor` or `Numpy array` of shape `(batch_size, input_ids_length)`, *optional*):
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+ config.max_position_embeddings - 1]`.
+
+ [What are position IDs?](../glossary#position-ids)
+ head_mask (`Numpy array` or `tf.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ inputs_embeds (`tf.Tensor` of shape `(batch_size, input_ids_length, hidden_size)`, *optional*):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
+ model's internal embedding lookup matrix.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail. This argument can only be used in eager mode; in graph mode the value in the
+ config will be used instead.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail. This argument can only be used in eager mode; in graph mode the value in the config will be
+ used instead.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in
+ eager mode; in graph mode the value will always be set to `True`.
+ training (`bool`, *optional*, defaults to `False`):
+ Whether or not to use the model in training mode (some modules like dropout modules have different
+ behaviors between training and evaluation).
+"""
+
+
+@add_start_docstrings(
+ "The bare GPT2 Model transformer outputting raw hidden-states without any specific head on top.",
+ GPT2_START_DOCSTRING,
+)
+class TFGPT2Model(TFGPT2PreTrainedModel):
+ def __init__(self, config, *inputs, **kwargs):
+ super().__init__(config, *inputs, **kwargs)
+ self.transformer = TFGPT2MainLayer(config, name="transformer")
+
+ @unpack_inputs
+ @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=TFBaseModelOutputWithPastAndCrossAttentions,
+ config_class=_CONFIG_FOR_DOC,
+ )
+ def call(
+ self,
+ input_ids: TFModelInputType | None = None,
+ past_key_values: tuple[tuple[np.ndarray | tf.Tensor]] | None = None,
+ attention_mask: np.ndarray | tf.Tensor | None = None,
+ token_type_ids: np.ndarray | tf.Tensor | None = None,
+ position_ids: np.ndarray | tf.Tensor | None = None,
+ head_mask: np.ndarray | tf.Tensor | None = None,
+ inputs_embeds: np.ndarray | tf.Tensor | None = None,
+ encoder_hidden_states: np.ndarray | tf.Tensor | None = None,
+ encoder_attention_mask: np.ndarray | tf.Tensor | None = None,
+ use_cache: bool | None = None,
+ output_attentions: bool | None = None,
+ output_hidden_states: bool | None = None,
+ return_dict: bool | None = None,
+ training: bool | None = False,
+ ) -> TFBaseModelOutputWithPastAndCrossAttentions | tuple[tf.Tensor]:
+ r"""
+ encoder_hidden_states (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
+ the model is configured as a decoder.
+ encoder_attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
+ the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ past_key_values (`tuple[tuple[tf.Tensor]]` of length `config.n_layers`):
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
+ If `past` are used, the user can optionally input only the last `decoder_input_ids` (those that don't have
+ their past key value states given to this model) of shape `(batch_size, 1)` instead of all
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+ `past_key_values`). Set to `False` during training and `True` during generation.
+ """
+
+ outputs = self.transformer(
+ input_ids=input_ids,
+ past_key_values=past_key_values,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ training=training,
+ )
+
+ return outputs
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "transformer", None) is not None:
+ with tf.name_scope(self.transformer.name):
+ self.transformer.build(None)
+
+
+@add_start_docstrings(
+ """
+ The GPT2 Model transformer with a language modeling head on top (linear layer with weights tied to the input
+ embeddings).
+ """,
+ GPT2_START_DOCSTRING,
+)
+class TFGPT2LMHeadModel(TFGPT2PreTrainedModel, TFCausalLanguageModelingLoss):
+ def __init__(self, config, *inputs, **kwargs):
+ super().__init__(config, *inputs, **kwargs)
+ self.transformer = TFGPT2MainLayer(config, name="transformer")
+
+ def get_output_embeddings(self):
+ return self.get_input_embeddings()
+
+ def set_output_embeddings(self, value):
+ self.set_input_embeddings(value)
+
+ def prepare_inputs_for_generation(self, inputs, past_key_values=None, use_cache=None, **kwargs):
+ token_type_ids = kwargs.get("token_type_ids")
+ # only last token for inputs_ids if past is defined in kwargs
+ if past_key_values:
+ inputs = tf.expand_dims(inputs[:, -1], -1)
+ if token_type_ids is not None:
+ token_type_ids = tf.expand_dims(token_type_ids[:, -1], -1)
+
+ position_ids = kwargs.get("position_ids")
+ attention_mask = kwargs.get("attention_mask")
+
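+ # build position ids from the attention mask (exclusive cumulative sum) so padding tokens
+ # do not advance the position counter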
+ if attention_mask is not None and position_ids is None:
+ position_ids = tf.math.cumsum(attention_mask, axis=-1, exclusive=True)
+ if past_key_values:
+ position_ids = tf.expand_dims(position_ids[:, -1], -1)
+
+ return {
+ "input_ids": inputs,
+ "attention_mask": attention_mask,
+ "position_ids": position_ids,
+ "past_key_values": past_key_values,
+ "use_cache": use_cache,
+ "token_type_ids": token_type_ids,
+ }
+
+ @unpack_inputs
+ @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=TFCausalLMOutputWithCrossAttentions,
+ config_class=_CONFIG_FOR_DOC,
+ )
+ def call(
+ self,
+ input_ids: TFModelInputType | None = None,
+ past_key_values: tuple[tuple[np.ndarray | tf.Tensor]] | None = None,
+ attention_mask: np.ndarray | tf.Tensor | None = None,
+ token_type_ids: np.ndarray | tf.Tensor | None = None,
+ position_ids: np.ndarray | tf.Tensor | None = None,
+ head_mask: np.ndarray | tf.Tensor | None = None,
+ inputs_embeds: np.ndarray | tf.Tensor | None = None,
+ encoder_hidden_states: np.ndarray | tf.Tensor | None = None,
+ encoder_attention_mask: np.ndarray | tf.Tensor | None = None,
+ use_cache: bool | None = None,
+ output_attentions: bool | None = None,
+ output_hidden_states: bool | None = None,
+ return_dict: bool | None = None,
+ labels: np.ndarray | tf.Tensor | None = None,
+ training: bool | None = False,
+ ) -> TFCausalLMOutputWithCrossAttentions | tuple[tf.Tensor]:
+ r"""
+ encoder_hidden_states (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
+ the model is configured as a decoder.
+ encoder_attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
+ the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ past_key_values (`tuple[tuple[tf.Tensor]]` of length `config.n_layers`):
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
+ If `past` are used, the user can optionally input only the last `decoder_input_ids` (those that don't have
+ their past key value states given to this model) of shape `(batch_size, 1)` instead of all
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+ `past_key_values`). Set to `False` during training and `True` during generation.
+ labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the language modeling loss. Indices should be in `[0, ..., config.vocab_size - 1]`.
+ """
+
+ transformer_outputs = self.transformer(
+ input_ids=input_ids,
+ past_key_values=past_key_values,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ training=training,
+ )
+ hidden_states = transformer_outputs[0]
+ logits = tf.matmul(hidden_states, self.transformer.wte.weights, transpose_b=True)
+
+ loss = None
+ if labels is not None:
+ # shift labels to the left and cut last logit token
+ shifted_logits = logits[:, :-1]
+ labels = labels[:, 1:]
+ loss = self.hf_compute_loss(labels, shifted_logits)
+
+ if not return_dict:
+ output = (logits,) + transformer_outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return TFCausalLMOutputWithCrossAttentions(
+ loss=loss,
+ logits=logits,
+ past_key_values=transformer_outputs.past_key_values,
+ hidden_states=transformer_outputs.hidden_states,
+ attentions=transformer_outputs.attentions,
+ cross_attentions=transformer_outputs.cross_attentions,
+ )
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "transformer", None) is not None:
+ with tf.name_scope(self.transformer.name):
+ self.transformer.build(None)
+
+
+@add_start_docstrings(
+ """
+ The GPT2 Model transformer with a language modeling and a multiple-choice classification head on top e.g. for
+ RocStories/SWAG tasks. The two heads are two linear layers. The language modeling head has its weights tied to the
+ input embeddings, and the classification head takes as input the hidden state at a specified classification token
+ index in the input sequence.
+ """,
+ GPT2_START_DOCSTRING,
+)
+class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel):
+ def __init__(self, config, *inputs, **kwargs):
+ super().__init__(config, *inputs, **kwargs)
+ config.num_labels = 1
+ self.transformer = TFGPT2MainLayer(config, name="transformer")
+ self.multiple_choice_head = TFSequenceSummary(
+ config, initializer_range=config.initializer_range, name="multiple_choice_head"
+ )
+
+ @unpack_inputs
+ @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=TFGPT2DoubleHeadsModelOutput, config_class=_CONFIG_FOR_DOC)
+ def call(
+ self,
+ input_ids: TFModelInputType | None = None,
+ past_key_values: tuple[tuple[np.ndarray | tf.Tensor]] | None = None,
+ attention_mask: np.ndarray | tf.Tensor | None = None,
+ token_type_ids: np.ndarray | tf.Tensor | None = None,
+ position_ids: np.ndarray | tf.Tensor | None = None,
+ head_mask: np.ndarray | tf.Tensor | None = None,
+ inputs_embeds: np.ndarray | tf.Tensor | None = None,
+ mc_token_ids: np.ndarray | tf.Tensor | None = None,
+ use_cache: bool | None = None,
+ output_attentions: bool | None = None,
+ output_hidden_states: bool | None = None,
+ return_dict: bool | None = None,
+ training: bool | None = False,
+ ) -> TFGPT2DoubleHeadsModelOutput | tuple[tf.Tensor]:
+ r"""
+ mc_token_ids (`tf.Tensor` or `Numpy array` of shape `(batch_size, num_choices)`, *optional*, defaults to the index of the last token of the input):
+ Index of the classification token in each input sequence. Selected in the range `[0, input_ids.size(-1) -
+ 1]`.
+
+ Return:
+
+ Examples:
+
+ ```python
+ >>> import tensorflow as tf
+ >>> from transformers import AutoTokenizer, TFGPT2DoubleHeadsModel
+
+ >>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
+ >>> model = TFGPT2DoubleHeadsModel.from_pretrained("openai-community/gpt2")
+
+ >>> # Add a [CLS] to the vocabulary (we should train it also!)
+ >>> num_added_tokens = tokenizer.add_special_tokens({"cls_token": "[CLS]"})
+
+ >>> embedding_layer = model.resize_token_embeddings(
+ ... len(tokenizer)
+ ... ) # Update the model embeddings with the new vocabulary size
+
+ >>> choices = ["Hello, my dog is cute [CLS]", "Hello, my cat is cute [CLS]"]
+ >>> encoded_choices = [tokenizer.encode(s) for s in choices]
+ >>> cls_token_location = [tokens.index(tokenizer.cls_token_id) for tokens in encoded_choices]
+
+ >>> input_ids = tf.constant(encoded_choices)[None, :] # Batch size: 1, number of choices: 2
+ >>> mc_token_ids = tf.constant([cls_token_location]) # Batch size: 1
+
+ >>> outputs = model(input_ids, mc_token_ids=mc_token_ids)
+ >>> lm_prediction_scores, mc_prediction_scores = outputs[:2]
+ ```"""
+
+ if input_ids is not None:
+ input_shapes = shape_list(input_ids)
+ else:
+ input_shapes = shape_list(inputs_embeds)[:-1]
+
+ seq_length = input_shapes[-1]
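+ # flatten the (batch_size, num_choices) leading dimensions so the transformer sees a plain batch of sequences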
+ flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None
+ flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
+ flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None
+ flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None
+ transformer_outputs = self.transformer(
+ input_ids=flat_input_ids,
+ past_key_values=past_key_values,
+ attention_mask=flat_attention_mask,
+ token_type_ids=flat_token_type_ids,
+ position_ids=flat_position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ training=training,
+ )
+ hidden_states = transformer_outputs[0]
+ hidden_states = tf.reshape(hidden_states, input_shapes + shape_list(hidden_states)[-1:])
+ if return_dict and output_hidden_states:
+ # We do this to match the slightly odd PT behaviour - the final hidden state is reshaped to rank 4 when the
+ # input is rank 3, but all other hidden states remain at rank-3 (with the first 2 dims merged)
+ all_hidden_states = transformer_outputs.hidden_states[:-1] + (hidden_states,)
+ else:
+ all_hidden_states = None
+ lm_logits = tf.matmul(hidden_states, self.transformer.wte.weights, transpose_b=True)
+ mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids, training=training)
+ mc_logits = tf.squeeze(mc_logits, axis=-1)
+
+ if not return_dict:
+ return (lm_logits, mc_logits) + transformer_outputs[1:]
+
+ return TFGPT2DoubleHeadsModelOutput(
+ logits=lm_logits,
+ mc_logits=mc_logits,
+ past_key_values=transformer_outputs.past_key_values,
+ hidden_states=all_hidden_states,
+ attentions=transformer_outputs.attentions,
+ )
+
+ @property
+ def input_signature(self):
+ return {
+ "input_ids": tf.TensorSpec((None, None, None), tf.int32, name="input_ids"),
+ "attention_mask": tf.TensorSpec((None, None, None), tf.int32, name="attention_mask"),
+ "mc_token_ids": tf.TensorSpec((None, None), tf.int32, name="mc_token_ids"),
+ }
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "transformer", None) is not None:
+ with tf.name_scope(self.transformer.name):
+ self.transformer.build(None)
+ if getattr(self, "multiple_choice_head", None) is not None:
+ with tf.name_scope(self.multiple_choice_head.name):
+ self.multiple_choice_head.build(None)
+
+
+@add_start_docstrings(
+ """
+ The GPT2 Model transformer with a sequence classification head on top (linear layer).
+
+ [`TFGPT2ForSequenceClassification`] uses the last token in order to do the classification, as other causal models
+ (e.g. GPT-1) do.
+
+ Since it does classification on the last token, it requires to know the position of the last token. If a
+ `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
+ no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
+ padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
+ each row of the batch).
+ """,
+ GPT2_START_DOCSTRING,
+)
+class TFGPT2ForSequenceClassification(TFGPT2PreTrainedModel, TFSequenceClassificationLoss):
+ def __init__(self, config, *inputs, **kwargs):
+ super().__init__(config, *inputs, **kwargs)
+ self.num_labels = config.num_labels
+ self.score = keras.layers.Dense(
+ config.num_labels,
+ kernel_initializer=get_initializer(config.initializer_range),
+ name="score",
+ use_bias=False,
+ )
+ self.transformer = TFGPT2MainLayer(config, name="transformer")
+ self.config = config
+
+ @unpack_inputs
+ @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ checkpoint="microsoft/DialogRPT-updown",
+ output_type=TFSequenceClassifierOutputWithPast,
+ config_class=_CONFIG_FOR_DOC,
+ )
+ def call(
+ self,
+ input_ids: TFModelInputType | None = None,
+ past_key_values: tuple[tuple[np.ndarray | tf.Tensor]] | None = None,
+ attention_mask: np.ndarray | tf.Tensor | None = None,
+ token_type_ids: np.ndarray | tf.Tensor | None = None,
+ position_ids: np.ndarray | tf.Tensor | None = None,
+ head_mask: np.ndarray | tf.Tensor | None = None,
+ inputs_embeds: np.ndarray | tf.Tensor | None = None,
+ use_cache: bool | None = None,
+ output_attentions: bool | None = None,
+ output_hidden_states: bool | None = None,
+ return_dict: bool | None = None,
+ labels: np.ndarray | tf.Tensor | None = None,
+ training: bool | None = False,
+ ) -> TFSequenceClassifierOutputWithPast | tuple[tf.Tensor]:
+ r"""
+ labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`.
+ """
+ transformer_outputs = self.transformer(
+ input_ids=input_ids,
+ past_key_values=past_key_values,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ training=training,
+ )
+ hidden_states = transformer_outputs[0]
+ logits = self.score(hidden_states)
+ logits_shape = shape_list(logits)
+ batch_size = logits_shape[0]
+
+ if self.config.pad_token_id is None:
+ last_non_pad_token = tf.fill((batch_size,), value=logits_shape[1] - 1)
+ else:
+ if input_ids is not None:
+ token_indices = tf.range(shape_list(input_ids)[-1])
+ non_pad_mask = tf.cast(input_ids != self.config.pad_token_id, token_indices.dtype)
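+ # zero out the indices of padding tokens, then take the max to get the index of the last non-padding token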
+ last_non_pad_token = tf.reduce_max(token_indices * non_pad_mask, axis=-1)
+ else:
+ last_non_pad_token = tf.fill((batch_size,), value=logits_shape[1] - 1)
+ logger.warning_once(
+ f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
+ "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
+ )
+ loss = None
+
+ pooled_logits = tf.gather(logits, last_non_pad_token, batch_dims=1, axis=1)
+
+ if labels is not None:
+ if self.config.pad_token_id is None and logits_shape[0] != 1:
+ raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
+
+ loss = self.hf_compute_loss(tf.reshape(labels, [-1]), tf.reshape(pooled_logits, [-1, self.num_labels]))
+
+ if not return_dict:
+ output = (pooled_logits,) + transformer_outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return TFSequenceClassifierOutputWithPast(
+ loss=loss,
+ logits=pooled_logits,
+ past_key_values=transformer_outputs.past_key_values,
+ hidden_states=transformer_outputs.hidden_states,
+ attentions=transformer_outputs.attentions,
+ )
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "score", None) is not None:
+ with tf.name_scope(self.score.name):
+ self.score.build([None, None, self.config.n_embd])
+ if getattr(self, "transformer", None) is not None:
+ with tf.name_scope(self.transformer.name):
+ self.transformer.build(None)
+
+
+__all__ = [
+ "TFGPT2DoubleHeadsModel",
+ "TFGPT2ForSequenceClassification",
+ "TFGPT2LMHeadModel",
+ "TFGPT2MainLayer",
+ "TFGPT2Model",
+ "TFGPT2PreTrainedModel",
+]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/gpt2/tokenization_gpt2.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/gpt2/tokenization_gpt2.py
new file mode 100644
index 0000000000000000000000000000000000000000..608164ef2d83ab15bf7f99d33f9c6eb56ed1fcff
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/gpt2/tokenization_gpt2.py
@@ -0,0 +1,334 @@
+# coding=utf-8
+# Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tokenization classes for OpenAI GPT."""
+
+import json
+import os
+from functools import lru_cache
+from typing import Optional
+
+import regex as re
+
+from ...tokenization_utils import AddedToken, PreTrainedTokenizer
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+VOCAB_FILES_NAMES = {
+ "vocab_file": "vocab.json",
+ "merges_file": "merges.txt",
+}
+
+
+@lru_cache
+def bytes_to_unicode():
+ """
+ Returns a list of utf-8 bytes and a mapping to unicode strings. We specifically avoid mapping to whitespace/control
+ characters that the bpe code barfs on.
+
+ The reversible bpe codes work on unicode strings. This means you need a large # of unicode characters in your vocab
+ if you want to avoid UNKs. When you're at something like a 10B token dataset you end up needing around 5K for
+ decent coverage. This is a significant percentage of your normal, say, 32K bpe vocab. To avoid that, we want lookup
+ tables between utf-8 bytes and unicode strings.
+ """
+ bs = (
+ list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
+ )
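+ # bytes already in bs map to themselves; every remaining byte is remapped to an unused code point
+ # above 255, so no byte ends up mapped to a whitespace/control character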
+ cs = bs[:]
+ n = 0
+ for b in range(2**8):
+ if b not in bs:
+ bs.append(b)
+ cs.append(2**8 + n)
+ n += 1
+ cs = [chr(n) for n in cs]
+ return dict(zip(bs, cs))
+
+
+def get_pairs(word):
+ """
+ Return set of symbol pairs in a word.
+
+ Word is represented as tuple of symbols (symbols being variable-length strings).
+ """
+ pairs = set()
+ prev_char = word[0]
+ for char in word[1:]:
+ pairs.add((prev_char, char))
+ prev_char = char
+ return pairs
+
+
+class GPT2Tokenizer(PreTrainedTokenizer):
+ """
+ Construct a GPT-2 tokenizer. Based on byte-level Byte-Pair-Encoding.
+
+ This tokenizer has been trained to treat spaces like parts of the tokens (a bit like sentencepiece) so a word will
+ be encoded differently depending on whether it is at the beginning of the sentence (without space) or not:
+
+ ```python
+ >>> from transformers import GPT2Tokenizer
+
+ >>> tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2")
+ >>> tokenizer("Hello world")["input_ids"]
+ [15496, 995]
+
+ >>> tokenizer(" Hello world")["input_ids"]
+ [18435, 995]
+ ```
+
+ You can get around that behavior by passing `add_prefix_space=True` when instantiating this tokenizer or when you
+ call it on some text, but since the model was not pretrained this way, it might yield a decrease in performance.
+
+
+
+ When used with `is_split_into_words=True`, this tokenizer will add a space before each word (even the first one).
+
+
+
+ This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
+ this superclass for more information regarding those methods.
+
+ Args:
+ vocab_file (`str`):
+ Path to the vocabulary file.
+ merges_file (`str`):
+ Path to the merges file.
+ errors (`str`, *optional*, defaults to `"replace"`):
+ Paradigm to follow when decoding bytes to UTF-8. See
+ [bytes.decode](https://docs.python.org/3/library/stdtypes.html#bytes.decode) for more information.
+ unk_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+ token instead.
+ bos_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
+ The beginning of sequence token.
+ eos_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
+ The end of sequence token.
+ pad_token (`str`, *optional*):
+ The token used for padding, for example when batching sequences of different lengths.
+ add_prefix_space (`bool`, *optional*, defaults to `False`):
+ Whether or not to add an initial space to the input. This allows the leading word to be treated just like
+ any other word. (The GPT2 tokenizer detects the beginning of words by the preceding space.)
+ add_bos_token (`bool`, *optional*, defaults to `False`):
+ Whether or not to add an initial beginning of sentence token to the input. This allows the leading word to
+ be treated just like any other word.
+ """
+
+ vocab_files_names = VOCAB_FILES_NAMES
+ model_input_names = ["input_ids", "attention_mask"]
+
+ def __init__(
+ self,
+ vocab_file,
+ merges_file,
+ errors="replace",
+ unk_token="<|endoftext|>",
+ bos_token="<|endoftext|>",
+ eos_token="<|endoftext|>",
+ pad_token=None,
+ add_prefix_space=False,
+ add_bos_token=False,
+ **kwargs,
+ ):
+ bos_token = AddedToken(bos_token, lstrip=False, rstrip=False) if isinstance(bos_token, str) else bos_token
+ eos_token = AddedToken(eos_token, lstrip=False, rstrip=False) if isinstance(eos_token, str) else eos_token
+ unk_token = AddedToken(unk_token, lstrip=False, rstrip=False) if isinstance(unk_token, str) else unk_token
+ pad_token = AddedToken(pad_token, lstrip=False, rstrip=False) if isinstance(pad_token, str) else pad_token
+
+ self.add_bos_token = add_bos_token
+
+ with open(vocab_file, encoding="utf-8") as vocab_handle:
+ self.encoder = json.load(vocab_handle)
+ self.decoder = {v: k for k, v in self.encoder.items()}
+ self.errors = errors # how to handle errors in decoding
+ self.byte_encoder = bytes_to_unicode()
+ self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
+ with open(merges_file, encoding="utf-8") as merges_handle:
+ bpe_merges = merges_handle.read().split("\n")[1:-1]
+ bpe_merges = [tuple(merge.split()) for merge in bpe_merges]
+ self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
+ self.cache = {}
+ self.add_prefix_space = add_prefix_space
+
+ # Should have added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
+ self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
+
+ super().__init__(
+ errors=errors,
+ unk_token=unk_token,
+ bos_token=bos_token,
+ eos_token=eos_token,
+ pad_token=pad_token,
+ add_prefix_space=add_prefix_space,
+ add_bos_token=add_bos_token,
+ **kwargs,
+ )
+
+ @property
+ def vocab_size(self):
+ return len(self.encoder)
+
+ def get_vocab(self):
+ return dict(self.encoder, **self.added_tokens_encoder)
+
+ def bpe(self, token):
+ if token in self.cache:
+ return self.cache[token]
+ word = tuple(token)
+ pairs = get_pairs(word)
+
+ if not pairs:
+ return token
+
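+ # greedily merge the adjacent pair with the lowest (best) BPE rank until no learned merge applies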
+ while True:
+ bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
+ if bigram not in self.bpe_ranks:
+ break
+ first, second = bigram
+ new_word = []
+ i = 0
+ while i < len(word):
+ try:
+ j = word.index(first, i)
+ except ValueError:
+ new_word.extend(word[i:])
+ break
+ else:
+ new_word.extend(word[i:j])
+ i = j
+
+ if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
+ new_word.append(first + second)
+ i += 2
+ else:
+ new_word.append(word[i])
+ i += 1
+ new_word = tuple(new_word)
+ word = new_word
+ if len(word) == 1:
+ break
+ else:
+ pairs = get_pairs(word)
+ word = " ".join(word)
+ self.cache[token] = word
+ return word
+
+ def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
+ if self.add_bos_token:
+ bos_token_ids = [self.bos_token_id]
+ else:
+ bos_token_ids = []
+
+ output = bos_token_ids + token_ids_0
+
+ if token_ids_1 is None:
+ return output
+
+ return output + bos_token_ids + token_ids_1
+
+ def get_special_tokens_mask(
+ self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None, already_has_special_tokens: bool = False
+ ) -> list[int]:
+ """
+ Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
+ special tokens using the tokenizer `prepare_for_model` or `encode_plus` methods.
+
+ Args:
+ token_ids_0 (`list[int]`):
+ List of IDs.
+ token_ids_1 (`list[int]`, *optional*):
+ Optional second list of IDs for sequence pairs.
+ already_has_special_tokens (`bool`, *optional*, defaults to `False`):
+ Whether or not the token list is already formatted with special tokens for the model.
+
+ Returns:
+ `list[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
+ """
+ if already_has_special_tokens:
+ return super().get_special_tokens_mask(
+ token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
+ )
+
+ if not self.add_bos_token:
+ return super().get_special_tokens_mask(
+ token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=False
+ )
+
+ if token_ids_1 is None:
+ return [1] + ([0] * len(token_ids_0))
+ return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1))
+
+ def _tokenize(self, text):
+ """Tokenize a string."""
+ bpe_tokens = []
+ for token in re.findall(self.pat, text):
+ token = "".join(
+ self.byte_encoder[b] for b in token.encode("utf-8")
+ ) # Maps all our bytes to unicode strings, avoiding control tokens of the BPE (spaces in our case)
+ bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(" "))
+ return bpe_tokens
+
+ def _convert_token_to_id(self, token):
+ """Converts a token (str) in an id using the vocab."""
+ return self.encoder.get(token, self.encoder.get(self.unk_token))
+
+ def _convert_id_to_token(self, index):
+ """Converts an index (integer) in a token (str) using the vocab."""
+ return self.decoder.get(index)
+
+ def convert_tokens_to_string(self, tokens):
+ """Converts a sequence of tokens (string) in a single string."""
+ text = "".join(tokens)
+ text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors=self.errors)
+ return text
+
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]:
+ if not os.path.isdir(save_directory):
+ logger.error(f"Vocabulary path ({save_directory}) should be a directory")
+ return
+ vocab_file = os.path.join(
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
+ )
+ merge_file = os.path.join(
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["merges_file"]
+ )
+
+ with open(vocab_file, "w", encoding="utf-8") as f:
+ f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
+
+ index = 0
+ with open(merge_file, "w", encoding="utf-8") as writer:
+ writer.write("#version: 0.2\n")
+ for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
+ if index != token_index:
+ logger.warning(
+ f"Saving vocabulary to {merge_file}: BPE merge indices are not consecutive."
+ " Please check that the tokenizer is not corrupted!"
+ )
+ index = token_index
+ writer.write(" ".join(bpe_tokens) + "\n")
+ index += 1
+
+ return vocab_file, merge_file
+
+ def prepare_for_tokenization(self, text, is_split_into_words=False, **kwargs):
+ add_prefix_space = kwargs.pop("add_prefix_space", self.add_prefix_space)
+ if is_split_into_words or add_prefix_space:
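+ # prepend a space so the first word is tokenized the same way as a word in the middle of a sentence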
+ text = " " + text
+ return (text, kwargs)
+
+
+__all__ = ["GPT2Tokenizer"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/gpt2/tokenization_gpt2_fast.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/gpt2/tokenization_gpt2_fast.py
new file mode 100644
index 0000000000000000000000000000000000000000..f81c155e864476cf49c24f91a0235c939f42d3e0
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/gpt2/tokenization_gpt2_fast.py
@@ -0,0 +1,133 @@
+# coding=utf-8
+# Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tokenization classes for OpenAI GPT."""
+
+from typing import Optional
+
+from ...tokenization_utils_base import BatchEncoding
+from ...tokenization_utils_fast import PreTrainedTokenizerFast
+from ...utils import logging
+from .tokenization_gpt2 import GPT2Tokenizer
+
+
+logger = logging.get_logger(__name__)
+
+VOCAB_FILES_NAMES = {"vocab_file": "vocab.json", "merges_file": "merges.txt", "tokenizer_file": "tokenizer.json"}
+
+
+class GPT2TokenizerFast(PreTrainedTokenizerFast):
+ """
+ Construct a "fast" GPT-2 tokenizer (backed by HuggingFace's *tokenizers* library). Based on byte-level
+ Byte-Pair-Encoding.
+
+ This tokenizer has been trained to treat spaces like parts of the tokens (a bit like sentencepiece) so a word will
+ be encoded differently depending on whether it is at the beginning of the sentence (without space) or not:
+
+ ```python
+ >>> from transformers import GPT2TokenizerFast
+
+ >>> tokenizer = GPT2TokenizerFast.from_pretrained("openai-community/gpt2")
+ >>> tokenizer("Hello world")["input_ids"]
+ [15496, 995]
+
+ >>> tokenizer(" Hello world")["input_ids"]
+ [18435, 995]
+ ```
+
+ You can get around that behavior by passing `add_prefix_space=True` when instantiating this tokenizer, but since
+ the model was not pretrained this way, it might yield a decrease in performance.
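+
+    A quick sketch of that option (the ids shown come from the `openai-community/gpt2` vocabulary and are illustrative):
+
+    ```python
+    >>> tokenizer = GPT2TokenizerFast.from_pretrained("openai-community/gpt2", add_prefix_space=True)
+    >>> tokenizer("Hello world")["input_ids"]  # now encoded as if preceded by a space
+    [18435, 995]
+    ```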
+
+
+
+ When used with `is_split_into_words=True`, this tokenizer needs to be instantiated with `add_prefix_space=True`.
+
+
+
+ This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should
+ refer to this superclass for more information regarding those methods.
+
+ Args:
+ vocab_file (`str`, *optional*):
+ Path to the vocabulary file.
+ merges_file (`str`, *optional*):
+ Path to the merges file.
+ tokenizer_file (`str`, *optional*):
+ Path to [tokenizers](https://github.com/huggingface/tokenizers) file (generally has a .json extension) that
+ contains everything needed to load the tokenizer.
+ unk_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+ token instead.
+ bos_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
+ The beginning of sequence token.
+ eos_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
+ The end of sequence token.
+ add_prefix_space (`bool`, *optional*, defaults to `False`):
+            Whether or not to add an initial space to the input. This allows treating the leading word just like any
+            other word. (The GPT2 tokenizer detects the beginning of a word by the preceding space.)
+ """
+
+ vocab_files_names = VOCAB_FILES_NAMES
+ model_input_names = ["input_ids", "attention_mask"]
+ slow_tokenizer_class = GPT2Tokenizer
+
+ def __init__(
+ self,
+ vocab_file=None,
+ merges_file=None,
+ tokenizer_file=None,
+ unk_token="<|endoftext|>",
+ bos_token="<|endoftext|>",
+ eos_token="<|endoftext|>",
+ add_prefix_space=False,
+ **kwargs,
+ ):
+ super().__init__(
+ vocab_file=vocab_file,
+ merges_file=merges_file,
+ tokenizer_file=tokenizer_file,
+ unk_token=unk_token,
+ bos_token=bos_token,
+ eos_token=eos_token,
+ add_prefix_space=add_prefix_space,
+ **kwargs,
+ )
+
+ self.add_bos_token = kwargs.pop("add_bos_token", False)
+
+ def _batch_encode_plus(self, *args, **kwargs) -> BatchEncoding:
+ is_split_into_words = kwargs.get("is_split_into_words", False)
+ assert self.add_prefix_space or not is_split_into_words, (
+ f"You need to instantiate {self.__class__.__name__} with add_prefix_space=True "
+ "to use it with pretokenized inputs."
+ )
+
+ return super()._batch_encode_plus(*args, **kwargs)
+
+ def _encode_plus(self, *args, **kwargs) -> BatchEncoding:
+ is_split_into_words = kwargs.get("is_split_into_words", False)
+
+ assert self.add_prefix_space or not is_split_into_words, (
+ f"You need to instantiate {self.__class__.__name__} with add_prefix_space=True "
+ "to use it with pretokenized inputs."
+ )
+
+ return super()._encode_plus(*args, **kwargs)
+
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]:
+ files = self._tokenizer.model.save(save_directory, name=filename_prefix)
+ return tuple(files)
+
+
+__all__ = ["GPT2TokenizerFast"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/gpt2/tokenization_gpt2_tf.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/gpt2/tokenization_gpt2_tf.py
new file mode 100644
index 0000000000000000000000000000000000000000..145a45da0db6d36f75f5cec6091027e36541184e
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/gpt2/tokenization_gpt2_tf.py
@@ -0,0 +1,119 @@
+import os
+from typing import Optional, Union
+
+import tensorflow as tf
+from tensorflow_text import pad_model_inputs
+
+from ...modeling_tf_utils import keras
+from ...utils.import_utils import is_keras_nlp_available, requires
+from .tokenization_gpt2 import GPT2Tokenizer
+
+
+if is_keras_nlp_available():
+ from keras_nlp.tokenizers import BytePairTokenizer
+
+
+@requires(backends=("keras_nlp",))
+class TFGPT2Tokenizer(keras.layers.Layer):
+ """
+ This is an in-graph tokenizer for GPT2. It should be initialized similarly to other tokenizers, using the
+ `from_pretrained()` method. It can also be initialized with the `from_tokenizer()` method, which imports settings
+ from an existing standard tokenizer object.
+
+ In-graph tokenizers, unlike other Hugging Face tokenizers, are actually Keras layers and are designed to be run
+ when the model is called, rather than during preprocessing. As a result, they have somewhat more limited options
+ than standard tokenizer classes. They are most useful when you want to create an end-to-end model that goes
+ straight from `tf.string` inputs to outputs.
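+
+    A minimal end-to-end sketch of that use (assuming TensorFlow and the `keras_nlp` backend are installed):
+
+    ```python
+    import tensorflow as tf
+
+    from transformers import TFGPT2Tokenizer
+
+    tf_tokenizer = TFGPT2Tokenizer.from_pretrained("openai-community/gpt2")
+    # outputs is a dict with "input_ids" and "attention_mask", computed inside the TF graph
+    outputs = tf_tokenizer(tf.constant(["Hello world"]))
+    ```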
+
+ Args:
+ vocab (dict[str, int]): Vocabulary dict for Byte Pair Tokenizer
+ merges (list[str]): Merges list for Byte Pair Tokenizer
+ """
+
+ def __init__(
+ self,
+ vocab: dict[str, int],
+ merges: list[str],
+ max_length: Optional[int] = None,
+ pad_token_id: Optional[int] = None,
+ ):
+ super().__init__()
+ self.pad_token_id = pad_token_id
+ self.max_length = max_length
+ self.vocab = vocab
+ self.merges = merges
+
+ self.tf_tokenizer = BytePairTokenizer(vocab, merges, sequence_length=max_length)
+
+ @classmethod
+ def from_tokenizer(cls, tokenizer: GPT2Tokenizer, *args, **kwargs):
+ """Creates TFGPT2Tokenizer from GPT2Tokenizer
+
+ Args:
+ tokenizer (GPT2Tokenizer)
+
+ Examples:
+
+ ```python
+ from transformers import AutoTokenizer, TFGPT2Tokenizer
+
+ tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
+ tf_tokenizer = TFGPT2Tokenizer.from_tokenizer(tokenizer)
+ ```
+ """
+ merges = [" ".join(m) for m in tokenizer.bpe_ranks]
+ vocab = tokenizer.get_vocab()
+ return cls(vocab, merges, *args, **kwargs)
+
+ @classmethod
+ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], *init_inputs, **kwargs):
+ """Creates TFGPT2Tokenizer from pretrained GPT2Tokenizer
+
+ Args:
+ pretrained_model_name_or_path (Union[str, os.PathLike]): Path to pretrained model
+
+ Examples:
+
+ ```python
+ from transformers import TFGPT2Tokenizer
+
+ tf_tokenizer = TFGPT2Tokenizer.from_pretrained("openai-community/gpt2")
+ ```
+ """
+ tokenizer = GPT2Tokenizer.from_pretrained(pretrained_model_name_or_path, *init_inputs, **kwargs)
+ return cls.from_tokenizer(tokenizer, *init_inputs, **kwargs)
+
+ @classmethod
+ def from_config(cls, config):
+ """Creates TFGPT2Tokenizer from configurations
+
+ Args:
+ config (Dict): Dictionary with keys such as stated in `get_config`.
+ """
+ return cls(**config)
+
+ def get_config(self):
+ return {
+ "vocab": self.vocab,
+ "merges": self.merges,
+ "max_length": self.max_length,
+ "pad_token_id": self.pad_token_id,
+ }
+
+ def call(self, x, max_length: Optional[int] = None):
+ input_ids = self.tf_tokenizer(x)
+ attention_mask = tf.ones_like(input_ids)
+
+ if self.pad_token_id is not None:
+ # pad the tokens up to max length
+ max_length = max_length if max_length is not None else self.max_length
+
+ if max_length is not None:
+ input_ids, attention_mask = pad_model_inputs(
+ input_ids, max_seq_length=max_length, pad_value=self.pad_token_id
+ )
+
+ return {"attention_mask": attention_mask, "input_ids": input_ids}
+
+
+__all__ = ["TFGPT2Tokenizer"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/gpt_bigcode/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/gpt_bigcode/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..92e985d92734550a5b0635941294669386d35749
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/gpt_bigcode/__init__.py
@@ -0,0 +1,27 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_gpt_bigcode import *
+ from .modeling_gpt_bigcode import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/gpt_bigcode/configuration_gpt_bigcode.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/gpt_bigcode/configuration_gpt_bigcode.py
new file mode 100644
index 0000000000000000000000000000000000000000..127a0eed4732c15ef565a306a1a25f86b4e51ce4
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/gpt_bigcode/configuration_gpt_bigcode.py
@@ -0,0 +1,145 @@
+# coding=utf-8
+# Copyright 2023 The BigCode team and HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""GPTBigCode configuration"""
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class GPTBigCodeConfig(PretrainedConfig):
+ """
+ This is the configuration class to store the configuration of a [`GPTBigCodeModel`]. It is used to instantiate a
+ GPTBigCode model according to the specified arguments, defining the model architecture. Instantiating a
+ configuration with the defaults will yield a similar configuration to that of the GPTBigCode
+ [gpt_bigcode](https://huggingface.co/gpt_bigcode) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 50257):
+ Vocabulary size of the GPT-2 model. Defines the number of different tokens that can be represented by the
+ `inputs_ids` passed when calling [`GPTBigCodeModel`].
+ n_positions (`int`, *optional*, defaults to 1024):
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
+ just in case (e.g., 512 or 1024 or 2048).
+ n_embd (`int`, *optional*, defaults to 768):
+ Dimensionality of the embeddings and hidden states.
+ n_layer (`int`, *optional*, defaults to 12):
+ Number of hidden layers in the Transformer encoder.
+ n_head (`int`, *optional*, defaults to 12):
+ Number of attention heads for each attention layer in the Transformer encoder.
+        n_inner (`int`, *optional*, defaults to `None`):
+            Dimensionality of the inner feed-forward layers. `None` will set it to 4 times `n_embd`.
+ activation_function (`str`, *optional*, defaults to `"gelu_pytorch_tanh"`):
+ Activation function, to be selected in the list `["relu", "silu", "gelu", "tanh", "gelu_new",
+ "gelu_pytorch_tanh"]`.
+ resid_pdrop (`float`, *optional*, defaults to 0.1):
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+ embd_pdrop (`float`, *optional*, defaults to 0.1):
+ The dropout ratio for the embeddings.
+ attn_pdrop (`float`, *optional*, defaults to 0.1):
+ The dropout ratio for the attention.
+ layer_norm_epsilon (`float`, *optional*, defaults to 1e-5):
+ The epsilon to use in the layer normalization layers.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ scale_attn_weights (`bool`, *optional*, defaults to `True`):
+            Scale attention weights by dividing by sqrt(head_dim).
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return the last key/values attentions (not used by all models).
+ attention_softmax_in_fp32 (`bool`, *optional*, defaults to `True`):
+ Whether to call the fused softmax in float32.
+ scale_attention_softmax_in_fp32 (`bool`, *optional*, defaults to `True`):
+ Whether to scale the attention softmax in float32.
+        multi_query (`bool`, *optional*, defaults to `True`):
+            Whether to use Multi-Query Attention (`True`) or Multi-Head Attention (`False`).
+ Example:
+
+ ```python
+ >>> from transformers import GPTBigCodeConfig, GPTBigCodeModel
+
+ >>> # Initializing a GPTBigCode configuration
+ >>> configuration = GPTBigCodeConfig()
+
+ >>> # Initializing a model (with random weights) from the configuration
+ >>> model = GPTBigCodeModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "gpt_bigcode"
+ keys_to_ignore_at_inference = ["past_key_values"]
+ attribute_map = {
+ "hidden_size": "n_embd",
+ "max_position_embeddings": "n_positions",
+ "num_attention_heads": "n_head",
+ "num_hidden_layers": "n_layer",
+ }
+
+ def __init__(
+ self,
+ vocab_size=50257,
+ n_positions=1024,
+ n_embd=768,
+ n_layer=12,
+ n_head=12,
+ n_inner=None,
+ activation_function="gelu_pytorch_tanh",
+ resid_pdrop=0.1,
+ embd_pdrop=0.1,
+ attn_pdrop=0.1,
+ layer_norm_epsilon=1e-5,
+ initializer_range=0.02,
+ scale_attn_weights=True,
+ use_cache=True,
+ bos_token_id=50256,
+ eos_token_id=50256,
+ attention_softmax_in_fp32=True,
+ scale_attention_softmax_in_fp32=True,
+ multi_query=True,
+ **kwargs,
+ ):
+ self.vocab_size = vocab_size
+ self.n_positions = n_positions
+ self.n_embd = n_embd
+ self.n_layer = n_layer
+ self.n_head = n_head
+ self.n_inner = n_inner
+ self.activation_function = activation_function
+ self.resid_pdrop = resid_pdrop
+ self.embd_pdrop = embd_pdrop
+ self.attn_pdrop = attn_pdrop
+ self.layer_norm_epsilon = layer_norm_epsilon
+ self.initializer_range = initializer_range
+ self.scale_attn_weights = scale_attn_weights
+ self.use_cache = use_cache
+ self.attention_softmax_in_fp32 = attention_softmax_in_fp32
+ self.scale_attention_softmax_in_fp32 = scale_attention_softmax_in_fp32
+ self.multi_query = multi_query
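+        # With multi-query attention, all query heads share a single key/value head.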
+ self.num_key_value_heads = 1 if multi_query else n_head
+
+ self.bos_token_id = bos_token_id
+ self.eos_token_id = eos_token_id
+
+ super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
+
+
+__all__ = ["GPTBigCodeConfig"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
new file mode 100644
index 0000000000000000000000000000000000000000..6992dc642a4f024b97a9c143eff434bf4eea205c
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
@@ -0,0 +1,931 @@
+# coding=utf-8
+# Copyright 2023 The Bigcode team and HuggingFace Inc. team.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch GPTBigCode model."""
+
+import math
+from typing import Callable, Optional, Union
+
+import torch
+from torch import nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+
+from ...activations import ACT2FN
+from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
+from ...generation import GenerationMixin
+from ...masking_utils import create_causal_mask
+from ...modeling_flash_attention_utils import is_flash_attn_available
+from ...modeling_outputs import (
+ BaseModelOutputWithPastAndCrossAttentions,
+ CausalLMOutputWithCrossAttentions,
+ SequenceClassifierOutputWithPast,
+ TokenClassifierOutput,
+)
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...utils import (
+ auto_docstring,
+ can_return_tuple,
+ logging,
+)
+from .configuration_gpt_bigcode import GPTBigCodeConfig
+
+
+if is_flash_attn_available():
+ pass
+
+
+logger = logging.get_logger(__name__)
+
+
+# Fused kernels
+# Use separate functions for each case because conditionals prevent kernel fusion.
+# TODO: Could have better fused kernels depending on scaling, dropout and head mask.
+# Is it doable without writing 32 functions?
+@torch.jit.script
+def upcast_masked_softmax(
+ x: torch.Tensor, mask: torch.Tensor, mask_value: torch.Tensor, scale: float, softmax_dtype: torch.dtype
+):
+ input_dtype = x.dtype
+ x = x.to(softmax_dtype) * scale
+ x = torch.where(mask, x, mask_value)
+ x = torch.nn.functional.softmax(x, dim=-1).to(input_dtype)
+ return x
+
+
+@torch.jit.script
+def upcast_softmax(x: torch.Tensor, scale: float, softmax_dtype: torch.dtype):
+ input_dtype = x.dtype
+ x = x.to(softmax_dtype) * scale
+ x = torch.nn.functional.softmax(x, dim=-1).to(input_dtype)
+ return x
+
+
+@torch.jit.script
+def masked_softmax(x: torch.Tensor, mask: torch.Tensor, mask_value: torch.Tensor):
+ x = torch.where(mask, x, mask_value)
+ x = torch.nn.functional.softmax(x, dim=-1)
+ return x
+
+
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+ """
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+ """
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+ if n_rep == 1:
+ return hidden_states
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
+def eager_attention_forward(
+ module: nn.Module,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_mask: Optional[torch.Tensor],
+ scaling: float,
+ dropout: float = 0.0,
+ head_mask: Optional[torch.Tensor] = None,
+ **kwargs,
+):
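+    # Expand the shared key/value heads to match the number of query heads (a no-op for plain multi-head attention).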
+ key_states = repeat_kv(key, module.num_key_value_groups)
+ value_states = repeat_kv(value, module.num_key_value_groups)
+
+ attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
+ if attention_mask is not None:
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+ attn_weights = attn_weights + causal_mask
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+
+ if head_mask is not None:
+ attn_weights = attn_weights * head_mask.view(1, -1, 1, 1)
+
+ attn_output = torch.matmul(attn_weights, value_states)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ return attn_output, attn_weights
+
+
+class GPTBigCodeAttention(nn.Module):
+ def __init__(self, config, is_cross_attention=False, layer_idx=None):
+ super().__init__()
+ self.config = config
+
+ self.mask_value = None
+ self.multi_query = config.multi_query
+ self.embed_dim = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.head_dim = self.embed_dim // self.num_heads
+ self.kv_heads = 1 if self.multi_query else self.num_heads
+ self.kv_dim = self.kv_heads * self.head_dim
+ self.num_key_value_groups = self.num_heads // self.kv_heads
+ self.split_size = self.embed_dim
+ self.is_causal = True
+
+ if self.head_dim * self.num_heads != self.embed_dim:
+ raise ValueError(
+ f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
+ f" {self.num_heads})."
+ )
+
+ self.scale_attn_weights = config.scale_attn_weights
+ self.scaling = self.head_dim**-0.5 if config.scale_attn_weights else 1.0
+ self.is_cross_attention = is_cross_attention
+
+ self.layer_idx = layer_idx
+ self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32
+ self.scale_attention_softmax_in_fp32 = (
+ config.scale_attention_softmax_in_fp32 and config.attention_softmax_in_fp32
+ )
+ self.attn_pdrop = config.attn_pdrop
+
+ if self.is_cross_attention:
+ if self.multi_query:
+ raise NotImplementedError("Multi-Query Attention not supported for cross_attention")
+
+ self.c_attn = nn.Linear(self.embed_dim, 2 * self.embed_dim)
+ self.q_attn = nn.Linear(self.embed_dim, self.embed_dim)
+ else:
+ self.c_attn = nn.Linear(self.embed_dim, self.embed_dim + 2 * self.kv_dim)
+
+ self.c_proj = nn.Linear(self.embed_dim, self.embed_dim)
+
+ self.attn_dropout = config.attn_pdrop
+ self.resid_dropout = nn.Dropout(config.resid_pdrop)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ layer_past: Optional[Cache] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ use_cache: Optional[bool] = False,
+ output_attentions: Optional[bool] = False,
+ cache_position: Optional[torch.Tensor] = None,
+ **kwargs,
+ ) -> Union[
+ tuple[torch.Tensor, Optional[torch.Tensor]],
+ tuple[torch.Tensor, Optional[torch.Tensor], tuple[torch.Tensor, ...]],
+ ]:
+ input_shape = hidden_states.shape[:-1]
+
+ if layer_past is not None:
+ if isinstance(layer_past, EncoderDecoderCache):
+ is_updated = layer_past.is_updated.get(self.layer_idx)
+ if self.is_cross_attention:
+ # after the first generated id, we can subsequently re-use all key/value_states from cache
+ curr_past_key_value = layer_past.cross_attention_cache
+ else:
+ curr_past_key_value = layer_past.self_attention_cache
+ else:
+ curr_past_key_value = layer_past
+
+ if self.is_cross_attention:
+ if not hasattr(self, "q_attn") or not self.is_cross_attention:
+ raise ValueError(
+ "If class is used as cross attention, the weights `q_attn` have to be defined. "
+ "Please make sure to instantiate class with `GPTBigCodeAttention(..., is_cross_attention=True)`."
+ )
+ if layer_past is not None and is_updated:
+ # reuse k,v, cross_attentions
+ key = curr_past_key_value.layers[self.layer_idx].keys
+ value = curr_past_key_value.layers[self.layer_idx].values
+ else:
+ query = self.q_attn(hidden_states).view(*input_shape, -1, self.head_dim).transpose(1, 2)
+ key, value = self.c_attn(encoder_hidden_states).split((self.head_dim, self.head_dim), dim=-1)
+ else:
+ if self.multi_query:
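+                # The fused projection output is split into [embed_dim | kv_dim | kv_dim]:
+                # per-head queries plus one shared key/value head.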
+ query, key, value = (
+ self.c_attn(hidden_states).unsqueeze(1).split((self.embed_dim, self.kv_dim, self.kv_dim), dim=3)
+ )
+ query = query.view(*input_shape, -1, self.head_dim).transpose(1, 2)
+ else:
+ query, key, value = (
+ self.c_attn(hidden_states)
+ .view(*hidden_states.shape[:2], self.num_heads, 3 * self.head_dim)
+ .transpose(1, 2)
+ .split(3 * [self.head_dim], dim=3)
+ )
+
+ if layer_past is not None:
+ # save all key/value_states to cache to be re-used for fast auto-regressive generation
+ cache_position = cache_position if not self.is_cross_attention else None
+ key, value = curr_past_key_value.update(key, value, self.layer_idx, {"cache_position": cache_position})
+ # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
+ if self.is_cross_attention:
+ layer_past.is_updated[self.layer_idx] = True
+
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ attn_output, attn_weights = attention_interface(
+ self,
+ query,
+ key,
+ value,
+ attention_mask,
+ dropout=0.0 if not self.training else self.attn_dropout,
+ scaling=self.scaling,
+ head_mask=head_mask,
+ **kwargs,
+ )
+
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+ attn_output = self.c_proj(attn_output)
+ attn_output = self.resid_dropout(attn_output)
+ return attn_output, attn_weights
+
+
+class GPTBigCodeMLP(nn.Module):
+ def __init__(self, intermediate_size, config):
+ super().__init__()
+ embed_dim = config.hidden_size
+ self.c_fc = nn.Linear(embed_dim, intermediate_size)
+ self.c_proj = nn.Linear(intermediate_size, embed_dim)
+ self.act = ACT2FN[config.activation_function]
+ self.dropout = nn.Dropout(config.resid_pdrop)
+
+ # Copied from transformers.models.gpt2.modeling_gpt2.GPT2MLP.forward
+ def forward(self, hidden_states: Optional[tuple[torch.FloatTensor]]) -> torch.FloatTensor:
+ hidden_states = self.c_fc(hidden_states)
+ hidden_states = self.act(hidden_states)
+ hidden_states = self.c_proj(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ return hidden_states
+
+
+class GPTBigCodeBlock(nn.Module):
+ def __init__(self, config, layer_idx=None):
+ super().__init__()
+ hidden_size = config.hidden_size
+ self.inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size
+
+ self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
+
+ self.attn = GPTBigCodeAttention(config, layer_idx=layer_idx)
+
+ self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
+
+ if config.add_cross_attention:
+ if config.multi_query:
+ raise NotImplementedError("Cross-attention not implemented for MQA")
+
+ self.crossattention = GPTBigCodeAttention(config, is_cross_attention=True, layer_idx=layer_idx)
+
+ self.ln_cross_attn = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
+
+ self.mlp = GPTBigCodeMLP(self.inner_dim, config)
+
+ def forward(
+ self,
+ hidden_states: Optional[tuple[torch.Tensor]],
+ layer_past: Optional[Cache] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ use_cache: Optional[bool] = False,
+ output_attentions: Optional[bool] = False,
+ cache_position: Optional[torch.Tensor] = None,
+ **kwargs,
+ ) -> Union[
+ tuple[torch.Tensor], tuple[torch.Tensor, torch.Tensor], tuple[torch.Tensor, torch.Tensor, torch.Tensor]
+ ]:
+ residual = hidden_states
+ hidden_states = self.ln_1(hidden_states)
+ attn_outputs = self.attn(
+ hidden_states,
+ layer_past=layer_past,
+ attention_mask=attention_mask,
+ head_mask=head_mask,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ cache_position=cache_position,
+ **kwargs,
+ )
+ attn_output = attn_outputs[0] # output_attn: a, present, (attentions)
+ outputs = attn_outputs[1:]
+ # residual connection
+ hidden_states = attn_output + residual
+
+ if encoder_hidden_states is not None:
+ # add one self-attention block for cross-attention
+ if not hasattr(self, "crossattention"):
+ raise ValueError(
+ f"If `encoder_hidden_states` are passed, {self} has to be instantiated with "
+ "cross-attention layers by setting `config.add_cross_attention=True`"
+ )
+ residual = hidden_states
+ hidden_states = self.ln_cross_attn(hidden_states)
+ cross_attn_outputs = self.crossattention(
+ hidden_states,
+ attention_mask=attention_mask,
+ head_mask=head_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ output_attentions=output_attentions,
+ cache_position=cache_position,
+ **kwargs,
+ )
+ attn_output = cross_attn_outputs[0]
+ # residual connection
+ hidden_states = residual + attn_output
+ outputs = outputs + cross_attn_outputs[1:] # add cross attentions if we output attention weights
+
+ residual = hidden_states
+ hidden_states = self.ln_2(hidden_states)
+ feed_forward_hidden_states = self.mlp(hidden_states)
+ hidden_states = residual + feed_forward_hidden_states
+ return (hidden_states,) + outputs
+
+
+@auto_docstring
+class GPTBigCodePreTrainedModel(PreTrainedModel):
+ config: GPTBigCodeConfig
+ base_model_prefix = "transformer"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["GPTBigCodeBlock"]
+ _skip_keys_device_placement = "past_key_values"
+ _supports_flash_attn = True
+ _supports_sdpa = True
+
+ def __init__(self, *inputs, **kwargs):
+ super().__init__(*inputs, **kwargs)
+
+ def _init_weights(self, module):
+ """Initialize the weights."""
+ if isinstance(module, (GPTBigCodeMLP, GPTBigCodeAttention)):
+ # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
+ # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
+ # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/
+ #
+ # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
+ module.c_proj.weight.data.normal_(
+ mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.n_layer))
+ )
+ module.c_proj._is_hf_initialized = True
+ elif isinstance(module, nn.Linear):
+ # Slightly different from the TF version which uses truncated_normal for initialization
+ # cf https://github.com/pytorch/pytorch/pull/5617
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+ elif isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+
+
+@auto_docstring
+class GPTBigCodeModel(GPTBigCodePreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.multi_query = config.multi_query
+ self.embed_dim = config.hidden_size
+
+ self.wte = nn.Embedding(config.vocab_size, self.embed_dim)
+ self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
+
+ self.drop = nn.Dropout(config.embd_pdrop)
+ self.h = nn.ModuleList([GPTBigCodeBlock(config, layer_idx=i) for i in range(config.num_hidden_layers)])
+ self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
+
+ max_positions = config.max_position_embeddings
+ self.register_buffer(
+ "bias", torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)), persistent=False
+ )
+
+ self.gradient_checkpointing = False
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.wte
+
+ def set_input_embeddings(self, new_embeddings):
+ self.wte = new_embeddings
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ past_key_values: Optional[Cache] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ token_type_ids: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.Tensor] = None,
+ **kwargs,
+ ) -> Union[tuple, BaseModelOutputWithPastAndCrossAttentions]:
+ r"""
+ input_ids (`torch.Tensor` of shape `(batch_size, input_ids_length)`):
+ `input_ids_length` = `sequence_length` if `past_key_values` is `None` else
+ `past_key_values.get_seq_length()` (`sequence_length` of input past key value states). Indices of input
+ sequence tokens in the vocabulary.
+
+ If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
+ `input_ids`.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+ elif input_ids is not None:
+ input_shape = input_ids.size()
+ input_ids = input_ids.view(-1, input_shape[-1])
+ batch_size = input_ids.shape[0]
+ elif inputs_embeds is not None:
+ input_shape = inputs_embeds.size()[:-1]
+ batch_size = inputs_embeds.shape[0]
+ else:
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+ if batch_size <= 0:
+ raise ValueError("batch_size has to be defined and > 0")
+
+ if use_cache and past_key_values is None:
+ past_key_values = EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
+ if use_cache and isinstance(past_key_values, tuple):
+ logger.warning_once(
+ "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. "
+ "You should pass an instance of `EncoderDecoderCache` instead, e.g. "
+ "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`."
+ )
+ past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values)
+
+ if inputs_embeds is None:
+ inputs_embeds = self.wte(input_ids)
+
+ if cache_position is None:
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ cache_position = torch.arange(
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+ )
+
+ if position_ids is None:
+ position_ids = cache_position.unsqueeze(0)
+
+ causal_mask = create_causal_mask(
+ config=self.config,
+ input_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ cache_position=cache_position,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ )
+
+ if self.config._attn_implementation == "flash_attention_2":
+ encoder_attention_mask = (
+ encoder_attention_mask.bool()
+ if (encoder_attention_mask is not None and 0 in encoder_attention_mask)
+ else None
+ )
+ else:
+ # If a 2D or 3D attention mask is provided for the cross-attention
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
+ if (
+ self.config.add_cross_attention
+ and encoder_hidden_states is not None
+ and encoder_attention_mask is not None
+ ):
+ if encoder_attention_mask.dim() == 2:
+                    encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
+ assert encoder_attention_mask.dim() == 3
+ encoder_attention_mask = encoder_attention_mask.bool().unsqueeze(2 if self.multi_query else 1)
+ else:
+ encoder_attention_mask = None
+
+ # Prepare head mask if needed
+ # 1.0 in head_mask indicate we keep the head
+ # attention_probs has shape bsz x n_heads x N x N
+ # head_mask has shape n_layer x batch x n_heads x N x N
+ head_mask = self.get_head_mask(head_mask, self.config.n_layer)
+
+ position_embeds = self.wpe(position_ids)
+ hidden_states = inputs_embeds + position_embeds.to(inputs_embeds.device)
+
+ if token_type_ids is not None:
+ token_type_ids = token_type_ids.view(-1, input_shape[-1])
+ token_type_embeds = self.wte(token_type_ids)
+ hidden_states = hidden_states + token_type_embeds
+
+ hidden_states = self.drop(hidden_states)
+ output_shape = input_shape + (hidden_states.size(-1),)
+
+ all_self_attentions = () if output_attentions else None
+ all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
+ all_hidden_states = () if output_hidden_states else None
+ for i, block in enumerate(self.h):
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ outputs = block(
+ hidden_states,
+ past_key_values,
+ causal_mask,
+ head_mask[i],
+ encoder_hidden_states, # as a positional argument for gradient checkpointing
+ encoder_attention_mask=encoder_attention_mask,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+ hidden_states = outputs[0]
+ if output_attentions:
+ all_self_attentions = all_self_attentions + (outputs[1],)
+ if self.config.add_cross_attention:
+ all_cross_attentions = all_cross_attentions + (outputs[2],)
+
+ hidden_states = self.ln_f(hidden_states)
+
+ hidden_states = hidden_states.view(output_shape)
+ # Add last hidden state
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ return BaseModelOutputWithPastAndCrossAttentions(
+ last_hidden_state=hidden_states,
+ past_key_values=past_key_values,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attentions,
+ cross_attentions=all_cross_attentions,
+ )
+
+
+@auto_docstring(
+ custom_intro="""
+ The GPT_BIGCODE Model transformer with a language modeling head on top (linear layer with weights tied to the input
+ embeddings).
+ """
+)
+class GPTBigCodeForCausalLM(GPTBigCodePreTrainedModel, GenerationMixin):
+ _tied_weights_keys = ["lm_head.weight"]
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.transformer = GPTBigCodeModel(config)
+ self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ past_key_values: Optional[Cache] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ token_type_ids: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.Tensor] = None,
+ **kwargs,
+ ) -> Union[tuple, CausalLMOutputWithCrossAttentions]:
+ r"""
+ input_ids (`torch.Tensor` of shape `(batch_size, input_ids_length)`):
+ `input_ids_length` = `sequence_length` if `past_key_values` is `None` else
+ `past_key_values.get_seq_length()` (`sequence_length` of input past key value states). Indices of input
+ sequence tokens in the vocabulary.
+
+ If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
+ `input_ids`.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ labels (`torch.Tensor` of shape `(batch_size, input_ids_length)`, *optional*):
+ Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
+ `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
+ are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ transformer_outputs = self.transformer(
+ input_ids,
+ past_key_values=past_key_values,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ cache_position=cache_position,
+ )
+ hidden_states = transformer_outputs[0]
+
+ lm_logits = self.lm_head(hidden_states)
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(
+ lm_logits,
+ labels,
+ vocab_size=self.config.vocab_size,
+ **kwargs,
+ )
+
+ if not return_dict:
+ output = (lm_logits,) + transformer_outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return CausalLMOutputWithCrossAttentions(
+ loss=loss,
+ logits=lm_logits,
+ past_key_values=transformer_outputs.past_key_values,
+ hidden_states=transformer_outputs.hidden_states,
+ attentions=transformer_outputs.attentions,
+ cross_attentions=transformer_outputs.cross_attentions,
+ )
+
+
+@auto_docstring(
+ custom_intro="""
+ The GPTBigCode Model transformer with a sequence classification head on top (linear layer).
+
+ [`GPTBigCodeForSequenceClassification`] uses the last token in order to do the classification, as other causal
+ models (e.g. GPT-1) do.
+
+ Since it does classification on the last token, it requires to know the position of the last token. If a
+ `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
+ no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
+ padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
+ each row of the batch).
+ """
+)
+class GPTBigCodeForSequenceClassification(GPTBigCodePreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+ self.transformer = GPTBigCodeModel(config)
+ self.score = nn.Linear(config.n_embd, self.num_labels, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ past_key_values: Optional[Cache] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ token_type_ids: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ **kwargs,
+ ) -> Union[tuple, SequenceClassifierOutputWithPast]:
+ r"""
+ input_ids (`torch.Tensor` of shape `(batch_size, input_ids_length)`):
+ `input_ids_length` = `sequence_length` if `past_key_values` is `None` else
+ `past_key_values.get_seq_length()` (`sequence_length` of input past key value states). Indices of input
+ sequence tokens in the vocabulary.
+
+ If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
+ `input_ids`.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ labels (`torch.Tensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ transformer_outputs = self.transformer(
+ input_ids,
+ past_key_values=past_key_values,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ **kwargs,
+ )
+ hidden_states = transformer_outputs[0]
+ logits = self.score(hidden_states)
+
+ if input_ids is not None:
+ batch_size, sequence_length = input_ids.shape[:2]
+ else:
+ batch_size, sequence_length = inputs_embeds.shape[:2]
+
+ if self.config.pad_token_id is None and batch_size != 1:
+ raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
+ if self.config.pad_token_id is None:
+ last_non_pad_token = -1
+ elif input_ids is not None:
+ # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
+ non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
+ token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
+ last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
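+            # e.g. a row [tok, tok, pad, pad] gives mask [1, 1, 0, 0], so argmax over [0, 1, 0, 0] selects position 1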
+ else:
+ last_non_pad_token = -1
+ logger.warning_once(
+ f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
+ "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
+ )
+
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
+
+ loss = None
+ if labels is not None:
+ labels = labels.to(logits.device)
+
+ if self.config.problem_type is None:
+ if self.num_labels == 1:
+ self.config.problem_type = "regression"
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+ self.config.problem_type = "single_label_classification"
+ else:
+ self.config.problem_type = "multi_label_classification"
+
+ if self.config.problem_type == "regression":
+ loss_fct = MSELoss()
+ if self.num_labels == 1:
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
+ else:
+ loss = loss_fct(pooled_logits, labels)
+ elif self.config.problem_type == "single_label_classification":
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
+ elif self.config.problem_type == "multi_label_classification":
+ loss_fct = BCEWithLogitsLoss()
+ loss = loss_fct(pooled_logits, labels)
+ if not return_dict:
+ output = (pooled_logits,) + transformer_outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return SequenceClassifierOutputWithPast(
+ loss=loss,
+ logits=pooled_logits,
+ past_key_values=transformer_outputs.past_key_values,
+ hidden_states=transformer_outputs.hidden_states,
+ attentions=transformer_outputs.attentions,
+ )
+
+
+@auto_docstring
+class GPTBigCodeForTokenClassification(GPTBigCodePreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+
+ self.transformer = GPTBigCodeModel(config)
+ if hasattr(config, "classifier_dropout") and config.classifier_dropout is not None:
+ classifier_dropout = config.classifier_dropout
+ elif hasattr(config, "hidden_dropout") and config.hidden_dropout is not None:
+ classifier_dropout = config.hidden_dropout
+ else:
+ classifier_dropout = 0.1
+ self.dropout = nn.Dropout(classifier_dropout)
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ past_key_values: Optional[Cache] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ token_type_ids: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[tuple, TokenClassifierOutput]:
+ r"""
+ input_ids (`torch.Tensor` of shape `(batch_size, input_ids_length)`):
+ `input_ids_length` = `sequence_length` if `past_key_values` is `None` else
+ `past_key_values.get_seq_length()` (`sequence_length` of input past key value states). Indices of input
+ sequence tokens in the vocabulary.
+
+ If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
+ `input_ids`.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ labels (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ transformer_outputs = self.transformer(
+ input_ids,
+ past_key_values=past_key_values,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ hidden_states = transformer_outputs[0]
+ hidden_states = self.dropout(hidden_states)
+ logits = self.classifier(hidden_states)
+
+ loss = None
+ if labels is not None:
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1).to(logits.device))
+
+ if not return_dict:
+ output = (logits,) + transformer_outputs[2:]
+ return ((loss,) + output) if loss is not None else output
+
+ return TokenClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=transformer_outputs.hidden_states,
+ attentions=transformer_outputs.attentions,
+ )
+
+
+__all__ = [
+ "GPTBigCodeForSequenceClassification",
+ "GPTBigCodeForTokenClassification",
+ "GPTBigCodeForCausalLM",
+ "GPTBigCodeModel",
+ "GPTBigCodePreTrainedModel",
+]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/gpt_neo/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/gpt_neo/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..578577f22882cdc5eea08928e274a18725cf4615
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/gpt_neo/__init__.py
@@ -0,0 +1,28 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_gpt_neo import *
+ from .modeling_flax_gpt_neo import *
+ from .modeling_gpt_neo import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/gpt_neo/configuration_gpt_neo.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/gpt_neo/configuration_gpt_neo.py
new file mode 100644
index 0000000000000000000000000000000000000000..875a170277d2048dcadda9cd8f57205a11742797
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/gpt_neo/configuration_gpt_neo.py
@@ -0,0 +1,273 @@
+# coding=utf-8
+# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""GPT Neo model configuration"""
+
+from collections import OrderedDict
+from collections.abc import Mapping
+from typing import Any, Optional
+
+from ... import PreTrainedTokenizer, TensorType, is_torch_available
+from ...configuration_utils import PretrainedConfig
+from ...onnx import OnnxConfigWithPast
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class GPTNeoConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`GPTNeoModel`]. It is used to instantiate a GPT
+ Neo model according to the specified arguments, defining the model architecture. Instantiating a configuration with
+ the defaults will yield a similar configuration to that of the GPTNeo
+ [EleutherAI/gpt-neo-1.3B](https://huggingface.co/EleutherAI/gpt-neo-1.3B) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 50257):
+ Vocabulary size of the GPT Neo model. Defines the number of different tokens that can be represented by the
+ `inputs_ids` passed when calling [`GPTNeoModel`]. Vocabulary size of the model. Defines the different
+ tokens that can be represented by the *inputs_ids* passed to the forward method of [`GPTNeoModel`].
+ max_position_embeddings (`int`, *optional*, defaults to 2048):
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
+ just in case (e.g., 512 or 1024 or 2048).
+ hidden_size (`int`, *optional*, defaults to 2048):
+ Dimensionality of the encoder layers and the pooler layer.
+ num_layers (`int`, *optional*, defaults to 24):
+ Number of hidden layers in the Transformer encoder.
+        attention_types (`List`, *optional*, defaults to `[[['global', 'local'], 12]]`):
+            The type of attention for each layer in a `List` of the following format `[[["attention_type"],
+            num_layers]]`, e.g. `[[["global"], 24]]` for a 24-layer model or `[[["global", "local"], 12]]`. Choose the
+            value of `attention_type` from `["global", "local"]`.
+ num_heads (`int`, *optional*, defaults to 16):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ intermediate_size (`int`, *optional*, defaults to 8192):
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
+ window_size (`int`, *optional*, defaults to 256):
+ The size of the sliding window for local attention.
+ activation_function (`str` or `function`, *optional*, defaults to `"gelu_new"`):
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+ `"relu"`, `"selu"` and `"gelu_new"` are supported.
+ resid_dropout (`float`, *optional*, defaults to 0.0):
+ Residual dropout used in the attention pattern.
+ embed_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ classifier_dropout (`float`, *optional*, defaults to 0.1):
+ Argument used when doing token classification, used in the model [`GPTNeoForTokenClassification`]. The
+ dropout ratio for the hidden layer.
+ layer_norm_epsilon (`float`, *optional*, defaults to 1e-05):
+ The epsilon used by the layer normalization layers.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
+ relevant if `config.is_decoder=True`.
+ bos_token_id (`int`, *optional*, defaults to 50256):
+ The id of the beginning of sentence token in the vocabulary.
+ eos_token_id (`int`, *optional*, defaults to 50256):
+ The id of the end of sentence token in the vocabulary.
+
+ Example:
+
+ ```python
+ >>> from transformers import GPTNeoConfig, GPTNeoModel
+
+ >>> # Initializing a GPTNeo EleutherAI/gpt-neo-1.3B style configuration
+ >>> configuration = GPTNeoConfig()
+
+ >>> # Initializing a model (with random weights) from the EleutherAI/gpt-neo-1.3B style configuration
+ >>> model = GPTNeoModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "gpt_neo"
+ keys_to_ignore_at_inference = ["past_key_values"]
+ attribute_map = {"num_attention_heads": "num_heads", "num_hidden_layers": "num_layers"}
+
+ def __init__(
+ self,
+ vocab_size=50257,
+ max_position_embeddings=2048,
+ hidden_size=2048,
+ num_layers=24,
+ attention_types=[[["global", "local"], 12]],
+ num_heads=16,
+ intermediate_size=None,
+ window_size=256,
+ activation_function="gelu_new",
+ resid_dropout=0.0,
+ embed_dropout=0.0,
+ attention_dropout=0.0,
+ classifier_dropout=0.1,
+ layer_norm_epsilon=1e-5,
+ initializer_range=0.02,
+ use_cache=True,
+ bos_token_id=50256,
+ eos_token_id=50256,
+ **kwargs,
+ ):
+ self.vocab_size = vocab_size
+ self.max_position_embeddings = max_position_embeddings
+ self.hidden_size = hidden_size
+ self.num_layers = num_layers
+ self.num_heads = num_heads
+ self.intermediate_size = intermediate_size
+ self.window_size = window_size
+ self.activation_function = activation_function
+ self.resid_dropout = resid_dropout
+ self.embed_dropout = embed_dropout
+ self.attention_dropout = attention_dropout
+ self.classifier_dropout = classifier_dropout
+ self.layer_norm_epsilon = layer_norm_epsilon
+ self.initializer_range = initializer_range
+ self.use_cache = use_cache
+
+ self.bos_token_id = bos_token_id
+ self.eos_token_id = eos_token_id
+
+ self.attention_types = attention_types
+ self.attention_layers = self.expand_attention_types_params(attention_types)
+
+ if len(self.attention_layers) != self.num_layers:
+ raise ValueError(
+ "Configuration for convolutional module is incorrect. "
+ "It is required that `len(config.attention_layers)` == `config.num_layers` "
+ f"but is `len(config.attention_layers) = {len(self.attention_layers)}`, "
+ f"`config.num_layers = {self.num_layers}`. "
+ "`config.attention_layers` is prepared using `config.attention_types`. "
+ "Please verify the value of `config.attention_types` argument."
+ )
+
+ super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
+
+ @staticmethod
+ def expand_attention_types_params(attention_types):
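+        """
+        Expand `attention_types` into a flat, per-layer list of attention type strings.
+
+        Illustrative example (a sketch of the expansion, not a doctest executed by the library): the default
+        `[[["global", "local"], 12]]` expands to `["global", "local", "global", "local", ...]` with 24 entries,
+        one per layer.
+        """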
+ attentions = []
+ for item in attention_types:
+ for _ in range(item[1]):
+ attentions.extend(item[0])
+ return attentions
+
+
+def custom_unfold(input, dimension, size, step):
+ """Custom torch.Tensor.unfold implementation to enable the export to ONNX."""
+ import torch
+
+ shape = input.size()
+ rank = len(shape)
+ sizedim = shape[dimension]
+
+ low_indices = torch.arange(0, sizedim, step)
+ min_length = torch.div(sizedim - size, step, rounding_mode="floor") + 1
+ indices = torch.arange(size) + low_indices[:min_length][:, None]
+
+ s = [slice(None)] * rank
+ s[dimension] = indices
+ sliced = input[s]
+
+ perm = list(range(0, rank + 1))
+ perm.append(perm.pop(dimension + 1))
+
+ return sliced.permute(perm)
+
+
+def custom_get_block_length_and_num_blocks(seq_length, window_size):
+ """
+ Custom implementation for GPTNeoAttentionMixin._get_block_length_and_num_blocks to enable the export to ONNX as
+ original implementation uses Python variables and control flow.
+ """
+ import torch
+
+ candidates = torch.arange(1, window_size)
+ remainders = torch.remainder(seq_length, candidates)
+ divisor_indices = remainders == 0
+ divisors = candidates[divisor_indices]
+ largest_divisor = torch.max(divisors)
+ return largest_divisor, torch.div(seq_length, largest_divisor, rounding_mode="floor")
+
+
+class GPTNeoOnnxConfig(OnnxConfigWithPast):
+ @property
+ def inputs(self) -> Mapping[str, Mapping[int, str]]:
+ common_inputs = OrderedDict({"input_ids": {0: "batch", 1: "sequence"}})
+ if self.use_past:
+ self.fill_with_past_key_values_(common_inputs, direction="inputs")
+ common_inputs["attention_mask"] = {0: "batch", 1: "past_sequence + sequence"}
+ else:
+ common_inputs["attention_mask"] = {0: "batch", 1: "sequence"}
+
+ return common_inputs
+
+ @property
+ def num_attention_heads(self) -> int:
+ return self._config.num_heads
+
+ def generate_dummy_inputs(
+ self,
+ tokenizer: PreTrainedTokenizer,
+ batch_size: int = -1,
+ seq_length: int = -1,
+ is_pair: bool = False,
+ framework: Optional[TensorType] = None,
+ ) -> Mapping[str, Any]:
+ common_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs(
+ tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework
+ )
+
+        # We need to order the inputs in the way they appear in the forward()
+ ordered_inputs = OrderedDict({"input_ids": common_inputs["input_ids"]})
+
+ # Need to add the past_keys
+ if self.use_past:
+ if not is_torch_available():
+ raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.")
+ else:
+ import torch
+
+ batch, seqlen = common_inputs["input_ids"].shape
+ # Not using the same length for past_key_values
+ past_key_values_length = seqlen + 2
+ past_shape = (
+ batch,
+ self.num_attention_heads,
+ past_key_values_length,
+ self._config.hidden_size // self.num_attention_heads,
+ )
+ ordered_inputs["past_key_values"] = [
+ (torch.zeros(past_shape), torch.zeros(past_shape)) for _ in range(self.num_layers)
+ ]
+
+ ordered_inputs["attention_mask"] = common_inputs["attention_mask"]
+ if self.use_past:
+ mask_dtype = ordered_inputs["attention_mask"].dtype
+ ordered_inputs["attention_mask"] = torch.cat(
+ [ordered_inputs["attention_mask"], torch.ones(batch, past_key_values_length, dtype=mask_dtype)], dim=1
+ )
+
+ return ordered_inputs
+
+ @property
+ def default_onnx_opset(self) -> int:
+ return 13
+
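+# A minimal usage sketch for the ONNX config (illustrative only, not executed as part of the library):
+#
+#     from transformers.models.gpt_neo.configuration_gpt_neo import GPTNeoConfig, GPTNeoOnnxConfig
+#
+#     onnx_config = GPTNeoOnnxConfig(GPTNeoConfig())
+#     print(onnx_config.inputs)  # OrderedDict with "input_ids" and "attention_mask" dynamic axes
+#     print(onnx_config.default_onnx_opset)  # 13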
+
+__all__ = ["GPTNeoConfig", "GPTNeoOnnxConfig"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/gpt_neo/modeling_flax_gpt_neo.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/gpt_neo/modeling_flax_gpt_neo.py
new file mode 100644
index 0000000000000000000000000000000000000000..a6cdc50b359b0415fd165b05311eeb9bc07c7526
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/gpt_neo/modeling_flax_gpt_neo.py
@@ -0,0 +1,687 @@
+# coding=utf-8
+# Copyright 2021 The Eleuther AI and The Google Flax Team Authors and The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from functools import partial
+from typing import Optional
+
+import flax.linen as nn
+import jax
+import jax.numpy as jnp
+from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
+from flax.linen import combine_masks, make_causal_mask
+from flax.linen.attention import dot_product_attention_weights
+from flax.traverse_util import flatten_dict, unflatten_dict
+from jax import lax
+
+from ...modeling_flax_outputs import FlaxBaseModelOutput, FlaxCausalLMOutput
+from ...modeling_flax_utils import ACT2FN, FlaxPreTrainedModel, append_call_sample_docstring
+from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging
+from .configuration_gpt_neo import GPTNeoConfig
+
+
+logger = logging.get_logger(__name__)
+
+_CONFIG_FOR_DOC = "GPTNeoConfig"
+_CHECKPOINT_FOR_DOC = "EleutherAI/gpt-neo-1.3B"
+
+
+GPT_NEO_START_DOCSTRING = r"""
+
+ This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
+    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning
+    heads, etc.)
+
+ This model is also a Flax Linen
+ [flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a
+    regular Flax Module and refer to the Flax documentation for all matters related to general usage and behavior.
+
+ Finally, this model supports inherent JAX features such as:
+
+ - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
+ - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
+ - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
+ - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
+
+ Parameters:
+ config ([`GPTNeoConfig`]): Model configuration class with all the parameters of the model.
+ Initializing with a config file does not load the weights associated with the model, only the
+ configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights.
+ dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
+ The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
+ `jax.numpy.bfloat16` (on TPUs).
+
+ This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
+ specified all the computation will be performed with the given `dtype`.
+
+ **Note that this only specifies the dtype of the computation and does not influence the dtype of model
+ parameters.**
+
+ If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and
+ [`~FlaxPreTrainedModel.to_bf16`].
+"""
+
+GPT_NEO_INPUTS_DOCSTRING = r"""
+ Args:
+ input_ids (`numpy.ndarray` of shape `(batch_size, input_ids_length)`):
+ `input_ids_length` = `sequence_length`. Indices of input sequence tokens in the vocabulary.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ attention_mask (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ position_ids (`numpy.ndarray` of shape `(batch_size, input_ids_length)`, *optional*):
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+ config.max_position_embeddings - 1]`.
+ past_key_values (`dict[str, np.ndarray]`, *optional*, returned by `init_cache` or when passing previous `past_key_values`):
+ Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast
+ auto-regressive decoding. Pre-computed key and value hidden-states are of shape *[batch_size, max_length]*.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+class FlaxGPTNeoSelfAttention(nn.Module):
+ config: GPTNeoConfig
+ attention_type: str
+ dtype: jnp.dtype = jnp.float32
+
+ def setup(self):
+ config = self.config
+ self.embed_dim = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.head_dim = self.embed_dim // self.num_heads
+ if self.head_dim * self.num_heads != self.embed_dim:
+ raise ValueError(
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and "
+ f"`num_heads`: {self.num_heads})."
+ )
+
+ self.attn_dropout = nn.Dropout(config.attention_dropout)
+ self.resid_dropout = nn.Dropout(config.resid_dropout)
+
+ dense = partial(
+ nn.Dense,
+ self.embed_dim,
+ dtype=self.dtype,
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
+ )
+
+ self.q_proj, self.k_proj, self.v_proj = dense(use_bias=False), dense(use_bias=False), dense(use_bias=False)
+ self.out_proj = dense()
+
+ self.causal_mask = make_causal_mask(jnp.ones((1, config.max_position_embeddings), dtype="bool"), dtype="bool")
+ if self.attention_type == "local":
+ self.causal_mask = self.causal_mask ^ jnp.tril(self.causal_mask, -config.window_size)
+
+ def _split_heads(self, hidden_states):
+ return hidden_states.reshape(hidden_states.shape[:2] + (self.num_heads, self.head_dim))
+
+ def _merge_heads(self, hidden_states):
+ return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,))
+
+ @nn.compact
+ def _concatenate_to_cache(self, key, value, query, attention_mask):
+ """
+ This function takes projected key, value states from a single input token and concatenates the states to cached
+ states from previous steps. This function is slightly adapted from the official Flax repository:
+ https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252
+ """
+ # detect if we're initializing by absence of existing cache data.
+ is_initialized = self.has_variable("cache", "cached_key")
+ cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype)
+ cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype)
+ cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32))
+
+ if is_initialized:
+ *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape
+ # update key, value caches with our new 1d spatial slices
+ cur_index = cache_index.value
+ indices = (0,) * len(batch_dims) + (cur_index, 0, 0)
+ key = lax.dynamic_update_slice(cached_key.value, key, indices)
+ value = lax.dynamic_update_slice(cached_value.value, value, indices)
+ cached_key.value = key
+ cached_value.value = value
+ num_updated_cache_vectors = query.shape[1]
+ cache_index.value = cache_index.value + num_updated_cache_vectors
+            # causal mask for cached decoder self-attention: our single query position should only attend to those
+            # key positions that have already been generated and cached, not the remaining zero elements.
+ pad_mask = jnp.broadcast_to(
+ jnp.arange(max_length) < cur_index + num_updated_cache_vectors,
+ tuple(batch_dims) + (1, num_updated_cache_vectors, max_length),
+ )
+ attention_mask = combine_masks(pad_mask, attention_mask)
+ return key, value, attention_mask
+
+ def __call__(
+ self,
+ hidden_states,
+ attention_mask=None,
+ deterministic: bool = True,
+ init_cache: bool = False,
+ output_attentions: bool = False,
+ ):
+ query = self.q_proj(hidden_states) * jnp.sqrt(self.head_dim).astype(self.dtype)
+ key = self.k_proj(hidden_states)
+ value = self.v_proj(hidden_states)
+
+ query = self._split_heads(query)
+ key = self._split_heads(key)
+ value = self._split_heads(value)
+
+ query_length, key_length = query.shape[1], key.shape[1]
+
+ if self.has_variable("cache", "cached_key"):
+ mask_shift = self.variables["cache"]["cache_index"]
+ max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
+ causal_mask = lax.dynamic_slice(
+ self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length)
+ )
+ else:
+ causal_mask = self.causal_mask[:, :, :query_length, :key_length]
+
+ batch_size = hidden_states.shape[0]
+ causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:])
+
+ attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape)
+ attention_mask = combine_masks(attention_mask, causal_mask)
+
+ dropout_rng = None
+ if not deterministic and self.config.attention_dropout > 0.0:
+ dropout_rng = self.make_rng("dropout")
+
+ # During fast autoregressive decoding, we feed one position at a time,
+ # and cache the keys and values step by step.
+ if self.has_variable("cache", "cached_key") or init_cache:
+ key, value, attention_mask = self._concatenate_to_cache(key, value, query, attention_mask)
+
+ # transform boolean mask into float mask
+ attention_bias = lax.select(
+ attention_mask > 0,
+ jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
+ jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
+ )
+
+ # usual dot product attention
+ attn_weights = dot_product_attention_weights(
+ query,
+ key,
+ bias=attention_bias,
+ dropout_rng=dropout_rng,
+ dropout_rate=self.config.attention_dropout,
+ deterministic=deterministic,
+ dtype=self.dtype,
+ precision=None,
+ )
+
+ attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value)
+ attn_output = self._merge_heads(attn_output)
+ attn_output = self.out_proj(attn_output)
+ attn_output = self.resid_dropout(attn_output, deterministic=deterministic)
+
+ outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)
+ return outputs
+
+
+class FlaxGPTNeoAttention(nn.Module):
+ config: GPTNeoConfig
+ layer_id: int = 0
+ dtype: jnp.dtype = jnp.float32
+
+ def setup(self):
+ attention_type = self.config.attention_layers[self.layer_id]
+ self.attention = FlaxGPTNeoSelfAttention(self.config, attention_type, dtype=self.dtype)
+
+ def __call__(
+ self,
+ hidden_states,
+ attention_mask=None,
+ deterministic: bool = True,
+ init_cache: bool = False,
+ output_attentions: bool = False,
+ ):
+ return self.attention(
+ hidden_states,
+ attention_mask=attention_mask,
+ deterministic=deterministic,
+ init_cache=init_cache,
+ output_attentions=output_attentions,
+ )
+
+
+class FlaxGPTNeoMLP(nn.Module):
+ config: GPTNeoConfig
+ intermediate_size: int
+ dtype: jnp.dtype = jnp.float32
+
+ def setup(self):
+ embed_dim = self.config.hidden_size
+ kernel_init = jax.nn.initializers.normal(self.config.initializer_range)
+ self.c_fc = nn.Dense(self.intermediate_size, dtype=self.dtype, kernel_init=kernel_init)
+ self.c_proj = nn.Dense(embed_dim, dtype=self.dtype, kernel_init=kernel_init)
+ self.act = ACT2FN[self.config.activation_function]
+ self.dropout = nn.Dropout(rate=self.config.resid_dropout)
+
+ def __call__(self, hidden_states, deterministic: bool = True):
+ hidden_states = self.c_fc(hidden_states)
+ hidden_states = self.act(hidden_states)
+ hidden_states = self.c_proj(hidden_states)
+ hidden_states = self.dropout(hidden_states, deterministic=deterministic)
+ return hidden_states
+
+
+class FlaxGPTNeoBlock(nn.Module):
+ config: GPTNeoConfig
+ layer_id: int = 0
+ dtype: jnp.dtype = jnp.float32
+
+ def setup(self):
+ hidden_size = self.config.hidden_size
+ inner_dim = self.config.intermediate_size if self.config.intermediate_size is not None else 4 * hidden_size
+
+ self.ln_1 = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon, dtype=self.dtype)
+ self.attn = FlaxGPTNeoAttention(self.config, layer_id=self.layer_id, dtype=self.dtype)
+ self.ln_2 = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon, dtype=self.dtype)
+ self.mlp = FlaxGPTNeoMLP(self.config, inner_dim, dtype=self.dtype)
+
+ def __call__(
+ self,
+ hidden_states,
+ attention_mask=None,
+ deterministic: bool = True,
+ init_cache: bool = False,
+ output_attentions: bool = False,
+ ):
+ residual = hidden_states
+ hidden_states = self.ln_1(hidden_states)
+ outputs = self.attn(
+ hidden_states,
+ attention_mask=attention_mask,
+ deterministic=deterministic,
+ init_cache=init_cache,
+ output_attentions=output_attentions,
+ )
+ # residual connection
+ attn_output = outputs[0]
+ hidden_states = attn_output + residual
+
+ residual = hidden_states
+ hidden_states = self.ln_2(hidden_states)
+ feed_forward_hidden_states = self.mlp(hidden_states, deterministic=deterministic)
+ # residual connection
+ hidden_states = residual + feed_forward_hidden_states
+
+ return (hidden_states,) + outputs[1:]
+
+
+class FlaxGPTNeoPreTrainedModel(FlaxPreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = GPTNeoConfig
+ base_model_prefix = "transformer"
+ module_class: nn.Module = None
+
+ def __init__(
+ self,
+ config: GPTNeoConfig,
+ input_shape: tuple = (1, 1),
+ seed: int = 0,
+ dtype: jnp.dtype = jnp.float32,
+ _do_init: bool = True,
+ **kwargs,
+ ):
+ module = self.module_class(config=config, dtype=dtype, **kwargs)
+ super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
+
+ def init_weights(self, rng: jax.random.PRNGKey, input_shape: tuple, params: FrozenDict = None) -> FrozenDict:
+ # init input tensors
+ input_ids = jnp.zeros(input_shape, dtype="i4")
+ attention_mask = jnp.ones_like(input_ids)
+ position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape)
+ params_rng, dropout_rng = jax.random.split(rng)
+ rngs = {"params": params_rng, "dropout": dropout_rng}
+
+ random_params = self.module.init(rngs, input_ids, attention_mask, position_ids, return_dict=False)["params"]
+
+ if params is not None:
+ random_params = flatten_dict(unfreeze(random_params))
+ params = flatten_dict(unfreeze(params))
+ for missing_key in self._missing_keys:
+ params[missing_key] = random_params[missing_key]
+ self._missing_keys = set()
+ return freeze(unflatten_dict(params))
+ else:
+ return random_params
+
+ def init_cache(self, batch_size, max_length):
+ r"""
+ Args:
+ batch_size (`int`):
+ batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
+ max_length (`int`):
+ maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized
+ cache.
+ """
+ # init input variables to retrieve cache
+ input_ids = jnp.ones((batch_size, max_length))
+ attention_mask = jnp.ones_like(input_ids)
+ position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)
+
+ init_variables = self.module.init(
+ jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True
+ )
+ return unfreeze(init_variables["cache"])
+
+ @add_start_docstrings_to_model_forward(GPT_NEO_INPUTS_DOCSTRING)
+ def __call__(
+ self,
+ input_ids,
+ attention_mask=None,
+ position_ids=None,
+ params: Optional[dict] = None,
+ past_key_values: Optional[dict] = None,
+ dropout_rng: jax.random.PRNGKey = None,
+ train: bool = False,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ):
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
+
+ batch_size, sequence_length = input_ids.shape
+
+ if position_ids is None:
+ if past_key_values is not None:
+ raise ValueError("Make sure to provide `position_ids` when passing `past_key_values`.")
+
+ position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))
+
+ if attention_mask is None:
+ attention_mask = jnp.ones((batch_size, sequence_length))
+
+ # Handle any PRNG if needed
+ rngs = {}
+ if dropout_rng is not None:
+ rngs["dropout"] = dropout_rng
+
+ inputs = {"params": params or self.params}
+
+        # If past_key_values are passed, then the cache is already initialized and a private flag `init_cache` has
+        # to be passed down to ensure the cache is used. The cache also has to be marked as mutable so that it can
+        # be changed by the FlaxGPTNeoAttention module.
+ if past_key_values:
+ inputs["cache"] = past_key_values
+ mutable = ["cache"]
+ else:
+ mutable = False
+
+ outputs = self.module.apply(
+ inputs,
+ jnp.array(input_ids, dtype="i4"),
+ jnp.array(attention_mask, dtype="i4"),
+ jnp.array(position_ids, dtype="i4"),
+ not train,
+ False,
+ output_attentions,
+ output_hidden_states,
+ return_dict,
+ rngs=rngs,
+ mutable=mutable,
+ )
+
+ # add updated cache to model output
+ if past_key_values is not None and return_dict:
+ outputs, past_key_values = outputs
+ outputs["past_key_values"] = unfreeze(past_key_values["cache"])
+ return outputs
+ elif past_key_values is not None and not return_dict:
+ outputs, past_key_values = outputs
+ outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:]
+
+ return outputs
+
+
+class FlaxGPTNeoBlockCollection(nn.Module):
+ config: GPTNeoConfig
+ dtype: jnp.dtype = jnp.float32
+
+ def setup(self):
+ self.blocks = [
+ FlaxGPTNeoBlock(self.config, layer_id=i, name=str(i), dtype=self.dtype)
+ for i in range(self.config.num_hidden_layers)
+ ]
+
+ def __call__(
+ self,
+ hidden_states,
+ attention_mask=None,
+ deterministic: bool = True,
+ init_cache: bool = False,
+ output_attentions: bool = False,
+ output_hidden_states: bool = False,
+ return_dict: bool = True,
+ ):
+ all_attentions = () if output_attentions else None
+ all_hidden_states = () if output_hidden_states else None
+
+ for block in self.blocks:
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ layer_outputs = block(
+ hidden_states,
+ attention_mask,
+ deterministic=deterministic,
+ init_cache=init_cache,
+ output_attentions=output_attentions,
+ )
+ hidden_states = layer_outputs[0]
+
+ if output_attentions:
+ all_attentions += (layer_outputs[1],)
+
+ # this contains possible `None` values - `FlaxGPTNeoModule` will filter them out
+ outputs = (hidden_states, all_hidden_states, all_attentions)
+
+ return outputs
+
+
+class FlaxGPTNeoModule(nn.Module):
+ config: GPTNeoConfig
+ dtype: jnp.dtype = jnp.float32
+
+ def setup(self):
+ self.embed_dim = self.config.hidden_size
+ embedding_init = jax.nn.initializers.normal(stddev=self.config.initializer_range)
+ self.wte = nn.Embed(
+ self.config.vocab_size,
+ self.embed_dim,
+ embedding_init=embedding_init,
+ )
+ self.wpe = nn.Embed(
+ self.config.max_position_embeddings,
+ self.embed_dim,
+ embedding_init=embedding_init,
+ )
+ self.dropout = nn.Dropout(rate=self.config.embed_dropout)
+ self.h = FlaxGPTNeoBlockCollection(self.config, dtype=self.dtype)
+ self.ln_f = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon, dtype=self.dtype)
+
+ def __call__(
+ self,
+ input_ids,
+ attention_mask,
+ position_ids,
+ deterministic=True,
+ init_cache: bool = False,
+ output_attentions: bool = False,
+ output_hidden_states: bool = False,
+ return_dict: bool = True,
+ ):
+ input_embeds = self.wte(input_ids.astype("i4"))
+ position_embeds = self.wpe(position_ids.astype("i4"))
+
+ hidden_states = input_embeds + position_embeds
+ hidden_states = self.dropout(hidden_states, deterministic=deterministic)
+
+ outputs = self.h(
+ hidden_states,
+ attention_mask,
+ deterministic=deterministic,
+ init_cache=init_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ hidden_states = outputs[0]
+ hidden_states = self.ln_f(hidden_states)
+
+ if output_hidden_states:
+ all_hidden_states = outputs[1] + (hidden_states,)
+ outputs = (hidden_states, all_hidden_states) + outputs[2:]
+ else:
+ outputs = (hidden_states,) + outputs[1:]
+
+ if not return_dict:
+ return tuple(v for v in outputs if v is not None)
+
+ return FlaxBaseModelOutput(
+ last_hidden_state=hidden_states,
+ hidden_states=outputs[1],
+ attentions=outputs[-1],
+ )
+
+
+@add_start_docstrings(
+ "The bare GPTNeo Model transformer outputting raw hidden-states without any specific head on top.",
+ GPT_NEO_START_DOCSTRING,
+)
+class FlaxGPTNeoModel(FlaxGPTNeoPreTrainedModel):
+ module_class = FlaxGPTNeoModule
+
+
+append_call_sample_docstring(FlaxGPTNeoModel, _CHECKPOINT_FOR_DOC, FlaxBaseModelOutput, _CONFIG_FOR_DOC)
+
+
+class FlaxGPTNeoForCausalLMModule(nn.Module):
+ config: GPTNeoConfig
+ dtype: jnp.dtype = jnp.float32
+
+ def setup(self):
+ self.transformer = FlaxGPTNeoModule(self.config, dtype=self.dtype)
+ self.lm_head = nn.Dense(
+ self.config.vocab_size,
+ use_bias=False,
+ dtype=self.dtype,
+ kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
+ )
+
+ def __call__(
+ self,
+ input_ids,
+ attention_mask,
+ position_ids,
+ deterministic: bool = True,
+ init_cache: bool = False,
+ output_attentions: bool = False,
+ output_hidden_states: bool = False,
+ return_dict: bool = True,
+ ):
+ outputs = self.transformer(
+ input_ids,
+ attention_mask,
+ position_ids,
+ deterministic=deterministic,
+ init_cache=init_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ hidden_states = outputs[0]
+
+ if self.config.tie_word_embeddings:
+ shared_kernel = self.transformer.variables["params"]["wte"]["embedding"].T
+ lm_logits = self.lm_head.apply({"params": {"kernel": shared_kernel}}, hidden_states)
+ else:
+ lm_logits = self.lm_head(hidden_states)
+
+ if not return_dict:
+ return (lm_logits,) + outputs[1:]
+
+ return FlaxCausalLMOutput(logits=lm_logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions)
+
+
+@add_start_docstrings(
+ """
+ The GPTNeo Model transformer with a language modeling head on top (linear layer with weights tied to the input
+ embeddings).
+ """,
+ GPT_NEO_START_DOCSTRING,
+)
+class FlaxGPTNeoForCausalLM(FlaxGPTNeoPreTrainedModel):
+ module_class = FlaxGPTNeoForCausalLMModule
+
+ def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None):
+ # initializing the cache
+ batch_size, seq_length = input_ids.shape
+
+ past_key_values = self.init_cache(batch_size, max_length)
+ # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length.
+ # But since GPTNeo uses a causal mask, those positions are masked anyways.
+ # Thus we can create a single static attention_mask here, which is more efficient for compilation
+ extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
+ if attention_mask is not None:
+ position_ids = attention_mask.cumsum(axis=-1) - 1
+ extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0))
+ else:
+ position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length))
+
+ return {
+ "past_key_values": past_key_values,
+ "attention_mask": extended_attention_mask,
+ "position_ids": position_ids,
+ }
+
+ def update_inputs_for_generation(self, model_outputs, model_kwargs):
+ model_kwargs["past_key_values"] = model_outputs.past_key_values
+ model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1
+ return model_kwargs
+
+
+append_call_sample_docstring(FlaxGPTNeoForCausalLM, _CHECKPOINT_FOR_DOC, FlaxCausalLMOutput, _CONFIG_FOR_DOC)
+
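+# A minimal generation sketch (illustrative only, not executed as part of the library; it assumes JAX/Flax is
+# installed and the "EleutherAI/gpt-neo-1.3B" checkpoint can be downloaded):
+#
+#     from transformers import AutoTokenizer, FlaxGPTNeoForCausalLM
+#
+#     tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neo-1.3B")
+#     model = FlaxGPTNeoForCausalLM.from_pretrained("EleutherAI/gpt-neo-1.3B")
+#     inputs = tokenizer("GPT Neo is", return_tensors="np")
+#     output_ids = model.generate(**inputs, max_length=20).sequences
+#     print(tokenizer.batch_decode(output_ids, skip_special_tokens=True))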
+
+__all__ = ["FlaxGPTNeoForCausalLM", "FlaxGPTNeoModel", "FlaxGPTNeoPreTrainedModel"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/gpt_neo/modeling_gpt_neo.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/gpt_neo/modeling_gpt_neo.py
new file mode 100644
index 0000000000000000000000000000000000000000..69d74565745a578412af97ba686dacde513f0fb9
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/gpt_neo/modeling_gpt_neo.py
@@ -0,0 +1,1192 @@
+# coding=utf-8
+# Copyright 2021 The Eleuther AI and HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch GPT Neo model."""
+
+import os
+from typing import Optional, Union
+
+import torch
+from torch import nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+
+from ...activations import ACT2FN
+from ...cache_utils import Cache, DynamicCache
+from ...generation import GenerationMixin
+from ...modeling_attn_mask_utils import AttentionMaskConverter, _prepare_4d_causal_attention_mask
+from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_outputs import (
+ BaseModelOutputWithPast,
+ BaseModelOutputWithPastAndCrossAttentions,
+ CausalLMOutputWithCrossAttentions,
+ CausalLMOutputWithPast,
+ QuestionAnsweringModelOutput,
+ SequenceClassifierOutputWithPast,
+ TokenClassifierOutput,
+)
+from ...modeling_utils import PreTrainedModel
+from ...utils import (
+ auto_docstring,
+ is_torch_flex_attn_available,
+ logging,
+)
+from .configuration_gpt_neo import GPTNeoConfig
+
+
+if is_torch_flex_attn_available():
+ from torch.nn.attention.flex_attention import BlockMask
+
+ from ...integrations.flex_attention import make_flex_block_causal_mask
+
+
+if is_flash_attn_available():
+ from ...modeling_flash_attention_utils import _flash_attention_forward
+
+
+# This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph.
+# It means that the function will not be traced through and simply appear as a node in the graph.
+_prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask)
+
+
+logger = logging.get_logger(__name__)
+
+
+def load_tf_weights_in_gpt_neo(model, config, gpt_neo_checkpoint_path):
+ """Load tf checkpoints in a pytorch model"""
+ try:
+ import re
+
+ import tensorflow as tf
+ except ImportError:
+ logger.error(
+ "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
+ "https://www.tensorflow.org/install/ for installation instructions."
+ )
+ raise
+ tf_path = os.path.abspath(gpt_neo_checkpoint_path)
+ logger.info(f"Converting TensorFlow checkpoint from {tf_path}")
+ # Load weights from TF model
+ init_vars = tf.train.list_variables(tf_path)
+ names = []
+ arrays = []
+ for name, shape in init_vars:
+ if "global_step" not in name and "adam" not in name:
+ array = tf.train.load_variable(tf_path, name)
+ array = tf.dtypes.cast(array.squeeze(), tf.float32).numpy()
+ name = name.replace("attn/q", "attn/attention/q_proj/w")
+ name = name.replace("attn/k", "attn/attention/k_proj/w")
+ name = name.replace("attn/v", "attn/attention/v_proj/w")
+ name = name.replace("attn/o", "attn/attention/out_proj/w")
+ name = name.replace("norm_1", "ln_1")
+ name = name.replace("norm_2", "ln_2")
+ name = name.replace("attn/compute_output_bias/o_b", "attn/attention/out_proj/b")
+ name = name.replace("conv1d_main/c_fc/kernel", "c_fc/w")
+ name = name.replace("conv1d_main/c_fc/bias", "c_fc/b")
+ name = name.replace("conv1d_main/c_proj/kernel", "c_proj/w")
+ name = name.replace("conv1d_main/c_proj/bias", "c_proj/b")
+
+ names.append(name)
+ arrays.append(array)
+
+ for name, array in zip(names, arrays):
+ name = name[5:] # skip "gpt2/"
+ name = name.split("/")
+ pointer = model.transformer
+ for m_name in name:
+ if re.fullmatch(r"[A-Za-z]+\d+", m_name):
+ scope_names = re.split(r"(\d+)", m_name)
+ else:
+ scope_names = [m_name]
+ if scope_names[0] == "w" or scope_names[0] == "g":
+ pointer = getattr(pointer, "weight")
+ elif scope_names[0] == "b":
+ pointer = getattr(pointer, "bias")
+ elif scope_names[0] == "wpe" or scope_names[0] == "wte":
+ pointer = getattr(pointer, scope_names[0])
+ pointer = getattr(pointer, "weight")
+ else:
+ pointer = getattr(pointer, scope_names[0])
+ if len(scope_names) >= 2:
+ num = int(scope_names[1])
+ pointer = pointer[num]
+
+ if name[-1] == "w" and name[-2] in ["out_proj", "k_proj", "q_proj", "v_proj", "c_proj", "c_fc"]:
+ array = array.transpose()
+
+ if name == ["wte"]:
+ # if vocab is padded, then trim off the padding embeddings
+ array = array[: config.vocab_size]
+
+ if pointer.shape != array.shape:
+ raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched {name}")
+
+ print(f"Initialize PyTorch weight {name}")
+ pointer.data = torch.from_numpy(array)
+
+ # init the final linear layer using word embeddings
+ embs = model.transformer.wte.weight
+ lin = nn.Linear(embs.size()[1], embs.size()[0], bias=False)
+ lin.weight = embs
+ model.set_output_embeddings(lin)
+ return model
+
+
+class GPTNeoSelfAttention(nn.Module):
+ def __init__(self, config, attention_type, layer_id=None):
+ super().__init__()
+ self.config = config
+
+ max_positions = config.max_position_embeddings
+ bias = torch.tril(torch.ones((max_positions, max_positions), dtype=bool)).view(
+ 1, 1, max_positions, max_positions
+ )
+
+ # local causal self attention is a sliding window where each token can only attend to the previous
+ # window_size tokens. This is implemented by updating the causal mask such that for each token
+ # all other tokens are masked except the previous window_size tokens.
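+        # Illustrative example: with window_size=3, the token at position 4 can attend to positions 2, 3 and 4
+        # only (the window of `window_size` positions includes the token itself).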
+ if attention_type == "local":
+ bias = torch.bitwise_xor(bias, torch.tril(bias, -config.window_size))
+
+ self.register_buffer("bias", bias, persistent=False)
+ self.register_buffer("masked_bias", torch.tensor(-1e9), persistent=False)
+
+ self.attn_dropout = nn.Dropout(float(config.attention_dropout))
+ self.resid_dropout = nn.Dropout(float(config.resid_dropout))
+ self.is_causal = True
+ self.layer_id = layer_id
+
+ self.embed_dim = config.hidden_size
+ self.num_heads = config.num_heads
+ self.head_dim = self.embed_dim // self.num_heads
+ if self.head_dim * self.num_heads != self.embed_dim:
+ raise ValueError(
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
+ f" {self.num_heads})."
+ )
+
+ self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
+ self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
+ self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
+ self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True)
+
+ def _split_heads(self, tensor, num_heads, attn_head_size):
+ """
+ Splits hidden_size dim into attn_head_size and num_heads
+ """
+ new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
+ tensor = tensor.view(new_shape)
+ return tensor.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features)
+
+ def _merge_heads(self, tensor, num_heads, attn_head_size):
+ """
+ Merges attn_head_size dim and num_attn_heads dim into hidden_size
+ """
+ tensor = tensor.permute(0, 2, 1, 3).contiguous()
+ new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,)
+ return tensor.view(new_shape)
+
+ def _attn(self, query, key, value, attention_mask=None, head_mask=None):
+ # Keep the attention weights computation in fp32 to avoid overflow issues
+ query = query.to(torch.float32)
+ key = key.to(torch.float32)
+
+ attn_weights = torch.matmul(query, key.transpose(-1, -2))
+
+ # Apply sliding window masking for local attention layers
+ query_length, key_length = query.size(-2), key.size(-2)
+ causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
+ mask_value = torch.finfo(attn_weights.dtype).min
+ # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
+ # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
+ mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype, device=attn_weights.device)
+ attn_weights = torch.where(causal_mask, attn_weights, mask_value)
+
+ if attention_mask is not None: # no matter the length, we just slice it
+ causal_mask = attention_mask[:, :, :, : key.shape[-2]]
+ attn_weights = attn_weights + causal_mask
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+ attn_weights = attn_weights.to(value.dtype)
+ attn_weights = self.attn_dropout(attn_weights)
+
+ # Mask heads if we want to
+ if head_mask is not None:
+ attn_weights = attn_weights * head_mask
+
+ attn_output = torch.matmul(attn_weights, value)
+
+ return attn_output, attn_weights
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask=None,
+ layer_past=None,
+ head_mask=None,
+ use_cache=False,
+ output_attentions=False,
+ cache_position=None,
+ ):
+ query = self.q_proj(hidden_states)
+ key = self.k_proj(hidden_states)
+ value = self.v_proj(hidden_states)
+
+ query = self._split_heads(query, self.num_heads, self.head_dim)
+ key = self._split_heads(key, self.num_heads, self.head_dim)
+ value = self._split_heads(value, self.num_heads, self.head_dim)
+
+ if layer_past is not None:
+ cache_kwargs = {"cache_position": cache_position}
+ key, value = layer_past.update(key, value, self.layer_id, cache_kwargs)
+
+ attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
+
+ attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
+ attn_output = self.out_proj(attn_output)
+ attn_output = self.resid_dropout(attn_output)
+
+ return attn_output, attn_weights
+
+
+class GPTNeoFlashAttention2(GPTNeoSelfAttention):
+ """
+    GPTNeo flash attention module. This module inherits from `GPTNeoSelfAttention`, as the weights of the module stay
+    untouched. The only required change is in the forward pass, where it needs to correctly call the public API of
+ flash attention and deal with padding tokens in case the input contains any of them.
+ """
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
+        # flash_attn<2.1 generates a top-left aligned causal mask, while what is needed here is bottom-right
+        # alignment, which was made the default for flash_attn>=2.1. This attribute is used to handle this
+        # difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
+        # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces
+        # a wrong mask (top-left).
+ self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask()
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask=None,
+ layer_past=None,
+ head_mask=None,
+ use_cache=False,
+ output_attentions=False,
+ cache_position=None,
+ ):
+ bsz, _, _ = hidden_states.size()
+
+ query = self.q_proj(hidden_states)
+ key = self.k_proj(hidden_states)
+ value = self.v_proj(hidden_states)
+
+ query = self._split_heads(query, self.num_heads, self.head_dim)
+ key = self._split_heads(key, self.num_heads, self.head_dim)
+ value = self._split_heads(value, self.num_heads, self.head_dim)
+
+ if layer_past is not None:
+ cache_kwargs = {"cache_position": cache_position}
+ key, value = layer_past.update(key, value, self.layer_id, cache_kwargs)
+
+ query_length = query.shape[2]
+ tgt_len = key.shape[2]
+
+ # Flash attention requires the input to have the shape
+        # batch_size x seq_length x num_heads x head_dim
+ query = query.transpose(1, 2).view(bsz, query_length, self.num_heads, self.head_dim)
+ key = key.transpose(1, 2).view(bsz, tgt_len, self.num_heads, self.head_dim)
+ value = value.transpose(1, 2).view(bsz, tgt_len, self.num_heads, self.head_dim)
+
+ attn_dropout = self.config.attention_dropout if self.training else 0.0
+
+ if attention_mask is not None: # no matter the length, we just slice it
+ attention_mask = attention_mask[:, :, :, : key.shape[-2]]
+
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
+        # therefore the input hidden states get silently cast in float32. Hence, we need to
+        # cast them back to the correct dtype just to be sure everything works as expected.
+        # This might slow down training & inference, so it is recommended to not cast the LayerNorms
+        # in fp32. (LlamaRMSNorm handles it correctly)
+
+ device_type = query.device.type if query.device.type != "mps" else "cpu"
+ if query.dtype == torch.float32:
+ if torch.is_autocast_enabled():
+ target_dtype = (
+ torch.get_autocast_dtype(device_type)
+ if hasattr(torch, "get_autocast_dtype")
+ else torch.get_autocast_gpu_dtype()
+ )
+ # Handle the case where the model is quantized
+ elif hasattr(self.config, "_pre_quantization_dtype"):
+ target_dtype = self.config._pre_quantization_dtype
+ else:
+ target_dtype = self.q_proj.weight.dtype
+
+ logger.warning_once(
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+ f" {target_dtype}."
+ )
+
+ query = query.to(target_dtype)
+ key = key.to(target_dtype)
+ value = value.to(target_dtype)
+
+ attn_output = _flash_attention_forward(
+ query,
+ key,
+ value,
+ attention_mask,
+ query_length,
+ dropout=attn_dropout,
+ softmax_scale=1.0,
+ is_causal=self.is_causal,
+ use_top_left_mask=self._flash_attn_uses_top_left_mask,
+ )
+
+ attn_weights_reshaped = attn_output.reshape(bsz, query_length, self.num_heads * self.head_dim)
+ attn_output = self.out_proj(attn_weights_reshaped)
+ attn_output = self.resid_dropout(attn_output)
+
+ return attn_output, attn_weights_reshaped
+
+
+GPT_NEO_ATTENTION_CLASSES = {
+ "eager": GPTNeoSelfAttention,
+ "flash_attention_2": GPTNeoFlashAttention2,
+}
+
+
+class GPTNeoAttention(nn.Module):
+ def __init__(self, config, layer_id=0):
+ super().__init__()
+ self.layer_id = layer_id
+ self.attention_layers = config.attention_layers
+ self.attention_type = self.attention_layers[layer_id]
+
+ if self.attention_type in ["global", "local"]:
+ self.attention = GPT_NEO_ATTENTION_CLASSES[config._attn_implementation](
+ config, self.attention_type, layer_id
+ )
+ else:
+ raise NotImplementedError(
+ "Only attn layer types 'global' and 'local' exist, but got `config.attention_layers`: "
+ f"{config.attention_layers}. Select attn layer types from ['global', 'local'] only."
+ )
+
+ def forward(
+ self,
+ hidden_states,
+ layer_past=None,
+ attention_mask=None,
+ head_mask=None,
+ use_cache=False,
+ output_attentions=False,
+ cache_position=None,
+ ):
+ return self.attention(
+ hidden_states,
+ attention_mask=attention_mask,
+ layer_past=layer_past,
+ head_mask=head_mask,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ cache_position=cache_position,
+ )
+
+
+class GPTNeoMLP(nn.Module):
+ def __init__(self, intermediate_size, config): # in MLP: intermediate_size= 4 * hidden_size
+ super().__init__()
+ embed_dim = config.hidden_size
+ self.c_fc = nn.Linear(embed_dim, intermediate_size)
+ self.c_proj = nn.Linear(intermediate_size, embed_dim)
+ self.act = ACT2FN[config.activation_function]
+ self.dropout = nn.Dropout(float(config.resid_dropout))
+
+ def forward(self, hidden_states):
+ hidden_states = self.c_fc(hidden_states)
+ hidden_states = self.act(hidden_states)
+ hidden_states = self.c_proj(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ return hidden_states
+
+
+class GPTNeoBlock(GradientCheckpointingLayer):
+ def __init__(self, config, layer_id=None):
+ super().__init__()
+ hidden_size = config.hidden_size
+ inner_dim = config.intermediate_size if config.intermediate_size is not None else 4 * hidden_size
+ self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
+ self.attn = GPTNeoAttention(config, layer_id)
+ self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
+ self.mlp = GPTNeoMLP(inner_dim, config)
+
+ def forward(
+ self,
+ hidden_states,
+ layer_past=None,
+ attention_mask=None,
+ head_mask=None,
+ use_cache=False,
+ output_attentions=False,
+ cache_position=None,
+ ):
+ residual = hidden_states
+ hidden_states = self.ln_1(hidden_states)
+ attn_output, attn_weights = self.attn(
+ hidden_states,
+ layer_past=layer_past,
+ attention_mask=attention_mask,
+ head_mask=head_mask,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ cache_position=cache_position,
+ )
+
+ # residual connection
+ hidden_states = attn_output + residual
+
+ residual = hidden_states
+ hidden_states = self.ln_2(hidden_states)
+ feed_forward_hidden_states = self.mlp(hidden_states)
+ # residual connection
+ hidden_states = residual + feed_forward_hidden_states
+
+ return hidden_states, attn_weights
+
+
+@auto_docstring
+class GPTNeoPreTrainedModel(PreTrainedModel):
+ config: GPTNeoConfig
+ load_tf_weights = load_tf_weights_in_gpt_neo
+ base_model_prefix = "transformer"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["GPTNeoBlock"]
+ _skip_keys_device_placement = "past_key_values"
+ _supports_flash_attn = True
+ _can_compile_fullgraph = False # TODO: needs a hybrid cache
+
+ def __init__(self, *inputs, **kwargs):
+ super().__init__(*inputs, **kwargs)
+
+ def _init_weights(self, module):
+ """Initialize the weights."""
+ if isinstance(module, (nn.Linear,)):
+ # Slightly different from the TF version which uses truncated_normal for initialization
+ # cf https://github.com/pytorch/pytorch/pull/5617
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+ elif isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+
+
+@auto_docstring
+class GPTNeoModel(GPTNeoPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+
+ self.embed_dim = config.hidden_size
+ self.wte = nn.Embedding(config.vocab_size, self.embed_dim)
+ self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
+ self.drop = nn.Dropout(float(config.embed_dropout))
+ self.h = nn.ModuleList([GPTNeoBlock(config, layer_id=i) for i in range(config.num_layers)])
+ self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
+
+ self.gradient_checkpointing = False
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.wte
+
+ def set_input_embeddings(self, new_embeddings):
+ self.wte = new_embeddings
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ past_key_values: Optional[Union[Cache, tuple[torch.FloatTensor]]] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ token_type_ids: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
+ r"""
+ input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
+ `input_ids_length` = `sequence_length` if `past_key_values` is `None` else
+ `past_key_values.get_seq_length()` (`sequence_length` of input past key value states). Indices of input
+ sequence tokens in the vocabulary.
+
+ If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
+ `input_ids`.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ if self.gradient_checkpointing and self.training:
+ if use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+ )
+ use_cache = False
+
+ if inputs_embeds is None:
+ inputs_embeds = self.wte(input_ids)
+
+ if use_cache and past_key_values is None:
+ past_key_values = DynamicCache(config=self.config)
+
+ seq_length = inputs_embeds.shape[1]
+ if cache_position is None:
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ cache_position = torch.arange(past_seen_tokens, past_seen_tokens + seq_length, device=inputs_embeds.device)
+
+ if position_ids is None:
+ position_ids = cache_position.unsqueeze(0)
+
+ causal_mask = self._update_causal_mask(
+ attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
+ )
+
+ # Prepare head mask if needed
+ # 1.0 in head_mask indicate we keep the head
+ # attention_probs has shape bsz x num_heads x N x N
+ # head_mask has shape n_layer x batch x num_heads x N x N
+ head_mask = self.get_head_mask(head_mask, self.config.num_layers)
+ position_embeds = self.wpe(position_ids)
+ hidden_states = inputs_embeds + position_embeds
+
+ if token_type_ids is not None:
+ token_type_ids = token_type_ids.view(-1, seq_length)
+ token_type_embeds = self.wte(token_type_ids)
+ hidden_states = hidden_states + token_type_embeds
+
+ hidden_states = self.drop(hidden_states)
+ output_shape = (-1, seq_length, hidden_states.size(-1))
+
+ all_self_attentions = () if output_attentions else None
+ all_hidden_states = () if output_hidden_states else None
+ for i, block in enumerate(self.h):
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ outputs = block(
+ hidden_states,
+ layer_past=past_key_values,
+ attention_mask=causal_mask,
+ head_mask=head_mask[i],
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ cache_position=cache_position,
+ )
+
+ hidden_states = outputs[0]
+ if output_attentions:
+ all_self_attentions = all_self_attentions + (outputs[1],)
+
+ hidden_states = self.ln_f(hidden_states)
+
+ hidden_states = hidden_states.view(output_shape)
+ # Add last hidden state
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(
+ v for v in [hidden_states, past_key_values, all_hidden_states, all_self_attentions] if v is not None
+ )
+
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=past_key_values,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attentions,
+ )
+
+ # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask
+ def _update_causal_mask(
+ self,
+ attention_mask: Union[torch.Tensor, "BlockMask"],
+ input_tensor: torch.Tensor,
+ cache_position: torch.Tensor,
+ past_key_values: Cache,
+ output_attentions: bool = False,
+ ):
+ if self.config._attn_implementation == "flash_attention_2":
+ if attention_mask is not None and (attention_mask == 0.0).any():
+ return attention_mask
+ return None
+ if self.config._attn_implementation == "flex_attention":
+ if isinstance(attention_mask, torch.Tensor):
+ attention_mask = make_flex_block_causal_mask(attention_mask)
+ return attention_mask
+
+ # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
+ # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
+ # to infer the attention mask.
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False
+
+ # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
+ if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions:
+ if AttentionMaskConverter._ignore_causal_mask_sdpa(
+ attention_mask,
+ inputs_embeds=input_tensor,
+ past_key_values_length=past_seen_tokens,
+ is_training=self.training,
+ ):
+ return None
+
+ dtype = input_tensor.dtype
+ sequence_length = input_tensor.shape[1]
+ if using_compilable_cache:
+ target_length = past_key_values.get_max_cache_shape()
+ else:
+ target_length = (
+ attention_mask.shape[-1]
+ if isinstance(attention_mask, torch.Tensor)
+ else past_seen_tokens + sequence_length + 1
+ )
+
+ # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
+ causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
+ attention_mask,
+ sequence_length=sequence_length,
+ target_length=target_length,
+ dtype=dtype,
+ cache_position=cache_position,
+ batch_size=input_tensor.shape[0],
+ )
+
+ if (
+ self.config._attn_implementation == "sdpa"
+ and attention_mask is not None
+ and attention_mask.device.type in ["cuda", "xpu", "npu"]
+ and not output_attentions
+ ):
+ # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
+ # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
+ # Details: https://github.com/pytorch/pytorch/issues/110213
+ min_dtype = torch.finfo(dtype).min
+ causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
+
+ return causal_mask
+
+ @staticmethod
+ # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position
+ def _prepare_4d_causal_attention_mask_with_cache_position(
+ attention_mask: torch.Tensor,
+ sequence_length: int,
+ target_length: int,
+ dtype: torch.dtype,
+ cache_position: torch.Tensor,
+ batch_size: int,
+ **kwargs,
+ ):
+ """
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
+ `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
+
+ Args:
+ attention_mask (`torch.Tensor`):
+ A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
+ `(batch_size, 1, query_length, key_value_length)`.
+ sequence_length (`int`):
+ The sequence length being processed.
+ target_length (`int`):
+ The target length: when generating with static cache, the mask should be as long as the static cache,
+ to account for the 0 padding, the part of the cache that is not filled yet.
+ dtype (`torch.dtype`):
+ The dtype to use for the 4D attention mask.
+ cache_position (`torch.Tensor`):
+ Indices depicting the position of the input sequence tokens in the sequence.
+ batch_size (`int`):
+ Batch size.
+ """
+ if attention_mask is not None and attention_mask.dim() == 4:
+ # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
+ causal_mask = attention_mask
+ else:
+ min_dtype = torch.finfo(dtype).min
+ causal_mask = torch.full(
+ (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
+ )
+ if sequence_length != 1:
+ causal_mask = torch.triu(causal_mask, diagonal=1)
+ causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
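+ # Illustrative example: with sequence_length=2, target_length=4 and cache_position=[2, 3], the row for the
+ # query at position 2 keeps keys 0..2 visible (only key 3 stays at min_dtype), while the row for position 3
+ # keeps keys 0..3 visible, i.e. each query attends to all cached keys up to its own position.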
+ causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
+ if attention_mask is not None:
+ causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
+ mask_length = attention_mask.shape[-1]
+ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
+ causal_mask.device
+ )
+ padding_mask = padding_mask == 0
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+ padding_mask, min_dtype
+ )
+
+ return causal_mask
+
+
+@auto_docstring(
+ custom_intro="""
+ The GPT Neo Model transformer with a language modeling head on top (linear layer with weights tied to the input
+ embeddings).
+ """
+)
+class GPTNeoForCausalLM(GPTNeoPreTrainedModel, GenerationMixin):
+ _tied_weights_keys = ["lm_head.weight"]
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.transformer = GPTNeoModel(config)
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ past_key_values: Optional[Union[Cache, tuple[torch.FloatTensor]]] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ token_type_ids: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs,
+ ) -> Union[tuple[torch.Tensor], CausalLMOutputWithPast]:
+ r"""
+ input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
+ `input_ids_length` = `sequence_length` if `past_key_values` is `None` else
+ `past_key_values.get_seq_length()` (`sequence_length` of input past key value states). Indices of input
+ sequence tokens in the vocabulary.
+
+ If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
+ `input_ids`.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ labels (`torch.LongTensor` of shape `(batch_size, input_ids_length)`, *optional*):
+ Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
+ `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size]`. All labels set to `-100`
+ are ignored (masked); the loss is only computed for labels in `[0, ..., config.vocab_size]`.
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ transformer_outputs = self.transformer(
+ input_ids,
+ past_key_values=past_key_values,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ cache_position=cache_position,
+ )
+ hidden_states = transformer_outputs[0]
+
+ lm_logits = self.lm_head(hidden_states)
+
+ loss = None
+ if labels is not None:
+ # move labels to correct device to enable model parallelism
+ labels = labels.to(lm_logits.device)
+ # Compute loss in fp32 to match with mesh-tf version
+ # https://github.com/EleutherAI/gpt-neo/blob/89ce74164da2fb16179106f54e2269b5da8db333/models/gpt2/gpt2.py#L179
+ lm_logits = lm_logits.to(torch.float32)
+
+ # Flatten the tokens
+ loss = self.loss_function(
+ lm_logits,
+ labels,
+ vocab_size=self.config.vocab_size,
+ **kwargs,
+ )
+
+ lm_logits = lm_logits.to(hidden_states.dtype)
+ loss = loss.to(hidden_states.dtype)
+
+ if not return_dict:
+ output = (lm_logits,) + transformer_outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return CausalLMOutputWithPast(
+ loss=loss,
+ logits=lm_logits,
+ past_key_values=transformer_outputs.past_key_values,
+ hidden_states=transformer_outputs.hidden_states,
+ attentions=transformer_outputs.attentions,
+ )
+
+
+@auto_docstring(
+ custom_intro="""
+ The GPTNeo Model transformer with a sequence classification head on top (linear layer).
+
+ [`GPTNeoForSequenceClassification`] uses the last token in order to do the classification, as other causal models
+ (e.g. GPT-1) do.
+
+ Since it does classification on the last token, it needs to know the position of the last token. If a
+ `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
+ no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
+ padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
+ each row of the batch).
+ """
+)
+class GPTNeoForSequenceClassification(GPTNeoPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+ self.transformer = GPTNeoModel(config)
+ self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ past_key_values: Optional[Union[Cache, tuple[torch.FloatTensor]]] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ token_type_ids: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[tuple[torch.Tensor], SequenceClassifierOutputWithPast]:
+ r"""
+ input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
+ `input_ids_length` = `sequence_length` if `past_key_values` is `None` else
+ `past_key_values.get_seq_length()` (`sequence_length` of input past key value states). Indices of input
+ sequence tokens in the vocabulary.
+
+ If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
+ `input_ids`.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss); if
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ transformer_outputs = self.transformer(
+ input_ids,
+ past_key_values=past_key_values,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ hidden_states = transformer_outputs[0]
+ logits = self.score(hidden_states)
+
+ if input_ids is not None:
+ batch_size, sequence_length = input_ids.shape[:2]
+ else:
+ batch_size, sequence_length = inputs_embeds.shape[:2]
+
+ if self.config.pad_token_id is None and batch_size != 1:
+ raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
+ if self.config.pad_token_id is None:
+ last_non_pad_token = -1
+ elif input_ids is not None:
+ # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
+ non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
+ token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
+ last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
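+ # For example, with right padding and input_ids = [[5, 6, pad, pad]]: non_pad_mask = [1, 1, 0, 0] and
+ # token_indices = [0, 1, 2, 3], so (token_indices * non_pad_mask).argmax(-1) picks index 1.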
+ else:
+ last_non_pad_token = -1
+ logger.warning_once(
+ f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
+ "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
+ )
+
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
+
+ loss = None
+ if labels is not None:
+ if self.config.problem_type is None:
+ if self.num_labels == 1:
+ self.config.problem_type = "regression"
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+ self.config.problem_type = "single_label_classification"
+ else:
+ self.config.problem_type = "multi_label_classification"
+
+ if self.config.problem_type == "regression":
+ loss_fct = MSELoss()
+ if self.num_labels == 1:
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
+ else:
+ loss = loss_fct(pooled_logits, labels)
+ elif self.config.problem_type == "single_label_classification":
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
+ elif self.config.problem_type == "multi_label_classification":
+ loss_fct = BCEWithLogitsLoss()
+ loss = loss_fct(pooled_logits, labels)
+ if not return_dict:
+ output = (pooled_logits,) + transformer_outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return SequenceClassifierOutputWithPast(
+ loss=loss,
+ logits=pooled_logits,
+ past_key_values=transformer_outputs.past_key_values,
+ hidden_states=transformer_outputs.hidden_states,
+ attentions=transformer_outputs.attentions,
+ )
+
+
+@auto_docstring
+class GPTNeoForTokenClassification(GPTNeoPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+
+ self.transformer = GPTNeoModel(config)
+ self.dropout = nn.Dropout(config.classifier_dropout)
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Union[Cache, tuple[tuple[torch.Tensor]]]] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[tuple, TokenClassifierOutput]:
+ r"""
+ input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
+ `input_ids_length` = `sequence_length` if `past_key_values` is `None` else
+ `past_key_values.get_seq_length()` (`sequence_length` of input past key value states). Indices of input
+ sequence tokens in the vocabulary.
+
+ If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
+ `input_ids`.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ transformer_outputs = self.transformer(
+ input_ids,
+ past_key_values=past_key_values,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ hidden_states = transformer_outputs[0]
+ hidden_states = self.dropout(hidden_states)
+ logits = self.classifier(hidden_states)
+
+ loss = None
+ if labels is not None:
+ labels = labels.to(logits.device)
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+
+ if not return_dict:
+ output = (logits,) + transformer_outputs[2:]
+ return ((loss,) + output) if loss is not None else output
+
+ return TokenClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=transformer_outputs.hidden_states,
+ attentions=transformer_outputs.attentions,
+ )
+
+
+@auto_docstring
+class GPTNeoForQuestionAnswering(GPTNeoPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+ self.transformer = GPTNeoModel(config)
+ self.qa_outputs = nn.Linear(config.hidden_size, 2)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ start_positions: Optional[torch.LongTensor] = None,
+ end_positions: Optional[torch.LongTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[tuple, QuestionAnsweringModelOutput]:
+ r"""
+ input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
+ `input_ids_length` = `sequence_length` if `past_key_values` is `None` else
+ `past_key_values.get_seq_length()` (`sequence_length` of input past key value states). Indices of input
+ sequence tokens in the vocabulary.
+
+ If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
+ `input_ids`.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.transformer(
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ sequence_output = outputs[0]
+
+ logits = self.qa_outputs(sequence_output)
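+ # each token gets two scores; split(1, dim=-1) separates them into start and end logits of shape (batch, seq_len, 1)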
+ start_logits, end_logits = logits.split(1, dim=-1)
+ start_logits = start_logits.squeeze(-1).contiguous()
+ end_logits = end_logits.squeeze(-1).contiguous()
+
+ total_loss = None
+ if start_positions is not None and end_positions is not None:
+ # If we are on multi-GPU, the gathered positions may carry an extra dimension; squeeze it away
+ if len(start_positions.size()) > 1:
+ start_positions = start_positions.squeeze(-1)
+ if len(end_positions.size()) > 1:
+ end_positions = end_positions.squeeze(-1)
+ # sometimes the start/end positions are outside our model inputs; we ignore these terms
+ ignored_index = start_logits.size(1)
+ start_positions = start_positions.clamp(0, ignored_index)
+ end_positions = end_positions.clamp(0, ignored_index)
+
+ loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
+ start_loss = loss_fct(start_logits, start_positions)
+ end_loss = loss_fct(end_logits, end_positions)
+ total_loss = (start_loss + end_loss) / 2
+
+ if not return_dict:
+ output = (start_logits, end_logits) + outputs[2:]
+ return ((total_loss,) + output) if total_loss is not None else output
+
+ return QuestionAnsweringModelOutput(
+ loss=total_loss,
+ start_logits=start_logits,
+ end_logits=end_logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+__all__ = [
+ "GPTNeoForCausalLM",
+ "GPTNeoForQuestionAnswering",
+ "GPTNeoForSequenceClassification",
+ "GPTNeoForTokenClassification",
+ "GPTNeoModel",
+ "GPTNeoPreTrainedModel",
+ "load_tf_weights_in_gpt_neo",
+]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/gpt_neox_japanese/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/gpt_neox_japanese/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..94ba39d69ad638c706f6ac8491e2dea80e269929
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/gpt_neox_japanese/__init__.py
@@ -0,0 +1,28 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_gpt_neox_japanese import *
+ from .modeling_gpt_neox_japanese import *
+ from .tokenization_gpt_neox_japanese import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/gpt_neox_japanese/configuration_gpt_neox_japanese.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/gpt_neox_japanese/configuration_gpt_neox_japanese.py
new file mode 100644
index 0000000000000000000000000000000000000000..320157334539b1d7c418c8cf97c8b57dc38629f7
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/gpt_neox_japanese/configuration_gpt_neox_japanese.py
@@ -0,0 +1,167 @@
+# coding=utf-8
+# Copyright 2022 ABEJA, Inc. and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""GPTNeoX Japanese model configuration"""
+
+from ...configuration_utils import PretrainedConfig
+from ...modeling_rope_utils import rope_config_validation
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class GPTNeoXJapaneseConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`GPTNeoXJapaneseModel`]. It is used to instantiate
+ a GPTNeoXJapanese model according to the specified arguments, defining the model architecture. Instantiating a
+ configuration with the defaults will yield a similar configuration to that of the GPTNeoXJapanese
+ [abeja/gpt-neox-japanese-2.7b](https://huggingface.co/abeja/gpt-neox-japanese-2.7b) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information. The default configuration corresponds to the 2.7B model.
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 32000):
+ Vocabulary size of the GPTNeoXJapanese model. Defines the number of different tokens that can be
+ represented by the `input_ids` passed when calling [`GPTNeoXJapaneseModel`].
+ hidden_size (`int`, *optional*, defaults to 2560):
+ Dimension of the encoder layers and the pooler layer.
+ num_hidden_layers (`int`, *optional*, defaults to 32):
+ Number of hidden layers in the Transformer encoder.
+ num_attention_heads (`int`, *optional*, defaults to 32):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ intermediate_multiple_size (`int`, *optional*, defaults to 4):
+ Dimension of the "intermediate" layer in the Transformer encoder is calculated by hidden_size *
+ intermediate_multiple_size.
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
+ The non-linear activation function (function or string) in the encoder and pooler.
+ rotary_pct (`float`, *optional*, defaults to 1.00):
+ Percentage of hidden dimensions to allocate to rotary embeddings.
+ rotary_emb_base (`int`, *optional*, defaults to 10000):
+ Base for computing the rotary embedding frequencies.
+ max_position_embeddings (`int`, *optional*, defaults to 2048):
+ The maximum sequence length that this model might ever be used with.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ layer_norm_eps (`float`, *optional*, defaults to 1e-5):
+ The epsilon used by the layer normalization layers.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
+ relevant if `config.is_decoder=True`.
+ rope_scaling (`Dict`, *optional*):
+ Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply a new rope type
+ and you expect the model to work on a longer `max_position_embeddings`, we recommend updating this value
+ accordingly.
+ Expected contents:
+ `rope_type` (`str`):
+ The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
+ 'llama3'], with 'default' being the original RoPE implementation.
+ `factor` (`float`, *optional*):
+ Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
+ most scaling types, a `factor` of x will enable the model to handle sequences of length x *
+ original maximum pre-trained length.
+ `original_max_position_embeddings` (`int`, *optional*):
+ Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
+ pretraining.
+ `attention_factor` (`float`, *optional*):
+ Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
+ computation. If unspecified, it defaults to value recommended by the implementation, using the
+ `factor` field to infer the suggested value.
+ `beta_fast` (`float`, *optional*):
+ Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
+ ramp function. If unspecified, it defaults to 32.
+ `beta_slow` (`float`, *optional*):
+ Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
+ ramp function. If unspecified, it defaults to 1.
+ `short_factor` (`list[float]`, *optional*):
+ Only used with 'longrope'. The scaling factor to be applied to short contexts (<
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+ size divided by the number of attention heads divided by 2
+ `long_factor` (`list[float]`, *optional*):
+ Only used with 'longrope'. The scaling factor to be applied to long contexts (>
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+ size divided by the number of attention heads divided by 2.
+ `low_freq_factor` (`float`, *optional*):
+ Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
+ `high_freq_factor` (`float`, *optional*):
+ Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
+ attention_dropout (`float`, *optional*, defaults to 0.1):
+ The dropout ratio for the attention.
+ hidden_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the hidden layer.
+ Example:
+
+ ```python
+ >>> from transformers import GPTNeoXJapaneseConfig, GPTNeoXJapaneseModel
+
+ >>> # Initializing a GPTNeoXJapanese gpt-neox-japanese-2.7b style configuration
+ >>> configuration = GPTNeoXJapaneseConfig()
+
+ >>> # Initializing a model (with random weights) from the gpt-neox-japanese-2.7b style configuration
+ >>> model = GPTNeoXJapaneseModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
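+
+ >>> # Illustrative only: a configuration with linear RoPE scaling via the `rope_scaling` dict
+ >>> scaled_configuration = GPTNeoXJapaneseConfig(rope_scaling={"rope_type": "linear", "factor": 2.0})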
+ ```"""
+
+ model_type = "gpt_neox_japanese"
+
+ def __init__(
+ self,
+ vocab_size=32000,
+ hidden_size=2560,
+ num_hidden_layers=32,
+ num_attention_heads=32,
+ intermediate_multiple_size=4,
+ hidden_act="gelu",
+ rotary_pct=1.00,
+ rotary_emb_base=10000,
+ max_position_embeddings=2048,
+ initializer_range=0.02,
+ layer_norm_eps=1e-5,
+ use_cache=True,
+ bos_token_id=31996,
+ eos_token_id=31999,
+ rope_scaling=None,
+ attention_dropout=0.1,
+ hidden_dropout=0.0,
+ **kwargs,
+ ):
+ super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
+ self.vocab_size = vocab_size
+ self.max_position_embeddings = max_position_embeddings
+ self.hidden_size = hidden_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.intermediate_multiple_size = intermediate_multiple_size
+ self.hidden_act = hidden_act
+ self.rotary_pct = rotary_pct
+ self.partial_rotary_factor = rotary_pct
+ self.rotary_emb_base = rotary_emb_base
+ self.rope_theta = rotary_emb_base
+ self.initializer_range = initializer_range
+ self.layer_norm_eps = layer_norm_eps
+ self.use_cache = use_cache
+ self.rope_scaling = rope_scaling
+ self.attention_dropout = attention_dropout
+ self.hidden_dropout = hidden_dropout
+ # Validate the correctness of rotary position embeddings parameters
+ # BC: if there is a 'type' field, move it to 'rope_type'.
+ if self.rope_scaling is not None and "type" in self.rope_scaling:
+ self.rope_scaling["rope_type"] = self.rope_scaling["type"]
+ rope_config_validation(self)
+
+
+__all__ = ["GPTNeoXJapaneseConfig"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py
new file mode 100644
index 0000000000000000000000000000000000000000..70399f376c7553fc5d7a1437b0a4b760732697ba
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py
@@ -0,0 +1,755 @@
+# coding=utf-8
+# Copyright 2022 ABEJA, Inc. and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch GPTNeoX model."""
+
+import math
+from typing import Optional, Union
+
+import torch
+from torch import Tensor, nn
+
+from ...activations import ACT2FN
+from ...cache_utils import Cache, DynamicCache
+from ...generation import GenerationMixin
+from ...modeling_attn_mask_utils import AttentionMaskConverter
+from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
+from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
+from ...modeling_utils import PreTrainedModel
+from ...utils import auto_docstring, is_torch_flex_attn_available, logging
+from .configuration_gpt_neox_japanese import GPTNeoXJapaneseConfig
+
+
+if is_torch_flex_attn_available():
+ from torch.nn.attention.flex_attention import BlockMask
+
+ from ...integrations.flex_attention import make_flex_block_causal_mask
+
+
+logger = logging.get_logger(__name__)
+
+
+@auto_docstring
+class GPTNeoXJapanesePreTrainedModel(PreTrainedModel):
+ config: GPTNeoXJapaneseConfig
+ base_model_prefix = "gpt_neox_japanese"
+ _no_split_modules = ["GPTNeoXJapaneseLayer"]
+ _skip_keys_device_placement = "past_key_values"
+
+ _can_compile_fullgraph = True
+
+ def _init_weights(self, module):
+ """Initialize the weights"""
+ if isinstance(module, nn.Linear):
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+ elif isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+ elif isinstance(module, GPTNeoXJapaneseAttention):
+ if module.dense_bias is not None:
+ module.dense_bias.data.zero_()
+
+
+class GPTNeoXJapaneseAttention(nn.Module):
+ def __init__(self, config, use_bias=False, layer_idx=None):
+ super().__init__()
+ self.num_attention_heads = config.num_attention_heads
+ self.hidden_size = config.hidden_size
+ self.head_size = self.hidden_size // self.num_attention_heads
+ if layer_idx is None:
+ logger.warning_once(
+ f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
+ "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
+ "when creating this class."
+ )
+
+ self.layer_idx = layer_idx
+ self.rotary_ndims = int(self.head_size * config.rotary_pct)
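+ # with the default rotary_pct=1.0, every head dimension is rotated (rotary_ndims == head_size)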
+ self.rope_theta = config.rotary_emb_base
+ self.rotary_emb = GPTNeoXJapaneseRotaryEmbedding(config=config)
+ self.attention_dropout = nn.Dropout(config.attention_dropout)
+ self.norm_factor = math.sqrt(self.head_size)
+
+ self.query_key_value = nn.Linear(config.hidden_size, 3 * config.hidden_size, bias=False)
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
+ # Activate the bias only for the last layer
+ self.use_bias = use_bias
+ self.dense_bias = nn.Parameter(torch.zeros(config.hidden_size)) if use_bias else None
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ attention_mask: torch.FloatTensor,
+ position_ids: torch.LongTensor,
+ head_mask: Optional[torch.FloatTensor] = None,
+ layer_past: Optional[Cache] = None,
+ use_cache: Optional[bool] = False,
+ output_attentions: Optional[bool] = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
+ ):
+ # Compute QKV
+ # Attention heads [batch, seq_len, hidden_size]
+ # --> [batch, seq_len, (np * 3 * head_size)]
+ qkv = self.query_key_value(hidden_states)
+
+ # [batch, seq_len, (num_heads * 3 * head_size)]
+ # --> [batch, seq_len, num_heads, 3 * head_size]
+ new_qkv_shape = qkv.size()[:-1] + (self.num_attention_heads, 3 * self.head_size)
+ qkv = qkv.view(*new_qkv_shape)
+
+ # [batch, seq_len, num_attention_heads, 3 * head_size] --> 3 [batch, num_attention_heads, seq_len, head_size]
+ query = qkv[..., : self.head_size].permute(0, 2, 1, 3)
+ key = qkv[..., self.head_size : 2 * self.head_size].permute(0, 2, 1, 3)
+ value = qkv[..., 2 * self.head_size :].permute(0, 2, 1, 3)
+
+ # Compute rotary embeddings on rotary_ndims
+ query_rot = query[..., : self.rotary_ndims]
+ query_pass = query[..., self.rotary_ndims :]
+ key_rot = key[..., : self.rotary_ndims]
+ key_pass = key[..., self.rotary_ndims :]
+
+ cos, sin = position_embeddings
+ query, key = apply_rotary_pos_emb(query_rot, key_rot, cos, sin)
+ query = torch.cat((query, query_pass), dim=-1).contiguous()
+ key = torch.cat((key, key_pass), dim=-1).contiguous()
+
+ # Cache QKV values
+ if layer_past is not None:
+ cache_kwargs = {
+ "sin": sin,
+ "cos": cos,
+ "partial_rotation_size": self.rotary_ndims,
+ "cache_position": cache_position,
+ }
+ key, value = layer_past.update(key, value, self.layer_idx, cache_kwargs)
+
+ # Compute attention
+ attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
+
+ # Reshape outputs
+ attn_output = self._merge_heads(attn_output, self.num_attention_heads, self.head_size)
+ attn_output = self.dense(attn_output)
+
+ return attn_output, attn_weights, self.dense_bias
+
+ @classmethod
+ def _split_heads(cls, tensor, num_attention_heads, attn_head_size):
+ """
+ Splits hidden dim into attn_head_size and num_attention_heads
+ """
+ # tensor: [bs, seq_len, hidden_size]
+ new_shape = tensor.size()[:-1] + (num_attention_heads, attn_head_size)
+ # -> [bs, seq_len, num_attention_heads, attn_head_size]
+ tensor = tensor.view(new_shape)
+ # -> [bs, num_attention_heads, seq_len, attn_head_size]
+ tensor = tensor.permute(0, 2, 1, 3)
+ return tensor
+
+ @classmethod
+ def _merge_heads(cls, tensor, num_attention_heads, attn_head_size):
+ """
+ Merges attn_head_size dim and num_attn_heads dim into hidden dim
+ """
+ # tensor [bs, num_attention_heads, seq_len, attn_head_size]
+ tensor = tensor.permute(0, 2, 1, 3).contiguous()
+ # -> [bs, seq_len, num_attention_heads, attn_head_size]
+ tensor = tensor.view(tensor.size(0), tensor.size(1), num_attention_heads * attn_head_size)
+ # -> [bs, seq_len, hidden_size]
+ return tensor
+
+ def _attn(self, query, key, value, attention_mask=None, head_mask=None):
+ # q, k, v: [bs, num_attention_heads, seq_len, attn_head_size]
+ # compute causal mask from causal mask buffer
+ batch_size, num_attention_heads, query_length, attn_head_size = query.size()
+ key_length = key.size(-2)
+
+ query = query.view(batch_size * num_attention_heads, query_length, attn_head_size)
+ key = key.view(batch_size * num_attention_heads, key_length, attn_head_size)
+
+ # [batch_size * num_heads, q_length, kv_length]
+ attn_scores = torch.zeros(
+ batch_size * num_attention_heads,
+ query_length,
+ key_length,
+ dtype=query.dtype,
+ device=key.device,
+ )
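+ # with a zero `input` and beta=1.0, baddbmm reduces to a batched matmul: scores = (Q @ K^T) / norm_factor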
+ attention_scores = torch.baddbmm(
+ attn_scores,
+ query,
+ key.transpose(1, 2),
+ beta=1.0,
+ alpha=1.0 / self.norm_factor,
+ )
+
+ attention_scores = attention_scores.view(batch_size, num_attention_heads, query_length, -1)
+ if attention_mask is not None: # no matter the length, we just slice it
+ causal_mask = attention_mask[:, :, :, : key.shape[-2]]
+ attention_scores = attention_scores + causal_mask
+
+ attn_weights = nn.functional.softmax(attention_scores, dim=-1)
+ attn_weights = self.attention_dropout(attn_weights)
+ attn_weights = attn_weights.to(value.dtype)
+
+ # Mask heads if we want to
+ if head_mask is not None:
+ attn_weights = attn_weights * head_mask
+
+ attn_output = torch.matmul(attn_weights, value)
+ return attn_output, attn_weights
+
+
+# Copied from transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXRotaryEmbedding with GPTNeoX->GPTNeoXJapanese
+class GPTNeoXJapaneseRotaryEmbedding(nn.Module):
+ inv_freq: torch.Tensor # fix linting for `register_buffer`
+
+ def __init__(self, config: GPTNeoXJapaneseConfig, device=None):
+ super().__init__()
+ # BC: "rope_type" was originally "type"
+ if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
+ else:
+ self.rope_type = "default"
+ self.max_seq_len_cached = config.max_position_embeddings
+ self.original_max_seq_len = config.max_position_embeddings
+
+ self.config = config
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
+ self.original_inv_freq = self.inv_freq
+
+ @torch.no_grad()
+ @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
+ def forward(self, x, position_ids):
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
+ position_ids_expanded = position_ids[:, None, :].float()
+
+ device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
+ with torch.autocast(device_type=device_type, enabled=False): # Force float32
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+ emb = torch.cat((freqs, freqs), dim=-1)
+ cos = emb.cos() * self.attention_scaling
+ sin = emb.sin() * self.attention_scaling
+
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
+def rotate_half(x):
+ """Rotates half the hidden dims of the input."""
+ x1 = x[..., : x.shape[-1] // 2]
+ x2 = x[..., x.shape[-1] // 2 :]
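+ # e.g. a last dim of [x1, x2, x3, x4] becomes [-x3, -x4, x1, x2]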
+ return torch.cat((-x2, x1), dim=-1)
+
+
+# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+ """Applies Rotary Position Embedding to the query and key tensors.
+
+ Args:
+ q (`torch.Tensor`): The query tensor.
+ k (`torch.Tensor`): The key tensor.
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
+ position_ids (`torch.Tensor`, *optional*):
+ Deprecated and unused.
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+ Returns:
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+ """
+ cos = cos.unsqueeze(unsqueeze_dim)
+ sin = sin.unsqueeze(unsqueeze_dim)
+ q_embed = (q * cos) + (rotate_half(q) * sin)
+ k_embed = (k * cos) + (rotate_half(k) * sin)
+ return q_embed, k_embed
+
+
+def bias_dropout_add(x: Tensor, bias: Tensor, residual: Optional[Tensor], prob: float, training: bool) -> Tensor:
+ """add bias to x, apply dropout and residual connection
+
+ Args:
+ x (Tensor): main path of output
+ bias (Tensor): None or attn_bias of the last attention layer
+ residual (Optional[Tensor]): residual value
+ prob (float): dropout probability
+ training (bool): whether in training mode or not
+
+ Returns:
+ Tensor: dropout(x + bias) + residual
+ """
+ if bias is not None:
+ x = x + bias
+ out = torch.nn.functional.dropout(x, p=prob, training=training)
+ if residual is not None:
+ out = residual + out
+ return out
+
+
+class GPTNeoXJapaneseMLP(nn.Module):
+ def __init__(self, config):
+ super().__init__()
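+ # e.g. with the 2.7B defaults (hidden_size=2560, intermediate_multiple_size=4) this gives 10240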
+ intermediate_size = int(config.hidden_size * config.intermediate_multiple_size)
+ self.dense_h_to_4h = nn.Linear(config.hidden_size, intermediate_size, bias=False)
+ # Project back to h.
+ self.dense_4h_to_h = nn.Linear(intermediate_size, config.hidden_size, bias=False)
+ self.act = ACT2FN[config.hidden_act]
+
+ def forward(self, hidden_states):
+ intermediate = self.dense_h_to_4h(hidden_states)
+ intermediate = self.act(intermediate)
+ output = self.dense_4h_to_h(intermediate)
+ return output
+
+
+class GPTNeoXJapaneseLayer(nn.Module):
+ def __init__(self, config, layer_number):
+ super().__init__()
+ self.layer_number = layer_number
+ self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ # activate the bias only in the last layer
+ self.attention = GPTNeoXJapaneseAttention(
+ config=config, use_bias=layer_number == config.num_hidden_layers - 1, layer_idx=layer_number
+ )
+ self.mlp = GPTNeoXJapaneseMLP(config)
+ self.hidden_dropout = config.hidden_dropout
+
+ def forward(
+ self,
+ hidden_states: Optional[torch.FloatTensor],
+ attention_mask: Optional[torch.FloatTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = False,
+ layer_past: Optional[Cache] = None,
+ output_attentions: Optional[bool] = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
+ ):
+ residual = hidden_states
+ ln_out = self.input_layernorm(hidden_states)
+ attn_output, attn_weights, attn_bias = self.attention(
+ ln_out,
+ attention_mask=attention_mask,
+ layer_past=layer_past,
+ head_mask=head_mask,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ position_ids=position_ids,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ )
+
+ # attn_output = dropout(attn_output + bias) + residual
+ attn_output = bias_dropout_add(
+ attn_output,
+ bias=attn_bias.expand_as(residual) if attn_bias is not None else attn_bias,
+ residual=residual,
+ prob=self.hidden_dropout,
+ training=self.training,
+ )
+ mlp_output = self.mlp(self.post_attention_layernorm(attn_output))
+
+ # attn_output = dropout(mlp_output) + attn_output
+ attn_output = bias_dropout_add(
+ mlp_output, bias=None, residual=attn_output, prob=self.hidden_dropout, training=self.training
+ )
+
+ return attn_output, attn_weights
+
+
+@auto_docstring
+class GPTNeoXJapaneseModel(GPTNeoXJapanesePreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.config = config
+
+ self.embed_in = nn.Embedding(config.vocab_size, config.hidden_size)
+ self.layers = nn.ModuleList(
+ [GPTNeoXJapaneseLayer(config=config, layer_number=i) for i in range(config.num_hidden_layers)]
+ )
+ self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.rotary_emb = GPTNeoXJapaneseRotaryEmbedding(config=config)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.embed_in
+
+ def set_input_embeddings(self, value):
+ self.embed_in = value
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ past_key_values: Optional[Union[Cache, tuple[tuple[torch.FloatTensor]]]] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ ) -> Union[tuple, BaseModelOutputWithPast]:
+ r"""
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, GPTNeoXJapaneseModel
+ >>> import torch
+
+ >>> tokenizer = AutoTokenizer.from_pretrained("abeja/gpt-neox-japanese-2.7b")
+ >>> model = GPTNeoXJapaneseModel.from_pretrained("abeja/gpt-neox-japanese-2.7b")
+
+ >>> inputs = tokenizer("日本語のGPT-neoxがHugging Faceで使えます😀", return_tensors="pt")
+ >>> outputs = model(**inputs)
+
+ >>> last_hidden_states = outputs.last_hidden_state
+ ```
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_in(input_ids)
+
+ if use_cache and past_key_values is None:
+ past_key_values = DynamicCache(config=self.config)
+
+ seq_length = inputs_embeds.shape[1]
+ if cache_position is None:
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ cache_position = torch.arange(past_seen_tokens, past_seen_tokens + seq_length, device=inputs_embeds.device)
+
+ if position_ids is None:
+ position_ids = cache_position.unsqueeze(0)
+
+ causal_mask = self._update_causal_mask(
+ attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
+ )
+
+ # Prepare head mask if needed
+ # 1.0 in head_mask indicate we keep the head
+ # attention_probs has shape bsz x n_heads x N x N
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
+ hidden_states = inputs_embeds
+
+ # create position embeddings to be shared across the decoder layers
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+ all_attentions = () if output_attentions else None
+ all_hidden_states = () if output_hidden_states else None
+ for i, layer in enumerate(self.layers):
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ outputs = layer(
+ hidden_states,
+ attention_mask=causal_mask,
+ position_ids=position_ids,
+ head_mask=head_mask[i],
+ layer_past=past_key_values,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ )
+ hidden_states = outputs[0]
+ if output_attentions:
+ all_attentions = all_attentions + (outputs[1],)
+
+ hidden_states = self.final_layer_norm(hidden_states)
+ # Add last hidden state
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(
+ v for v in [hidden_states, past_key_values, all_hidden_states, all_attentions] if v is not None
+ )
+
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=past_key_values,
+ hidden_states=all_hidden_states,
+ attentions=all_attentions,
+ )
+
+ # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask
+ def _update_causal_mask(
+ self,
+ attention_mask: Union[torch.Tensor, "BlockMask"],
+ input_tensor: torch.Tensor,
+ cache_position: torch.Tensor,
+ past_key_values: Cache,
+ output_attentions: bool = False,
+ ):
+ if self.config._attn_implementation == "flash_attention_2":
+ if attention_mask is not None and (attention_mask == 0.0).any():
+ return attention_mask
+ return None
+ if self.config._attn_implementation == "flex_attention":
+ if isinstance(attention_mask, torch.Tensor):
+ attention_mask = make_flex_block_causal_mask(attention_mask)
+ return attention_mask
+
+ # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
+ # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
+ # to infer the attention mask.
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False
+
+ # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
+ if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions:
+ if AttentionMaskConverter._ignore_causal_mask_sdpa(
+ attention_mask,
+ inputs_embeds=input_tensor,
+ past_key_values_length=past_seen_tokens,
+ is_training=self.training,
+ ):
+ return None
+
+ dtype = input_tensor.dtype
+ sequence_length = input_tensor.shape[1]
+ if using_compilable_cache:
+ target_length = past_key_values.get_max_cache_shape()
+ else:
+ target_length = (
+ attention_mask.shape[-1]
+ if isinstance(attention_mask, torch.Tensor)
+ else past_seen_tokens + sequence_length + 1
+ )
+
+ # In case the provided `attention_mask` is 2D, we generate a causal mask here (4D).
+ causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
+ attention_mask,
+ sequence_length=sequence_length,
+ target_length=target_length,
+ dtype=dtype,
+ cache_position=cache_position,
+ batch_size=input_tensor.shape[0],
+ )
+
+ if (
+ self.config._attn_implementation == "sdpa"
+ and attention_mask is not None
+ and attention_mask.device.type in ["cuda", "xpu", "npu"]
+ and not output_attentions
+ ):
+ # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
+ # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
+ # Details: https://github.com/pytorch/pytorch/issues/110213
+ min_dtype = torch.finfo(dtype).min
+ causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
+
+ return causal_mask
+
+ @staticmethod
+ # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position
+ def _prepare_4d_causal_attention_mask_with_cache_position(
+ attention_mask: torch.Tensor,
+ sequence_length: int,
+ target_length: int,
+ dtype: torch.dtype,
+ cache_position: torch.Tensor,
+ batch_size: int,
+ **kwargs,
+ ):
+ """
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
+ `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
+
+ Args:
+ attention_mask (`torch.Tensor`):
+ A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
+ `(batch_size, 1, query_length, key_value_length)`.
+ sequence_length (`int`):
+ The sequence length being processed.
+ target_length (`int`):
+ The target length: when generating with static cache, the mask should be as long as the static cache,
+ to account for the 0 padding, the part of the cache that is not filled yet.
+ dtype (`torch.dtype`):
+ The dtype to use for the 4D attention mask.
+ cache_position (`torch.Tensor`):
+ Indices depicting the position of the input sequence tokens in the sequence.
+ batch_size (`int`):
+ Batch size.
+ """
+ if attention_mask is not None and attention_mask.dim() == 4:
+ # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
+ causal_mask = attention_mask
+ else:
+ min_dtype = torch.finfo(dtype).min
+ causal_mask = torch.full(
+ (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
+ )
+ if sequence_length != 1:
+ causal_mask = torch.triu(causal_mask, diagonal=1)
+ causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
+ causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
+ if attention_mask is not None:
+ causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
+ mask_length = attention_mask.shape[-1]
+ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
+ causal_mask.device
+ )
+ padding_mask = padding_mask == 0
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+ padding_mask, min_dtype
+ )
+
+ return causal_mask
+
+
+@auto_docstring(
+ custom_intro="""
+ The GPTNeoXJapanese Model transformer with a `language modeling` head on top.
+ """
+)
+class GPTNeoXJapaneseForCausalLM(GPTNeoXJapanesePreTrainedModel, GenerationMixin):
+ _tied_weights_keys = ["embed_out.weight"]
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.config = config
+
+ self.gpt_neox_japanese = GPTNeoXJapaneseModel(config)
+ self.embed_out = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_output_embeddings(self):
+ return self.embed_out
+
+ def set_output_embeddings(self, new_embeddings):
+ self.embed_out = new_embeddings
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ past_key_values: Optional[Union[Cache, tuple[tuple[torch.FloatTensor]]]] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs,
+ ) -> Union[tuple, CausalLMOutputWithPast]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
+ `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring). Tokens with indices set to `-100` are
+ ignored (masked); the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, GPTNeoXJapaneseForCausalLM, GPTNeoXJapaneseConfig
+ >>> import torch
+
+ >>> tokenizer = AutoTokenizer.from_pretrained("abeja/gpt-neox-japanese-2.7b")
+ >>> config = GPTNeoXJapaneseConfig.from_pretrained("abeja/gpt-neox-japanese-2.7b")
+ >>> config.is_decoder = True
+ >>> model = GPTNeoXJapaneseForCausalLM.from_pretrained("abeja/gpt-neox-japanese-2.7b", config=config)
+
+ >>> inputs = tokenizer("日本語のGPT-neoxがHugging Faceで使えます😀", return_tensors="pt")
+ >>> outputs = model(**inputs)
+
+ >>> prediction_logits = outputs.logits
+ ```
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.gpt_neox_japanese(
+ input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ cache_position=cache_position,
+ )
+
+ hidden_states = outputs[0]
+ lm_logits = self.embed_out(hidden_states)
+
+ lm_loss = None
+ if labels is not None:
+ # move labels to correct device to enable model parallelism
+ labels = labels.to(lm_logits.device)
+
+ lm_loss = self.loss_function(
+ lm_logits,
+ labels,
+ vocab_size=self.config.vocab_size,
+ **kwargs,
+ )
+
+ if not return_dict:
+ output = (lm_logits,) + outputs[1:]
+ return ((lm_loss,) + output) if lm_loss is not None else output
+
+ return CausalLMOutputWithPast(
+ loss=lm_loss,
+ logits=lm_logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+__all__ = [
+ "GPTNeoXJapaneseForCausalLM",
+ "GPTNeoXJapaneseLayer",
+ "GPTNeoXJapaneseModel",
+ "GPTNeoXJapanesePreTrainedModel",
+]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/gpt_neox_japanese/tokenization_gpt_neox_japanese.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/gpt_neox_japanese/tokenization_gpt_neox_japanese.py
new file mode 100644
index 0000000000000000000000000000000000000000..584e74a8123e7cfdf31c4738a656a8417085e9a1
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/gpt_neox_japanese/tokenization_gpt_neox_japanese.py
@@ -0,0 +1,369 @@
+# coding=utf-8
+# Copyright 2022 ABEJA, Inc. and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tokenization classes for GPTNeoXJapanese."""
+
+import collections
+import json
+import os
+import re
+import sys
+from typing import Optional
+
+import numpy as np
+
+from ...tokenization_utils_fast import PreTrainedTokenizer
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt", "emoji_file": "emoji.json"}
+
+
+def load_vocab_and_emoji(vocab_file, emoji_file):
+ """Loads a vocabulary file and emoji file into a dictionary."""
+ with open(emoji_file, "r", encoding="utf-8") as f:
+ emoji = json.loads(f.read())
+
+ vocab = collections.OrderedDict()
+ raw_vocab = collections.OrderedDict()
+ ids_to_tokens = collections.OrderedDict()
+ with open(vocab_file, "r", encoding="utf-8") as f:
+ token = f.readlines()
+ token = [[t.rstrip("\n")] if (t == "," or "," not in t) else t.rstrip("\n").split(",") for t in token]
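+ # each vocab line may list several comma-separated surface forms (heterographs); all of them map to the same token id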
+ for idx, b in enumerate(token):
+ ids_to_tokens[idx] = b
+ raw_vocab[",".join(b)] = idx
+ for wd in b:
+ vocab[wd] = idx
+
+ return vocab, raw_vocab, ids_to_tokens, emoji
+
+
+class GPTNeoXJapaneseTokenizer(PreTrainedTokenizer):
+ """
+ This tokenizer inherits from [`PreTrainedTokenizer`] and is based on Japanese special Sub-Word-Encoding that is
+ used in this repository (https://github.com/tanreinama/Japanese-BPEEncoder_V2). Check the repository for details.
+ Japanese has a relatively large vocabulary and there is no separation between words. Furthermore, the language is a
+ combination of hiragana, katakana, and kanji, and variants such as "1" and "①" are often used. In order to cope
+ with these, this tokenizer has the following features:
+ - Subword-by-subword segmentation, which is intermediate between byte strings and morphological analysis.
+ - BPEs are created for each Kanji, Hiragana, and Katakana character, and there are no BPEs that cross character
+ types, such as Kanji + Hiragana or Hiragana + Katakana.
+ - All-byte encoding that does not require <unk>.
+ - Independent of UTF codes such as 2-byte and 3-byte characters
+ - Conversion of heterographs to the same token_id
+ - Emoji and Emoticon are grouped into 12 types as special tags.
+
+ Example:
+
+ ```python
+ >>> from transformers import GPTNeoXJapaneseTokenizer
+
+ >>> tokenizer = GPTNeoXJapaneseTokenizer.from_pretrained("abeja/gpt-neox-japanese-2.7b")
+ >>> # You can confirm both 慶応 and 慶應 are encoded to 17749
+ >>> tokenizer("吾輩は猫である🐯。実は慶応(慶應)大学出身")["input_ids"]
+ [30014, 26883, 26638, 27228, 25, 26650, 31732, 31679, 27809, 26638, 17749, 31592, 17749, 31593, 321, 1281]
+
+ >>> # Both 慶応 and 慶應 are decoded to 慶応
+ >>> tokenizer.decode(tokenizer("吾輩は猫である🐯。実は慶応(慶應)大学出身")["input_ids"])
+ '吾輩は猫である🐯。実は慶応(慶応)大学出身'
+ ```
+
+ Args:
+ vocab_file (`str`):
+ File containing the vocabulary.
+ emoji_file (`str`):
+ File containing the emoji.
+ unk_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+ token instead.
+ pad_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
+ The token used for padding.
+ bos_token (`str`, *optional*, defaults to `"<|startoftext|>"`):
+ The beginning of sequence token.
+ eos_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
+ The end of sequence token.
+ do_clean_text (`bool`, *optional*, defaults to `False`):
+ Whether or not to clean text for URL, EMAIL, TEL, Japanese DATE and Japanese PRICE.
+ """
+
+ vocab_files_names = VOCAB_FILES_NAMES
+ model_input_names = ["input_ids", "attention_mask"]
+
+ def __init__(
+ self,
+ vocab_file,
+ emoji_file,
+ unk_token="<|endoftext|>",
+ pad_token="<|endoftext|>",
+ bos_token="<|startoftext|>",
+ eos_token="<|endoftext|>",
+ do_clean_text=False,
+ **kwargs,
+ ):
+ if not os.path.isfile(vocab_file):
+ raise ValueError(
+ f"Can't find a vocabulary file at path '{vocab_file}'. To load the vocabulary from a Google pretrained"
+ " model use `tokenizer = GPTNeoXJapaneseokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`"
+ )
+ if not os.path.isfile(emoji_file):
+ raise ValueError(
+ f"Can't find a emoji file at path '{emoji_file}'. To load the emoji information from a Google"
+ " pretrained model use `tokenizer = GPTNeoXJapaneseokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`"
+ )
+ self.do_clean_text = do_clean_text
+ self.vocab, self.raw_vocab, self.ids_to_tokens, self.emoji = load_vocab_and_emoji(vocab_file, emoji_file)
+ self.subword_tokenizer = SubWordJapaneseTokenizer(
+ vocab=self.vocab, ids_to_tokens=self.ids_to_tokens, emoji=self.emoji
+ )
+ super().__init__(
+ unk_token=unk_token,
+ pad_token=pad_token,
+ bos_token=bos_token,
+ eos_token=eos_token,
+ do_clean_text=do_clean_text,
+ **kwargs,
+ )
+
+ @property
+ def vocab_size(self):
+ # self.vocab contains support for character fluctuation (variant forms) unique to Japanese, so it holds more entries than raw_vocab
+ return len(self.raw_vocab)
+
+ def get_vocab(self):
+ return dict(self.raw_vocab, **self.added_tokens_encoder)
+
+ def _tokenize(self, text):
+ return self.subword_tokenizer.tokenize(text, clean=self.do_clean_text)
+
+ def _convert_token_to_id(self, token):
+ """Converts a token (str) in an id using the vocab."""
+ return self.vocab.get(token, self.vocab.get(self.unk_token))
+
+ def _convert_id_to_token(self, index):
+ """Converts an index (integer) in a token (str) using the vocab."""
+ return self.subword_tokenizer.convert_id_to_token(index)
+
+ def convert_tokens_to_string(self, tokens):
+ """Converts a sequence of tokens (string) in a single string."""
+ out_string = "".join(tokens).strip()
+ return out_string
+
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]:
+ index = 0
+ if os.path.isdir(save_directory):
+ vocab_file = os.path.join(
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
+ )
+ emoji_file = os.path.join(
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["emoji_file"]
+ )
+ else:
+ vocab_file = (
+ (filename_prefix + "-" if filename_prefix else "") + save_directory + VOCAB_FILES_NAMES["vocab_file"]
+ )
+ emoji_file = (
+ (filename_prefix + "-" if filename_prefix else "") + save_directory + VOCAB_FILES_NAMES["emoji_file"]
+ )
+ with open(vocab_file, "w", encoding="utf-8") as writer:
+ for token_index, token in self.ids_to_tokens.items():
+ if index != token_index:
+ logger.warning(
+ f"Saving vocabulary to {vocab_file}: vocabulary indices are not consecutive."
+ " Please check that the vocabulary is not corrupted!"
+ )
+ index = token_index
+ writer.write(",".join(token) + "\n")
+ index += 1
+ with open(emoji_file, "w", encoding="utf-8") as writer:
+ json.dump(self.emoji, writer)
+ return vocab_file, emoji_file
+
+
+class SubWordJapaneseTokenizer:
+ """
+ https://github.com/tanreinama/Japanese-BPEEncoder_V2 This tokenizer class is under MIT License according to the
+ original repository.
+
+ MIT License
+
+ Copyright (c) 2020 tanreinama
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
+ documentation files (the "Software"), to deal in the Software without restriction, including without limitation the
+ rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to
+ permit persons to whom the Software is furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all copies or substantial portions of
+ the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO
+ THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
+ TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
+ """
+
+ def __init__(self, vocab, ids_to_tokens, emoji):
+ self.vocab = vocab # same as swe
+ self.ids_to_tokens = ids_to_tokens # same as bpe
+ self.emoji = emoji
+ self.maxlen = np.max([len(w) for w in self.vocab])
+ self.content_repatter1 = re.compile(r"(https?|ftp)(:\/\/[-_\.!~*\'()a-zA-Z0-9;\/?:\@&=\+$,%#]+)")
+ self.content_repatter2 = re.compile(r"[A-Za-z0-9\._+]*@[\-_0-9A-Za-z]+(\.[A-Za-z]+)*")
+ self.content_repatter3 = re.compile(r"[\(]{0,1}[0-9]{2,4}[\)\-\(]{0,1}[0-9]{2,4}[\)\-]{0,1}[0-9]{3,4}")
+ self.content_repatter4 = re.compile(
+ r"([12]\d{3}[/\-年])*(0?[1-9]|1[0-2])[/\-月]((0?[1-9]|[12][0-9]|3[01])日?)*(\d{1,2}|:|\d{1,2}時|\d{1,2}分|\(日\)|\(月\)|\(火\)|\(水\)|\(木\)|\(金\)|\(土\)|㈰|㈪|㈫|㈬|㈭|㈮|㈯)*"
+ )
+ self.content_repatter5 = re.compile(
+ r"(明治|大正|昭和|平成|令和|㍾|㍽|㍼|㍻|\u32ff)\d{1,2}年(0?[1-9]|1[0-2])月(0?[1-9]|[12][0-9]|3[01])日(\d{1,2}|:|\d{1,2}時|\d{1,2}分|\(日\)|\(月\)|\(火\)|\(水\)|\(木\)|\(金\)|\(土\)|㈰|㈪|㈫|㈬|㈭|㈮|㈯)*"
+ )
+ # The original version of this regex displays catastrophic backtracking behaviour. We avoid this using
+ # possessive quantifiers in Py >= 3.11. In versions below this, we avoid the vulnerability using a slightly
+ # different regex that should generally have the same behaviour in most non-pathological cases.
+ if sys.version_info >= (3, 11):
+ self.content_repatter6 = re.compile(
+ r"(?:\d,\d{3}|[\d億])*+"
+ r"(?:\d,\d{3}|[\d万])*+"
+ r"(?:\d,\d{3}|[\d千])*+"
+ r"(?:千円|万円|千万円|円|千ドル|万ドル|千万ドル|ドル|千ユーロ|万ユーロ|千万ユーロ|ユーロ)+"
+ r"(?:\(税込\)|\(税抜\)|\+tax)*"
+ )
+ else:
+ self.content_repatter6 = re.compile(
+ r"(?:\d,\d{3}|[\d億万千])*"
+ r"(?:千円|万円|千万円|円|千ドル|万ドル|千万ドル|ドル|千ユーロ|万ユーロ|千万ユーロ|ユーロ)+"
+ r"(?:\(税込\)|\(税抜\)|\+tax)*"
+ )
+ keisen = "─━│┃┄┅┆┇┈┉┊┋┌┍┎┏┐┑┒┓└┕┖┗┘┙┚┛├┝┞┟┠┡┢┣┤┥┦┧┨┩┪┫┬┭┮┯┰┱┲┳┴┵┶┷┸┹┺┻┼┽┾┿╀╁╂╃╄╅╆╇╈╉╊╋╌╍╎╏═║╒╓╔╕╖╗╘╙╚╛╜╝╞╟╠╡╢╣╤╥╦╧╨╩╪╫╬╭╮╯╰╱╲╳╴╵╶╷╸╹╺╻╼╽╾╿"
+ blocks = "▀▁▂▃▄▅▆▇█▉▊▋▌▍▎▏▐░▒▓▔▕▖▗▘▙▚▛▜▝▞▟"
+ self.content_trans1 = str.maketrans(dict.fromkeys(keisen + blocks, ""))
+
+ def __len__(self):
+ return len(self.ids_to_tokens)
+
+ def clean_text(self, content):
+ content = self.content_repatter1.sub("<URL>", content)
+ content = self.content_repatter2.sub("<EMAIL>", content)
+ content = self.content_repatter3.sub("<TEL>", content)
+ content = self.content_repatter4.sub("<DATE>", content)
+ content = self.content_repatter5.sub("<DATE>", content)
+ content = self.content_repatter6.sub("<PRICE>", content)
+ content = content.translate(self.content_trans1)
+ while "<BLOCK><BLOCK>" in content:
+ content = content.replace("<BLOCK><BLOCK>", "<BLOCK>")
+ return content
+
+ def tokenize(self, text, clean=False):
+ text = text.replace(" ", "<SP>")
+ text = text.replace("　", "<SP>")
+ text = text.replace("\r\n", "<BR>")
+ text = text.replace("\n", "<BR>")
+ text = text.replace("\r", "<BR>")
+ text = text.replace("\t", "<TAB>")
+ text = text.replace("—", "ー")
+ text = text.replace("−", "ー")
+ for k, v in self.emoji["emoji"].items():
+ if k in text:
+ text = text.replace(k, v)
+ if clean:
+ text = self.clean_text(text)
+
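+ # check_simbol flags single characters whose UTF-8 encoding is 2 bytes in a few symbol/diacritic ranges,
+ # checku2e flags single characters whose 3-byte UTF-8 encoding falls roughly in the U+2000-U+2BFF block;
+ # both are mapped to dedicated placeholder tokens below instead of byte-level fallback tokens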
+ def check_simbol(x):
+ e = x.encode()
+ if len(x) == 1 and len(e) == 2:
+ c = (int(e[0]) << 8) + int(e[1])
+ if (
+ (c >= 0xC2A1 and c <= 0xC2BF)
+ or (c >= 0xC780 and c <= 0xC783)
+ or (c >= 0xCAB9 and c <= 0xCBBF)
+ or (c >= 0xCC80 and c <= 0xCDA2)
+ ):
+ return True
+ return False
+
+ def checku2e(x):
+ e = x.encode()
+ if len(x) == 1 and len(e) == 3:
+ c = (int(e[0]) << 16) + (int(e[1]) << 8) + int(e[2])
+ if c >= 0xE28080 and c <= 0xE2B07F:
+ return True
+ return False
+
+ pos = 0
+ result = []
+ while pos < len(text):
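+ # greedy longest-match search: candidates starting with "<" (special tokens) may span up to maxlen characters, plain text is matched over at most 3 characters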
+ end = min(len(text), pos + self.maxlen + 1) if text[pos] == "<" else pos + 3
+ candidates = [] # (token_id, token, pos)
+ for e in range(end, pos, -1):
+ wd = text[pos:e]
+ if wd in self.vocab:
+ if wd[0] == "<" and len(wd) > 2:
+ candidates = [(self.vocab[wd], wd, e)]
+ break
+ else:
+ candidates.append((self.vocab[wd], wd, e))
+ if len(candidates) > 0:
+ # the smallest token_id is adopted
+ _, wd, e = min(candidates, key=lambda x: x[0])
+ result.append(wd)
+ pos = e
+ else:
+ end = pos + 1
+ wd = text[pos:end]
+ if check_simbol(wd):
+ result.append("<KIGOU>")
+ elif checku2e(wd):
+ result.append("<U2000U2BFF>")
+ else:
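+ # not in the vocab and not a recognized symbol: fall back to UTF-8 byte tokens of the form <|byteN|>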
+ for i in wd.encode("utf-8"):
+ result.append("<|byte%d|>" % i)
+ pos = end
+ return result
+
+ def convert_id_to_token(self, index, breakline="\n"):
+ words = []
+ byte_tokens = []
+ word = self.ids_to_tokens[index][0]
+ if word[:6] == "<|byte" and word[-2:] == "|>":
+ byte_tokens.append(int(word[6:-2]))
+ else:
+ if len(byte_tokens) > 0:
+ words.append(bytearray(byte_tokens).decode("utf-8", errors="replace"))
+ byte_tokens = []
+ if word[:7] == "<|emoji" and word[-2:] == "|>":
+ words.append(self.emoji["emoji_inv"][word])
+ elif word == "":
+ words.append(" ")
+ elif word == "
":
+ words.append(breakline)
+ elif word == "":
+ words.append("\t")
+ elif word == "":
+ words.append("▀")
+ elif word == "":
+ words.append("ǀ")
+ elif word == "":
+ words.append("‖")
+ else:
+ words.append(word)
+ if len(byte_tokens) > 0:
+ words.append(bytearray(byte_tokens).decode("utf-8", errors="replace"))
+ text = "".join(words)
+ return text
+
+
+__all__ = ["GPTNeoXJapaneseTokenizer"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/gpt_oss/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/gpt_oss/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..19e12e75ef8f46a09244a1a0541bec8097ba3a94
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/gpt_oss/__init__.py
@@ -0,0 +1,27 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_gpt_oss import *
+ from .modeling_gpt_oss import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/gpt_oss/configuration_gpt_oss.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/gpt_oss/configuration_gpt_oss.py
new file mode 100644
index 0000000000000000000000000000000000000000..6459e9a7fd4a936033f06b95fd7fc4c1381fe31b
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/gpt_oss/configuration_gpt_oss.py
@@ -0,0 +1,126 @@
+# coding=utf-8
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""openai model configuration"""
+
+from ...configuration_utils import PretrainedConfig, layer_type_validation
+from ...modeling_rope_utils import rope_config_validation
+
+
+class GptOssConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`GptOssModel`]. It is used to instantiate a
+ GPT-OSS model according to the specified arguments, defining the model architecture. Instantiating a
+ configuration with the defaults will yield a configuration similar to that of the GPT-OSS architecture.
+
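+ Example (a minimal usage sketch; note that the default arguments describe a large mixture-of-experts model,
+ so instantiating the model allocates substantial memory):
+
+ ```python
+ >>> from transformers import GptOssConfig, GptOssModel
+
+ >>> # Initializing a GPT-OSS style configuration
+ >>> configuration = GptOssConfig()
+
+ >>> # Initializing a model (with random weights) from that configuration
+ >>> model = GptOssModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```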
+ """
+
+ model_type = "gpt_oss"
+ base_model_pp_plan = {
+ "embed_tokens": (["input_ids"], ["inputs_embeds"]),
+ "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
+ "norm": (["hidden_states"], ["hidden_states"]),
+ }
+ base_model_tp_plan = {
+ "layers.*.self_attn.q_proj": "colwise",
+ "layers.*.self_attn.k_proj": "colwise",
+ "layers.*.self_attn.v_proj": "colwise",
+ "layers.*.self_attn.o_proj": "rowwise",
+ "layers.*.self_attn.sinks": "local_rowwise",
+ "layers.*.mlp.experts": "gather",
+ "layers.*.mlp.router": "ep_router",
+ "layers.*.mlp.experts.gate_up_proj": "grouped_gemm",
+ "layers.*.mlp.experts.gate_up_proj_bias": "grouped_gemm",
+ "layers.*.mlp.experts.down_proj": "grouped_gemm",
+ "layers.*.mlp.experts.down_proj_bias": "grouped_gemm",
+ }
+
+ def __init__(
+ self,
+ num_hidden_layers: int = 36,
+ num_local_experts: int = 128,
+ vocab_size: int = 201088,
+ hidden_size: int = 2880,
+ intermediate_size: int = 2880,
+ head_dim: int = 64,
+ num_attention_heads: int = 64,
+ num_key_value_heads: int = 8,
+ sliding_window: int = 128,
+ rope_theta: float = 150000.0,
+ tie_word_embeddings=False,
+ hidden_act: str = "silu",
+ initializer_range: float = 0.02,
+ max_position_embeddings=131072,
+ rms_norm_eps: float = 1e-5,
+ rope_scaling={
+ "rope_type": "yarn",
+ "factor": 32.0,
+ "beta_fast": 32.0,
+ "beta_slow": 1.0,
+ "truncate": False,
+ "original_max_position_embeddings": 4096,
+ },
+ attention_dropout: float = 0.0,
+ num_experts_per_tok=4,
+ router_aux_loss_coef: float = 0.9,
+ output_router_logits=False,
+ use_cache=True,
+ layer_types=None,
+ **kwargs,
+ ):
+ self.vocab_size = vocab_size
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.num_local_experts = num_local_experts
+ self.sliding_window = sliding_window
+ self.num_experts_per_tok = num_experts_per_tok
+ # for backward compatibility
+ if num_key_value_heads is None:
+ num_key_value_heads = num_attention_heads
+
+ self.num_key_value_heads = num_key_value_heads
+ self.hidden_act = hidden_act
+ self.initializer_range = initializer_range
+ self.rms_norm_eps = rms_norm_eps
+ self.rope_theta = rope_theta
+ self.rope_scaling = rope_scaling
+ self.attention_dropout = attention_dropout
+ self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads
+ self.layer_types = layer_types
+ if self.layer_types is None:
+ self.layer_types = [
+ "sliding_attention" if bool((i + 1) % 2) else "full_attention" for i in range(self.num_hidden_layers)
+ ]
+ layer_type_validation(self.layer_types, self.num_hidden_layers)
+
+ self.attention_bias = True
+ self.max_position_embeddings = max_position_embeddings
+ self.router_aux_loss_coef = router_aux_loss_coef
+ self.output_router_logits = output_router_logits
+ self.use_cache = use_cache
+
+ # Validate the correctness of rotary position embeddings parameters
+ # BC: if there is a 'type' field, copy it to 'rope_type'.
+ if self.rope_scaling is not None and "type" in self.rope_scaling:
+ self.rope_scaling["rope_type"] = self.rope_scaling["type"]
+ rope_config_validation(self)
+
+ super().__init__(
+ tie_word_embeddings=tie_word_embeddings,
+ **kwargs,
+ )
+
+
+__all__ = ["GptOssConfig"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/gpt_oss/modeling_gpt_oss.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/gpt_oss/modeling_gpt_oss.py
new file mode 100644
index 0000000000000000000000000000000000000000..0d5c936e8adca36cf69054ec7d7952ccb0125f75
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/gpt_oss/modeling_gpt_oss.py
@@ -0,0 +1,725 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/gpt_oss/modular_gpt_oss.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_gpt_oss.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Callable, Optional, Union
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from ...cache_utils import Cache, DynamicCache
+from ...generation import GenerationMixin
+from ...integrations.hub_kernels import use_kernel_forward_from_hub
+from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
+from ...modeling_layers import (
+ GenericForSequenceClassification,
+ GenericForTokenClassification,
+ GradientCheckpointingLayer,
+)
+from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast
+from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
+from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
+from ...utils.deprecation import deprecate_kwarg
+from ...utils.generic import OutputRecorder, check_model_inputs
+from .configuration_gpt_oss import GptOssConfig
+
+
+@use_kernel_forward_from_hub("RMSNorm")
+class GptOssRMSNorm(nn.Module):
+ def __init__(self, hidden_size, eps=1e-6):
+ """
+ GptOssRMSNorm is equivalent to T5LayerNorm
+ """
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+ self.variance_epsilon = eps
+
+ def forward(self, hidden_states):
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+ return (self.weight * hidden_states).to(input_dtype) # main diff with Llama
+
+ def extra_repr(self):
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
+
+
+class GptOssExperts(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.intermediate_size = config.intermediate_size
+ self.num_experts = config.num_local_experts
+ self.hidden_size = config.hidden_size
+ self.expert_dim = self.intermediate_size
+ self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_size, 2 * self.expert_dim))
+ self.gate_up_proj_bias = nn.Parameter(torch.empty(self.num_experts, 2 * self.expert_dim))
+ self.down_proj = nn.Parameter(torch.empty((self.num_experts, self.expert_dim, self.hidden_size)))
+ self.down_proj_bias = nn.Parameter(torch.empty(self.num_experts, self.hidden_size))
+ self.alpha = 1.702
+ self.limit = 7.0
+
+ def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None) -> torch.Tensor:
+ """
+ When training, it is more efficient to loop over the experts and compute the output for each expert separately,
+ as otherwise memory usage would explode.
+
+ For inference, we can sacrifice some memory and compute the output for all experts at once by repeating the inputs.
+
+ Args:
+ hidden_states (torch.Tensor): (batch_size, seq_len, hidden_size)
+ router_indices (torch.Tensor): (batch_size * seq_len, top_k)
+ routing_weights (torch.Tensor): (batch_size * seq_len, num_experts)
+ Returns:
+ torch.Tensor
+ """
+ batch_size = hidden_states.shape[0]
+ hidden_states = hidden_states.reshape(-1, self.hidden_size) # (num_tokens, hidden_size)
+ num_experts = routing_weights.shape[1]
+ if hidden_states.device.type == "cpu" or self.training:
+ next_states = torch.zeros_like(hidden_states, dtype=hidden_states.dtype, device=hidden_states.device)
+ with torch.no_grad():
+ expert_mask = torch.nn.functional.one_hot(
+ router_indices, num_classes=num_experts + 1
+ ) # masking is also a class
+ expert_mask = expert_mask.permute(2, 1, 0)
+ # we sum on the top_k and on the sequence length to get which experts
+ # are hit this time around
+ expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
+ for expert_idx in expert_hit[:]:
+ # expert_idx only has one element, so we can use it as a scalar for fast indexing
+ expert_idx = expert_idx[0]
+ # skip masking index
+ if expert_idx == num_experts:
+ continue
+ with torch.no_grad():
+ _, token_idx = torch.where(expert_mask[expert_idx])
+ current_state = hidden_states[token_idx]
+ gate_up = current_state @ self.gate_up_proj[expert_idx] + self.gate_up_proj_bias[expert_idx]
+ gate, up = gate_up[..., ::2], gate_up[..., 1::2]
+ gate = gate.clamp(min=None, max=self.limit)
+ up = up.clamp(min=-self.limit, max=self.limit)
+ glu = gate * torch.sigmoid(gate * self.alpha)
+ gated_output = (up + 1) * glu
+ out = gated_output @ self.down_proj[expert_idx] + self.down_proj_bias[expert_idx]
+ weighted_output = out * routing_weights[token_idx, expert_idx, None]
+ next_states.index_add_(0, token_idx, weighted_output.to(hidden_states.dtype))
+ next_states = next_states.view(batch_size, -1, self.hidden_size)
+ else:
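+ # inference path: run every expert on every token with batched matmuls, then weight each expert's output by its dense routing score and sum over experts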
+ hidden_states = hidden_states.repeat(num_experts, 1)
+ hidden_states = hidden_states.view(num_experts, -1, self.hidden_size)
+ gate_up = torch.bmm(hidden_states, self.gate_up_proj) + self.gate_up_proj_bias[..., None, :]
+ gate, up = gate_up[..., ::2], gate_up[..., 1::2]
+ gate = gate.clamp(min=None, max=self.limit)
+ up = up.clamp(min=-self.limit, max=self.limit)
+ glu = gate * torch.sigmoid(gate * self.alpha)
+ next_states = torch.bmm(((up + 1) * glu), self.down_proj)
+ next_states = next_states + self.down_proj_bias[..., None, :]
+ next_states = next_states.view(num_experts, batch_size, -1, self.hidden_size)
+ next_states = next_states * routing_weights.transpose(0, 1).view(num_experts, batch_size, -1)[..., None]
+ next_states = next_states.sum(dim=0)
+ return next_states
+
+
+class GptOssTopKRouter(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.top_k = config.num_experts_per_tok
+ self.num_experts = config.num_local_experts
+ self.hidden_dim = config.hidden_size
+ self.weight = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim))
+ self.bias = nn.Parameter(torch.empty(self.num_experts))
+
+ def forward(self, hidden_states):
+ hidden_states = hidden_states.reshape(-1, self.hidden_dim)
+ router_logits = F.linear(hidden_states, self.weight, self.bias) # (seq_len, num_experts)
+ router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1) # (seq_len, top_k)
+ router_top_value = torch.nn.functional.softmax(router_top_value, dim=1, dtype=router_top_value.dtype)
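+ # scatter the top-k probabilities back into a dense (num_tokens, num_experts) matrix; experts that were not selected keep a score of 0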
+ router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value)
+ return router_scores, router_indices
+
+
+@use_kernel_forward_from_hub("MegaBlocksMoeMLP")
+class GptOssMLP(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.router = GptOssTopKRouter(config)
+ self.experts = GptOssExperts(config)
+
+ def forward(self, hidden_states):
+ router_scores, router_indices = self.router(hidden_states) # (num_experts, seq_len)
+ routed_out = self.experts(hidden_states, router_indices=router_indices, routing_weights=router_scores)
+ return routed_out, router_scores
+
+
+class GptOssRotaryEmbedding(nn.Module):
+ inv_freq: torch.Tensor # fix linting for `register_buffer`
+
+ def __init__(self, config: GptOssConfig, device=None):
+ super().__init__()
+ # BC: "rope_type" was originally "type"
+ if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
+ else:
+ self.rope_type = "default"
+ self.max_seq_len_cached = config.max_position_embeddings
+ self.original_max_seq_len = config.max_position_embeddings
+
+ self.config = config
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
+ self.original_inv_freq = self.inv_freq
+
+ @torch.no_grad()
+ @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
+ def forward(self, x, position_ids):
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
+ position_ids_expanded = position_ids[:, None, :].float()
+
+ device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
+ with torch.autocast(device_type=device_type, enabled=False): # Force float32
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+ emb = freqs
+ cos = emb.cos() * self.attention_scaling
+ sin = emb.sin() * self.attention_scaling
+
+ return cos.to(x.dtype), sin.to(x.dtype)
+
+
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+ """
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+ """
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+ if n_rep == 1:
+ return hidden_states
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
+def _apply_rotary_emb(
+ x: torch.Tensor,
+ cos: torch.Tensor,
+ sin: torch.Tensor,
+) -> torch.Tensor:
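+ # rotate-half RoPE: split the head dimension into two halves and rotate each (first, second) pair by the position-dependent angle given by cos/sin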
+ first_half, second_half = torch.chunk(x, 2, dim=-1)
+ first_ = first_half * cos - second_half * sin
+ second_ = second_half * cos + first_half * sin
+ return torch.cat((first_, second_), dim=-1)
+
+
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+ cos = cos.unsqueeze(unsqueeze_dim)
+ sin = sin.unsqueeze(unsqueeze_dim)
+ q_embed = _apply_rotary_emb(q, cos, sin)
+ k_embed = _apply_rotary_emb(k, cos, sin)
+ return q_embed, k_embed
+
+
+def eager_attention_forward(
+ module: nn.Module,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_mask: Optional[torch.Tensor],
+ scaling: float,
+ dropout: float = 0.0,
+ **kwargs,
+):
+ key_states = repeat_kv(key, module.num_key_value_groups)
+ value_states = repeat_kv(value, module.num_key_value_groups)
+ attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
+ if attention_mask is not None:
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+ attn_weights = attn_weights + causal_mask
+
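+ # attention sinks: append one learned logit per head as an extra "key" column; after the softmax its probability mass is dropped, which lets a head attend to (almost) nothing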
+ sinks = module.sinks.reshape(1, -1, 1, 1).expand(query.shape[0], -1, query.shape[-2], -1)
+ combined_logits = torch.cat([attn_weights, sinks], dim=-1)
+
+ # This was not in the original implementation and slightly affects results; subtracting the per-row max
+ # below prevents overflow in BF16/FP16 when training with bsz > 1.
+
+ combined_logits = combined_logits - combined_logits.max(dim=-1, keepdim=True).values
+ probs = F.softmax(combined_logits, dim=-1, dtype=combined_logits.dtype)
+ scores = probs[..., :-1] # we drop the sink here
+ attn_weights = nn.functional.dropout(scores, p=dropout, training=module.training)
+ attn_output = torch.matmul(attn_weights, value_states)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+ return attn_output, attn_weights
+
+
+class GptOssAttention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(self, config: GptOssConfig, layer_idx: int):
+ super().__init__()
+ self.config = config
+ self.layer_idx = layer_idx
+ self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+ self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
+ self.scaling = self.head_dim**-0.5
+ self.attention_dropout = config.attention_dropout
+ self.is_causal = True
+ self.q_proj = nn.Linear(
+ config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
+ )
+ self.k_proj = nn.Linear(
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
+ )
+ self.v_proj = nn.Linear(
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
+ )
+ self.o_proj = nn.Linear(
+ config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
+ )
+ self.sliding_window = config.sliding_window if config.layer_types[layer_idx] == "sliding_attention" else None
+ self.sinks = nn.Parameter(torch.empty(config.num_attention_heads))
+
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
+ attention_mask: Optional[torch.Tensor],
+ past_key_values: Optional[Cache] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ input_shape = hidden_states.shape[:-1]
+ hidden_shape = (*input_shape, -1, self.head_dim)
+
+ query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+ key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+ value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+
+ cos, sin = position_embeddings
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_values is not None:
+ cache_kwargs = {"cache_position": cache_position}
+ key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ attn_output, attn_weights = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ dropout=0.0 if not self.training else self.attention_dropout,
+ scaling=self.scaling,
+ sliding_window=self.sliding_window,
+ s_aux=self.sinks, # diff with Llama
+ **kwargs,
+ )
+
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+ attn_output = self.o_proj(attn_output)
+ return attn_output, attn_weights
+
+
+class GptOssDecoderLayer(GradientCheckpointingLayer):
+ def __init__(self, config: GptOssConfig, layer_idx: int):
+ super().__init__()
+ self.hidden_size = config.hidden_size
+ self.self_attn = GptOssAttention(config=config, layer_idx=layer_idx)
+ self.mlp = GptOssMLP(config)
+ self.input_layernorm = GptOssRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.post_attention_layernorm = GptOssRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.attention_type = config.layer_types[layer_idx]
+
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ use_cache: Optional[bool] = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> torch.Tensor:
+ residual = hidden_states
+ hidden_states = self.input_layernorm(hidden_states)
+ # Self Attention
+ hidden_states, _ = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ **kwargs,
+ )
+ hidden_states = residual + hidden_states
+
+ # Fully Connected
+ residual = hidden_states
+ hidden_states = self.post_attention_layernorm(hidden_states)
+ hidden_states, _ = self.mlp(hidden_states) # diff with llama: router scores
+ hidden_states = residual + hidden_states
+ return hidden_states
+
+
+@auto_docstring
+class GptOssPreTrainedModel(PreTrainedModel):
+ config: GptOssConfig
+ base_model_prefix = "model"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["GptOssDecoderLayer"]
+ _skip_keys_device_placement = ["past_key_values"]
+ _supports_flash_attn = True
+ _supports_sdpa = False
+ _supports_flex_attn = True
+
+ _can_compile_fullgraph = True
+ _supports_attention_backend = True
+ _can_record_outputs = {
+ "router_logits": OutputRecorder(GptOssTopKRouter, index=0),
+ "hidden_states": GptOssDecoderLayer,
+ "attentions": GptOssAttention,
+ }
+ _keep_in_fp32_modules = ["post_attention_layernorm", "input_layernorm", "norm"]
+ _supports_flash_attention = False
+ _supports_flex_attention = False
+
+ def _init_weights(self, module):
+ std = self.config.initializer_range
+ if isinstance(module, nn.Linear):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Parameter):
+ module.data.normal_(mean=0.0, std=std)
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+ elif isinstance(module, GptOssRMSNorm):
+ module.weight.data.fill_(1.0)
+ elif isinstance(module, GptOssExperts):
+ module.gate_up_proj.data.normal_(mean=0.0, std=std)
+ module.gate_up_proj_bias.data.zero_()
+ module.down_proj.data.normal_(mean=0.0, std=std)
+ module.down_proj_bias.data.zero_()
+ elif isinstance(module, GptOssAttention):
+ module.sinks.data.normal_(mean=0.0, std=std)
+ elif isinstance(module, GptOssTopKRouter):
+ module.weight.data.normal_(mean=0.0, std=std)
+ module.bias.data.normal_(mean=0.0, std=std)
+
+
+@auto_docstring
+class GptOssModel(GptOssPreTrainedModel):
+ _no_split_modules = ["GptOssDecoderLayer"]
+
+ def __init__(self, config: GptOssConfig):
+ super().__init__(config)
+ self.padding_idx = config.pad_token_id
+ self.vocab_size = config.vocab_size
+
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+ self.layers = nn.ModuleList(
+ [GptOssDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+ )
+ self.norm = GptOssRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.rotary_emb = GptOssRotaryEmbedding(config=config)
+ self.gradient_checkpointing = False
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @check_model_inputs
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> MoeModelOutputWithPast:
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ if use_cache and past_key_values is None:
+ past_key_values = DynamicCache(config=self.config)
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+
+ if cache_position is None:
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ cache_position = torch.arange(
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+ )
+ if position_ids is None:
+ position_ids = cache_position.unsqueeze(0)
+
+ # It may already have been prepared by e.g. `generate`
+ if not isinstance(causal_mask_mapping := attention_mask, dict):
+ mask_kwargs = {
+ "config": self.config,
+ "input_embeds": inputs_embeds,
+ "attention_mask": attention_mask,
+ "cache_position": cache_position,
+ "past_key_values": past_key_values,
+ }
+ causal_mask_mapping = {
+ "full_attention": create_causal_mask(**mask_kwargs),
+ "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs),
+ }
+
+ hidden_states = inputs_embeds
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+ for decoder_layer in self.layers:
+ hidden_states = decoder_layer(
+ hidden_states,
+ attention_mask=causal_mask_mapping[decoder_layer.attention_type],
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ **kwargs,
+ )
+ hidden_states = self.norm(hidden_states)
+ return MoeModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=past_key_values,
+ )
+
+
+def load_balancing_loss_func(
+ gate_logits: Union[torch.Tensor, tuple[torch.Tensor], None],
+ num_experts: Optional[int] = None,
+ top_k=2,
+ attention_mask: Optional[torch.Tensor] = None,
+) -> Union[torch.Tensor, int]:
+ r"""
+ Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.
+
+ See Switch Transformer (https://huggingface.co/papers/2101.03961) for more details. This function implements the loss
+ function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between
+ experts is too unbalanced.
+
+ Args:
+ gate_logits:
+ Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of
+ shape [batch_size X sequence_length, num_experts].
+ num_experts:
+ Number of experts
+ top_k:
+ The number of experts to route per-token, can be also interpreted as the `top-k` routing
+ parameter.
+ attention_mask (`torch.Tensor`, *optional*):
+ The attention_mask used in forward function
+ shape [batch_size X sequence_length] if not None.
+
+ Returns:
+ The auxiliary loss.
+ """
+ if gate_logits is None or not isinstance(gate_logits, tuple):
+ return 0
+
+ if isinstance(gate_logits, tuple):
+ compute_device = gate_logits[0].device
+ concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0)
+
+ routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)
+
+ _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
+
+ expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)
+
+ if attention_mask is None:
+ # Compute the percentage of tokens routed to each experts
+ tokens_per_expert = torch.mean(expert_mask.float(), dim=0)
+
+ # Compute the average probability of routing to these experts
+ router_prob_per_expert = torch.mean(routing_weights, dim=0)
+ else:
+ batch_size, sequence_length = attention_mask.shape
+ num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length)
+
+ # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask
+ expert_attention_mask = (
+ attention_mask[None, :, :, None, None]
+ .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts))
+ .reshape(-1, top_k, num_experts)
+ .to(compute_device)
+ )
+
+ # Compute the percentage of tokens routed to each experts
+ tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum(
+ expert_attention_mask, dim=0
+ )
+
+ # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert
+ router_per_expert_attention_mask = (
+ attention_mask[None, :, :, None]
+ .expand((num_hidden_layers, batch_size, sequence_length, num_experts))
+ .reshape(-1, num_experts)
+ .to(compute_device)
+ )
+
+ # Compute the average probability of routing to these experts
+ router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum(
+ router_per_expert_attention_mask, dim=0
+ )
+
+ overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
+ return overall_loss * num_experts
+
+
+@auto_docstring
+class GptOssForCausalLM(GptOssPreTrainedModel, GenerationMixin):
+ _tied_weights_keys = ["lm_head.weight"]
+ _tp_plan = {"lm_head": "colwise_rep"}
+ _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.model = GptOssModel(config)
+ self.vocab_size = config.vocab_size
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+ self.router_aux_loss_coef = config.router_aux_loss_coef
+ self.num_experts = config.num_local_experts
+ self.num_experts_per_tok = config.num_experts_per_tok
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_router_logits: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ logits_to_keep: Union[int, torch.Tensor] = 0,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> MoeCausalLMOutputWithPast:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, GptOssForCausalLM
+
+ >>> model = GptOssForCausalLM.from_pretrained("mistralai/GptOss-8x7B-v0.1")
+ >>> tokenizer = AutoTokenizer.from_pretrained("mistralai/GptOss-8x7B-v0.1")
+
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+ >>> # Generate
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+ ```"""
+
+ output_router_logits = (
+ output_router_logits if output_router_logits is not None else self.config.output_router_logits
+ )
+
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+ outputs: MoeModelOutputWithPast = self.model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_router_logits=output_router_logits,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+ hidden_states = outputs.last_hidden_state
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)
+
+ aux_loss = None
+ if output_router_logits:
+ aux_loss = load_balancing_loss_func(
+ outputs.router_logits,
+ self.num_experts,
+ self.num_experts_per_tok,
+ attention_mask,
+ )
+ if labels is not None:
+ loss += self.router_aux_loss_coef * aux_loss.to(loss.device) # make sure to reside in the same device
+
+ return MoeCausalLMOutputWithPast(
+ loss=loss,
+ aux_loss=aux_loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ router_logits=outputs.router_logits,
+ )
+
+
+class GptOssForSequenceClassification(GenericForSequenceClassification, GptOssPreTrainedModel):
+ pass
+
+
+class GptOssForTokenClassification(GenericForTokenClassification, GptOssPreTrainedModel):
+ pass
+
+
+__all__ = [
+ "GptOssForCausalLM",
+ "GptOssForSequenceClassification",
+ "GptOssForTokenClassification",
+ "GptOssModel",
+ "GptOssPreTrainedModel",
+]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/gpt_oss/modular_gpt_oss.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/gpt_oss/modular_gpt_oss.py
new file mode 100644
index 0000000000000000000000000000000000000000..aba879af9336d882e742e4bbbf963bee24cef89b
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/gpt_oss/modular_gpt_oss.py
@@ -0,0 +1,472 @@
+# coding=utf-8
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Callable, Optional
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from ...cache_utils import Cache, DynamicCache
+from ...integrations.hub_kernels import use_kernel_forward_from_hub
+from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
+from ...modeling_outputs import (
+ MoeModelOutputWithPast,
+)
+from ...modeling_rope_utils import dynamic_rope_update
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
+from ...processing_utils import Unpack
+from ...utils import (
+ TransformersKwargs,
+ auto_docstring,
+ logging,
+)
+from ...utils.deprecation import deprecate_kwarg
+from ...utils.generic import OutputRecorder, check_model_inputs
+from ..llama.modeling_llama import (
+ LlamaDecoderLayer,
+ LlamaPreTrainedModel,
+ LlamaRMSNorm,
+ LlamaRotaryEmbedding,
+ repeat_kv,
+)
+from ..mixtral.modeling_mixtral import (
+ MixtralForCausalLM,
+ MixtralForSequenceClassification,
+ MixtralForTokenClassification,
+ MixtralModel,
+)
+from ..qwen2.modeling_qwen2 import Qwen2Attention
+from .configuration_gpt_oss import GptOssConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+class GptOssRMSNorm(LlamaRMSNorm):
+ def forward(self, hidden_states):
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+ return (self.weight * hidden_states).to(input_dtype) # main diff with Llama
+
+
+class GptOssExperts(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.intermediate_size = config.intermediate_size
+ self.num_experts = config.num_local_experts
+ self.hidden_size = config.hidden_size
+ self.expert_dim = self.intermediate_size
+ self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_size, 2 * self.expert_dim))
+ self.gate_up_proj_bias = nn.Parameter(torch.empty(self.num_experts, 2 * self.expert_dim))
+ self.down_proj = nn.Parameter(torch.empty((self.num_experts, self.expert_dim, self.hidden_size)))
+ self.down_proj_bias = nn.Parameter(torch.empty(self.num_experts, self.hidden_size))
+ self.alpha = 1.702
+ self.limit = 7.0
+
+ def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None) -> torch.Tensor:
+ """
+ When training, it is more efficient to loop over the experts and compute the output for each expert separately,
+ as otherwise memory usage would explode.
+
+ For inference, we can sacrifice some memory and compute the output for all experts at once by repeating the inputs.
+
+ Args:
+ hidden_states (torch.Tensor): (batch_size, seq_len, hidden_size)
+ router_indices (torch.Tensor): (batch_size * seq_len, top_k)
+ routing_weights (torch.Tensor): (batch_size * seq_len, num_experts)
+ Returns:
+ torch.Tensor
+ """
+ batch_size = hidden_states.shape[0]
+ hidden_states = hidden_states.reshape(-1, self.hidden_size) # (num_tokens, hidden_size)
+ num_experts = routing_weights.shape[1]
+ if hidden_states.device.type == "cpu" or self.training:
+ next_states = torch.zeros_like(hidden_states, dtype=hidden_states.dtype, device=hidden_states.device)
+ with torch.no_grad():
+ expert_mask = torch.nn.functional.one_hot(
+ router_indices, num_classes=num_experts + 1
+ ) # masking is also a class
+ expert_mask = expert_mask.permute(2, 1, 0)
+ # we sum on the top_k and on the sequence length to get which experts
+ # are hit this time around
+ expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
+ for expert_idx in expert_hit[:]:
+ # expert_idx only has one element, so we can use it as a scalar for fast indexing
+ expert_idx = expert_idx[0]
+ # skip masking index
+ if expert_idx == num_experts:
+ continue
+ with torch.no_grad():
+ _, token_idx = torch.where(expert_mask[expert_idx])
+ current_state = hidden_states[token_idx]
+ gate_up = current_state @ self.gate_up_proj[expert_idx] + self.gate_up_proj_bias[expert_idx]
+ gate, up = gate_up[..., ::2], gate_up[..., 1::2]
+ gate = gate.clamp(min=None, max=self.limit)
+ up = up.clamp(min=-self.limit, max=self.limit)
+ glu = gate * torch.sigmoid(gate * self.alpha)
+ gated_output = (up + 1) * glu
+ out = gated_output @ self.down_proj[expert_idx] + self.down_proj_bias[expert_idx]
+ weighted_output = out * routing_weights[token_idx, expert_idx, None]
+ next_states.index_add_(0, token_idx, weighted_output.to(hidden_states.dtype))
+ next_states = next_states.view(batch_size, -1, self.hidden_size)
+ else:
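+ # inference path: run every expert on every token with batched matmuls, then weight each expert's output by its dense routing score and sum over experts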
+ hidden_states = hidden_states.repeat(num_experts, 1)
+ hidden_states = hidden_states.view(num_experts, -1, self.hidden_size)
+ gate_up = torch.bmm(hidden_states, self.gate_up_proj) + self.gate_up_proj_bias[..., None, :]
+ gate, up = gate_up[..., ::2], gate_up[..., 1::2]
+ gate = gate.clamp(min=None, max=self.limit)
+ up = up.clamp(min=-self.limit, max=self.limit)
+ glu = gate * torch.sigmoid(gate * self.alpha)
+ next_states = torch.bmm(((up + 1) * glu), self.down_proj)
+ next_states = next_states + self.down_proj_bias[..., None, :]
+ next_states = next_states.view(num_experts, batch_size, -1, self.hidden_size)
+ next_states = next_states * routing_weights.transpose(0, 1).view(num_experts, batch_size, -1)[..., None]
+ next_states = next_states.sum(dim=0)
+ return next_states
+
+
+class GptOssTopKRouter(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.top_k = config.num_experts_per_tok
+ self.num_experts = config.num_local_experts
+ self.hidden_dim = config.hidden_size
+ self.weight = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim))
+ self.bias = nn.Parameter(torch.empty(self.num_experts))
+
+ def forward(self, hidden_states):
+ hidden_states = hidden_states.reshape(-1, self.hidden_dim)
+ router_logits = F.linear(hidden_states, self.weight, self.bias) # (seq_len, num_experts)
+ router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1) # (seq_len, top_k)
+ router_top_value = torch.nn.functional.softmax(router_top_value, dim=1, dtype=router_top_value.dtype)
+ router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value)
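+ # scores of experts outside the top-k stay at 0, so their expert outputs are zeroed out downstream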
+ return router_scores, router_indices
+
+
+@use_kernel_forward_from_hub("MegaBlocksMoeMLP")
+class GptOssMLP(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.router = GptOssTopKRouter(config)
+ self.experts = GptOssExperts(config)
+
+ def forward(self, hidden_states):
+ router_scores, router_indices = self.router(hidden_states) # (num_tokens, num_experts), (num_tokens, top_k)
+ routed_out = self.experts(hidden_states, router_indices=router_indices, routing_weights=router_scores)
+ return routed_out, router_scores
+
+
+class GptOssRotaryEmbedding(LlamaRotaryEmbedding):
+ @torch.no_grad()
+ @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
+ def forward(self, x, position_ids):
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
+ position_ids_expanded = position_ids[:, None, :].float()
+
+ device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
+ with torch.autocast(device_type=device_type, enabled=False): # Force float32
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+ emb = freqs
+ cos = emb.cos() * self.attention_scaling
+ sin = emb.sin() * self.attention_scaling
+
+ return cos.to(x.dtype), sin.to(x.dtype)
+
+
+def _apply_rotary_emb(
+ x: torch.Tensor,
+ cos: torch.Tensor,
+ sin: torch.Tensor,
+) -> torch.Tensor:
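+ # Note: the GptOss rotary embedding emits cos/sin of size head_dim // 2 (it does not duplicate `freqs`),
+ # so the rotation pairs the two contiguous halves of the head dimension directly.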
+ first_half, second_half = torch.chunk(x, 2, dim=-1)
+ first_ = first_half * cos - second_half * sin
+ second_ = second_half * cos + first_half * sin
+ return torch.cat((first_, second_), dim=-1)
+
+
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+ cos = cos.unsqueeze(unsqueeze_dim)
+ sin = sin.unsqueeze(unsqueeze_dim)
+ q_embed = _apply_rotary_emb(q, cos, sin)
+ k_embed = _apply_rotary_emb(k, cos, sin)
+ return q_embed, k_embed
+
+
+def eager_attention_forward(
+ module: nn.Module,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_mask: Optional[torch.Tensor],
+ scaling: float,
+ dropout: float = 0.0,
+ **kwargs,
+):
+ key_states = repeat_kv(key, module.num_key_value_groups)
+ value_states = repeat_kv(value, module.num_key_value_groups)
+ attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
+ if attention_mask is not None:
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+ attn_weights = attn_weights + causal_mask
+
+ sinks = module.sinks.reshape(1, -1, 1, 1).expand(query.shape[0], -1, query.shape[-2], -1)
+ combined_logits = torch.cat([attn_weights, sinks], dim=-1)
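+ # the learned per-head sink acts as an extra softmax slot that can absorb probability mass;
+ # it is dropped again after the softmax (`probs[..., :-1]` below)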
+
+ # This max-subtraction was not in the original implementation and slightly affects results;
+ # it prevents overflow in BF16/FP16 when training with batch size > 1.
+
+ combined_logits = combined_logits - combined_logits.max(dim=-1, keepdim=True).values
+ probs = F.softmax(combined_logits, dim=-1, dtype=combined_logits.dtype)
+ scores = probs[..., :-1] # we drop the sink here
+ attn_weights = nn.functional.dropout(scores, p=dropout, training=module.training)
+ attn_output = torch.matmul(attn_weights, value_states)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+ return attn_output, attn_weights
+
+
+class GptOssAttention(Qwen2Attention):
+ def __init__(self, config: GptOssConfig, layer_idx: int):
+ super().__init__(config, layer_idx)
+ self.q_proj = nn.Linear(
+ config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
+ )
+ self.k_proj = nn.Linear(
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
+ )
+ self.v_proj = nn.Linear(
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
+ )
+ self.o_proj = nn.Linear(
+ config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
+ )
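+ # learnable per-head attention-sink logits, consumed in `eager_attention_forward`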
+ self.sinks = nn.Parameter(torch.empty(config.num_attention_heads))
+
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
+ attention_mask: Optional[torch.Tensor],
+ past_key_values: Optional[Cache] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ input_shape = hidden_states.shape[:-1]
+ hidden_shape = (*input_shape, -1, self.head_dim)
+
+ query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+ key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+ value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+
+ cos, sin = position_embeddings
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_values is not None:
+ cache_kwargs = {"cache_position": cache_position}
+ key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ attn_output, attn_weights = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ dropout=0.0 if not self.training else self.attention_dropout,
+ scaling=self.scaling,
+ sliding_window=self.sliding_window,
+ s_aux=self.sinks, # diff with Llama
+ **kwargs,
+ )
+
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+ attn_output = self.o_proj(attn_output)
+ return attn_output, attn_weights
+
+
+class GptOssDecoderLayer(LlamaDecoderLayer):
+ def __init__(self, config: GptOssConfig, layer_idx: int):
+ super().__init__(config, layer_idx)
+ self.hidden_size = config.hidden_size
+ self.self_attn = GptOssAttention(config=config, layer_idx=layer_idx)
+ self.mlp = GptOssMLP(config)
+ self.input_layernorm = GptOssRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.post_attention_layernorm = GptOssRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.attention_type = config.layer_types[layer_idx]
+
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ use_cache: Optional[bool] = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> torch.Tensor:
+ residual = hidden_states
+ hidden_states = self.input_layernorm(hidden_states)
+ # Self Attention
+ hidden_states, _ = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ **kwargs,
+ )
+ hidden_states = residual + hidden_states
+
+ # Fully Connected
+ residual = hidden_states
+ hidden_states = self.post_attention_layernorm(hidden_states)
+ hidden_states, _ = self.mlp(hidden_states) # diff with llama: router scores
+ hidden_states = residual + hidden_states
+ return hidden_states
+
+
+class GptOssPreTrainedModel(LlamaPreTrainedModel):
+ _keep_in_fp32_modules = ["post_attention_layernorm", "input_layernorm", "norm"]
+ _supports_sdpa = False
+ _supports_flash_attention = False
+ _supports_flex_attention = False
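+ # Note: SDPA / flash / flex attention backends are disabled here, presumably because the learned
+ # attention sinks require the custom softmax in `eager_attention_forward` (or a dedicated hub kernel).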
+ _can_record_outputs = {
+ "router_logits": OutputRecorder(GptOssTopKRouter, index=0),
+ "hidden_states": GptOssDecoderLayer,
+ "attentions": GptOssAttention,
+ }
+
+ def _init_weights(self, module):
+ std = self.config.initializer_range
+ if isinstance(module, nn.Linear):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Parameter):
+ module.data.normal_(mean=0.0, std=std)
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+ elif isinstance(module, GptOssRMSNorm):
+ module.weight.data.fill_(1.0)
+ elif isinstance(module, GptOssExperts):
+ module.gate_up_proj.data.normal_(mean=0.0, std=std)
+ module.gate_up_proj_bias.data.zero_()
+ module.down_proj.data.normal_(mean=0.0, std=std)
+ module.down_proj_bias.data.zero_()
+ elif isinstance(module, GptOssAttention):
+ module.sinks.data.normal_(mean=0.0, std=std)
+ elif isinstance(module, GptOssTopKRouter):
+ module.weight.data.normal_(mean=0.0, std=std)
+ module.bias.data.normal_(mean=0.0, std=std)
+
+
+class GptOssModel(MixtralModel):
+ _no_split_modules = ["GptOssDecoderLayer"]
+
+ @check_model_inputs
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> MoeModelOutputWithPast:
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ if use_cache and past_key_values is None:
+ past_key_values = DynamicCache(config=self.config)
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+
+ if cache_position is None:
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ cache_position = torch.arange(
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+ )
+ if position_ids is None:
+ position_ids = cache_position.unsqueeze(0)
+
+ # It may already have been prepared by e.g. `generate`
+ if not isinstance(causal_mask_mapping := attention_mask, dict):
+ mask_kwargs = {
+ "config": self.config,
+ "input_embeds": inputs_embeds,
+ "attention_mask": attention_mask,
+ "cache_position": cache_position,
+ "past_key_values": past_key_values,
+ }
+ causal_mask_mapping = {
+ "full_attention": create_causal_mask(**mask_kwargs),
+ "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs),
+ }
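+ # each decoder layer picks its mask from this mapping via `decoder_layer.attention_type`
+ # ("full_attention" or "sliding_attention", as set by `config.layer_types`)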
+
+ hidden_states = inputs_embeds
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+ for decoder_layer in self.layers:
+ hidden_states = decoder_layer(
+ hidden_states,
+ attention_mask=causal_mask_mapping[decoder_layer.attention_type],
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ **kwargs,
+ )
+ hidden_states = self.norm(hidden_states)
+ return MoeModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=past_key_values,
+ )
+
+
+class GptOssForCausalLM(MixtralForCausalLM):
+ pass
+
+
+class GptOssForSequenceClassification(MixtralForSequenceClassification):
+ pass
+
+
+class GptOssForTokenClassification(MixtralForTokenClassification):
+ pass
+
+
+__all__ = [
+ "GptOssForCausalLM",
+ "GptOssForSequenceClassification",
+ "GptOssForTokenClassification",
+ "GptOssModel",
+ "GptOssPreTrainedModel",
+]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/granite/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/granite/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..08b74d14ca871a7752b61fb71f08b7a7886d80f7
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/granite/__init__.py
@@ -0,0 +1,27 @@
+# Copyright 2024 EleutherAI and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_granite import *
+ from .modeling_granite import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/granite/configuration_granite.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/granite/configuration_granite.py
new file mode 100644
index 0000000000000000000000000000000000000000..61d3ba9e7bb2775e537608d277e5973ec42a8cf9
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/granite/configuration_granite.py
@@ -0,0 +1,197 @@
+# coding=utf-8
+# Copyright 2024 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Granite model configuration"""
+
+from ...configuration_utils import PretrainedConfig
+from ...modeling_rope_utils import rope_config_validation
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class GraniteConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`GraniteModel`]. It is used to instantiate a Granite
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
+ defaults will yield a similar configuration to that of the Granite-3B.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 32000):
+ Vocabulary size of the Granite model. Defines the number of different tokens that can be represented by the
+ `input_ids` passed when calling [`GraniteModel`].
+ hidden_size (`int`, *optional*, defaults to 4096):
+ Dimension of the hidden representations.
+ intermediate_size (`int`, *optional*, defaults to 11008):
+ Dimension of the MLP representations.
+ num_hidden_layers (`int`, *optional*, defaults to 32):
+ Number of hidden layers in the Transformer decoder.
+ num_attention_heads (`int`, *optional*, defaults to 32):
+ Number of attention heads for each attention layer in the Transformer decoder.
+ num_key_value_heads (`int`, *optional*):
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
+ `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+ by meanpooling all the original heads within that group. For more details, check out [this
+ paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to
+ `num_attention_heads`.
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+ The non-linear activation function (function or string) in the decoder.
+ max_position_embeddings (`int`, *optional*, defaults to 2048):
+ The maximum sequence length that this model might ever be used with.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ rms_norm_eps (`float`, *optional*, defaults to 1e-06):
+ The epsilon used by the rms normalization layers.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
+ relevant if `config.is_decoder=True`.
+ pad_token_id (`int`, *optional*):
+ Padding token id.
+ bos_token_id (`int`, *optional*, defaults to 1):
+ Beginning of stream token id.
+ eos_token_id (`int`, *optional*, defaults to 2):
+ End of stream token id.
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+ Whether to tie the input and output word embeddings.
+ rope_theta (`float`, *optional*, defaults to 10000.0):
+ The base period of the RoPE embeddings.
+ rope_scaling (`Dict`, *optional*):
+ Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
+ strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
+ `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
+ `max_position_embeddings` to the expected new maximum. See the following thread for more information on how
+ these scaling strategies behave:
+ https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
+ experimental feature, subject to breaking API changes in future versions.
+ attention_bias (`bool`, *optional*, defaults to `False`):
+ Whether to use a bias in the query, key, value and output projection layers during self-attention.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ mlp_bias (`bool`, *optional*, defaults to `False`):
+ Whether to use a bias in up_proj, down_proj and gate_proj layers in the MLP layers.
+ embedding_multiplier (`float`, *optional*, defaults to 1.0):
+ Multiplier applied to the token embeddings before the first decoder layer.
+ logits_scaling (`float`, *optional*, defaults to 1.0):
+ Divisor applied to the output logits of the language modeling head.
+ residual_multiplier (`float`, *optional*, defaults to 1.0):
+ Multiplier applied to each residual branch (attention and MLP) before it is added back to the stream.
+ attention_multiplier (`float`, *optional*, defaults to 1.0):
+ Scaling factor applied to the attention scores, used in place of the default `1/sqrt(head_dim)`.
+
+ ```python
+ >>> from transformers import GraniteModel, GraniteConfig
+
+ >>> # Initializing a Granite granite-3b style configuration
+ >>> configuration = GraniteConfig()
+
+ >>> # Initializing a model from the granite-3b style configuration
+ >>> model = GraniteModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "granite"
+ keys_to_ignore_at_inference = ["past_key_values"]
+ # Default tensor parallel plan for base model `GraniteModel`
+ base_model_tp_plan = {
+ "layers.*.self_attn.q_proj": "colwise",
+ "layers.*.self_attn.k_proj": "colwise",
+ "layers.*.self_attn.v_proj": "colwise",
+ "layers.*.self_attn.o_proj": "rowwise",
+ "layers.*.mlp.gate_proj": "colwise",
+ "layers.*.mlp.up_proj": "colwise",
+ "layers.*.mlp.down_proj": "rowwise",
+ }
+ base_model_pp_plan = {
+ "embed_tokens": (["input_ids"], ["inputs_embeds"]),
+ "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
+ "norm": (["hidden_states"], ["hidden_states"]),
+ }
+
+ def __init__(
+ self,
+ vocab_size=32000,
+ hidden_size=4096,
+ intermediate_size=11008,
+ num_hidden_layers=32,
+ num_attention_heads=32,
+ num_key_value_heads=None,
+ hidden_act="silu",
+ max_position_embeddings=2048,
+ initializer_range=0.02,
+ rms_norm_eps=1e-6,
+ use_cache=True,
+ pad_token_id=None,
+ bos_token_id=1,
+ eos_token_id=2,
+ tie_word_embeddings=False,
+ rope_theta=10000.0,
+ rope_scaling=None,
+ attention_bias=False,
+ attention_dropout=0.0,
+ mlp_bias=False,
+ embedding_multiplier=1.0,
+ logits_scaling=1.0,
+ residual_multiplier=1.0,
+ attention_multiplier=1.0,
+ **kwargs,
+ ):
+ self.vocab_size = vocab_size
+ self.max_position_embeddings = max_position_embeddings
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+
+ # for backward compatibility
+ if num_key_value_heads is None:
+ num_key_value_heads = num_attention_heads
+
+ self.num_key_value_heads = num_key_value_heads
+ self.hidden_act = hidden_act
+ self.initializer_range = initializer_range
+ self.rms_norm_eps = rms_norm_eps
+ self.use_cache = use_cache
+ self.rope_theta = rope_theta
+ self.rope_scaling = rope_scaling
+ self.attention_bias = attention_bias
+ self.attention_dropout = attention_dropout
+ self.mlp_bias = mlp_bias
+
+ self.embedding_multiplier = embedding_multiplier
+ self.logits_scaling = logits_scaling
+ self.residual_multiplier = residual_multiplier
+ self.attention_multiplier = attention_multiplier
+
+ super().__init__(
+ pad_token_id=pad_token_id,
+ bos_token_id=bos_token_id,
+ eos_token_id=eos_token_id,
+ tie_word_embeddings=tie_word_embeddings,
+ **kwargs,
+ )
+
+ rope_config_validation(self)
+
+
+__all__ = ["GraniteConfig"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/granite/modeling_granite.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/granite/modeling_granite.py
new file mode 100644
index 0000000000000000000000000000000000000000..846865c55508e223bf6b512ba03b2b64bd0e2434
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/granite/modeling_granite.py
@@ -0,0 +1,565 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/granite/modular_granite.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_granite.py file directly. One of our CI checks enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2024 IBM and the HuggingFace Inc. team. All rights reserved.
+#
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Callable, Optional, Union
+
+import torch
+from torch import nn
+
+from ...activations import ACT2FN
+from ...cache_utils import Cache, DynamicCache
+from ...generation import GenerationMixin
+from ...integrations import use_kernel_forward_from_hub
+from ...masking_utils import create_causal_mask
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
+from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
+from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
+from ...utils.deprecation import deprecate_kwarg
+from ...utils.generic import check_model_inputs
+from .configuration_granite import GraniteConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+def rotate_half(x):
+ """Rotates half the hidden dims of the input."""
+ x1 = x[..., : x.shape[-1] // 2]
+ x2 = x[..., x.shape[-1] // 2 :]
+ return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+ """Applies Rotary Position Embedding to the query and key tensors.
+
+ Args:
+ q (`torch.Tensor`): The query tensor.
+ k (`torch.Tensor`): The key tensor.
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
+ position_ids (`torch.Tensor`, *optional*):
+ Deprecated and unused.
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+ Returns:
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+ """
+ cos = cos.unsqueeze(unsqueeze_dim)
+ sin = sin.unsqueeze(unsqueeze_dim)
+ q_embed = (q * cos) + (rotate_half(q) * sin)
+ k_embed = (k * cos) + (rotate_half(k) * sin)
+ return q_embed, k_embed
+
+
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+ """
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+ """
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+ if n_rep == 1:
+ return hidden_states
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
+def eager_attention_forward(
+ module: nn.Module,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_mask: Optional[torch.Tensor],
+ scaling: float,
+ dropout: float = 0.0,
+ **kwargs: Unpack[TransformersKwargs],
+):
+ key_states = repeat_kv(key, module.num_key_value_groups)
+ value_states = repeat_kv(value, module.num_key_value_groups)
+
+ attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
+ if attention_mask is not None:
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+ attn_weights = attn_weights + causal_mask
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+ attn_output = torch.matmul(attn_weights, value_states)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ return attn_output, attn_weights
+
+
+class GraniteAttention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(self, config: GraniteConfig, layer_idx: Optional[int] = None):
+ super().__init__()
+ self.config = config
+ self.layer_idx = layer_idx
+ self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+ self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
+ self.scaling = config.attention_multiplier
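+ # Granite replaces the default 1/sqrt(head_dim) attention scaling with a configurable multiplier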
+ self.attention_dropout = config.attention_dropout
+ self.is_causal = True
+
+ self.q_proj = nn.Linear(
+ config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
+ )
+ self.k_proj = nn.Linear(
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
+ )
+ self.v_proj = nn.Linear(
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
+ )
+ self.o_proj = nn.Linear(
+ config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
+ )
+
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
+ attention_mask: Optional[torch.Tensor],
+ past_key_values: Optional[Cache] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ input_shape = hidden_states.shape[:-1]
+ hidden_shape = (*input_shape, -1, self.head_dim)
+
+ query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+ key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+ value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+
+ cos, sin = position_embeddings
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_values is not None:
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ attn_output, attn_weights = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ dropout=0.0 if not self.training else self.attention_dropout,
+ scaling=self.scaling,
+ **kwargs,
+ )
+
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+ attn_output = self.o_proj(attn_output)
+ return attn_output, attn_weights
+
+
+@use_kernel_forward_from_hub("RMSNorm")
+class GraniteRMSNorm(nn.Module):
+ def __init__(self, hidden_size, eps=1e-6):
+ """
+ GraniteRMSNorm is equivalent to T5LayerNorm
+ """
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+ self.variance_epsilon = eps
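+ # RMSNorm: y = weight * x / sqrt(mean(x^2) + eps), computed in float32 for numerical stability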
+
+ def forward(self, hidden_states):
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+ return self.weight * hidden_states.to(input_dtype)
+
+ def extra_repr(self):
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
+
+
+class GraniteMLP(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.hidden_size = config.hidden_size
+ self.intermediate_size = config.intermediate_size
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
+ self.act_fn = ACT2FN[config.hidden_act]
+
+ def forward(self, x):
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+ return down_proj
+
+
+class GraniteDecoderLayer(GradientCheckpointingLayer):
+ def __init__(self, config: GraniteConfig, layer_idx: int):
+ super().__init__()
+ self.hidden_size = config.hidden_size
+ self.self_attn = GraniteAttention(config=config, layer_idx=layer_idx)
+
+ self.mlp = GraniteMLP(config)
+ self.input_layernorm = GraniteRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.post_attention_layernorm = GraniteRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.residual_multiplier = config.residual_multiplier
+
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ output_attentions: Optional[bool] = False,
+ use_cache: Optional[bool] = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
+ **kwargs,
+ ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+ attention_mask (`torch.FloatTensor`, *optional*):
+ attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
+ query_sequence_length, key_sequence_length)` if default attention is used.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+ (see `past_key_values`).
+ past_key_values (`Cache`, *optional*): cached past key and value projection states
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+ Indices depicting the position of the input sequence tokens in the sequence
+ position_embeddings (`tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
+ Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
+ with `head_dim` being the embedding dimension of each attention head.
+ kwargs (`dict`, *optional*):
+ Arbitrary kwargs to be ignored, used for FSDP and other methods that inject code
+ into the model.
+ """
+ residual = hidden_states
+
+ hidden_states = self.input_layernorm(hidden_states)
+
+ # Self Attention
+ hidden_states, self_attn_weights = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ **kwargs,
+ )
+ hidden_states = residual + hidden_states * self.residual_multiplier
+
+ # Fully Connected
+ residual = hidden_states
+ hidden_states = self.post_attention_layernorm(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = residual + hidden_states * self.residual_multiplier # main diff with Llama
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (self_attn_weights,)
+
+ return outputs
+
+
+@auto_docstring
+class GranitePreTrainedModel(PreTrainedModel):
+ config: GraniteConfig
+ base_model_prefix = "model"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["GraniteDecoderLayer"]
+ _skip_keys_device_placement = ["past_key_values"]
+ _supports_flash_attn = True
+ _supports_sdpa = True
+ _supports_flex_attn = True
+
+ _can_compile_fullgraph = True
+ _supports_attention_backend = True
+ _can_record_outputs = {
+ "hidden_states": GraniteDecoderLayer,
+ "attentions": GraniteAttention,
+ }
+
+
+class GraniteRotaryEmbedding(nn.Module):
+ inv_freq: torch.Tensor # fix linting for `register_buffer`
+
+ def __init__(self, config: GraniteConfig, device=None):
+ super().__init__()
+ # BC: "rope_type" was originally "type"
+ if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
+ else:
+ self.rope_type = "default"
+ self.max_seq_len_cached = config.max_position_embeddings
+ self.original_max_seq_len = config.max_position_embeddings
+
+ self.config = config
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
+ self.original_inv_freq = self.inv_freq
+
+ @torch.no_grad()
+ @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
+ def forward(self, x, position_ids):
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
+ position_ids_expanded = position_ids[:, None, :].float()
+
+ device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
+ with torch.autocast(device_type=device_type, enabled=False): # Force float32
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+ emb = torch.cat((freqs, freqs), dim=-1)
+ cos = emb.cos() * self.attention_scaling
+ sin = emb.sin() * self.attention_scaling
+
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
+@auto_docstring
+class GraniteModel(GranitePreTrainedModel):
+ def __init__(self, config: GraniteConfig):
+ super().__init__(config)
+ self.padding_idx = config.pad_token_id
+ self.vocab_size = config.vocab_size
+
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+ self.layers = nn.ModuleList(
+ [GraniteDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+ )
+ self.norm = GraniteRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.rotary_emb = GraniteRotaryEmbedding(config=config)
+ self.gradient_checkpointing = False
+ self.embedding_multiplier = config.embedding_multiplier
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @check_model_inputs
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> BaseModelOutputWithPast:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ if self.gradient_checkpointing and self.training and use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
+ )
+ use_cache = False
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+
+ inputs_embeds = inputs_embeds * self.embedding_multiplier # main diff with Llama
+
+ if use_cache and past_key_values is None:
+ past_key_values = DynamicCache(config=self.config)
+
+ if cache_position is None:
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ cache_position = torch.arange(
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+ )
+
+ if position_ids is None:
+ position_ids = cache_position.unsqueeze(0)
+
+ causal_mask = create_causal_mask(
+ config=self.config,
+ input_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ cache_position=cache_position,
+ past_key_values=past_key_values,
+ position_ids=position_ids,
+ )
+
+ hidden_states = inputs_embeds
+
+ # create position embeddings to be shared across the decoder layers
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+ # decoder layers
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+
+ for decoder_layer in self.layers[: self.config.num_hidden_layers]:
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ layer_outputs = decoder_layer(
+ hidden_states,
+ attention_mask=causal_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ **kwargs,
+ )
+
+ hidden_states = layer_outputs[0]
+
+ if output_attentions:
+ all_self_attns += (layer_outputs[1],)
+
+ hidden_states = self.norm(hidden_states)
+
+ # add hidden states from the last decoder layer
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=past_key_values if use_cache else None,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ )
+
+
+@auto_docstring
+class GraniteForCausalLM(GranitePreTrainedModel, GenerationMixin):
+ _tied_weights_keys = ["lm_head.weight"]
+ _tp_plan = {"lm_head": "colwise_rep"}
+ _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.model = GraniteModel(config)
+ self.vocab_size = config.vocab_size
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ logits_to_keep: Union[int, torch.Tensor] = 0,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> CausalLMOutputWithPast:
+ r"""
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, GraniteForCausalLM
+
+ >>> model = GraniteForCausalLM.from_pretrained("meta-granite/Granite-2-7b-hf")
+ >>> tokenizer = AutoTokenizer.from_pretrained("meta-granite/Granite-2-7b-hf")
+
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+ >>> # Generate
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+ ```"""
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+
+ # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
+ outputs: BaseModelOutputWithPast = self.model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+ hidden_states = outputs.last_hidden_state
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
+ logits = logits / self.config.logits_scaling # main diff with Llama
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
+
+ return CausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+__all__ = ["GraniteForCausalLM", "GraniteModel", "GranitePreTrainedModel"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/granite/modular_granite.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/granite/modular_granite.py
new file mode 100644
index 0000000000000000000000000000000000000000..37e1955fcb0984ea62c9d37e73879667aab1ec1b
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/granite/modular_granite.py
@@ -0,0 +1,285 @@
+# coding=utf-8
+# Copyright 2024 IBM and the HuggingFace Inc. team. All rights reserved.
+#
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Optional, Union
+
+import torch
+from torch import nn
+
+from ...cache_utils import Cache, DynamicCache
+from ...masking_utils import create_causal_mask
+from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
+from ...processing_utils import Unpack
+from ...utils import TransformersKwargs, logging
+from ...utils.deprecation import deprecate_kwarg
+from ..llama.modeling_llama import (
+ LlamaAttention,
+ LlamaDecoderLayer,
+ LlamaForCausalLM,
+ LlamaModel,
+ LlamaPreTrainedModel,
+)
+from .configuration_granite import GraniteConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+class GraniteAttention(LlamaAttention):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(self, config: GraniteConfig, layer_idx: Optional[int] = None):
+ super().__init__(config, layer_idx)
+ self.scaling = config.attention_multiplier
+
+
+class GraniteDecoderLayer(LlamaDecoderLayer):
+ def __init__(self, config: GraniteConfig, layer_idx: int):
+ super().__init__(config, layer_idx)
+ self.residual_multiplier = config.residual_multiplier
+ self.self_attn = GraniteAttention(config=config, layer_idx=layer_idx)
+
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ output_attentions: Optional[bool] = False,
+ use_cache: Optional[bool] = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
+ **kwargs,
+ ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+ attention_mask (`torch.FloatTensor`, *optional*):
+ attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
+ query_sequence_length, key_sequence_length)` if default attention is used.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+ (see `past_key_values`).
+ past_key_values (`Cache`, *optional*): cached past key and value projection states
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+ Indices depicting the position of the input sequence tokens in the sequence
+ position_embeddings (`tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
+ Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
+ with `head_dim` being the embedding dimension of each attention head.
+ kwargs (`dict`, *optional*):
+ Arbitrary kwargs to be ignored, used for FSDP and other methods that inject code
+ into the model.
+ """
+ residual = hidden_states
+
+ hidden_states = self.input_layernorm(hidden_states)
+
+ # Self Attention
+ hidden_states, self_attn_weights = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ **kwargs,
+ )
+ hidden_states = residual + hidden_states * self.residual_multiplier
+
+ # Fully Connected
+ residual = hidden_states
+ hidden_states = self.post_attention_layernorm(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = residual + hidden_states * self.residual_multiplier # main diff with Llama
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (self_attn_weights,)
+
+ return outputs
+
+
+class GranitePreTrainedModel(LlamaPreTrainedModel):
+ pass
+
+
+class GraniteModel(LlamaModel):
+ def __init__(self, config: GraniteConfig):
+ super().__init__(config)
+ self.embedding_multiplier = config.embedding_multiplier
+ self.layers = nn.ModuleList(
+ [GraniteDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+ )
+
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> BaseModelOutputWithPast:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ if self.gradient_checkpointing and self.training and use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
+ )
+ use_cache = False
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+
+ inputs_embeds = inputs_embeds * self.embedding_multiplier # main diff with Llama
+
+ if use_cache and past_key_values is None:
+ past_key_values = DynamicCache(config=self.config)
+
+ if cache_position is None:
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ cache_position = torch.arange(
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+ )
+
+ if position_ids is None:
+ position_ids = cache_position.unsqueeze(0)
+
+ causal_mask = create_causal_mask(
+ config=self.config,
+ input_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ cache_position=cache_position,
+ past_key_values=past_key_values,
+ position_ids=position_ids,
+ )
+
+ hidden_states = inputs_embeds
+
+ # create position embeddings to be shared across the decoder layers
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+ # decoder layers
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+
+ for decoder_layer in self.layers[: self.config.num_hidden_layers]:
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ layer_outputs = decoder_layer(
+ hidden_states,
+ attention_mask=causal_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ **kwargs,
+ )
+
+ hidden_states = layer_outputs[0]
+
+ if output_attentions:
+ all_self_attns += (layer_outputs[1],)
+
+ hidden_states = self.norm(hidden_states)
+
+ # add hidden states from the last decoder layer
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=past_key_values if use_cache else None,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ )
+
+
+class GraniteForCausalLM(LlamaForCausalLM):
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ logits_to_keep: Union[int, torch.Tensor] = 0,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> CausalLMOutputWithPast:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+
+ # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
+ outputs: BaseModelOutputWithPast = self.model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+ hidden_states = outputs.last_hidden_state
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
+ logits = logits / self.config.logits_scaling # main diff with Llama
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
+
+ return CausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+__all__ = ["GraniteForCausalLM", "GraniteModel", "GranitePreTrainedModel"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/helium/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/helium/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..73d0966e5c1c489944c4b539c0cc06384c985c87
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/helium/__init__.py
@@ -0,0 +1,27 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_helium import *
+ from .modeling_helium import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/helium/configuration_helium.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/helium/configuration_helium.py
new file mode 100644
index 0000000000000000000000000000000000000000..9bb4d8d88750bc5e7becd8cd20a2ef8db1d5936b
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/helium/configuration_helium.py
@@ -0,0 +1,154 @@
+# coding=utf-8
+# Copyright 2024 The Kyutai and HuggingFace Inc. teams. All rights reserved.
+#
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ...configuration_utils import PretrainedConfig
+
+
+class HeliumConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`HeliumModel`]. It is used to instantiate a Helium
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
+ defaults will yield a similar configuration to that of the Helium 2b model.
+ e.g. [kyutai/helium-2b](https://huggingface.co/kyutai/helium-2b)
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+ Args:
+ vocab_size (`int`, *optional*, defaults to 48000):
+ Vocabulary size of the Helium model. Defines the number of different tokens that can be represented by the
+ `input_ids` passed when calling [`HeliumModel`]
+ hidden_size (`int`, *optional*, defaults to 2560):
+ Dimension of the hidden representations.
+ intermediate_size (`int`, *optional*, defaults to 7040):
+ Dimension of the MLP representations.
+ num_hidden_layers (`int`, *optional*, defaults to 24):
+ Number of hidden layers in the Transformer decoder.
+ num_attention_heads (`int`, *optional*, defaults to 20):
+ Number of attention heads for each attention layer in the Transformer decoder.
+ num_key_value_heads (`int`, *optional*, defaults to 20):
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
+ `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+ by meanpooling all the original heads within that group. For more details, check out [this
+ paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to
+ `num_attention_heads`.
+ head_dim (`int`, *optional*, defaults to 128):
+ The attention head dimension.
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+ The non-linear activation function (function or string) used in the decoder MLP layers.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ max_position_embeddings (`int`, *optional*, defaults to 4096):
+ The maximum sequence length that this model might ever be used with.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ rms_norm_eps (`float`, *optional*, defaults to 1e-08):
+ The epsilon used by the rms normalization layers.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
+ relevant if `config.is_decoder=True`.
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+ Whether to tie the input and output word embeddings.
+ rope_theta (`float`, *optional*, defaults to 100000.0):
+ The base period of the RoPE embeddings.
+ pad_token_id (`int`, *optional*, defaults to 3):
+ Padding token id.
+ eos_token_id (`int` | `list`, *optional*, defaults to 2):
+ End of stream token id.
+ bos_token_id (`int`, *optional*, defaults to 1):
+ Beginning of stream token id.
+ attention_bias (`bool`, *optional*, defaults to `False`):
+ Whether to use a bias in the query, key, value and output projection layers during self-attention.
+ mlp_bias (`bool`, *optional*, defaults to `False`):
+ Whether to use a bias in up_proj, down_proj and gate_proj layers in the MLP layers.
+
+ Example:
+
+ ```python
+ >>> from transformers import HeliumModel, HeliumConfig
+ >>> # Initializing a Helium 2b style configuration
+ >>> configuration = HeliumConfig()
+ >>> # Initializing a model from the Helium 2b style configuration
+ >>> model = HeliumModel(configuration)
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "helium"
+ keys_to_ignore_at_inference = ["past_key_values"]
+ base_model_tp_plan = {
+ "layers.*.self_attn.q_proj": "colwise",
+ "layers.*.self_attn.k_proj": "colwise",
+ "layers.*.self_attn.v_proj": "colwise",
+ "layers.*.self_attn.o_proj": "rowwise",
+ "layers.*.mlp.gate_proj": "colwise",
+ "layers.*.mlp.up_proj": "colwise",
+ "layers.*.mlp.down_proj": "rowwise",
+ }
+ base_model_pp_plan = {
+ "embed_tokens": (["input_ids"], ["inputs_embeds"]),
+ "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
+ "norm": (["hidden_states"], ["hidden_states"]),
+ }
+
+ def __init__(
+ self,
+ vocab_size=48000,
+ hidden_size=2560,
+ intermediate_size=7040,
+ num_hidden_layers=24,
+ num_attention_heads=20,
+ num_key_value_heads=20,
+ head_dim=128,
+ hidden_act="silu",
+ attention_dropout=0.0,
+ max_position_embeddings=4096,
+ initializer_range=0.02,
+ rms_norm_eps=1e-8,
+ use_cache=True,
+ tie_word_embeddings=False,
+ rope_theta=100000.0,
+ pad_token_id=3,
+ eos_token_id=2,
+ bos_token_id=1,
+ attention_bias=False,
+ mlp_bias=False,
+ **kwargs,
+ ):
+ self.vocab_size = vocab_size
+ self.max_position_embeddings = max_position_embeddings
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.num_key_value_heads = num_key_value_heads
+ self.head_dim = head_dim
+ self.hidden_act = hidden_act
+ self.initializer_range = initializer_range
+ self.rms_norm_eps = rms_norm_eps
+ self.use_cache = use_cache
+ self.rope_theta = rope_theta
+ self.attention_bias = attention_bias
+ self.attention_dropout = attention_dropout
+ self.mlp_bias = mlp_bias
+
+ super().__init__(
+ pad_token_id=pad_token_id,
+ bos_token_id=bos_token_id,
+ eos_token_id=eos_token_id,
+ tie_word_embeddings=tie_word_embeddings,
+ **kwargs,
+ )
+
+
+__all__ = ["HeliumConfig"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/helium/modeling_helium.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/helium/modeling_helium.py
new file mode 100644
index 0000000000000000000000000000000000000000..9f4a2e73affd76f079ecfafb2b51ba88e0ccf252
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/helium/modeling_helium.py
@@ -0,0 +1,497 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/helium/modular_helium.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_helium.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2024 The Kyutai and HuggingFace Inc. teams. All rights reserved.
+#
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import math
+from typing import Callable, Optional, Union
+
+import torch
+import torch.nn as nn
+
+from ...activations import ACT2FN
+from ...cache_utils import Cache, DynamicCache
+from ...generation import GenerationMixin
+from ...masking_utils import create_causal_mask
+from ...modeling_layers import (
+ GenericForSequenceClassification,
+ GenericForTokenClassification,
+ GradientCheckpointingLayer,
+)
+from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
+from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
+from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
+from ...utils.deprecation import deprecate_kwarg
+from ...utils.generic import check_model_inputs
+from .configuration_helium import HeliumConfig
+
+
+class HeliumRMSNorm(nn.Module):
+ def __init__(self, hidden_size, eps=1e-6):
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+ self.variance_epsilon = eps
+
+ def forward(self, hidden_states):
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+ return (self.weight.to(torch.float32) * hidden_states).to(input_dtype)
+
+ def extra_repr(self):
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
+
+
+class HeliumRotaryEmbedding(nn.Module):
+ inv_freq: torch.Tensor # fix linting for `register_buffer`
+
+ def __init__(self, config: HeliumConfig, device=None):
+ super().__init__()
+ # BC: "rope_type" was originally "type"
+ if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
+ else:
+ self.rope_type = "default"
+ self.max_seq_len_cached = config.max_position_embeddings
+ self.original_max_seq_len = config.max_position_embeddings
+
+ self.config = config
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
+ self.original_inv_freq = self.inv_freq
+
+ @torch.no_grad()
+ @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
+ def forward(self, x, position_ids):
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
+ position_ids_expanded = position_ids[:, None, :].float()
+
+ device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
+ with torch.autocast(device_type=device_type, enabled=False): # Force float32
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+ emb = torch.cat((freqs, freqs), dim=-1)
+ cos = emb.cos() * self.attention_scaling
+ sin = emb.sin() * self.attention_scaling
+
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
+class HeliumMLP(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.hidden_size = config.hidden_size
+ self.intermediate_size = config.intermediate_size
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
+ self.act_fn = ACT2FN[config.hidden_act]
+
+ def forward(self, x):
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+ return down_proj
+
+
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+ """
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+ """
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+ if n_rep == 1:
+ return hidden_states
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
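A quick check (editor's sketch, not part of the diff) of the docstring's claim that `repeat_kv` is equivalent to `torch.repeat_interleave` along the head dimension:

```python
import torch

batch, kv_heads, seq_len, head_dim, n_rep = 2, 4, 5, 8, 3
kv = torch.randn(batch, kv_heads, seq_len, head_dim)

# Same expand + reshape that repeat_kv performs.
expanded = kv[:, :, None, :, :].expand(batch, kv_heads, n_rep, seq_len, head_dim)
expanded = expanded.reshape(batch, kv_heads * n_rep, seq_len, head_dim)

assert torch.equal(expanded, torch.repeat_interleave(kv, repeats=n_rep, dim=1))
```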
+
+def eager_attention_forward(
+ module: nn.Module,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_mask: Optional[torch.Tensor],
+ scaling: float,
+ dropout: float = 0.0,
+ **kwargs: Unpack[TransformersKwargs],
+):
+ key_states = repeat_kv(key, module.num_key_value_groups)
+ value_states = repeat_kv(value, module.num_key_value_groups)
+
+ attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
+ if attention_mask is not None:
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+ attn_weights = attn_weights + causal_mask
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+ attn_output = torch.matmul(attn_weights, value_states)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ return attn_output, attn_weights
+
+
+def rotate_half(x):
+ """Rotates half the hidden dims of the input."""
+ x1 = x[..., 0::2]
+ x2 = x[..., 1::2]
+ return torch.stack((-x2, x1), dim=-1).flatten(-2)
+
+
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+ """Applies Rotary Position Embedding to the query and key tensors.
+
+ Args:
+ q (`torch.Tensor`): The query tensor.
+ k (`torch.Tensor`): The key tensor.
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
+ position_ids (`torch.Tensor`, *optional*):
+ Deprecated and unused.
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+ Returns:
+ `tuple(torch.Tensor)` comprising the query and key tensors rotated using the Rotary Position Embedding.
+ """
+ cos = cos.unsqueeze(unsqueeze_dim)
+ sin = sin.unsqueeze(unsqueeze_dim)
+
+ # Interleave them instead of usual shape
+ cos = cos[..., : cos.shape[-1] // 2].repeat_interleave(2, dim=-1)
+ sin = sin[..., : sin.shape[-1] // 2].repeat_interleave(2, dim=-1)
+
+ q_embed = (q * cos) + (rotate_half(q) * sin)
+ k_embed = (k * cos) + (rotate_half(k) * sin)
+
+ return q_embed, k_embed
+
+
+class HeliumAttention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(self, config: HeliumConfig, layer_idx: Optional[int] = None):
+ super().__init__()
+ self.config = config
+ self.layer_idx = layer_idx
+ self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+ self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
+ self.scaling = 1 / math.sqrt(self.head_dim)
+ self.attention_dropout = config.attention_dropout
+ self.is_causal = True
+
+ self.q_proj = nn.Linear(
+ config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
+ )
+ self.k_proj = nn.Linear(
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
+ )
+ self.v_proj = nn.Linear(
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
+ )
+ self.o_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
+
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
+ attention_mask: Optional[torch.Tensor],
+ past_key_values: Optional[Cache] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ input_shape = hidden_states.shape[:-1]
+ hidden_shape = (*input_shape, -1, self.head_dim)
+
+ query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+ key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+ value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+
+ cos, sin = position_embeddings
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_values is not None:
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ attn_output, attn_weights = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ dropout=0.0 if not self.training else self.attention_dropout,
+ scaling=self.scaling,
+ **kwargs,
+ )
+
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+ attn_output = self.o_proj(attn_output)
+ return attn_output, attn_weights
+
+
+class HeliumDecoderLayer(GradientCheckpointingLayer):
+ def __init__(self, config: HeliumConfig, layer_idx: Optional[int] = None):
+ super().__init__()
+ self.hidden_size = config.hidden_size
+
+ self.self_attn = HeliumAttention(config=config, layer_idx=layer_idx)
+
+ self.mlp = HeliumMLP(config)
+ self.input_layernorm = HeliumRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.post_attention_layernorm = HeliumRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ use_cache: Optional[bool] = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,  # required for RoPE in the attention layers; kept as an explicit keyword for backward compatibility
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> torch.Tensor:
+ residual = hidden_states
+ hidden_states = self.input_layernorm(hidden_states)
+ # Self Attention
+ hidden_states, _ = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ **kwargs,
+ )
+ hidden_states = residual + hidden_states
+
+ # Fully Connected
+ residual = hidden_states
+ hidden_states = self.post_attention_layernorm(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = residual + hidden_states
+ return hidden_states
+
+
+@auto_docstring
+class HeliumPreTrainedModel(PreTrainedModel):
+ config: HeliumConfig
+ base_model_prefix = "model"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["HeliumDecoderLayer"]
+ _skip_keys_device_placement = ["past_key_values"]
+ _supports_flash_attn = True
+ _supports_sdpa = True
+ _supports_flex_attn = True
+
+ _can_compile_fullgraph = True
+ _supports_attention_backend = True
+ _can_record_outputs = {
+ "hidden_states": HeliumDecoderLayer,
+ "attentions": HeliumAttention,
+ }
+
+
+@auto_docstring
+class HeliumModel(HeliumPreTrainedModel):
+ def __init__(self, config: HeliumConfig):
+ super().__init__(config)
+ self.padding_idx = config.pad_token_id
+ self.vocab_size = config.vocab_size
+
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+ self.layers = nn.ModuleList(
+ [HeliumDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+ )
+ self.norm = HeliumRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.rotary_emb = HeliumRotaryEmbedding(config)
+ self.gradient_checkpointing = False
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @check_model_inputs
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> BaseModelOutputWithPast:
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ if inputs_embeds is None:
+ inputs_embeds: torch.Tensor = self.embed_tokens(input_ids)
+
+ if use_cache and past_key_values is None:
+ past_key_values = DynamicCache(config=self.config)
+
+ if cache_position is None:
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ cache_position: torch.Tensor = torch.arange(
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+ )
+
+ if position_ids is None:
+ position_ids = cache_position.unsqueeze(0)
+
+ causal_mask = create_causal_mask(
+ config=self.config,
+ input_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ cache_position=cache_position,
+ past_key_values=past_key_values,
+ position_ids=position_ids,
+ )
+
+ hidden_states = inputs_embeds
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+ for decoder_layer in self.layers[: self.config.num_hidden_layers]:
+ hidden_states = decoder_layer(
+ hidden_states,
+ attention_mask=causal_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ **kwargs,
+ )
+
+ hidden_states = self.norm(hidden_states)
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=past_key_values,
+ )
+
+
+@auto_docstring
+class HeliumForCausalLM(HeliumPreTrainedModel, GenerationMixin):
+ _tied_weights_keys = ["lm_head.weight"]
+ _tp_plan = {"lm_head": "colwise_rep"}
+ _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.model = HeliumModel(config)
+ self.vocab_size = config.vocab_size
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ logits_to_keep: Union[int, torch.Tensor] = 0,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> CausalLMOutputWithPast:
+ r"""
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, HeliumForCausalLM
+
+ >>> model = HeliumForCausalLM.from_pretrained("kyutai/helium-2b")
+ >>> tokenizer = AutoTokenizer.from_pretrained("kyutai/helium-2b")
+
+ >>> prompt = "What is your favorite condiment?"
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+ >>> # Generate
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ "What is your favorite condiment?"
+ ```"""
+ outputs: BaseModelOutputWithPast = self.model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+ hidden_states = outputs.last_hidden_state
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
+
+ return CausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+class HeliumForSequenceClassification(GenericForSequenceClassification, HeliumPreTrainedModel):
+ pass
+
+
+class HeliumForTokenClassification(GenericForTokenClassification, HeliumPreTrainedModel):
+ pass
+
+
+__all__ = [
+ "HeliumPreTrainedModel",
+ "HeliumModel",
+ "HeliumForCausalLM",
+ "HeliumForSequenceClassification",
+ "HeliumForTokenClassification",
+]
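Note that `rotate_half`/`apply_rotary_pos_emb` above use an interleaved RoPE layout: adjacent feature pairs (0, 1), (2, 3), ... are rotated together and `cos`/`sin` are `repeat_interleave`d to match, instead of the Llama-style pairing of dimension `i` with `i + head_dim // 2`. A toy comparison (editor's sketch, not part of the diff):

```python
import torch

def rotate_half_interleaved(x):  # same logic as Helium's rotate_half
    x1, x2 = x[..., 0::2], x[..., 1::2]
    return torch.stack((-x2, x1), dim=-1).flatten(-2)

def rotate_half_split(x):  # the usual Llama convention, shown for contrast
    half = x.shape[-1] // 2
    return torch.cat((-x[..., half:], x[..., :half]), dim=-1)

x = torch.tensor([1.0, 2.0, 3.0, 4.0])
print(rotate_half_interleaved(x))  # tensor([-2., 1., -4., 3.]) -> pairs (0,1) and (2,3) rotate together
print(rotate_half_split(x))        # tensor([-3., -4., 1., 2.]) -> dimension i pairs with i + 2
```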
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/helium/modular_helium.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/helium/modular_helium.py
new file mode 100644
index 0000000000000000000000000000000000000000..6c2538d438f9126980e7347aa98896b52a77f6eb
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/helium/modular_helium.py
@@ -0,0 +1,149 @@
+# coding=utf-8
+# Copyright 2024 The Kyutai and HuggingFace Inc. teams. All rights reserved.
+#
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import math
+from typing import Optional
+
+import torch
+import torch.nn as nn
+
+from ...utils import logging
+from ..gemma.modeling_gemma import GemmaForCausalLM, GemmaForSequenceClassification, GemmaForTokenClassification
+from ..granite.modeling_granite import GraniteAttention
+from ..llama.modeling_llama import LlamaDecoderLayer, LlamaMLP, LlamaModel, LlamaPreTrainedModel, LlamaRotaryEmbedding
+from .configuration_helium import HeliumConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+class HeliumRMSNorm(nn.Module):
+ def __init__(self, hidden_size, eps=1e-6):
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+ self.variance_epsilon = eps
+
+ def forward(self, hidden_states):
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+ return (self.weight.to(torch.float32) * hidden_states).to(input_dtype)
+
+ def extra_repr(self):
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
+
+
+class HeliumRotaryEmbedding(LlamaRotaryEmbedding):
+ pass
+
+
+class HeliumMLP(LlamaMLP):
+ pass
+
+
+def rotate_half(x):
+ """Rotates half the hidden dims of the input."""
+ x1 = x[..., 0::2]
+ x2 = x[..., 1::2]
+ return torch.stack((-x2, x1), dim=-1).flatten(-2)
+
+
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+ """Applies Rotary Position Embedding to the query and key tensors.
+
+ Args:
+ q (`torch.Tensor`): The query tensor.
+ k (`torch.Tensor`): The key tensor.
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
+ position_ids (`torch.Tensor`, *optional*):
+ Deprecated and unused.
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+ Returns:
+ `tuple(torch.Tensor)` comprising the query and key tensors rotated using the Rotary Position Embedding.
+ """
+ cos = cos.unsqueeze(unsqueeze_dim)
+ sin = sin.unsqueeze(unsqueeze_dim)
+
+ # Interleave them instead of usual shape
+ cos = cos[..., : cos.shape[-1] // 2].repeat_interleave(2, dim=-1)
+ sin = sin[..., : sin.shape[-1] // 2].repeat_interleave(2, dim=-1)
+
+ q_embed = (q * cos) + (rotate_half(q) * sin)
+ k_embed = (k * cos) + (rotate_half(k) * sin)
+
+ return q_embed, k_embed
+
+
+class HeliumAttention(GraniteAttention):
+ def __init__(self, config: HeliumConfig, layer_idx: Optional[int] = None):
+ super().__init__(config, layer_idx)
+ self.o_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
+ self.scaling = 1 / math.sqrt(self.head_dim)
+
+
+class HeliumDecoderLayer(LlamaDecoderLayer):
+ def __init__(self, config: HeliumConfig, layer_idx: Optional[int] = None):
+ super().__init__(config, layer_idx)
+
+ self.mlp = HeliumMLP(config)
+ self.input_layernorm = HeliumRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.post_attention_layernorm = HeliumRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+
+class HeliumPreTrainedModel(LlamaPreTrainedModel):
+ pass
+
+
+class HeliumModel(HeliumPreTrainedModel, LlamaModel):
+ def __init__(self, config: HeliumConfig):
+ super().__init__(config)
+ self.layers = nn.ModuleList(
+ [HeliumDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+ )
+ self.norm = HeliumRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.rotary_emb = HeliumRotaryEmbedding(config)
+ self.gradient_checkpointing = False
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+
+class HeliumForCausalLM(GemmaForCausalLM):
+ pass
+
+
+class HeliumForSequenceClassification(GemmaForSequenceClassification):
+ pass
+
+
+class HeliumForTokenClassification(GemmaForTokenClassification):
+ pass
+
+
+__all__ = [
+ "HeliumPreTrainedModel",
+ "HeliumModel",
+ "HeliumForCausalLM",
+ "HeliumForSequenceClassification",
+ "HeliumForTokenClassification",
+]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/hiera/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/hiera/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..841f13be4c0d2f48f54eecc916acd826395449af
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/hiera/__init__.py
@@ -0,0 +1,27 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_hiera import *
+ from .modeling_hiera import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/hiera/configuration_hiera.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/hiera/configuration_hiera.py
new file mode 100644
index 0000000000000000000000000000000000000000..2342d7e562a50de0c0937040a8e8279c7860e931
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/hiera/configuration_hiera.py
@@ -0,0 +1,194 @@
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Hiera model configuration"""
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+from ...utils.backbone_utils import BackboneConfigMixin, get_aligned_output_features_output_indices
+
+
+logger = logging.get_logger(__name__)
+
+
+class HieraConfig(BackboneConfigMixin, PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`HieraModel`]. It is used to instantiate a Hiera
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
+ defaults will yield a similar configuration to that of the Hiera
+ [facebook/hiera-base-224](https://huggingface.co/facebook/hiera-base-224) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ embed_dim (`int`, *optional*, defaults to 96):
+ Dimensionality of patch embedding.
+ image_size (`list(int)`, *optional*, defaults to `[224, 224]`):
+ The size (resolution) of input in the format (height, width) for images
+ and (frames, height, width) for videos.
+ patch_size (`list(int)`, *optional*, defaults to `[7, 7]`):
+ The size (resolution) of each patch.
+ patch_stride (`list(int)`, *optional*, defaults to `[4, 4]`):
+ The stride of the patch.
+ patch_padding (`list(int)`, *optional*, defaults to `[3, 3]`):
+ The padding of the patch.
+ mlp_ratio (`float`, *optional*, defaults to 4.0):
+ The ratio of mlp hidden dim to embedding dim.
+ depths (`list(int)`, *optional*, defaults to `[2, 3, 16, 3]`):
+ Depth of each layer in the Transformer encoder.
+ num_heads (`list(int)`, *optional*, defaults to `[1, 2, 4, 8]`):
+ Number of attention heads in each layer of the Transformer encoder.
+ embed_dim_multiplier (`float`, *optional*, defaults to 2.0):
+ The multiplier to the dimensionality of patch embedding in each layer of the Transformer encoder.
+ num_query_pool (`int`, *optional*, defaults to 3):
+ The number of query pool stages.
+ query_stride (`list(int)`, *optional*, defaults to `[2, 2]`):
+ The stride of the query pool.
+ masked_unit_size (`list(int)`, *optional*, defaults to `[8, 8]`):
+ The size of the masked unit.
+ masked_unit_attention (`list(bool)`, *optional*, defaults to `[True, True, False, False]`):
+ Whether to use masked unit attention in each layer of the Transformer encoder.
+ drop_path_rate (`float`, *optional*, defaults to 0.0):
+ The drop path rate.
+ num_channels (`int`, *optional*, defaults to 3):
+ The number of input channels.
+ hidden_act (`str`, *optional*, defaults to `"gelu"`):
+ The non-linear activation function (function or string) in the encoder. If string, `"gelu"`, `"relu"`,
+ `"selu"` and `"gelu_new"` are supported.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices and
+ the zero_initializer for initializing all bias vectors.
+ layer_norm_init (`float`, *optional*, defaults to 1.0):
+ The initial weight value for layer normalization layers.
+ layer_norm_eps (`float`, *optional*, defaults to 1e-06):
+ The epsilon used by the layer normalization layers.
+ decoder_hidden_size (`int`, *optional*):
+ Dimensionality of decoder embeddings for MAE pretraining.
+ decoder_depth (`int`, *optional*):
+ Depth of the decoder for MAE pretraining.
+ decoder_num_heads (`int`, *optional*):
+ Number of attention heads in each layer of the decoder for MAE pretraining.
+ normalize_pixel_loss (`bool`, *optional*, defaults to `True`):
+ Whether to normalize the pixel loss by the number of pixels.
+ mask_ratio (`float`, *optional*, defaults to 0.6):
+ The ratio of masked tokens in the input.
+ out_features (`list[str]`, *optional*):
+ If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc.
+ (depending on how many stages the model has). If unset and `out_indices` is set, will default to the
+ corresponding stages. If unset and `out_indices` is unset, will default to the last stage. Must be in the
+ same order as defined in the `stage_names` attribute.
+ out_indices (`list[int]`, *optional*):
+ If used as backbone, list of indices of features to output. Can be any of 0, 1, 2, etc. (depending on how
+ many stages the model has). If unset and `out_features` is set, will default to the corresponding stages.
+ If unset and `out_features` is unset, will default to the last stage. Must be in the
+ same order as defined in the `stage_names` attribute.
+
+
+ Example:
+
+ ```python
+ >>> from transformers import HieraConfig, HieraModel
+
+ >>> # Initializing a Hiera hiera-base-patch16-224 style configuration
+ >>> configuration = HieraConfig()
+
+ >>> # Initializing a model (with random weights) from the hiera-base-patch16-224 style configuration
+ >>> model = HieraModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "hiera"
+
+ attribute_map = {"num_hidden_layers": "num_layers"}
+
+ def __init__(
+ self,
+ embed_dim=96,
+ image_size=[224, 224],
+ patch_size=[7, 7],
+ patch_stride=[4, 4],
+ patch_padding=[3, 3],
+ mlp_ratio=4.0,
+ depths=[2, 3, 16, 3],
+ num_heads=[1, 2, 4, 8],
+ embed_dim_multiplier=2.0,
+ num_query_pool=3,
+ query_stride=[2, 2],
+ masked_unit_size=[8, 8],
+ masked_unit_attention=[True, True, False, False],
+ drop_path_rate=0.0,
+ num_channels=3,
+ hidden_act="gelu",
+ initializer_range=0.02,
+ layer_norm_init=1.0,
+ layer_norm_eps=1e-6,
+ decoder_hidden_size=None,
+ decoder_depth=None,
+ decoder_num_heads=None,
+ normalize_pixel_loss=True,
+ mask_ratio=0.6,
+ out_features=None,
+ out_indices=None,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ if masked_unit_size[0] % query_stride[0] ** (len(depths) - 1) != 0:
+ raise ValueError(
+ f"masked_unit_size[0] ({masked_unit_size[0]}) must be divisible by query_stride[0] ({query_stride[0]}) "
+ f"raised to the power of the number of layers ({len(depths) - 1})"
+ )
+
+ if num_query_pool >= len(depths):
+ raise ValueError(
+ f"num_query_pool ({num_query_pool}) must be less than the number of layers ({len(depths)})"
+ )
+
+ self.embed_dim = embed_dim
+ self.image_size = image_size
+ self.patch_size = patch_size
+ self.patch_stride = patch_stride
+ self.patch_padding = patch_padding
+ self.mlp_ratio = mlp_ratio
+ self.depths = depths
+ self.num_heads = num_heads
+ self.num_layers = len(depths)
+ self.embed_dim_multiplier = embed_dim_multiplier
+ self.num_query_pool = num_query_pool
+ self.query_stride = query_stride
+ self.masked_unit_size = masked_unit_size
+ self.masked_unit_attention = masked_unit_attention
+ self.drop_path_rate = drop_path_rate
+ self.num_channels = num_channels
+ self.hidden_act = hidden_act
+ self.initializer_range = initializer_range
+ self.layer_norm_init = layer_norm_init
+ self.layer_norm_eps = layer_norm_eps
+ self.decoder_hidden_size = decoder_hidden_size
+ self.decoder_depth = decoder_depth
+ self.decoder_num_heads = decoder_num_heads
+ self.normalize_pixel_loss = normalize_pixel_loss
+ self.mask_ratio = mask_ratio
+ # we set the hidden_size attribute in order to make Hiera work with VisionEncoderDecoderModel
+ # this indicates the channel dimension after the last stage of the model
+ self.hidden_size = int(embed_dim * embed_dim_multiplier ** (len(depths) - 1))
+ self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, len(depths) + 1)]
+ self._out_features, self._out_indices = get_aligned_output_features_output_indices(
+ out_features=out_features, out_indices=out_indices, stage_names=self.stage_names
+ )
+
+
+__all__ = ["HieraConfig"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/hiera/modeling_hiera.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/hiera/modeling_hiera.py
new file mode 100644
index 0000000000000000000000000000000000000000..0c084f0f836e18c4891491ea983676e9752e80a2
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/hiera/modeling_hiera.py
@@ -0,0 +1,1439 @@
+# coding=utf-8
+# Copyright 2024 Meta and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch Hiera model."""
+
+import math
+from dataclasses import dataclass
+from typing import Optional, Union
+
+import torch
+from torch import nn
+
+from ...activations import ACT2FN
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_outputs import (
+ BackboneOutput,
+ BaseModelOutput,
+ BaseModelOutputWithPooling,
+ ImageClassifierOutput,
+ ModelOutput,
+)
+from ...modeling_utils import PreTrainedModel
+from ...utils import auto_docstring, logging, torch_int
+from ...utils.backbone_utils import BackboneMixin
+from .configuration_hiera import HieraConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Hiera encoder's outputs, with potential hidden states and attentions.
+ """
+)
+class HieraEncoderOutput(ModelOutput):
+ r"""
+ reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
+ shape `(batch_size, height, width, hidden_size)`. These are the reshaped and re-rolled hidden states of the model.
+
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
+ include the spatial dimensions.
+ """
+
+ last_hidden_state: Optional[torch.FloatTensor] = None
+ hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+ attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+ reshaped_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Hiera model's outputs that also contains a pooling of the last hidden states.
+ """
+)
+class HieraModelOutput(ModelOutput):
+ r"""
+ pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`, *optional*, returned when `add_pooling_layer=True` is passed):
+ Average pooling of the last layer hidden-state.
+ bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, sequence_length)`):
+ Tensor indicating which patches are masked (0) and which are not (1).
+ ids_restore (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ Tensor containing the original index of the (shuffled) masked patches.
+ reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
+ shape `(batch_size, height, width, hidden_size)`. These are the reshaped and re-rolled hidden states of the model.
+
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
+ include the spatial dimensions.
+ """
+
+ last_hidden_state: Optional[torch.FloatTensor] = None
+ pooler_output: Optional[torch.FloatTensor] = None
+ bool_masked_pos: Optional[torch.BoolTensor] = None
+ ids_restore: Optional[torch.LongTensor] = None
+ hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+ attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+ reshaped_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Hiera image classification outputs.
+ """
+)
+class HieraForImageClassificationOutput(ImageClassifierOutput):
+ r"""
+ loss (`torch.FloatTensor` of shape `(1,)`, `optional`):
+ Loss value for the training task.
+ logits (`torch.FloatTensor` of shape `(batch_size, num_labels)`):
+ Prediction scores of the classification head (logits of the output layer).
+ hidden_states (`tuple(torch.FloatTensor)`, `optional`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
+ shape `(batch_size, sequence_length, hidden_size)`. These are the unrolled hidden states of the model.
+
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+ attentions (`tuple(torch.FloatTensor)`, `optional`):
+ Tuple of `torch.FloatTensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ reshaped_hidden_states (`tuple(torch.FloatTensor)`, `optional`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
+ shape `(batch_size, height, width, hidden_size)`. These are the reshaped and re-rolled hidden states of the model.
+
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
+ include the spatial dimensions.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ logits: Optional[torch.FloatTensor] = None
+ hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+ attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+ reshaped_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Class for HieraForPreTraining's outputs, with potential hidden states and attentions.
+ """
+)
+class HieraForPreTrainingOutput(ModelOutput):
+ r"""
+ loss (`torch.FloatTensor` of shape `(1,)`):
+ Pixel reconstruction loss.
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, patch_size ** 2 * num_channels)`):
+ Pixel reconstruction logits.
+ bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, sequence_length)`):
+ Tensor indicating which patches are masked (0) and which are not (1).
+ ids_restore (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ Tensor containing the original index of the (shuffled) masked patches.
+ reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
+ shape `(batch_size, height, width, hidden_size)`. Hidden-states of the model at the output of each layer
+ plus the initial embedding outputs reshaped to include the spatial dimensions.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ logits: Optional[torch.FloatTensor] = None
+ bool_masked_pos: Optional[torch.BoolTensor] = None
+ ids_restore: Optional[torch.LongTensor] = None
+ hidden_states: Optional[tuple[torch.FloatTensor]] = None
+ attentions: Optional[tuple[torch.FloatTensor]] = None
+ reshaped_hidden_states: Optional[tuple[torch.FloatTensor]] = None
+
+
+class HieraPatchEmbeddings(nn.Module):
+ """
+ This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
+ `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
+ Transformer.
+ """
+
+ def __init__(self, config, is_mae: bool = False):
+ super().__init__()
+
+ # Support any number of spatial dimensions
+ self.spatial_dims = len(config.patch_size)
+ if self.spatial_dims != 2:
+ raise ValueError(f"The number of dimensions of the input image should be 2, but got {self.spatial_dims}.")
+ self.num_channels = config.num_channels
+ self.image_size = config.image_size[-2:]
+ self.tokens_spatial_shape = [i // s for i, s in zip(config.image_size, config.patch_stride)]
+ self.mask_spatial_shape = [i // s for i, s in zip(self.tokens_spatial_shape, config.masked_unit_size)]
+ self.mask_ratio = config.mask_ratio
+ self.is_mae = is_mae
+ self.projection = nn.Conv2d(
+ self.num_channels,
+ config.embed_dim,
+ kernel_size=config.patch_size,
+ stride=config.patch_stride,
+ padding=config.patch_padding,
+ )
+
+ def masked_conv(
+ self, pixel_values: torch.FloatTensor, bool_masked_pos: Optional[torch.BoolTensor] = None
+ ) -> torch.Tensor:
+ """Zero-out the masked regions of the input before conv.
+ Prevents leakage of masked regions when using overlapping kernels.
+ """
+ if bool_masked_pos is None:
+ return self.projection(pixel_values)
+
+ target_size = pixel_values.shape[2:]
+ # Reshape bool_masked_pos to (batch_size, 1, mask_unit_height, mask_unit_width)
+ bool_masked_pos = bool_masked_pos.view(pixel_values.shape[0], 1, *self.mask_spatial_shape)
+
+ bool_masked_pos = nn.functional.interpolate(bool_masked_pos.float(), size=target_size)
+
+ return self.projection(pixel_values * bool_masked_pos)
+
+ def random_masking(
+ self, pixel_values: torch.FloatTensor, noise: Optional[torch.FloatTensor] = None
+ ) -> tuple[torch.BoolTensor, torch.LongTensor]:
+ """
+ Perform per-sample random masking by per-sample shuffling. Per-sample shuffling is done by argsort random
+ noise.
+
+ Args:
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`)
+ noise (`torch.FloatTensor` of shape `(batch_size, num_mask_units)`, *optional*):
+ Noise used to shuffle the mask units; mainly used for testing to control randomness and ensure reproducibility.
+ """
+ batch_size = pixel_values.shape[0]
+ # Tokens selected for masking at mask unit level
+ num_windows = math.prod(self.mask_spatial_shape)
+ len_keep = int(num_windows * (1 - self.mask_ratio))
+
+ if noise is None:
+ noise = torch.rand(batch_size, num_windows, device=pixel_values.device)
+
+ # Sort noise for each sample
+ ids_shuffle = torch.argsort(noise, dim=1)
+ # ascend: small is keep, large is remove
+ ids_restore = torch.argsort(ids_shuffle, dim=1).to(pixel_values.device)
+
+ # Generate the binary bool_masked_pos: 1 is *keep*, 0 is *remove*
+ # Note this is opposite to original MAE
+ bool_masked_pos = torch.zeros([batch_size, num_windows], device=pixel_values.device)
+ bool_masked_pos[:, :len_keep] = 1
+ # Unshuffle to get the binary bool_masked_pos
+ bool_masked_pos = torch.gather(bool_masked_pos, dim=1, index=ids_restore).bool()
+
+ return bool_masked_pos, ids_restore
+
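A compact sketch (editor's example, not part of the diff) of the per-sample masking logic above: shuffle mask units with random noise, keep the first `len_keep`, then unshuffle so the boolean mask lines up with the original unit order:

```python
import torch

batch_size, num_windows, mask_ratio = 2, 8, 0.6
len_keep = int(num_windows * (1 - mask_ratio))  # 3 mask units kept per sample

noise = torch.rand(batch_size, num_windows)
ids_shuffle = torch.argsort(noise, dim=1)        # ascending: lowest noise = kept
ids_restore = torch.argsort(ids_shuffle, dim=1)  # inverse permutation

bool_masked_pos = torch.zeros(batch_size, num_windows)
bool_masked_pos[:, :len_keep] = 1                # 1 = keep, 0 = remove (opposite of the original MAE)
bool_masked_pos = torch.gather(bool_masked_pos, dim=1, index=ids_restore).bool()

print(bool_masked_pos.sum(dim=1))  # tensor([3, 3]) -> exactly len_keep units kept per sample
```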
+ def forward(
+ self,
+ pixel_values: torch.FloatTensor,
+ noise: Optional[torch.FloatTensor] = None,
+ ) -> tuple[torch.Tensor, Optional[torch.BoolTensor], Optional[torch.LongTensor]]:
+ (bool_masked_pos, ids_restore) = (
+ self.random_masking(pixel_values, noise=noise) if self.is_mae else (None, None)
+ )
+
+ embeddings = self.masked_conv(pixel_values, bool_masked_pos)
+ embeddings = embeddings.flatten(2).transpose(2, 1)
+
+ return embeddings, bool_masked_pos, ids_restore
+
+
+class HieraEmbeddings(nn.Module):
+ """
+ Construct position and patch embeddings.
+ """
+
+ def __init__(self, config: HieraConfig, is_mae: bool = False) -> None:
+ super().__init__()
+ self.patch_stride = config.patch_stride
+ tokens_spatial_shape = [i // s for i, s in zip(config.image_size, config.patch_stride)]
+ self.mask_spatial_shape = [i // s for i, s in zip(tokens_spatial_shape, config.masked_unit_size)]
+ self.num_tokens = math.prod(tokens_spatial_shape)
+ self.is_mae = is_mae
+
+ self.patch_embeddings = HieraPatchEmbeddings(config, is_mae=is_mae)
+
+ self.position_embeddings = nn.Parameter(torch.zeros(1, self.num_tokens, config.embed_dim))
+
+ def interpolate_pos_encoding(
+ self, embeddings: torch.Tensor, pos_embeds: torch.Tensor, height: int, width: int
+ ) -> torch.Tensor:
+ """
+ This method interpolates the pre-trained position encodings so that the model can be used on higher-resolution
+ images. It is also adapted to support torch.jit tracing, the absence of class embeddings, and different patch strides.
+
+ Adapted from:
+ - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
+ - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
+ """
+
+ num_patches = embeddings.shape[1]
+ num_positions = pos_embeds.shape[1]
+
+ # always interpolate when tracing to ensure the exported model works for dynamic input shapes
+ if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
+ return pos_embeds
+
+ dim = embeddings.shape[-1]
+
+ new_height = height // self.patch_stride[0]
+ new_width = width // self.patch_stride[1]
+
+ sqrt_num_positions = torch_int(num_positions**0.5)
+ pos_embeds = pos_embeds.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
+ pos_embeds = pos_embeds.permute(0, 3, 1, 2)
+
+ pos_embeds = nn.functional.interpolate(
+ pos_embeds,
+ size=(new_height, new_width),
+ mode="bicubic",
+ align_corners=False,
+ )
+
+ pos_embeds = pos_embeds.permute(0, 2, 3, 1).view(1, -1, dim)
+ return pos_embeds
+
+ def get_position_embedding(
+ self, embeddings: torch.Tensor, height: int, width: int, interpolate_pos_encoding: bool
+ ) -> torch.FloatTensor:
+ return (
+ self.interpolate_pos_encoding(embeddings, self.position_embeddings, height, width)
+ if interpolate_pos_encoding
+ else self.position_embeddings
+ )
+
+ def forward(
+ self,
+ pixel_values: torch.FloatTensor,
+ noise: Optional[torch.FloatTensor] = None,
+ interpolate_pos_encoding: bool = False,
+ ) -> tuple[torch.Tensor, Optional[torch.BoolTensor], Optional[torch.LongTensor]]:
+ height, width = pixel_values.shape[-2:]
+ embeddings, bool_masked_pos, ids_restore = self.patch_embeddings(pixel_values, noise=noise)
+ embeddings = embeddings + self.get_position_embedding(embeddings, height, width, interpolate_pos_encoding)
+ return embeddings, bool_masked_pos, ids_restore
+
+
+class HieraMaskUnitAttention(nn.Module):
+ """
+ Computes either Mask Unit or Global Attention. It can also perform query pooling.
+
+ Note: this assumes the tokens have already been flattened and unrolled into mask units.
+ """
+
+ def __init__(
+ self,
+ hidden_size: int,
+ hidden_size_output: int,
+ num_heads: int,
+ query_stride: int = 1,
+ window_size: int = 0,
+ use_mask_unit_attn: bool = False,
+ ) -> None:
+ super().__init__()
+ self.num_heads = num_heads
+ self.query_stride = query_stride
+ self.hidden_size_output = hidden_size_output
+
+ self.head_dim = hidden_size_output // num_heads
+ self.scale = (self.head_dim) ** -0.5
+
+ self.qkv = nn.Linear(hidden_size, 3 * hidden_size_output)
+ self.proj = nn.Linear(hidden_size_output, hidden_size_output)
+
+ self.window_size = window_size
+ self.use_mask_unit_attn = use_mask_unit_attn
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ head_mask: Optional[torch.FloatTensor] = None,
+ output_attentions: bool = False,
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+ """Input should be of shape [batch, tokens, channels]."""
+ batch_size, seq_len, _ = hidden_states.shape
+
+ num_windows = 1
+ if self.use_mask_unit_attn:
+ num_windows = seq_len // (self.query_stride * self.window_size)
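+ # Illustrative numbers (assuming the default Hiera configuration): with 3136 unmasked tokens,
+ # query_stride = 1 and window_size = 64, num_windows = 49, so attention is computed independently
+ # inside each 8x8 mask unit instead of globally.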
+
+ qkv = self.qkv(hidden_states)
+ qkv = qkv.reshape(batch_size, -1, num_windows, 3, self.num_heads, self.head_dim)
+ qkv = qkv.permute(3, 0, 4, 2, 1, 5)
+
+ query, key, value = qkv.unbind(0)
+
+ if self.query_stride > 1:
+ # Refer to unroll to see how this performs a maxpool-Nd
+ query = query.view(batch_size, self.num_heads, num_windows, self.query_stride, -1, self.head_dim)
+ query = query.max(dim=3).values
+
+ attn_weights = (query * self.scale) @ key.transpose(-1, -2)
+ attn_weights = attn_weights.softmax(dim=-1)
+
+ # Mask heads if we want to
+ if head_mask is not None:
+ attn_weights = attn_weights * head_mask
+
+ attn_output = attn_weights @ value
+ attn_output = attn_output.transpose(1, 3).reshape(batch_size, -1, self.hidden_size_output)
+ attn_output = self.proj(attn_output)
+
+ return (attn_output, attn_weights) if output_attentions else (attn_output, None)
+
+
+# Copied from transformers.models.beit.modeling_beit.drop_path
+def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
+ """
+ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+
+ Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
+ however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
+ See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
+ layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
+ argument.
+ """
+ if drop_prob == 0.0 or not training:
+ return input
+ keep_prob = 1 - drop_prob
+ shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
+ random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
+ random_tensor.floor_() # binarize
+ output = input.div(keep_prob) * random_tensor
+ return output
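+ # Illustrative example: with drop_prob=0.1 during training, roughly 10% of the samples in a batch have their
+ # residual branch zeroed out, while the surviving samples are scaled by 1 / 0.9 so the expected value is unchanged.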
+
+
+# Copied from transformers.models.beit.modeling_beit.BeitDropPath with Beit->Hiera
+class HieraDropPath(nn.Module):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
+
+ def __init__(self, drop_prob: Optional[float] = None) -> None:
+ super().__init__()
+ self.drop_prob = drop_prob
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ return drop_path(hidden_states, self.drop_prob, self.training)
+
+ def extra_repr(self) -> str:
+ return f"p={self.drop_prob}"
+
+
+class HieraMlp(nn.Module):
+ def __init__(self, config, dim: int) -> None:
+ super().__init__()
+ self.activation_fn = ACT2FN[config.hidden_act]
+ self.fc1 = nn.Linear(dim, int(dim * config.mlp_ratio))
+ self.fc2 = nn.Linear(int(dim * config.mlp_ratio), dim)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.fc1(hidden_states)
+ hidden_states = self.activation_fn(hidden_states)
+ hidden_states = self.fc2(hidden_states)
+ return hidden_states
+
+
+class HieraLayer(nn.Module):
+ def __init__(
+ self,
+ config,
+ hidden_size: int,
+ hidden_size_output: int,
+ num_heads: int,
+ drop_path: float = 0.0,
+ query_stride: int = 1,
+ window_size: int = 0,
+ use_mask_unit_attn: bool = False,
+ ) -> None:
+ super().__init__()
+
+ self.hidden_size = hidden_size
+ self.hidden_size_output = hidden_size_output
+ self.query_stride = query_stride
+
+ self.layernorm_before = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps)
+ self.attn = HieraMaskUnitAttention(
+ hidden_size=hidden_size,
+ hidden_size_output=hidden_size_output,
+ num_heads=num_heads,
+ query_stride=query_stride,
+ window_size=window_size,
+ use_mask_unit_attn=use_mask_unit_attn,
+ )
+
+ self.layernorm_after = nn.LayerNorm(hidden_size_output, eps=config.layer_norm_eps)
+ self.mlp = HieraMlp(config, hidden_size_output)
+
+ self.drop_path = HieraDropPath(drop_path) if drop_path > 0 else nn.Identity()
+ if hidden_size != hidden_size_output:
+ self.proj = nn.Linear(hidden_size, hidden_size_output)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ head_mask: Optional[torch.FloatTensor] = None,
+ output_attentions: bool = False,
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+ batch_size, seq_len, _ = hidden_states.shape
+ # Attention + Q Pooling
+ hidden_states_norm = self.layernorm_before(hidden_states)
+ if self.hidden_size != self.hidden_size_output:
+ hidden_states = self.proj(hidden_states_norm)
+ # Refer to unroll to see how this performs a maxpool-Nd
+ hidden_states = (
+ hidden_states.view(batch_size, self.query_stride, -1, self.hidden_size_output).max(dim=1).values
+ )
+
+ (hidden_states_norm, attn_weights) = self.attn(
+ hidden_states_norm, head_mask, output_attentions=output_attentions
+ )
+ hidden_states = hidden_states + self.drop_path(hidden_states_norm)
+
+ residual = hidden_states
+ hidden_states = self.layernorm_after(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = residual + self.drop_path(hidden_states)
+
+ return (hidden_states, attn_weights)
+
+
+class HieraStage(GradientCheckpointingLayer):
+ def __init__(
+ self,
+ config,
+ depth: int,
+ hidden_size: int,
+ hidden_size_output: int,
+ num_heads: int,
+ drop_path: list[float],
+ query_stride: list[int],
+ window_size: int,
+ use_mask_unit_attn: bool,
+ stage_num: Optional[int] = None,
+ ) -> None:
+ super().__init__()
+ # we need to know whether the previous stage used mask unit or global attention.
+ # the switch to global attention lags by one layer, so that global attention is
+ # applied post pooling, on the lower resolution
+ previous_stage_used_masked_attention = False
+ if stage_num is not None:
+ previous_stage_used_masked_attention = config.masked_unit_attention[stage_num - 1 if stage_num > 0 else 0]
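+ # e.g. (illustrative) with masked_unit_attention = [True, True, False, False], the first layer of stage 2 still
+ # uses mask unit attention because stage 1 did, so the switch to global attention only happens after that
+ # stage's query pooling.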
+ self.layers = nn.ModuleList(
+ [
+ HieraLayer(
+ config=config,
+ hidden_size=hidden_size if i == 0 else hidden_size_output,
+ hidden_size_output=hidden_size_output,
+ num_heads=num_heads,
+ drop_path=drop_path[i],
+ query_stride=query_stride[i],
+ window_size=window_size,
+ use_mask_unit_attn=use_mask_unit_attn or (previous_stage_used_masked_attention and i == 0),
+ )
+ for i in range(depth)
+ ]
+ )
+
+ def forward(
+ self, hidden_states: torch.Tensor, head_mask: Optional[torch.FloatTensor], output_attentions: bool = False
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+ for i, layer_module in enumerate(self.layers):
+ layer_head_mask = head_mask[i] if head_mask is not None else None
+ (hidden_states, attn_weights) = layer_module(
+ hidden_states, layer_head_mask, output_attentions=output_attentions
+ )
+
+ return hidden_states, attn_weights
+
+
+def undo_windowing(hidden_states: torch.Tensor, shape: list[int], mask_unit_shape: list[int]) -> torch.Tensor:
+ """
+ Restore spatial organization by undoing windowed organization of mask units.
+
+ Args:
+ hidden_states (`torch.Tensor`): The hidden states tensor of shape `[batch_size, num_mask_unit_height*num_mask_unit_width, hidden_size]`.
+ shape (`list[int]`): The original shape of the hidden states tensor before windowing.
+ mask_unit_shape (`list[int]`): The shape of the mask units used for windowing.
+
+ Returns:
+ torch.Tensor: The restored hidden states tensor of shape [batch_size, num_mask_unit_height*mask_unit_height, num_mask_unit_width*mask_unit_width, hidden_size].
+ """
+ batch_size, hidden_size = hidden_states.shape[0], hidden_states.shape[-1]
+ # From: [batch_size, num_mask_unit_height*num_mask_unit_width, hidden_size]
+ # To: [batch_size, num_mask_unit_height, num_mask_unit_width, mask_unit_height, mask_unit_width, hidden_size]
+ num_mask_units = [s // mu for s, mu in zip(shape, mask_unit_shape)]
+ hidden_states = hidden_states.view(batch_size, *num_mask_units, *mask_unit_shape, hidden_size)
+
+ # From: [batch_size, num_mask_unit_height, num_mask_unit_width, mask_unit_height, mask_unit_width, hidden_size]
+ # To: [batch_size, num_mask_unit_height*mask_unit_height, num_mask_unit_width*mask_unit_width, hidden_size]
+ hidden_states = hidden_states.permute(0, 1, 3, 2, 4, 5)
+ hidden_states = hidden_states.reshape(batch_size, *shape, hidden_size)
+
+ return hidden_states
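+ # Shape sketch (illustrative): with shape=[56, 56] and mask_unit_shape=[8, 8] there are 7x7 = 49 mask units, so an
+ # input of shape [batch_size, 49, 8, 8, hidden_size] is rearranged back to [batch_size, 56, 56, hidden_size].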
+
+
+class HieraEncoder(nn.Module):
+ def __init__(self, config: HieraConfig) -> None:
+ super().__init__()
+ total_depth = sum(config.depths)
+ # stochastic depth decay rule
+ dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, total_depth, device="cpu")]
+ # query strides rule
+ cumulative_depths = torch.tensor(config.depths, device="cpu").cumsum(0).tolist()
+ query_pool_layer = cumulative_depths[: config.num_query_pool]
+ query_strides = [math.prod(config.query_stride) if i in query_pool_layer else 1 for i in range(total_depth)]
+
+ # Transformer blocks
+ self.stages = nn.ModuleList()
+ hidden_size = config.embed_dim
+ stage_ends = [0] + cumulative_depths
+ masked_unit_area = math.prod(config.masked_unit_size)
+ query_stride_area = math.prod(config.query_stride)
+ for idx_stage, depth in enumerate(config.depths):
+ hidden_size_output = int(config.embed_dim * config.embed_dim_multiplier**idx_stage)
+
+ stage = HieraStage(
+ config=config,
+ depth=depth,
+ hidden_size=hidden_size,
+ hidden_size_output=hidden_size_output,
+ num_heads=config.num_heads[idx_stage],
+ drop_path=dpr[stage_ends[idx_stage] : stage_ends[idx_stage + 1]],
+ query_stride=query_strides[stage_ends[idx_stage] : stage_ends[idx_stage + 1]],
+ window_size=int(masked_unit_area * query_stride_area**-idx_stage),
+ use_mask_unit_attn=config.masked_unit_attention[idx_stage],
+ stage_num=idx_stage,
+ )
+
+ hidden_size = hidden_size_output
+ self.stages.append(stage)
+
+ # Setting reroll schedule
+ # The first stage has to reverse everything
+ # The next stage has to reverse all but the first unroll, etc.
+ stage_size = [i // s for i, s in zip(config.image_size, config.patch_stride)]
+ unroll_schedule = [config.query_stride] * len(config.depths[:-1])
+
+ self.schedule = {}
+ for idx_stage in range(len(config.depths)):
+ self.schedule[idx_stage] = unroll_schedule, stage_size
+ if idx_stage < config.num_query_pool:
+ stage_size = [i // s for i, s in zip(stage_size, config.query_stride)]
+ unroll_schedule = unroll_schedule[1:]
+
+ self.gradient_checkpointing = False
+
+ def reroll(
+ self, hidden_states: torch.Tensor, stage_idx: int, bool_masked_pos: Optional[torch.BoolTensor] = None
+ ) -> torch.Tensor:
+ """
+ Roll the given tensor back up to spatial order assuming it's from the given block.
+
+ If no bool_masked_pos is provided returns:
+ - [batch_size, height, width, hidden_size]
+ If a bool_masked_pos is provided returns:
+ - [batch_size, num_mask_units, mask_unit_height, mask_unit_width, hidden_size]
+ """
+ schedule, size = self.schedule[stage_idx]
+ batch_size, seq_len, hidden_size = hidden_states.shape
+
+ num_dim = len(size)
+ mask_unit_shape = [1] * num_dim
+
+ for strides in schedule:
+ # Extract the current patch from seq_len
+ hidden_states = hidden_states.view(
+ batch_size, *strides, seq_len // math.prod(strides), *mask_unit_shape, hidden_size
+ )
+
+ # Move that patch into the current MU
+ # Input: [batch_size, stride, stride, seq_len//(stride*stride), mask_unit_height, mask_unit_width, hidden_size]
+ # Output: [batch_size, seq_len//(stride*stride), stride, mask_unit_height, stride, mask_unit_width, hidden_size]
+ hidden_states = hidden_states.permute(0, 3, 1, 4, 2, 5, 6)
+
+ # Reshape to [batch_size, seq_len//(stride*stride), *mask_units, hidden_size]
+ for i in range(num_dim):
+ mask_unit_shape[i] *= strides[i]
+ hidden_states = hidden_states.reshape(batch_size, -1, *mask_unit_shape, hidden_size)
+ seq_len = hidden_states.shape[1]
+
+ # Current shape (e.g., 2d: [batch_size, #num_mask_units_height*#num_mask_units_width, mask_unit_height, mask_unit_width, hidden_size])
+ hidden_states = hidden_states.view(batch_size, seq_len, *mask_unit_shape, hidden_size)
+
+ # If masked, return [batch_size, num_mask_units, mask_unit_height, mask_unit_width, hidden_size]
+ if bool_masked_pos is not None:
+ return hidden_states
+
+ # If not masked, we can return [batch_size, height, width, hidden_size]
+ hidden_states = undo_windowing(hidden_states, size, mask_unit_shape)
+
+ return hidden_states
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ bool_masked_pos: Optional[torch.BoolTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ output_attentions: bool = False,
+ output_hidden_states: bool = False,
+ return_dict: bool = True,
+ ) -> Union[tuple, BaseModelOutput]:
+ all_hidden_states = () if output_hidden_states else None
+ all_reshaped_hidden_states = () if output_hidden_states else None
+ all_self_attentions = () if output_attentions else None
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+ reshaped_hidden_states = self.reroll(hidden_states, stage_idx=0, bool_masked_pos=bool_masked_pos)
+ all_reshaped_hidden_states = all_reshaped_hidden_states + (reshaped_hidden_states,)
+
+ for i, stage_module in enumerate(self.stages):
+ layer_head_mask = head_mask[i] if head_mask is not None else None
+
+ layer_outputs = stage_module(hidden_states, layer_head_mask, output_attentions)
+
+ hidden_states = layer_outputs[0]
+
+ if output_attentions:
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+ reshaped_hidden_states = self.reroll(hidden_states, stage_idx=i, bool_masked_pos=bool_masked_pos)
+ all_reshaped_hidden_states = all_reshaped_hidden_states + (reshaped_hidden_states,)
+
+ if not return_dict:
+ return tuple(
+ v
+ for v in [hidden_states, all_hidden_states, all_self_attentions, all_reshaped_hidden_states]
+ if v is not None
+ )
+ return HieraEncoderOutput(
+ last_hidden_state=hidden_states,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attentions,
+ reshaped_hidden_states=all_reshaped_hidden_states,
+ )
+
+
+def unroll(
+ hidden_states: torch.Tensor, image_shape: tuple[int, int], patch_stride: tuple[int, int], schedule: list[list[int]]
+) -> torch.Tensor:
+ """
+ Reorders the tokens such that patches are contiguous in memory.
+ E.g., given [batch_size, (height, width), hidden_size] and stride of (stride, stride), this will re-order the tokens as
+ [batch_size, (stride, stride, height // stride, width // stride), hidden_size]
+
+ This allows operations like Max2d to be computed as x.view(batch_size, stride*stride, -1, hidden_size).max(dim=1).
+ Not only is this faster, but it also makes it easy to support inputs of arbitrary
+ dimensions in addition to patch-wise sparsity.
+
+ Performing this operation multiple times in sequence makes entire windows contiguous
+ in memory. For instance, if you applied the stride (2, 2) 3 times, entire windows of
+ size 8x8 would be contiguous in memory, allowing operations like mask unit attention
+ to be computed easily and efficiently, while also allowing max to be applied sequentially.
+
+ Note: This means that intermediate values of the model are not in height x width order, so they
+ need to be re-rolled if you want to use the intermediate values as a height x width feature map.
+ The last block of the network is fine though, since by then the strides are all consumed.
+ """
+ batch_size, _, hidden_size = hidden_states.shape
+
+ size = [i // s for i, s in zip(image_shape, patch_stride)]
+
+ current_size = size
+ hidden_states = hidden_states.view(*([batch_size] + current_size + [hidden_size]))
+
+ for strides in schedule:
+ # Move patches with the given strides to the batch dimension
+
+ # Create a view of the tensor with the patch stride as separate dims
+ # For example in 2d: [batch_size, height // stride, stride, width // stride, stride, C]
+ current_size = [i // s for i, s in zip(current_size, strides)]
+ # initialize new_shape with [height // stride, stride, width // stride, stride]
+ new_shape = [item for pair in zip(current_size, strides) for item in pair]
+ # add batch_size and hidden_size to new_shape
+ new_shape = [batch_size] + new_shape + [hidden_size]
+ hidden_states = hidden_states.view(new_shape)
+
+ # Move the patch stride into the batch dimension
+ # For example in 2d: [batch_size, stride, stride, height // stride, width // stride, hidden_size]
+ num_dims = len(new_shape)
+ permute = [0] + list(range(2, num_dims - 1, 2)) + list(range(1, num_dims - 1, 2)) + [num_dims - 1]
+ hidden_states = hidden_states.permute(permute)
+
+ # Now finally flatten the relevant dims into the batch dimension
+ hidden_states = hidden_states.flatten(0, len(strides))
+ batch_size *= math.prod(strides)
+
+ hidden_states = hidden_states.reshape(-1, math.prod(size), hidden_size)
+ return hidden_states
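+ # Shape sketch (illustrative): for a 224x224 input with patch_stride (4, 4), the 56x56 token grid is flattened to
+ # [batch_size, 3136, hidden_size]. With schedule [(2, 2), (2, 2)], the output keeps that shape but the tokens are
+ # reordered so that each 2x2 query pooling can later be computed as x.view(batch_size, 4, -1, hidden_size).max(dim=1).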
+
+
+@auto_docstring
+class HieraPreTrainedModel(PreTrainedModel):
+ config: HieraConfig
+ base_model_prefix = "hiera"
+ main_input_name = "pixel_values"
+ supports_gradient_checkpointing = True
+
+ def _init_weights(self, module) -> None:
+ """Initialize the weights"""
+ std = self.config.initializer_range
+
+ if isinstance(module, HieraEmbeddings):
+ nn.init.trunc_normal_(module.position_embeddings, std=std)
+
+ elif isinstance(module, HieraDecoder):
+ nn.init.trunc_normal_(module.mask_token, std=std)
+ nn.init.trunc_normal_(module.decoder_position_embeddings, std=std)
+
+ elif isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv2d)):
+ nn.init.trunc_normal_(module.weight, std=std)
+ if module.bias is not None:
+ nn.init.constant_(module.bias, std)
+
+ elif isinstance(module, nn.LayerNorm):
+ nn.init.constant_(module.bias, std)
+ nn.init.constant_(module.weight, self.config.layer_norm_init)
+
+
+class HieraPooler(nn.Module):
+ def __init__(self, config: HieraConfig):
+ super().__init__()
+ num_features = int(config.embed_dim * config.embed_dim_multiplier ** (len(config.depths) - 1))
+ self.layernorm = nn.LayerNorm(num_features, eps=config.layer_norm_eps)
+ self.pooler = nn.AdaptiveAvgPool1d(1)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = hidden_states.transpose(1, 2)
+ pooled_output = self.pooler(hidden_states)
+ pooled_output = torch.flatten(pooled_output, 1)
+ pooled_output = self.layernorm(pooled_output)
+ return pooled_output
+
+
+@auto_docstring
+class HieraModel(HieraPreTrainedModel):
+ def __init__(self, config: HieraConfig, add_pooling_layer: bool = True, is_mae: bool = False):
+ r"""
+ add_pooling_layer (`bool`, *optional*, defaults to `True`):
+ Whether or not to apply pooling layer.
+ is_mae (`bool`, *optional*, defaults to `False`):
+ Whether or not to run the model in MAE mode.
+ """
+ super().__init__(config)
+ self.num_features = int(config.embed_dim * config.embed_dim_multiplier ** (len(config.depths) - 1))
+
+ self.embeddings = HieraEmbeddings(config, is_mae=is_mae)
+ self.encoder = HieraEncoder(config)
+
+ self.unroll_schedule = [config.query_stride] * len(config.depths[:-1])
+
+ self.pooler = HieraPooler(config) if add_pooling_layer else None
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self) -> HieraPatchEmbeddings:
+ return self.embeddings.patch_embeddings
+
+ def _prune_heads(self, heads_to_prune: dict[int, list[int]]) -> None:
+ """
+ Prunes heads of the model. `heads_to_prune` is a dict of {layer_num: list of heads to prune in this layer}.
+ See the base class `PreTrainedModel`.
+ """
+ for layer, heads in heads_to_prune.items():
+ self.encoder.layer[layer].attention.prune_heads(heads)
+
+ @auto_docstring
+ def forward(
+ self,
+ pixel_values: Optional[torch.Tensor] = None,
+ noise: Optional[torch.FloatTensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ interpolate_pos_encoding: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[tuple, BaseModelOutputWithPooling]:
+ r"""
+ noise (`torch.FloatTensor` of shape `(batch_size, num_mask_units)`, *optional*):
+ Mainly used for testing purposes to control randomness and maintain reproducibility.
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if pixel_values is None:
+ raise ValueError("You have to specify pixel_values")
+
+ # Prepare head mask if needed
+ # 1.0 in head_mask indicates we keep the head
+ # attention_probs has shape bsz x n_heads x N x N
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+ head_mask = self.get_head_mask(head_mask, len(self.config.depths))
+
+ embedding_output, bool_masked_pos, ids_restore = self.embeddings(
+ pixel_values, interpolate_pos_encoding=interpolate_pos_encoding, noise=noise
+ )
+
+ image_shape = (pixel_values.shape[-2], pixel_values.shape[-1])
+ hidden_states = unroll(
+ embedding_output,
+ image_shape=image_shape,
+ patch_stride=self.config.patch_stride,
+ schedule=self.unroll_schedule,
+ )
+
+ # Discard masked tokens if bool_masked_pos is provided
+ if bool_masked_pos is not None:
+ mask_unit_area = math.prod(self.config.masked_unit_size)
+ batch_size, _, hidden_size = hidden_states.shape
+ positions = bool_masked_pos.unsqueeze(-1).tile(1, mask_unit_area, hidden_size)
+ hidden_states = hidden_states[positions]
+ hidden_states = hidden_states.view(batch_size, -1, hidden_size)
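+ # Illustrative: if bool_masked_pos keeps 20 of 49 mask units and each unit covers 64 tokens,
+ # hidden_states shrinks from [batch_size, 3136, hidden_size] to [batch_size, 1280, hidden_size] here.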
+
+ encoder_outputs = self.encoder(
+ hidden_states,
+ bool_masked_pos=bool_masked_pos,
+ head_mask=head_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ sequence_output = encoder_outputs[0]
+ pooled_output = None
+ if self.pooler is not None:
+ pooled_output = self.pooler(sequence_output)
+
+ if not return_dict:
+ head_outputs = (sequence_output, pooled_output) if pooled_output is not None else (sequence_output,)
+ head_outputs = (
+ head_outputs + (bool_masked_pos, ids_restore) if bool_masked_pos is not None else head_outputs
+ )
+ return head_outputs + encoder_outputs[1:]
+
+ return HieraModelOutput(
+ last_hidden_state=sequence_output,
+ pooler_output=pooled_output,
+ bool_masked_pos=bool_masked_pos,
+ ids_restore=ids_restore,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ reshaped_hidden_states=encoder_outputs.reshaped_hidden_states,
+ )
+
+
+class HieraDecoder(nn.Module):
+ def __init__(self, config: HieraConfig):
+ super().__init__()
+ num_features = int(config.embed_dim * config.embed_dim_multiplier ** (len(config.depths) - 1))
+ tokens_spatial_shape = [i // s for i, s in zip(config.image_size, config.patch_stride)]
+ self.tokens_spatial_shape_final = [
+ i // s ** (config.num_query_pool) for i, s in zip(tokens_spatial_shape, config.query_stride)
+ ]
+ self.mask_unit_spatial_shape_final = [
+ i // s ** (config.num_query_pool) for i, s in zip(config.masked_unit_size, config.query_stride)
+ ]
+
+ self.decoder_embeddings = nn.Linear(num_features, config.decoder_hidden_size)
+
+ self.mask_token = nn.Parameter(torch.zeros(1, 1, config.decoder_hidden_size))
+
+ self.decoder_position_embeddings = nn.Parameter(
+ torch.zeros(1, math.prod(self.tokens_spatial_shape_final), config.decoder_hidden_size)
+ )
+
+ self.decoder_block = HieraStage(
+ config=config,
+ hidden_size=config.decoder_hidden_size,
+ hidden_size_output=config.decoder_hidden_size,
+ num_heads=config.decoder_num_heads,
+ depth=config.decoder_depth,
+ use_mask_unit_attn=False,
+ drop_path=[0.0] * config.decoder_depth,
+ query_stride=[1] * config.decoder_depth,
+ window_size=0,
+ )
+
+ self.decoder_norm = nn.LayerNorm(config.decoder_hidden_size, eps=config.layer_norm_eps)
+
+ # patch stride of prediction
+ self.pred_stride = config.patch_stride[-1] * (config.query_stride[-1] ** config.num_query_pool)
+ pred_dim = (self.pred_stride ** len(config.query_stride)) * config.num_channels
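+ # e.g. (illustrative) with patch_stride (4, 4), query_stride (2, 2) and num_query_pool=2, pred_stride = 4 * 2**2 = 16
+ # and pred_dim = 16**2 * num_channels = 768 for RGB inputs, i.e. a 14x14 grid of 16x16 pixel predictions on a 224x224 image.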
+
+ self.decoder_pred = nn.Linear(config.decoder_hidden_size, pred_dim)
+
+ def forward(
+ self,
+ encoder_hidden_states: torch.Tensor,
+ bool_masked_pos: torch.BoolTensor,
+ head_mask: Optional[torch.Tensor] = None,
+ output_attentions: bool = False,
+ ) -> tuple[torch.Tensor, torch.BoolTensor]:
+ # Embed tokens
+ hidden_states = self.decoder_embeddings(encoder_hidden_states)
+
+ # Combine visible and bool_masked_pos tokens
+
+ # hidden_states : [batch_size, num_mask_units_visible, *mask_unit_spatial_shape_final, decoder_hidden_size]
+ # bool_masked_pos: [batch_size, num_mask_units]
+ mask_unit_height, mask_unit_width, decoder_hidden_size = hidden_states.shape[2:]
+ batch_size, num_mask_units = bool_masked_pos.shape
+
+ decoder_hidden_states = torch.zeros(
+ batch_size,
+ num_mask_units,
+ mask_unit_height,
+ mask_unit_width,
+ decoder_hidden_size,
+ device=hidden_states.device,
+ dtype=hidden_states.dtype,
+ )
+ mask_tokens = self.mask_token.view(1, 1, 1, 1, -1)
+ bool_masked_pos = bool_masked_pos.reshape(batch_size, num_mask_units, 1, 1, 1)
+ bool_masked_pos = bool_masked_pos.expand(-1, -1, mask_unit_height, mask_unit_width, decoder_hidden_size)
+ decoder_hidden_states[bool_masked_pos] = hidden_states.flatten()
+ decoder_hidden_states = (
+ 1 - bool_masked_pos.float()
+ ) * mask_tokens + bool_masked_pos.float() * decoder_hidden_states
+
+ # Get back spatial order
+ hidden_states = undo_windowing(
+ decoder_hidden_states,
+ self.tokens_spatial_shape_final,
+ self.mask_unit_spatial_shape_final,
+ )
+ bool_masked_pos = undo_windowing(
+ bool_masked_pos[..., 0:1],
+ self.tokens_spatial_shape_final,
+ self.mask_unit_spatial_shape_final,
+ )
+
+ # Flatten
+ hidden_states = hidden_states.reshape(hidden_states.shape[0], -1, hidden_states.shape[-1])
+ bool_masked_pos = bool_masked_pos.view(hidden_states.shape[0], -1)
+
+ # Add pos embed
+ hidden_states = hidden_states + self.decoder_position_embeddings
+
+ # Apply decoder blocks
+ hidden_states, attn_weights = self.decoder_block(
+ hidden_states, head_mask=head_mask, output_attentions=output_attentions
+ )
+ hidden_states = self.decoder_norm(hidden_states)
+
+ # Predictor projection
+ hidden_states = self.decoder_pred(hidden_states)
+
+ return hidden_states, bool_masked_pos
+
+
+class HieraMultiScaleHead(nn.Module):
+ def __init__(self, config: HieraConfig):
+ super().__init__()
+ self.mask_unit_spatial_shape_final = [
+ i // s ** (config.num_query_pool) for i, s in zip(config.masked_unit_size, config.query_stride)
+ ]
+ self.stage_dimensions = [
+ int(config.embed_dim * config.embed_dim_multiplier**i) for i in range(len(config.depths))
+ ]
+ current_masked_unit_size = config.masked_unit_size
+ self.multi_scale_fusion_heads = nn.ModuleList()
+
+ for idx in range(config.num_query_pool):
+ kernel = [i // s for i, s in zip(current_masked_unit_size, self.mask_unit_spatial_shape_final)]
+ current_masked_unit_size = [i // s for i, s in zip(current_masked_unit_size, config.query_stride)]
+ self.multi_scale_fusion_heads.append(
+ nn.Conv2d(
+ self.stage_dimensions[idx],
+ self.stage_dimensions[-1],
+ kernel_size=kernel,
+ stride=kernel,
+ )
+ )
+ self.multi_scale_fusion_heads.append(nn.Identity())
+
+ def apply_fusion_head(self, head: nn.Module, hidden_states: torch.Tensor) -> torch.Tensor:
+ if isinstance(head, nn.Identity):
+ return hidden_states
+
+ # Done explicitly to avoid problems with torch.fx
+ batch_size, num_mask_units, mask_unit_height, mask_unit_width, hidden_size = hidden_states.shape
+ # From: [batch_size, num_mask_units, mask_unit_height, mask_unit_width, hidden_size]
+ # To: head([batch_size * num_mask_units, hidden_size, mask_unit_height, mask_unit_width])
+ hidden_states = hidden_states.reshape(
+ batch_size * num_mask_units, mask_unit_height, mask_unit_width, hidden_size
+ )
+ hidden_states = hidden_states.permute(0, 3, 1, 2)
+ hidden_states = head(hidden_states)
+
+ # Restore original layout
+ hidden_states = hidden_states.permute(0, 2, 3, 1)
+ mask_unit_height_final, mask_unit_width_final, hidden_size = hidden_states.shape[1:]
+ hidden_states = hidden_states.reshape(
+ batch_size, num_mask_units, mask_unit_height_final, mask_unit_width_final, hidden_size
+ )
+
+ return hidden_states
+
+ def forward(self, feature_maps: list[torch.Tensor]) -> torch.Tensor:
+ # Multi-scale fusion
+ hidden_states = 0.0
+ for head, feature_map in zip(self.multi_scale_fusion_heads, feature_maps):
+ hidden_states = hidden_states + self.apply_fusion_head(head, feature_map)
+
+ return hidden_states
+
+
+@auto_docstring(
+ custom_intro="""
+ The Hiera Model transformer with the decoder on top for self-supervised pre-training.
+
+
+
+ Note that we provide a script to pre-train this model on custom data in our [examples
+ directory](https://github.com/huggingface/transformers/tree/main/examples/pytorch/image-pretraining).
+
+
+ """
+)
+class HieraForPreTraining(HieraPreTrainedModel):
+ def __init__(self, config: HieraConfig) -> None:
+ super().__init__(config)
+ # Encoder
+ self.hiera = HieraModel(config, add_pooling_layer=False, is_mae=True)
+ self.encoder_norm = nn.LayerNorm(self.hiera.num_features, eps=config.layer_norm_eps)
+ # Multi-scale fusion heads
+ self.multiscale_fusion = HieraMultiScaleHead(config)
+ # Decoder
+ self.decoder = HieraDecoder(config)
+ self.pred_stride = self.decoder.pred_stride
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_pixel_label_2d(self, pixel_values: torch.Tensor, bool_masked_pos: torch.BoolTensor) -> torch.Tensor:
+ # bool_masked_pos (boolean tensor): True means *masked*
+ pixel_values = pixel_values.permute(0, 2, 3, 1)
+
+ size = self.pred_stride
+ label = pixel_values.unfold(1, size, size).unfold(2, size, size)
+ label = label.flatten(1, 2).flatten(2)
+ label = label[bool_masked_pos]
+ if self.config.normalize_pixel_loss:
+ mean = label.mean(dim=-1, keepdim=True)
+ var = label.var(dim=-1, keepdim=True)
+ label = (label - mean) / (var + 1.0e-6) ** 0.5
+
+ return label
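+ # Shape sketch (illustrative): with pred_stride=16 and a 224x224 RGB input, `unfold` produces a 14x14 grid of
+ # 16x16x3 patches, i.e. labels of shape [batch_size, 196, 768] before the masked positions are selected.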
+
+ def forward_loss(self, pixel_values: torch.Tensor, logits: torch.Tensor, bool_masked_pos: torch.BoolTensor):
+ # We invert the bool_masked_pos such that 1.0 is *masked*
+ bool_masked_pos = ~bool_masked_pos
+ label = self.get_pixel_label_2d(pixel_values, bool_masked_pos)
+
+ logits = logits[bool_masked_pos]
+ loss = (logits - label) ** 2
+ loss = loss.mean()
+
+ return loss
+
+ @auto_docstring
+ def forward(
+ self,
+ pixel_values: Optional[torch.Tensor] = None,
+ noise: Optional[torch.FloatTensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ interpolate_pos_encoding: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[tuple, HieraForPreTrainingOutput]:
+ r"""
+ noise (`torch.FloatTensor` of shape `(batch_size, num_mask_units)`, *optional*):
+ Mainly used for testing purposes to control randomness and maintain reproducibility.
+
+ Examples:
+ ```python
+ >>> from transformers import AutoImageProcessor, HieraForPreTraining
+ >>> import torch
+ >>> from PIL import Image
+ >>> import requests
+
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> image_processor = AutoImageProcessor.from_pretrained("facebook/hiera-tiny-224-mae-hf")
+ >>> model = HieraForPreTraining.from_pretrained("facebook/hiera-tiny-224-mae-hf")
+
+ >>> inputs = image_processor(images=image, return_tensors="pt")
+
+ >>> outputs = model(**inputs)
+ >>> logits = outputs.logits
+ >>> loss = outputs.loss
+ >>> print(list(logits.shape))
+ [1, 196, 768]
+ ```"""
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+
+ outputs = self.hiera(
+ pixel_values,
+ noise=noise,
+ head_mask=head_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=True,
+ interpolate_pos_encoding=interpolate_pos_encoding,
+ return_dict=return_dict,
+ )
+
+ feature_maps = outputs[-1]
+ bool_masked_pos = outputs[1]
+ ids_to_restore = outputs[2]
+ # Take only the query pooled and last hidden states
+ feature_maps = feature_maps[1 : self.hiera.config.num_query_pool + 1] + (feature_maps[-1],)
+ fused_hidden_states = self.multiscale_fusion(feature_maps)
+ fused_hidden_states = self.encoder_norm(fused_hidden_states)
+
+ # Reconstruct pixel values
+ logits, bool_masked_pos = self.decoder(
+ fused_hidden_states,
+ bool_masked_pos=bool_masked_pos,
+ head_mask=head_mask,
+ output_attentions=output_attentions,
+ )
+
+ loss = self.forward_loss(pixel_values, logits, bool_masked_pos)
+
+ if not return_dict:
+ output = (logits, bool_masked_pos, ids_to_restore)
+ if output_hidden_states:
+ output = output + (outputs[3],)
+ if output_attentions:
+ output = output + (outputs[4],)
+ if output_hidden_states:
+ output = output + (outputs[-1],)
+ return ((loss,) + output) if loss is not None else output
+
+ return HieraForPreTrainingOutput(
+ loss=loss,
+ logits=logits,
+ bool_masked_pos=bool_masked_pos,
+ ids_restore=ids_to_restore,
+ hidden_states=outputs.hidden_states if output_hidden_states else None,
+ attentions=outputs.attentions,
+ reshaped_hidden_states=outputs.reshaped_hidden_states if output_hidden_states else None,
+ )
+
+
+@auto_docstring(
+ custom_intro="""
+ Hiera Model transformer with an image classification head on top (a linear layer on top of the final hidden state with
+ average pooling) e.g. for ImageNet.
+
+
+
+ Note that it's possible to fine-tune Hiera on higher resolution images than the ones it has been trained on, by
+ setting `interpolate_pos_encoding` to `True` in the forward of the model. This will interpolate the pre-trained
+ position embeddings to the higher resolution.
+
+
+ """
+)
+class HieraForImageClassification(HieraPreTrainedModel):
+ def __init__(self, config: HieraConfig) -> None:
+ super().__init__(config)
+
+ self.num_labels = config.num_labels
+ self.hiera = HieraModel(config, add_pooling_layer=True, is_mae=False)
+
+ # Classifier head
+ self.classifier = (
+ nn.Linear(self.hiera.num_features, config.num_labels) if config.num_labels > 0 else nn.Identity()
+ )
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @auto_docstring
+ def forward(
+ self,
+ pixel_values,
+ head_mask: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ interpolate_pos_encoding: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[tuple, HieraForImageClassificationOutput]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss). If
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+
+ outputs = self.hiera(
+ pixel_values,
+ head_mask=head_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ interpolate_pos_encoding=interpolate_pos_encoding,
+ return_dict=return_dict,
+ )
+
+ pooled_output = outputs[1]
+
+ logits = self.classifier(pooled_output)
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(labels, logits, self.config)
+
+ if not return_dict:
+ output = (logits,) + outputs[2:]
+ return ((loss,) + output) if loss is not None else output
+
+ return HieraForImageClassificationOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ reshaped_hidden_states=outputs.reshaped_hidden_states,
+ )
+
+
+@auto_docstring(
+ custom_intro="""
+ Hiera backbone, to be used with frameworks like DETR and MaskFormer.
+ """
+)
+class HieraBackbone(HieraPreTrainedModel, BackboneMixin):
+ def __init__(self, config: HieraConfig):
+ super().__init__(config)
+ super()._init_backbone(config)
+
+ self.num_features = [config.embed_dim] + [
+ int(config.embed_dim * config.embed_dim_multiplier**i) for i in range(len(config.depths))
+ ]
+ self.embeddings = HieraEmbeddings(config, is_mae=False)
+ self.encoder = HieraEncoder(config)
+
+ # Add layer norms to hidden states of out_features
+ hidden_states_norms = {}
+ for stage, num_channels in zip(self._out_features, self.channels):
+ hidden_states_norms[stage] = nn.LayerNorm(num_channels)
+ self.hidden_states_norms = nn.ModuleDict(hidden_states_norms)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.embeddings.patch_embeddings
+
+ def forward(
+ self,
+ pixel_values: torch.Tensor,
+ output_hidden_states: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> BackboneOutput:
+ """
+ Returns:
+
+ Examples:
+
+ ```python
+ >>> from transformers import AutoImageProcessor, AutoBackbone
+ >>> import torch
+ >>> from PIL import Image
+ >>> import requests
+
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> processor = AutoImageProcessor.from_pretrained("facebook/hiera-tiny-224-hf")
+ >>> model = AutoBackbone.from_pretrained(
+ ... "facebook/hiera-tiny-224-hf", out_features=["stage1", "stage2", "stage3", "stage4"]
+ ... )
+
+ >>> inputs = processor(image, return_tensors="pt")
+ >>> outputs = model(**inputs)
+ >>> feature_maps = outputs.feature_maps
+ >>> list(feature_maps[-1].shape)
+ [1, 768, 7, 7]
+ ```"""
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+
+ embedding_output, _, _ = self.embeddings(pixel_values)
+
+ outputs = self.encoder(
+ embedding_output,
+ head_mask=None,
+ output_attentions=output_attentions,
+ output_hidden_states=True,
+ return_dict=return_dict,
+ )
+
+ hidden_states = outputs[-1]
+
+ feature_maps = ()
+ for stage, hidden_state in zip(self.stage_names, hidden_states):
+ if stage in self.out_features:
+ batch_size, height, width, num_channels = hidden_state.shape
+ hidden_state = hidden_state.view(batch_size, height * width, num_channels)
+ hidden_state = self.hidden_states_norms[stage](hidden_state)
+ hidden_state = hidden_state.view(batch_size, height, width, num_channels)
+ hidden_state = hidden_state.permute(0, 3, 1, 2).contiguous()
+ feature_maps += (hidden_state,)
+
+ if not return_dict:
+ output = (feature_maps,)
+ if output_hidden_states:
+ output += (outputs[1],)
+ if output_attentions:
+ output += (outputs[2],)
+ return output
+
+ return BackboneOutput(
+ feature_maps=feature_maps,
+ hidden_states=outputs[1] if output_hidden_states else None,
+ attentions=outputs[2] if output_attentions else None,
+ )
+
+
+__all__ = ["HieraForImageClassification", "HieraForPreTraining", "HieraBackbone", "HieraModel", "HieraPreTrainedModel"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/idefics2/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/idefics2/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c6715adc9ab8614bbb463b6b700fcdc2ca671d22
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/idefics2/__init__.py
@@ -0,0 +1,30 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_idefics2 import *
+ from .image_processing_idefics2 import *
+ from .image_processing_idefics2_fast import *
+ from .modeling_idefics2 import *
+ from .processing_idefics2 import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/idefics2/configuration_idefics2.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/idefics2/configuration_idefics2.py
new file mode 100644
index 0000000000000000000000000000000000000000..a8fa442a1dbc67276864f033623c5526f3fed750
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/idefics2/configuration_idefics2.py
@@ -0,0 +1,268 @@
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Idefics2 model configuration"""
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+from ..auto import CONFIG_MAPPING, AutoConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+class Idefics2VisionConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`Idefics2VisionModel`]. It is used to instantiate a
+ Idefics2 vision encoder according to the specified arguments, defining the model architecture. Instantiating a
+ configuration with the defaults will yield a similar configuration to that of the SigLIP checkpoint
+ [google/siglip-base-patch16-224](https://huggingface.co/google/siglip-base-patch16-224) used in the Idefics2 model
+ [HuggingFaceM4/idefics2-8b](https://huggingface.co/HuggingFaceM4/idefics2-8b).
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ hidden_size (`int`, *optional*, defaults to 768):
+ Dimensionality of the encoder layers and the pooler layer.
+ intermediate_size (`int`, *optional*, defaults to 3072):
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
+ num_hidden_layers (`int`, *optional*, defaults to 12):
+ Number of hidden layers in the Transformer encoder.
+ num_attention_heads (`int`, *optional*, defaults to 12):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ num_channels (`int`, *optional*, defaults to 3):
+ Number of channels in the input images.
+ image_size (`int`, *optional*, defaults to 224):
+ The size (resolution) of each image.
+ patch_size (`int`, *optional*, defaults to 32):
+ The size (resolution) of each patch.
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+ `"relu"`, `"selu"` and `"gelu_new"` `"quick_gelu"` are supported.
+ layer_norm_eps (`float`, *optional*, defaults to 1e-06):
+ The epsilon used by the layer normalization layers.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation for initializing all weight matrices in the model.
+
+ Example:
+
+ ```python
+ >>> from transformers.models.idefics2.modeling_idefics2 import Idefics2VisionTransformer
+ >>> from transformers.models.idefics2.configuration_idefics2 import Idefics2VisionConfig
+
+ >>> # Initializing a Idefics2VisionConfig with google/siglip-base-patch16-224 style configuration
+ >>> configuration = Idefics2VisionConfig()
+
+ >>> # Initializing a Idefics2VisionTransformer (with random weights) from the google/siglip-base-patch16-224 style configuration
+ >>> model = Idefics2VisionTransformer(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "idefics2_vision"
+ base_config_key = "vision_config"
+
+ def __init__(
+ self,
+ hidden_size=768,
+ intermediate_size=3072,
+ num_hidden_layers=12,
+ num_attention_heads=12,
+ num_channels=3,
+ image_size=224,
+ patch_size=32,
+ hidden_act="gelu_pytorch_tanh",
+ layer_norm_eps=1e-6,
+ attention_dropout=0.0,
+ initializer_range=0.02,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.num_channels = num_channels
+ self.patch_size = patch_size
+ self.image_size = image_size
+ self.attention_dropout = attention_dropout
+ self.layer_norm_eps = layer_norm_eps
+ self.hidden_act = hidden_act
+ self.initializer_range = initializer_range
+
+
+class Idefics2PerceiverConfig(PretrainedConfig):
+ r"""
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+ The non-linear activation function (function or string) in the perceiver block.
+ hidden_size (`int`, *optional*, defaults to 4096):
+ Dimension of the hidden representations.
+ rms_norm_eps (`float`, *optional*, defaults to 1e-06):
+ The epsilon used by the rms normalization layers.
+ resampler_n_latents (`int`, *optional*, defaults to 64):
+ Number of latent embeddings to resample ("compress") the input sequence to (usually < 128).
+ resampler_depth (`int`, *optional*, defaults to 3):
+ Depth of the Perceiver Resampler (Transformer w/ cross attention). Should be shallow (<= 3).
+ resampler_n_heads (`int`, *optional*, defaults to 16):
+ Number of heads in each Transformer block (for multi-headed self-attention).
+ resampler_head_dim (`int`, *optional*, defaults to 96):
+ Dimensionality of each head projection in the Transformer block.
+ num_key_value_heads (`int`, *optional*, defaults to 4):
+ Number of key-value heads in the perceiver attention block.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation for initializing all weight matrices in the model.
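+
+ Example (a minimal sketch):
+
+ ```python
+ >>> from transformers.models.idefics2.configuration_idefics2 import Idefics2PerceiverConfig
+
+ >>> # Initializing an Idefics2PerceiverConfig with the default values listed above
+ >>> configuration = Idefics2PerceiverConfig()
+ ```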
+ """
+
+ model_type = "idefics2_perceiver"
+
+ def __init__(
+ self,
+ hidden_act="silu",
+ hidden_size=4096,
+ rms_norm_eps=1e-06,
+ resampler_n_latents=64,
+ resampler_depth=3,
+ resampler_n_heads=16,
+ resampler_head_dim=96,
+ num_key_value_heads=4,
+ attention_dropout=0.0,
+ initializer_range=0.02,
+ **kwargs,
+ ):
+ self.hidden_act = hidden_act
+ self.hidden_size = hidden_size
+ self.rms_norm_eps = rms_norm_eps
+ self.resampler_n_latents = resampler_n_latents
+ self.resampler_depth = resampler_depth
+ self.resampler_n_heads = resampler_n_heads
+ self.num_key_value_heads = num_key_value_heads
+ self.resampler_head_dim = resampler_head_dim
+ self.attention_dropout = attention_dropout
+ self.initializer_range = initializer_range
+ if self.num_key_value_heads > self.resampler_n_heads:
+ raise ValueError(
+ f"num_key_value_heads={self.num_key_value_heads} must be less than or equal to"
+ f" resampler_n_heads={self.resampler_n_heads}"
+ )
+ super().__init__(**kwargs)
+
+
+class Idefics2Config(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`Idefics2Model`]. It is used to instantiate a
+ Idefics2 model according to the specified arguments, defining the model architecture. Instantiating a
+ configuration with the defaults will yield a similar configuration to that of the model of the Idefics2
+ [HuggingFaceM4/idefics2-8b](https://huggingface.co/HuggingFaceM4/idefics2-8b) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should cache the key/value pairs of the attention mechanism.
+ image_token_id (`int`, *optional*, defaults to 32001):
+ The id of the "image" token.
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+ Whether or not to tie the word embeddings with the token embeddings.
+ vision_config (`Idefics2VisionConfig` or `dict`, *optional*):
+ Custom vision config or dict.
+ perceiver_config (`Idefics2PerceiverConfig` or `dict`, *optional*):
+ Custom perceiver config or dict.
+ text_config (`MistralConfig` or `dict`, *optional*):
+ Custom text config or dict for the text model.
+
+ Example:
+ ```python
+ >>> from transformers import Idefics2Model, Idefics2Config
+ >>> # Initializing configuration
+ >>> configuration = Idefics2Config()
+ >>> # Initializing a model from the configuration
+ >>> model = Idefics2Model(configuration)
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "idefics2"
+ sub_configs = {
+ "text_config": AutoConfig,
+ "perceiver_config": Idefics2PerceiverConfig,
+ "vision_config": Idefics2VisionConfig,
+ }
+
+ def __init__(
+ self,
+ use_cache=True,
+ image_token_id=32_001,
+ tie_word_embeddings=False,
+ vision_config=None,
+ perceiver_config=None,
+ text_config=None,
+ **kwargs,
+ ):
+ self.image_token_id = image_token_id
+ self.use_cache = use_cache
+ self.tie_word_embeddings = tie_word_embeddings
+
+ if perceiver_config is None:
+ self.perceiver_config = Idefics2PerceiverConfig()
+ logger.info("perciver_config is None, using default perceiver config")
+ elif isinstance(perceiver_config, dict):
+ self.perceiver_config = Idefics2PerceiverConfig(**perceiver_config)
+ elif isinstance(perceiver_config, Idefics2PerceiverConfig):
+ self.perceiver_config = perceiver_config
+
+ if vision_config is None:
+ self.vision_config = Idefics2VisionConfig()
+ logger.info("vision_config is None, using default vision config")
+ elif isinstance(vision_config, dict):
+ self.vision_config = Idefics2VisionConfig(**vision_config)
+ elif isinstance(vision_config, Idefics2VisionConfig):
+ self.vision_config = vision_config
+
+ if isinstance(text_config, dict):
+ text_config["model_type"] = text_config.get("model_type", "mistral")
+ text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
+ elif text_config is None:
+ logger.info("text_config is None, using default text config")
+ text_config = CONFIG_MAPPING["mistral"](
+ max_position_embeddings=4096 * 8,
+ rms_norm_eps=1e-5,
+ # None in the original configuration_mistral, we set it to the unk_token_id
+ pad_token_id=0,
+ tie_word_embeddings=False,
+ )
+
+ self.text_config = text_config
+ if self.text_config.hidden_size != self.perceiver_config.hidden_size:
+ self.perceiver_config.hidden_size = self.text_config.hidden_size
+ self.perceiver_config.rms_norm_eps = self.text_config.rms_norm_eps
+ logger.warning_once(
+ "Perceiver config has a different `hidden_size` than text config, which means default values were used. "
+ "In your model's config on the hub, add `hidden_size` and `rms_norm_eps` keys under the `perceiver_config` dict. "
+ )
+
+ super().__init__(**kwargs, tie_word_embeddings=tie_word_embeddings)
+
+
+__all__ = ["Idefics2Config"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/idefics2/image_processing_idefics2.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/idefics2/image_processing_idefics2.py
new file mode 100644
index 0000000000000000000000000000000000000000..3f0db7644563a22d78c849f2154f779193d36030
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/idefics2/image_processing_idefics2.py
@@ -0,0 +1,572 @@
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from collections.abc import Iterable
+from typing import Any, Optional, Union
+
+import numpy as np
+
+from ...image_processing_utils import BaseImageProcessor, BatchFeature
+from ...image_transforms import PaddingMode, pad, resize, to_channel_dimension_format
+from ...image_utils import (
+ IMAGENET_STANDARD_MEAN,
+ IMAGENET_STANDARD_STD,
+ ChannelDimension,
+ ImageInput,
+ PILImageResampling,
+ get_image_size,
+ infer_channel_dimension_format,
+ is_scaled_image,
+ make_nested_list_of_images,
+ to_numpy_array,
+ valid_images,
+ validate_preprocess_arguments,
+)
+from ...utils import TensorType, is_vision_available, logging
+
+
+logger = logging.get_logger(__name__)
+
+
+if is_vision_available():
+ import PIL
+ from PIL import Image
+
+
+def get_resize_output_image_size(image, size, input_data_format) -> tuple[int, int]:
+ """
+ Get the output size of the image after resizing given a dictionary specifying the max and min sizes.
+
+ Args:
+ image (`np.ndarray`):
+ Image to resize.
+ size (`dict[str, int]`):
+ Size of the output image containing the keys "shortest_edge" and "longest_edge".
+ input_data_format (`ChannelDimension` or `str`):
+ The channel dimension format of the input image.
+
+ Returns:
+ The output size of the image after resizing.
+ """
+ height, width = get_image_size(image, channel_dim=input_data_format)
+
+ min_len = size["shortest_edge"]
+ max_len = size["longest_edge"]
+ aspect_ratio = width / height
+
+ if width >= height and width > max_len:
+ width = max_len
+ height = int(width / aspect_ratio)
+ elif height > width and height > max_len:
+ height = max_len
+ width = int(height * aspect_ratio)
+ height = max(height, min_len)
+ width = max(width, min_len)
+ return height, width
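+ # Illustrative example: for an 800x1200 (height x width) image with size={"shortest_edge": 378, "longest_edge": 980},
+ # the width is capped at 980, the height becomes int(980 / 1.5) = 653, and the function returns (653, 980).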
+
+
+# Copied from transformers.models.detr.image_processing_detr.max_across_indices
+def max_across_indices(values: Iterable[Any]) -> list[Any]:
+ """
+ Return the maximum value across all indices of an iterable of values.
+ """
+ return [max(values_i) for values_i in zip(*values)]
+
+
+def get_max_height_width(
+ images_list: list[list[np.ndarray]], input_data_format: Optional[Union[str, ChannelDimension]] = None
+) -> tuple[int, int]:
+ """
+ Get the maximum height and width across all images in a batch.
+ """
+ if input_data_format is None:
+ input_data_format = infer_channel_dimension_format(images_list[0][0])
+
+ image_sizes = []
+ for images in images_list:
+ for image in images:
+ image_sizes.append(get_image_size(image, channel_dim=input_data_format))
+
+ max_height, max_width = max_across_indices(image_sizes)
+ return (max_height, max_width)
+
+
+# Copied from transformers.models.detr.image_processing_detr.make_pixel_mask
+def make_pixel_mask(
+ image: np.ndarray, output_size: tuple[int, int], input_data_format: Optional[Union[str, ChannelDimension]] = None
+) -> np.ndarray:
+ """
+ Make a pixel mask for the image, where 1 indicates a valid pixel and 0 indicates padding.
+
+ Args:
+ image (`np.ndarray`):
+ Image to make the pixel mask for.
+ output_size (`tuple[int, int]`):
+ Output size of the mask.
+ """
+ input_height, input_width = get_image_size(image, channel_dim=input_data_format)
+ mask = np.zeros(output_size, dtype=np.int64)
+ mask[:input_height, :input_width] = 1
+ return mask
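+
+# Illustrative example (shapes assumed): for a channels-first image of shape (3, 2, 3) and
+# output_size=(4, 5), the mask is a (4, 5) int64 array with ones in mask[:2, :3] and zeros elsewhere,
+# marking which pixels of a padded image are valid.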
+
+
+# FIXME Amy: merge this function with the one in image_transforms.py
+def convert_to_rgb(image: ImageInput) -> ImageInput:
+ """
+ Converts an image to RGB format. Only converts if the image is of type PIL.Image.Image, otherwise returns the image
+ as is.
+ Args:
+ image (Image):
+ The image to convert.
+ """
+ if not isinstance(image, PIL.Image.Image):
+ return image
+
+ # `image.convert("RGB")` would only work for .jpg images, as it creates a wrong background
+ # for transparent images. The call to `alpha_composite` handles this case
+ if image.mode == "RGB":
+ return image
+
+ image_rgba = image.convert("RGBA")
+ background = Image.new("RGBA", image_rgba.size, (255, 255, 255))
+ alpha_composite = Image.alpha_composite(background, image_rgba)
+ alpha_composite = alpha_composite.convert("RGB")
+ return alpha_composite
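+
+# Illustrative note: compositing onto a white background means a fully transparent RGBA pixel becomes
+# (255, 255, 255) in the returned RGB image, rather than the arbitrary colour that `image.convert("RGB")`
+# would keep for that pixel.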
+
+
+class Idefics2ImageProcessor(BaseImageProcessor):
+ r"""
+    Constructs an Idefics2 image processor.
+
+ Args:
+ do_convert_rgb (`bool`, *optional*, defaults to `True`):
+ Whether to convert the image to RGB. This is useful if the input image is of a different format e.g. RGBA.
+ Only has an effect if the input image is in the PIL format.
+ do_resize (`bool`, *optional*, defaults to `True`):
+ Whether to resize the image. The longest edge of the image is resized to be <= `size["longest_edge"]`, with the
+ shortest edge resized to keep the input aspect ratio, with a minimum size of `size["shortest_edge"]`.
+ size (`Dict`, *optional*):
+ Controls the size of the output image. This is a dictionary containing the keys "shortest_edge" and "longest_edge".
+ resample (`Resampling`, *optional*, defaults to `Resampling.BILINEAR`):
+ Resampling filter to use when resizing the image.
+ do_rescale (`bool`, *optional*, defaults to `True`):
+ Whether to rescale the image. If set to `True`, the image is rescaled to have pixel values between 0 and 1.
+ rescale_factor (`float`, *optional*, defaults to `1/255`):
+ Rescale factor to rescale the image by if `do_rescale` is set to `True`.
+ do_normalize (`bool`, *optional*, defaults to `True`):
+ Whether to normalize the image. If set to `True`, the image is normalized to have a mean of `image_mean` and
+ a standard deviation of `image_std`.
+        image_mean (`float` or `list[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`):
+            Mean to use if normalizing the image. This is a float or list of floats the length of the number of
+            channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
+        image_std (`float` or `list[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`):
+            Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
+            number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
+ do_pad (`bool`, *optional*, defaults to `True`):
+ Whether or not to pad the images to the largest height and width in the batch and number of images per
+ sample in the batch, such that the returned tensor is of shape (batch_size, max_num_images, num_channels, max_height, max_width).
+ do_image_splitting (`bool`, *optional*, defaults to `False`):
+            Whether to split the image into a sequence of 4 equal sub-images concatenated with the original image. That
+ strategy was first introduced in https://huggingface.co/papers/2311.06607.
+ """
+
+ model_input_names = ["pixel_values", "pixel_attention_mask"]
+
+ def __init__(
+ self,
+ do_convert_rgb: bool = True,
+ do_resize: bool = True,
+ size: Optional[dict[str, int]] = None,
+ resample: PILImageResampling = PILImageResampling.BILINEAR,
+ do_rescale: bool = True,
+ rescale_factor: float = 1 / 255,
+ do_normalize: bool = True,
+ image_mean: Optional[Union[float, list[float]]] = None,
+ image_std: Optional[Union[float, list[float]]] = None,
+ do_pad: bool = True,
+ do_image_splitting: bool = False,
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+ self.do_convert_rgb = do_convert_rgb
+ self.do_resize = do_resize
+ self.size = size if size is not None else {"shortest_edge": 378, "longest_edge": 980}
+ self.resample = resample
+ self.do_rescale = do_rescale
+ self.rescale_factor = rescale_factor
+ self.do_normalize = do_normalize
+ self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
+ self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
+ self.do_pad = do_pad
+ self.do_image_splitting = do_image_splitting
+
+ def resize(
+ self,
+ image: np.ndarray,
+ size: dict[str, int],
+ resample: PILImageResampling = PILImageResampling.BILINEAR,
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ **kwargs,
+ ) -> np.ndarray:
+ """
+        Resize an image. The longest edge of the image is capped at `size["longest_edge"]` while keeping the input
+        aspect ratio, and both edges are clamped to be at least `size["shortest_edge"]`.
+
+ Args:
+ image (`np.ndarray`):
+ Image to resize.
+ size (`dict[str, int]`):
+ Size of the output image.
+            resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
+                Resampling filter to use when resizing the image.
+ data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format of the image. If not provided, it will be the same as the input image.
+ input_data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format of the input image. If not provided, it will be inferred.
+ """
+ if "shortest_edge" in size and "longest_edge" in size:
+ size = get_resize_output_image_size(image, size, input_data_format)
+ elif "height" in size and "width" in size:
+ size = (size["height"], size["width"])
+ else:
+ raise ValueError(
+ "size must be a dictionary with keys 'shortest_edge' and 'longest_edge' or 'height' and 'width'."
+ )
+ return resize(
+ image, size, resample=resample, data_format=data_format, input_data_format=input_data_format, **kwargs
+ )
+
+ # Copied from transformers.models.vilt.image_processing_vilt.ViltImageProcessor._pad_image
+ def _pad_image(
+ self,
+ image: np.ndarray,
+ output_size: tuple[int, int],
+ constant_values: Union[float, Iterable[float]] = 0,
+ data_format: Optional[ChannelDimension] = None,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ ) -> np.ndarray:
+ """
+ Pad an image with zeros to the given size.
+ """
+ input_height, input_width = get_image_size(image, channel_dim=input_data_format)
+ output_height, output_width = output_size
+
+ pad_bottom = output_height - input_height
+ pad_right = output_width - input_width
+ padding = ((0, pad_bottom), (0, pad_right))
+ padded_image = pad(
+ image,
+ padding,
+ mode=PaddingMode.CONSTANT,
+ constant_values=constant_values,
+ data_format=data_format,
+ input_data_format=input_data_format,
+ )
+ return padded_image
+
+ def pad(
+ self,
+ images: list[np.ndarray],
+ constant_values: Union[float, Iterable[float]] = 0,
+ return_pixel_mask: bool = True,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ data_format: Optional[ChannelDimension] = None,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+    ) -> tuple[list[list[np.ndarray]], Optional[list[list[np.ndarray]]]]:
+        """
+        For a list of lists of images, pads each image to the bottom and right with zeros up to the largest height
+        and width in the batch, and pads each sample with all-zero images up to the maximum number of images per
+        sample in the batch. Optionally returns a pixel mask.
+
+ Args:
+            images (`list[list[np.ndarray]]`):
+                List of lists of images to pad. Pads to the largest height and width in the batch.
+ constant_values (`float` or `Iterable[float]`, *optional*):
+ The value to use for the padding if `mode` is `"constant"`.
+ return_pixel_mask (`bool`, *optional*, defaults to `True`):
+ Whether to return a pixel mask.
+ return_tensors (`str` or `TensorType`, *optional*):
+ The type of tensors to return. Can be one of:
+ - Unset: Return a list of `np.ndarray`.
+ - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
+ - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
+ data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format of the image. If not provided, it will be the same as the input image.
+ input_data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format of the input image. If not provided, it will be inferred.
+ """
+ pad_size = get_max_height_width(images, input_data_format=input_data_format)
+
+ batch_size = len(images)
+ max_num_images = max(len(images_) for images_ in images)
+ input_data_format = (
+ infer_channel_dimension_format(images[0][0]) if input_data_format is None else input_data_format
+ )
+ data_format = input_data_format if data_format is None else data_format
+
+ def empty_image(size, input_data_format):
+ if input_data_format == ChannelDimension.FIRST:
+ return np.zeros((3, *size), dtype=np.uint8)
+ elif input_data_format == ChannelDimension.LAST:
+ return np.zeros((*size, 3), dtype=np.uint8)
+ raise ValueError("Invalid channel dimension format.")
+
+ padded_images_list = [
+ [empty_image(pad_size, data_format) for _ in range(max_num_images)] for _ in range(batch_size)
+ ]
+ padded_masks = [[np.zeros(pad_size) for _ in range(max_num_images)] for _ in range(batch_size)]
+
+ for batch_idx in range(batch_size):
+ for sample_idx, image in enumerate(images[batch_idx]):
+ padded_images_list[batch_idx][sample_idx] = self._pad_image(
+ image,
+ pad_size,
+ constant_values=constant_values,
+ data_format=data_format,
+ input_data_format=input_data_format,
+ )
+ padded_masks[batch_idx][sample_idx] = make_pixel_mask(
+ image, output_size=pad_size, input_data_format=input_data_format
+ )
+
+ padded_masks = padded_masks if return_pixel_mask else None
+ return padded_images_list, padded_masks
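+
+    # Illustrative sketch (shapes assumed): given images = [[(3, 10, 8), (3, 6, 12)], [(3, 4, 4)]]
+    # (channels-first arrays), pad_size is (10, 12) and max_num_images is 2, so the method returns a
+    # 2 x 2 nested list of (3, 10, 12) arrays (missing slots are all-zero images) together with a
+    # 2 x 2 nested list of (10, 12) pixel masks.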
+
+ def _crop(
+ self,
+ im: np.ndarray,
+ w1: int,
+ h1: int,
+ w2: int,
+ h2: int,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ ) -> np.ndarray:
+ if input_data_format == ChannelDimension.FIRST:
+ return im[:, h1:h2, w1:w2]
+ elif input_data_format == ChannelDimension.LAST:
+ return im[h1:h2, w1:w2, :]
+
+ def split_image(
+ self,
+ image: np.ndarray,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ ):
+ """
+        Split an image into 4 equal sub-images, and then concatenate that sequence with the original image.
+ That means that a single image becomes a sequence of 5 images.
+ This is a "trick" to spend more compute on each image with no changes in the vision encoder.
+
+ Args:
+ image (`np.ndarray`):
+ Images to split.
+ input_data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format of the input image. If not provided, it will be inferred.
+ """
+ height, width = get_image_size(image, input_data_format)
+
+ mid_width = width // 2
+ mid_height = height // 2
+ return [
+ self._crop(image, 0, 0, mid_width, mid_height, input_data_format),
+ self._crop(image, mid_width, 0, width, mid_height, input_data_format),
+ self._crop(image, 0, mid_height, mid_width, height, input_data_format),
+ self._crop(image, mid_width, mid_height, width, height, input_data_format),
+ image,
+ ]
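+
+    # Illustrative example (shape assumed): a channels-first array of shape (3, 100, 80) is split into
+    # four quadrant crops of shape (3, 50, 40) followed by the original (3, 100, 80) image, i.e. a
+    # sequence of 5 images.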
+
+ def preprocess(
+ self,
+ images: ImageInput,
+ do_convert_rgb: Optional[bool] = None,
+ do_resize: Optional[bool] = None,
+ size: Optional[dict[str, int]] = None,
+ resample: Optional[PILImageResampling] = None,
+ do_rescale: Optional[bool] = None,
+ rescale_factor: Optional[float] = None,
+ do_normalize: Optional[bool] = None,
+ image_mean: Optional[Union[float, list[float]]] = None,
+ image_std: Optional[Union[float, list[float]]] = None,
+ do_pad: Optional[bool] = None,
+ do_image_splitting: Optional[bool] = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ input_data_format: Optional[ChannelDimension] = None,
+ data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
+ ):
+ """
+ Preprocess a batch of images.
+
+ Args:
+ images (`ImageInput`):
+ A list of images to preprocess.
+ do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
+ Whether to convert the image to RGB.
+ do_resize (`bool`, *optional*, defaults to `self.do_resize`):
+ Whether to resize the image.
+ size (`dict[str, int]`, *optional*, defaults to `self.size`):
+                Size of the image after resizing. The longest edge is capped at `size["longest_edge"]` while keeping
+                the aspect ratio, and both edges are clamped to be at least `size["shortest_edge"]`.
+ resample (`int`, *optional*, defaults to `self.resample`):
+ Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
+ has an effect if `do_resize` is set to `True`.
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
+ Whether to rescale the image.
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
+ Rescale factor to rescale the image by if `do_rescale` is set to `True`.
+ do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
+ Whether to normalize the image.
+ image_mean (`float` or `list[float]`, *optional*, defaults to `self.image_mean`):
+ Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
+ image_std (`float` or `list[float]`, *optional*, defaults to `self.image_std`):
+ Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
+ `True`.
+ do_pad (`bool`, *optional*, defaults to `self.do_pad`):
+ Whether or not to pad the images to the largest height and width in the batch.
+ do_image_splitting (`bool`, *optional*, defaults to `self.do_image_splitting`):
+                Whether to split the image into a sequence of 4 equal sub-images concatenated with the original image. That
+ strategy was first introduced in https://huggingface.co/papers/2311.06607.
+ return_tensors (`str` or `TensorType`, *optional*):
+ The type of tensors to return. Can be one of:
+ - Unset: Return a list of `np.ndarray`.
+ - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
+ - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
+ data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
+ The channel dimension format for the output image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - Unset: Use the channel dimension format of the input image.
+ input_data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
+ from the input image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
+ """
+ do_resize = do_resize if do_resize is not None else self.do_resize
+ size = size if size is not None else self.size
+ resample = resample if resample is not None else self.resample
+ do_rescale = do_rescale if do_rescale is not None else self.do_rescale
+ rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
+ do_normalize = do_normalize if do_normalize is not None else self.do_normalize
+ image_mean = image_mean if image_mean is not None else self.image_mean
+ image_std = image_std if image_std is not None else self.image_std
+ do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
+ do_pad = do_pad if do_pad is not None else self.do_pad
+ do_image_splitting = do_image_splitting if do_image_splitting is not None else self.do_image_splitting
+
+ images = self.fetch_images(images)
+ images_list = make_nested_list_of_images(images)
+
+ if not valid_images(images_list[0]):
+ raise ValueError(
+ "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
+ "torch.Tensor, tf.Tensor or jax.ndarray."
+ )
+
+ validate_preprocess_arguments(
+ do_rescale=do_rescale,
+ rescale_factor=rescale_factor,
+ do_normalize=do_normalize,
+ image_mean=image_mean,
+ image_std=image_std,
+ do_resize=do_resize,
+ size=size,
+ resample=resample,
+ )
+
+ if do_convert_rgb:
+ images_list = [[convert_to_rgb(image) for image in images] for images in images_list]
+
+ # All transformations expect numpy arrays.
+ images_list = [[to_numpy_array(image) for image in images] for images in images_list]
+ # Search for the first image in the image list.
+ # NOTE: we can't slice the first image with images_list[0][0] if the first batch contains no images. See #36682
+ first_image_in_list = [images for images in images_list if images][0][0]
+
+ if do_rescale and is_scaled_image(first_image_in_list):
+ logger.warning_once(
+ "It looks like you are trying to rescale already rescaled images. If the input"
+ " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
+ )
+
+ if input_data_format is None:
+ # We assume that all images have the same channel dimension format.
+ input_data_format = infer_channel_dimension_format(first_image_in_list)
+
+ if do_image_splitting:
+ new_images_list = []
+ for images in images_list:
+ new_images = []
+ for image in images:
+ new_images.extend(self.split_image(image, input_data_format))
+ new_images_list.append(new_images)
+ images_list = new_images_list
+
+ if do_resize:
+ images_list = [
+ [
+ self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
+ for image in images
+ ]
+ for images in images_list
+ ]
+
+ if do_rescale:
+ images_list = [
+ [
+ self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
+ for image in images
+ ]
+ for images in images_list
+ ]
+
+ if do_normalize:
+ images_list = [
+ [
+ self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
+ for image in images
+ ]
+ for images in images_list
+ ]
+
+ pixel_attention_mask = None
+ if do_pad:
+ images_list, pixel_attention_mask = self.pad(
+ images_list, return_pixel_mask=True, return_tensors=return_tensors, input_data_format=input_data_format
+ )
+
+ if data_format is not None:
+ images_list = [
+ [
+ to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
+ for image in images
+ ]
+ for images in images_list
+ ]
+
+ data = {"pixel_values": np.array(images_list) if do_pad else images_list} # Faster tensor conversion
+ if pixel_attention_mask is not None:
+ data["pixel_attention_mask"] = np.array(pixel_attention_mask) if do_pad else pixel_attention_mask
+
+ return BatchFeature(data=data, tensor_type=return_tensors)
+
+
+__all__ = ["Idefics2ImageProcessor"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/idefics2/image_processing_idefics2_fast.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/idefics2/image_processing_idefics2_fast.py
new file mode 100644
index 0000000000000000000000000000000000000000..5348bda389edb5b3ac2831cfa52264cd799334d1
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/idefics2/image_processing_idefics2_fast.py
@@ -0,0 +1,312 @@
+# coding=utf-8
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from typing import Optional, Union
+
+import torch
+
+from ...image_processing_utils_fast import (
+ BaseImageProcessorFast,
+ BatchFeature,
+ DefaultFastImageProcessorKwargs,
+ SizeDict,
+ group_images_by_shape,
+ reorder_images,
+)
+from ...image_utils import (
+ IMAGENET_STANDARD_MEAN,
+ IMAGENET_STANDARD_STD,
+ ImageInput,
+ PILImageResampling,
+ make_nested_list_of_images,
+)
+from ...processing_utils import Unpack
+from ...utils import TensorType, auto_docstring, is_torchvision_available, logging
+from .image_processing_idefics2 import convert_to_rgb
+
+
+if is_torchvision_available():
+ from torchvision.transforms import functional as F
+
+
+logger = logging.get_logger(__name__)
+
+
+def get_resize_output_image_size(image: "torch.Tensor", size: SizeDict) -> tuple[int, int]:
+ """
+ Get the output size of the image after resizing given a dictionary specifying the max and min sizes.
+
+ Args:
+ image (`torch.Tensor`):
+ Image to resize.
+ size (`SizeDict`):
+ Size of the output image containing the keys "shortest_edge" and "longest_edge".
+
+ Returns:
+ The output size of the image after resizing.
+ """
+ height, width = image.size()[-2:]
+
+ min_len = size.shortest_edge
+ max_len = size.longest_edge
+ aspect_ratio = width / height
+
+ if width >= height and width > max_len:
+ width = max_len
+ height = int(width / aspect_ratio)
+ elif height > width and height > max_len:
+ height = max_len
+ width = int(height * aspect_ratio)
+ height = max(height, min_len)
+ width = max(width, min_len)
+ return height, width
+
+
+def get_max_height_width(images_list: list[list["torch.Tensor"]]) -> tuple[int, int]:
+ """
+ Get the maximum height and width across all images in a batch.
+ """
+ image_sizes = []
+ for images in images_list:
+ for image in images:
+ image_sizes.append(image.size()[-2:])
+
+ max_height = max(size[0] for size in image_sizes)
+ max_width = max(size[1] for size in image_sizes)
+ return (max_height, max_width)
+
+
+def make_pixel_mask(image: "torch.Tensor", output_size: tuple[int, int]) -> "torch.Tensor":
+ """
+ Make a pixel mask for the image, where 1 indicates a valid pixel and 0 indicates padding.
+
+ Args:
+ image (`torch.Tensor`):
+ Image to make the pixel mask for.
+ output_size (`Tuple[int, int]`):
+ Output size of the mask.
+ """
+ input_height, input_width = image.size()[-2:]
+ mask = torch.zeros(output_size, dtype=torch.int64, device=image.device)
+ mask[:input_height, :input_width] = 1
+ return mask
+
+
+class Idefics2FastImageProcessorKwargs(DefaultFastImageProcessorKwargs):
+ """
+ do_image_splitting (`bool`, *optional*, defaults to `False`):
+        Whether to split the image into a sequence of 4 equal sub-images concatenated with the original image.
+ """
+
+ do_image_splitting: Optional[bool]
+
+
+@auto_docstring
+class Idefics2ImageProcessorFast(BaseImageProcessorFast):
+ resample = PILImageResampling.BILINEAR
+ image_mean = IMAGENET_STANDARD_MEAN
+ image_std = IMAGENET_STANDARD_STD
+ do_resize = True
+ do_rescale = True
+ do_normalize = True
+ do_pad = True
+ do_convert_rgb = True
+ do_image_splitting = False
+ size = {"shortest_edge": 378, "longest_edge": 980}
+ model_input_names = ["pixel_values", "pixel_attention_mask"]
+ valid_kwargs = Idefics2FastImageProcessorKwargs
+
+ def convert_to_rgb(self, image: ImageInput) -> ImageInput:
+ """
+ Converts an image to RGB format. Only converts if the image is of type PIL.Image.Image, otherwise returns the image
+ as is.
+ """
+ return convert_to_rgb(image)
+
+ def resize(
+ self, image: torch.Tensor, size: SizeDict, interpolation: Optional["F.InterpolationMode"] = None, **kwargs
+ ) -> torch.Tensor:
+ """
+ Resize an image using torchvision's functional resize.
+ """
+ interpolation = interpolation if interpolation is not None else F.InterpolationMode.BILINEAR
+
+ if size.shortest_edge and size.longest_edge:
+ new_size = get_resize_output_image_size(image, size)
+ elif size.height and size.width:
+ new_size = (size.height, size.width)
+ else:
+ raise ValueError("Size must contain 'height' and 'width' keys or 'shortest_edge' and 'longest_edge' keys.")
+
+ image = F.resize(image, size=new_size, interpolation=interpolation, **kwargs)
+ return image
+
+ def _prepare_images_structure(self, images: ImageInput, expected_ndims: int = 3) -> ImageInput:
+ """
+ Prepare a nested images structure for processing.
+ """
+ return make_nested_list_of_images(images, expected_ndims=expected_ndims)
+
+ def split_images(
+ self,
+ images: "torch.Tensor",
+ ) -> list["torch.Tensor"]:
+ """
+ Split a batch of images into 4 equal sub-images, and concatenate that sequence with the original image.
+ """
+ height, width = images.size()[-2:]
+
+ mid_width = width // 2
+ mid_height = height // 2
+
+ batch_split_images = [
+ images[..., :mid_height, :mid_width],
+ images[..., :mid_height, mid_width:],
+ images[..., mid_height:, :mid_width],
+ images[..., mid_height:, mid_width:],
+ images,
+ ]
+
+ # transpose the batch dimension to the first dimension
+ batch_split_images = [[image[i] for image in batch_split_images] for i in range(len(batch_split_images[0]))]
+ return batch_split_images
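+
+    # Illustrative example (shape assumed): a stacked batch of shape (2, 3, 100, 80) is returned as a
+    # list of 2 per-image lists, each holding four (3, 50, 40) quadrant crops followed by the original
+    # (3, 100, 80) image.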
+
+ def pad(
+ self, image: "torch.Tensor", padded_size: tuple[int, int], fill: int = 0
+ ) -> tuple["torch.Tensor", "torch.Tensor"]:
+ """
+ Pad an image to the specified size and create the corresponding pixel mask.
+ """
+ original_size = image.shape[-2:]
+ padding_bottom = padded_size[0] - original_size[0]
+ padding_right = padded_size[1] - original_size[1]
+
+ if padding_bottom < 0 or padding_right < 0:
+ raise ValueError(
+ f"Padding dimensions are negative. Please make sure that the padded size is larger than the "
+ f"original size. Got padded size: {padded_size}, original size: {original_size}."
+ )
+
+ # Only pad if necessary
+ if original_size != padded_size:
+ # torchvision's pad takes a 4-element tuple for 2D padding: (left, top, right, bottom)
+ padding = (0, 0, padding_right, padding_bottom)
+ # Use constant padding to match slow implementation
+ image = F.pad(image, padding, fill=fill, padding_mode="constant")
+
+ # Create pixel mask to match the slow implementation
+ pixel_mask = torch.zeros(padded_size, dtype=torch.int64, device=image.device)
+ pixel_mask[: original_size[0], : original_size[1]] = 1
+
+ return image, pixel_mask
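+
+    # Illustrative example (shapes assumed): padding a (3, 50, 40) image to padded_size=(100, 80) adds
+    # 50 rows at the bottom and 40 columns at the right, returning a (3, 100, 80) image and a (100, 80)
+    # int64 pixel mask with ones in mask[:50, :40].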
+
+ @auto_docstring
+ def preprocess(self, images: ImageInput, **kwargs: Unpack[Idefics2FastImageProcessorKwargs]) -> BatchFeature:
+ return super().preprocess(images, **kwargs)
+
+ def _preprocess(
+ self,
+ images: list[list["torch.Tensor"]],
+ do_resize: bool,
+ size: SizeDict,
+ interpolation: Optional["F.InterpolationMode"],
+ do_rescale: bool,
+ rescale_factor: float,
+ do_normalize: bool,
+ image_mean: Optional[Union[float, list[float]]],
+ image_std: Optional[Union[float, list[float]]],
+ do_pad: Optional[bool],
+ do_image_splitting: Optional[bool],
+ disable_grouping: Optional[bool],
+ return_tensors: Optional[Union[str, TensorType]],
+ **kwargs,
+ ) -> BatchFeature:
+ """
+ Process a batch of images for the model.
+ """
+ grouped_images, grouped_images_index = group_images_by_shape(
+ images, is_nested=True, disable_grouping=disable_grouping
+ )
+ split_images_grouped = {}
+ for shape, stacked_images in grouped_images.items():
+ if do_image_splitting:
+ stacked_images = self.split_images(stacked_images)
+ split_images_grouped[shape] = stacked_images
+ split_images = reorder_images(split_images_grouped, grouped_images_index, is_nested=True)
+ if do_image_splitting:
+            # flatten the doubly nested list into a nested list
+ for i, group_images in enumerate(split_images):
+ split_images[i] = [image for sublist in group_images for image in sublist]
+
+ # Group images by size for further processing
+ grouped_images, grouped_images_index = group_images_by_shape(
+ split_images, is_nested=True, disable_grouping=disable_grouping
+ )
+ resized_images_grouped = {}
+ for shape, stacked_images in grouped_images.items():
+ if do_resize:
+ stacked_images = self.resize(stacked_images, size, interpolation=interpolation)
+ resized_images_grouped[shape] = stacked_images
+ resized_images = reorder_images(resized_images_grouped, grouped_images_index, is_nested=True)
+
+ # Group images by size for further processing
+ # Needed in case do_resize is False, or resize returns images with different sizes
+ grouped_images, grouped_images_index = group_images_by_shape(
+ resized_images, is_nested=True, disable_grouping=disable_grouping
+ )
+ processed_images_grouped = {}
+ for shape, stacked_images in grouped_images.items():
+ # Fused rescale and normalize
+ stacked_images = self.rescale_and_normalize(
+ stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std
+ )
+ processed_images_grouped[shape] = stacked_images
+ processed_images = reorder_images(processed_images_grouped, grouped_images_index, is_nested=True)
+
+ if do_pad:
+ # Get max images per batch
+ max_num_images = max(len(images_) for images_ in processed_images)
+ max_height, max_width = get_max_height_width(processed_images)
+
+ processed_images_padded = torch.zeros(
+ len(processed_images),
+ max_num_images,
+ *(processed_images[0][0].shape[0], max_height, max_width),
+ device=processed_images[0][0].device,
+ )
+ pixel_attention_masks = torch.zeros(
+ len(processed_images),
+ max_num_images,
+ *(max_height, max_width),
+ device=processed_images[0][0].device,
+ )
+ for i, images in enumerate(processed_images):
+ for j, image in enumerate(images):
+ processed_images_padded[i, j], pixel_attention_masks[i, j] = self.pad(
+ image, (max_height, max_width)
+ )
+ processed_images = processed_images_padded
+ if do_pad:
+ data = {"pixel_values": processed_images, "pixel_attention_mask": pixel_attention_masks}
+ elif return_tensors == "pt":
+ data = {"pixel_values": torch.stack([torch.stack(images) for images in processed_images])}
+ else:
+ data = {"pixel_values": processed_images}
+
+ return BatchFeature(data=data, tensor_type=return_tensors)
+
+
+__all__ = ["Idefics2ImageProcessorFast"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/idefics2/modeling_idefics2.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/idefics2/modeling_idefics2.py
new file mode 100644
index 0000000000000000000000000000000000000000..9703a43d605c66735702668c619b18408cda8154
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/idefics2/modeling_idefics2.py
@@ -0,0 +1,1196 @@
+# coding=utf-8
+# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch Idefics2 model."""
+
+from dataclasses import dataclass
+from typing import Callable, Optional, Union
+
+import torch
+from torch import nn
+
+from ...activations import ACT2FN
+from ...cache_utils import Cache, DynamicCache
+from ...generation import GenerationMixin
+from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
+from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_outputs import BaseModelOutput, ModelOutput
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
+from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
+from ...utils.deprecation import deprecate_kwarg
+from ...utils.generic import check_model_inputs
+from ..auto import AutoModel
+from .configuration_idefics2 import Idefics2Config, Idefics2PerceiverConfig, Idefics2VisionConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Base class for Idefics2 model's outputs that may also contain a past key/values (to speed up sequential decoding).
+ """
+)
+class Idefics2BaseModelOutputWithPast(ModelOutput):
+ r"""
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+ Sequence of hidden-states at the output of the last layer of the model.
+ If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1,
+ hidden_size)` is output.
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if
+ `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values`
+ input) to speed up sequential decoding.
+ image_hidden_states (`tuple(torch.FloatTensor)`, *optional*):
+        Tuple of `torch.FloatTensor` (one for the output of the image embeddings) of shape `(batch_size, num_images,
+        sequence_length, hidden_size)`.
+        Image hidden states of the model produced by the vision encoder, and optionally by the perceiver resampler.
+ """
+
+ last_hidden_state: Optional[torch.FloatTensor] = None
+ past_key_values: Optional[Cache] = None
+ hidden_states: Optional[tuple[torch.FloatTensor]] = None
+ attentions: Optional[tuple[torch.FloatTensor]] = None
+ image_hidden_states: Optional[tuple[torch.FloatTensor]] = None
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Base class for Idefics2 causal language model (or autoregressive) outputs.
+ """
+)
+# Copied from transformers.models.idefics.modeling_idefics.IdeficsCausalLMOutputWithPast with Idefics->Idefics2
+class Idefics2CausalLMOutputWithPast(ModelOutput):
+ r"""
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Language modeling loss (for next-token prediction).
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
+ `past_key_values` input) to speed up sequential decoding.
+ image_hidden_states (`tuple(torch.FloatTensor)`, *optional*):
+        Tuple of `torch.FloatTensor` (one for the output of the image embeddings) of shape `(batch_size, num_images,
+        sequence_length, hidden_size)`.
+
+        Image hidden states of the model produced by the vision encoder, and optionally by the perceiver resampler.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ logits: Optional[torch.FloatTensor] = None
+ past_key_values: Optional[Cache] = None
+ hidden_states: Optional[tuple[torch.FloatTensor]] = None
+ attentions: Optional[tuple[torch.FloatTensor]] = None
+ image_hidden_states: Optional[tuple[torch.FloatTensor]] = None
+
+
+class Idefics2VisionEmbeddings(nn.Module):
+ """
+    This is a modified version of `siglip.modeling_siglip.SiglipVisionEmbeddings` to enable images of variable
+ resolution.
+
+ The modifications are adapted from [Patch n' Pack: NaViT, a Vision Transformer for any Aspect Ratio and Resolution](https://huggingface.co/papers/2307.06304)
+ which allows treating images in their native aspect ratio and without the need to resize them to the same
+ fixed size. In particular, we start from the original pre-trained SigLIP model
+    (which uses fixed-size square images) and adapt it by training on images of variable resolutions.
+ """
+
+ def __init__(self, config: Idefics2VisionConfig):
+ super().__init__()
+ self.embed_dim = config.hidden_size
+ self.image_size = config.image_size
+ self.patch_size = config.patch_size
+
+ self.patch_embedding = nn.Conv2d(
+ in_channels=config.num_channels,
+ out_channels=self.embed_dim,
+ kernel_size=self.patch_size,
+ stride=self.patch_size,
+ padding="valid",
+ )
+
+ self.num_patches_per_side = self.image_size // self.patch_size
+ self.num_patches = self.num_patches_per_side**2
+ self.num_positions = self.num_patches
+ self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
+
+ def forward(self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.BoolTensor) -> torch.Tensor:
+ batch_size, _, max_im_h, max_im_w = pixel_values.shape
+
+ patch_embeds = self.patch_embedding(pixel_values)
+ embeddings = patch_embeds.flatten(2).transpose(1, 2)
+
+ max_nb_patches_h, max_nb_patches_w = max_im_h // self.patch_size, max_im_w // self.patch_size
+ boundaries = torch.arange(
+ 1 / self.num_patches_per_side, 1.0, 1 / self.num_patches_per_side, device=pixel_values.device
+ )
+ position_ids = torch.full(
+ size=(batch_size, max_nb_patches_h * max_nb_patches_w), fill_value=0, device=pixel_values.device
+ )
+
+ for batch_idx, p_attn_mask in enumerate(patch_attention_mask):
+ nb_patches_h = p_attn_mask[:, 0].sum()
+ nb_patches_w = p_attn_mask[0].sum()
+
+ h_indices = torch.arange(nb_patches_h, device=position_ids.device, dtype=pixel_values.dtype)
+ w_indices = torch.arange(nb_patches_w, device=position_ids.device, dtype=pixel_values.dtype)
+
+ fractional_coords_h = h_indices / nb_patches_h * (1 - 1e-6)
+ fractional_coords_w = w_indices / nb_patches_w * (1 - 1e-6)
+
+ bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True)
+ bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True)
+
+ pos_ids = (bucket_coords_h[:, None] * self.num_patches_per_side + bucket_coords_w).flatten()
+ position_ids[batch_idx][p_attn_mask.view(-1)] = pos_ids
+
+ embeddings = embeddings + self.position_embedding(position_ids)
+ return embeddings
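+
+    # Illustrative sketch (config values assumed): with image_size=980 and patch_size=14,
+    # num_patches_per_side is 70; an unpadded 378 x 490 image contributes 27 x 35 valid patches, whose
+    # fractional row/column coordinates are bucketized onto the full 70 x 70 position grid, so the same
+    # learned position embeddings cover any input resolution.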
+
+
+def eager_attention_forward(
+ module: nn.Module,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_mask: Optional[torch.Tensor],
+ scaling: float,
+ dropout: float = 0.0,
+ **kwargs,
+):
+ if hasattr(module, "num_key_value_groups"):
+ key = repeat_kv(key, module.num_key_value_groups)
+ value = repeat_kv(value, module.num_key_value_groups)
+
+ attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
+ if attention_mask is not None:
+ causal_mask = attention_mask[:, :, :, : key.shape[-2]]
+ attn_weights = attn_weights + causal_mask
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+ attn_output = torch.matmul(attn_weights, value)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ return attn_output, attn_weights
+
+
+# Copied from transformers.models.siglip.modeling_siglip.SiglipAttention with Siglip->Idefics2Vision
+class Idefics2VisionAttention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ # Copied from transformers.models.clip.modeling_clip.CLIPAttention.__init__
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.embed_dim = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.head_dim = self.embed_dim // self.num_heads
+ if self.head_dim * self.num_heads != self.embed_dim:
+ raise ValueError(
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
+ f" {self.num_heads})."
+ )
+ self.scale = self.head_dim**-0.5
+ self.dropout = config.attention_dropout
+
+ self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
+ self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
+ self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
+ self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
+
+ # Ignore copy
+ self.is_causal = False
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ **kwargs,
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+ """Input shape: Batch x Time x Channel"""
+
+ batch_size, seq_length, embed_dim = hidden_states.shape
+
+ queries = self.q_proj(hidden_states)
+ keys = self.k_proj(hidden_states)
+ values = self.v_proj(hidden_states)
+
+ queries = queries.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
+ keys = keys.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
+ values = values.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
+
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ attn_output, attn_weights = attention_interface(
+ self,
+ queries,
+ keys,
+ values,
+ attention_mask,
+ is_causal=self.is_causal,
+ scaling=self.scale,
+ dropout=0.0 if not self.training else self.dropout,
+ )
+
+ attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous()
+ attn_output = self.out_proj(attn_output)
+
+ return attn_output, attn_weights
+
+
+# Copied from transformers.models.siglip.modeling_siglip.SiglipMLP with Siglip->Idefics2Vision
+class Idefics2VisionMLP(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.activation_fn = ACT2FN[config.hidden_act]
+ self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
+ self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.fc1(hidden_states)
+ hidden_states = self.activation_fn(hidden_states)
+ hidden_states = self.fc2(hidden_states)
+ return hidden_states
+
+
+class Idefics2MLP(nn.Module):
+ def __init__(
+ self,
+ hidden_size: int,
+ intermediate_size: int,
+ output_size: int,
+ hidden_act: str,
+ ):
+ super().__init__()
+ self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
+ self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
+ self.down_proj = nn.Linear(intermediate_size, output_size, bias=False)
+ self.act_fn = ACT2FN[hidden_act]
+
+ def forward(self, x):
+ return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+
+
+# Copied from transformers.models.siglip.modeling_siglip.SiglipMultiheadAttentionPoolingHead with Siglip->Idefics2
+class Idefics2MultiheadAttentionPoolingHead(nn.Module):
+ """Multihead Attention Pooling."""
+
+ def __init__(self, config: Idefics2VisionConfig):
+ super().__init__()
+
+ self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size))
+ self.attention = torch.nn.MultiheadAttention(config.hidden_size, config.num_attention_heads, batch_first=True)
+ self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ # Ignore copy
+ self.mlp = Idefics2MLP(
+ hidden_size=config.hidden_size,
+ intermediate_size=config.intermediate_size,
+ hidden_act=config.hidden_act,
+ output_size=config.hidden_size,
+ )
+
+ def forward(self, hidden_state):
+ batch_size = hidden_state.shape[0]
+ probe = self.probe.repeat(batch_size, 1, 1)
+
+ hidden_state = self.attention(probe, hidden_state, hidden_state)[0]
+
+ residual = hidden_state
+ hidden_state = self.layernorm(hidden_state)
+ hidden_state = residual + self.mlp(hidden_state)
+
+ return hidden_state[:, 0]
+
+
+class Idefics2EncoderLayer(GradientCheckpointingLayer):
+ def __init__(self, config: Idefics2VisionConfig):
+ super().__init__()
+ self.embed_dim = config.hidden_size
+ self.self_attn = Idefics2VisionAttention(config)
+ self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
+ self.mlp = Idefics2VisionMLP(config)
+ self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
+
+ @auto_docstring
+ # Copied from transformers.models.siglip.modeling_siglip.SiglipEncoderLayer.forward
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: torch.Tensor,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> torch.FloatTensor:
+ residual = hidden_states
+
+ hidden_states = self.layer_norm1(hidden_states)
+ hidden_states, _ = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ **kwargs,
+ )
+ hidden_states = residual + hidden_states
+
+ residual = hidden_states
+ hidden_states = self.layer_norm2(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = residual + hidden_states
+
+ return hidden_states
+
+
+# Copied from transformers.models.siglip.modeling_siglip.SiglipEncoder with Siglip->Idefics2
+class Idefics2Encoder(nn.Module):
+ """
+ Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
+ [`Idefics2EncoderLayer`].
+
+ Args:
+ config: Idefics2Config
+ """
+
+ def __init__(self, config: Idefics2Config):
+ super().__init__()
+ self.config = config
+ self.layers = nn.ModuleList([Idefics2EncoderLayer(config) for _ in range(config.num_hidden_layers)])
+ self.gradient_checkpointing = False
+
+ # Ignore copy
+ @auto_docstring
+ def forward(
+ self,
+ inputs_embeds,
+ attention_mask: Optional[torch.Tensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> BaseModelOutput:
+ hidden_states = inputs_embeds
+ for encoder_layer in self.layers:
+ hidden_states = encoder_layer(
+ hidden_states,
+ attention_mask,
+ **kwargs,
+ )
+
+ return BaseModelOutput(last_hidden_state=hidden_states)
+
+
+@auto_docstring
+class Idefics2PreTrainedModel(PreTrainedModel):
+ config: Idefics2Config
+ base_model_prefix = "model"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["Idefics2VisionAttention", "Idefics2MLP", "Idefics2PerceiverLayer", "Idefics2DecoderLayer"]
+ _skip_keys_device_placement = "past_key_values"
+ _supports_flash_attn = True
+ _supports_sdpa = True
+ _supports_flex_attn = True
+
+ _supports_attention_backend = True
+
+ def _init_weights(self, module):
+ std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range)
+
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+ elif isinstance(module, nn.LayerNorm):
+ module.weight.data.fill_(1.0)
+ module.bias.data.zero_()
+ elif isinstance(module, Idefics2RMSNorm):
+ module.weight.data.fill_(1.0)
+ elif isinstance(module, nn.MultiheadAttention):
+ module._reset_parameters() # native torch init
+ elif isinstance(module, Idefics2MultiheadAttentionPoolingHead):
+ module.probe.data.normal_()
+ elif isinstance(module, Idefics2PerceiverResampler):
+ module.latents.data.fill_(1.0)
+
+
+@auto_docstring(
+ custom_intro="""
+    Idefics2 vision encoder model that returns raw image embeddings.
+ """
+)
+class Idefics2VisionTransformer(Idefics2PreTrainedModel):
+ config: Idefics2VisionConfig
+ _supports_sdpa = True
+ _supports_flash_attn = True
+ _supports_flex_attn = True
+ _can_record_outputs = {
+ "hidden_states": Idefics2EncoderLayer,
+ "attentions": Idefics2VisionAttention,
+ }
+
+ def __init__(self, config: Idefics2VisionConfig):
+ super().__init__(config)
+ embed_dim = config.hidden_size
+
+ self.config = config
+ self.embeddings = Idefics2VisionEmbeddings(config)
+ self.encoder = Idefics2Encoder(config)
+ self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
+
+ def get_input_embeddings(self):
+ return self.embeddings
+
+ def set_input_embeddings(self, value):
+ self.embeddings = value
+
+ @check_model_inputs(tie_last_hidden_states=False)
+ @auto_docstring
+ def forward(
+ self,
+ pixel_values,
+ patch_attention_mask: Optional[torch.BoolTensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> Union[tuple, BaseModelOutput]:
+ r"""
+ patch_attention_mask (`torch.BoolTensor` of shape `(batch_size, num_patches_height, num_patches_width)`, *optional*):
+ The attention mask for the patches.
+ """
+ batch_size = pixel_values.size(0)
+ if patch_attention_mask is None:
+ patch_size = self.config.patch_size
+ patch_attention_mask = torch.ones(
+ (
+ batch_size,
+ pixel_values.size(2) // patch_size,
+ pixel_values.size(3) // patch_size,
+ )
+ )
+ patch_attention_mask = patch_attention_mask.to(dtype=torch.bool, device=pixel_values.device)
+
+ hidden_states = self.embeddings(pixel_values=pixel_values, patch_attention_mask=patch_attention_mask)
+
+ patch_attention_mask = patch_attention_mask.view(batch_size, -1)
+ # The call to `_upad_input` in `_flash_attention_forward` is expensive
+ # So when the `patch_attention_mask` is full of 1s (i.e. attending to the whole sequence),
+ # avoiding passing the attention_mask, which is equivalent to attending to the full sequence
+ if not torch.any(~patch_attention_mask):
+ patch_attention_mask = None
+ elif self.config._attn_implementation != "flash_attention_2":
+ patch_attention_mask = _prepare_4d_attention_mask(patch_attention_mask, hidden_states.dtype)
+
+ encoder_outputs: BaseModelOutput = self.encoder(
+ inputs_embeds=hidden_states,
+ attention_mask=patch_attention_mask,
+ **kwargs,
+ )
+
+ last_hidden_state = encoder_outputs.last_hidden_state
+ last_hidden_state = self.post_layernorm(last_hidden_state)
+
+ return BaseModelOutput(last_hidden_state=last_hidden_state)
+
+
+# Copied from transformers.models.llama.modeling_llama.repeat_kv
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+ """
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+ """
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+ if n_rep == 1:
+ return hidden_states
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
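+
+# Illustrative example (shapes assumed): for hidden_states of shape (2, 4, 16, 64) and n_rep=3, each of
+# the 4 key/value heads is repeated 3 times, giving a tensor of shape (2, 12, 16, 64).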
+
+
+# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Idefics2
+class Idefics2RMSNorm(nn.Module):
+ def __init__(self, hidden_size, eps=1e-6):
+ """
+ Idefics2RMSNorm is equivalent to T5LayerNorm
+ """
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+ self.variance_epsilon = eps
+
+ def forward(self, hidden_states):
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+ return self.weight * hidden_states.to(input_dtype)
+
+ def extra_repr(self):
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
+
+
+class Idefics2PerceiverAttention(nn.Module):
+ def __init__(self, config, layer_idx: Optional[int] = None) -> None:
+ """Perceiver Cross-Attention Module --> let long-form inputs be `context`, resampled embeddings be `latents`"""
+ super().__init__()
+ self.config = config
+ self.layer_idx = None
+ self.hidden_size = config.hidden_size
+ self.num_heads = config.resampler_n_heads
+ self.head_dim = config.resampler_head_dim
+ self.num_key_value_heads = config.num_key_value_heads
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
+ self.attention_dropout = config.attention_dropout
+ self.scaling = self.head_dim**-0.5
+
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
+
+ self.is_causal = False
+
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+ def forward(
+ self,
+ latents: torch.Tensor,
+ context: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+ """
+        Runs Perceiver attention, with `context` and `latents` concatenated along the `seq` dimension.
+
+ Args:
+ latents (`torch.Tensor`): Tensor of shape [bsz, n_latents, embed_dim] representing fixed length latents to compress to.
+ context (`torch.Tensor`): Tensor of shape [bsz, seq, embed_dim] representing long-form context to resample.
+ attention_mask (`torch.Tensor`, *optional*): Tensor of shape [bsz, 1, seq, n_latents] representing attention mask.
+ position_ids (`torch.LongTensor`, *optional*): Tensor of shape [bsz, seq] representing position indices of each input token.
+ past_key_values (`Cache`, *optional*): Tuple of tensors containing cached key and value states.
+ output_attentions (`bool`, *optional*, defaults to `False`): Whether to return attention weights.
+ use_cache (`bool`, *optional*, defaults to `False`): Whether to use past_key_values for caching.
+ """
+ bsz, q_len, _ = latents.size()
+ kv_seq_len = q_len + context.size()[1]
+
+ hidden_states = torch.concat([context, latents], dim=-2)
+
+ queries = self.q_proj(latents)
+ keys = self.k_proj(hidden_states)
+ values = self.v_proj(hidden_states)
+
+ queries = queries.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ keys = keys.view(bsz, kv_seq_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ values = values.view(bsz, kv_seq_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ past_key_values = getattr(self, "past_key_values", past_key_values)
+
+ if past_key_values is not None:
+ keys, values = past_key_values.update(keys, values, self.layer_idx)
+
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ attn_output, attn_weights = attention_interface(
+ self,
+ queries,
+ keys,
+ values,
+ attention_mask,
+ is_causal=self.is_causal,
+ scaling=self.scaling,
+ dropout=0.0 if not self.training else self.attention_dropout,
+ **kwargs,
+ )
+
+ attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim)
+ attn_output = self.o_proj(attn_output)
+
+ return attn_output, attn_weights
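+
+    # Illustrative sketch (shapes assumed): with latents of shape (bsz, n_latents, hidden_size) and
+    # context of shape (bsz, seq, hidden_size), queries come from the latents only while keys/values are
+    # built from the concatenation of length seq + n_latents, so the output has shape
+    # (bsz, n_latents, hidden_size).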
+
+
+class Idefics2PerceiverLayer(nn.Module):
+ def __init__(self, config, layer_idx: int):
+ super().__init__()
+ self.hidden_size = config.hidden_size
+ self.n_latents = config.resampler_n_latents
+ self.depth = config.resampler_depth
+ self.rms_norm_eps = config.rms_norm_eps
+
+ self.input_latents_norm = Idefics2RMSNorm(self.hidden_size, eps=self.rms_norm_eps)
+ self.input_context_norm = Idefics2RMSNorm(self.hidden_size, eps=self.rms_norm_eps)
+ self.self_attn = Idefics2PerceiverAttention(config, layer_idx=layer_idx)
+ self.post_attention_layernorm = Idefics2RMSNorm(self.hidden_size, eps=self.rms_norm_eps)
+ self.mlp = Idefics2MLP(
+ hidden_size=config.hidden_size,
+ intermediate_size=config.hidden_size * 4,
+ output_size=config.hidden_size,
+ hidden_act=config.hidden_act,
+ )
+
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+ def forward(
+ self,
+ latents: torch.Tensor,
+ context: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> torch.FloatTensor:
+ """
+ Args:
+ latents (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+ context (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
+ `(batch, sequence_length)` where padding elements are indicated by 0.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+ (see `past_key_values`).
+ past_key_values (`Cache`, *optional*): cached past key and value projection states
+ """
+ residual = latents
+
+ latents = self.input_latents_norm(latents)
+ context = self.input_context_norm(context)
+
+ latents, _ = self.self_attn(
+ latents=latents,
+ context=context,
+ attention_mask=attention_mask,
+ **kwargs,
+ )
+ latents = residual + latents
+ residual = latents
+
+ latents = self.post_attention_layernorm(latents)
+ latents = self.mlp(latents)
+ latents = residual + latents
+
+ return latents
+
+
+@auto_docstring(
+ custom_intro="""
+ Idefics2 perceiver resampler model that performs `depth` blocks of cross-attention with a fixed
+ number of `n_latents` learned latent queries to reduce the length of the embedding sequence.
+ """
+)
+class Idefics2PerceiverResampler(Idefics2PreTrainedModel):
+ config: Idefics2PerceiverConfig
+ _supports_sdpa = True
+ _supports_flash_attention_2 = True
+ _supports_flex_attn = True
+
+ def __init__(self, config) -> None:
+ super().__init__(config)
+ self.hidden_size = config.hidden_size
+ self.hidden_act = config.hidden_act
+ self.n_latents = config.resampler_n_latents
+ self.depth = config.resampler_depth
+ self.rms_norm_eps = config.rms_norm_eps
+
+ # Create Latents for Perceiver
+ self.latents = nn.Parameter(torch.ones(self.n_latents, self.hidden_size))
+
+ # Create Transformer Blocks
+ self.layers = nn.ModuleList([Idefics2PerceiverLayer(config, idx) for idx in range(self.depth)])
+ self.norm = Idefics2RMSNorm(self.hidden_size, eps=self.rms_norm_eps)
+
+ @auto_docstring
+ def forward(
+ self,
+ context: torch.Tensor,
+ attention_mask: torch.Tensor,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> torch.Tensor:
+ r"""
+ context (`torch.FloatTensor` of shape `(batch, seq_len, embed_dim)`):
+ Input to the layer.
+ """
+ # seq embed -> bsz seq embed
+ latents = self.latents.unsqueeze(0).expand((context.shape[0], *self.latents.size()))
+
+ latent_attention_mask = torch.ones(
+ (attention_mask.size(0), latents.size(1)), dtype=attention_mask.dtype, device=attention_mask.device
+ )
+ attention_mask = torch.cat([attention_mask, latent_attention_mask], dim=-1)
+ attention_mask = (
+ _prepare_4d_attention_mask(attention_mask, latents.dtype, tgt_len=self.n_latents)
+ if self.config._attn_implementation != "flash_attention_2"
+ else attention_mask
+ )
+
+ compressed_context = latents
+ for perceiver_layer in self.layers:
+ compressed_context = perceiver_layer(
+ compressed_context,
+ context,
+ attention_mask=attention_mask,
+ position_ids=None,
+ **kwargs,
+ )
+
+ compressed_context = self.norm(compressed_context)
+
+ return compressed_context
+
+
+class Idefics2Connector(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.modality_projection = Idefics2MLP(
+ hidden_size=config.vision_config.hidden_size,
+ intermediate_size=config.text_config.intermediate_size,
+ output_size=config.text_config.hidden_size,
+ hidden_act=config.text_config.hidden_act,
+ )
+ self.perceiver_resampler = Idefics2PerceiverResampler._from_config(config.perceiver_config)
+
+ def forward(self, image_hidden_states, attention_mask):
+ image_hidden_states = self.modality_projection(image_hidden_states)
+ image_hidden_states = self.perceiver_resampler(context=image_hidden_states, attention_mask=attention_mask)
+ return image_hidden_states
+
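+# Data-flow sketch for Idefics2Connector.forward above (illustrative comment, not part of the
+# original implementation):
+#   vision features (num_images, num_patches, vision_hidden_size)
+#     -> modality_projection MLP -> (num_images, num_patches, text_hidden_size)
+#     -> perceiver_resampler     -> (num_images, resampler_n_latents, text_hidden_size)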
+
+@auto_docstring(
+ custom_intro="""
+ Idefics2 model consisting of a SIGLIP vision encoder and Mistral language decoder
+ """
+)
+class Idefics2Model(Idefics2PreTrainedModel):
+ def __init__(self, config: Idefics2Config):
+ super().__init__(config)
+ self.padding_idx = self.config.text_config.pad_token_id
+ self.vocab_size = self.config.text_config.vocab_size
+
+ self.vision_model = Idefics2VisionTransformer._from_config(config.vision_config)
+ self.connector = Idefics2Connector(config)
+ self.text_model = AutoModel.from_config(config.text_config)
+
+ self.image_seq_len = config.perceiver_config.resampler_n_latents
+ self.image_token_id = self.config.image_token_id
+
+ self.post_init()
+
+ def enable_input_require_grads(self):
+ """
+ Enables the gradients for the input embeddings.
+
+ This is useful for LoRA when using gradient checkpointing.
+ c.f. https://github.com/huggingface/peft/issues/1402#issuecomment-1913675032
+
+ Override to set output.requires_grad = True for both the decoder's and vision model's embeddings.
+ """
+
+ def get_lowest_module(module):
+ if len(list(module.children())) == 0:
+ # If the module has no children, it is a leaf module (e.g., Linear, Conv2d, etc.)
+ return module
+ else:
+ # Recursively call the function on each child module
+ return get_lowest_module(list(module.children())[0])
+
+ def make_inputs_require_grads(module, input, output):
+ output.requires_grad_(True)
+
+ self._text_require_grads_hook = self.get_input_embeddings().register_forward_hook(make_inputs_require_grads)
+ self._vision_require_grads_hook = get_lowest_module(self.vision_model).register_forward_hook(
+ make_inputs_require_grads
+ )
+
+ def disable_input_require_grads(self):
+ self._text_require_grads_hook.remove()
+ self._vision_require_grads_hook.remove()
+
+ def get_input_embeddings(self):
+ return self.text_model.get_input_embeddings()
+
+ def set_input_embeddings(self, value):
+ self.text_model.set_input_embeddings(value)
+
+ def inputs_merger(
+ self,
+ input_ids: torch.LongTensor,
+ inputs_embeds: Optional[torch.Tensor],
+ image_hidden_states: Optional[torch.Tensor],
+ ):
+ """
+ This method aims at merging the token embeddings with the image hidden states into one single sequence of vectors that are fed to the transformer LM.
+ The merging happens as follows:
+ - The text token sequence is: `tok_1 tok_2 tok_3 ... tok_4`.
+ - We get the image hidden states for the image through the vision encoder (and potentially the perceiver), and that hidden state is then projected into the text embedding space.
+ We thus have a sequence of image hidden states of size (1, image_seq_len, hidden_dim), where 1 is for batch_size of 1 image and hidden_dim is the hidden_dim of the LM transformer.
+ - The merging happens so that we obtain the following sequence: `vector_tok_1 vector_tok_2 vector_tok_3 vector_fake_tok_around_image {sequence of image_seq_len image hidden states} vector_fake_tok_around_image vector_tok_4`. That sequence is fed to the LM.
+ - To fit the format of that sequence, `input_ids`, `input_embeds`, `attention_mask` are all 3 adapted to insert the image hidden states.
+ """
+ if input_ids is None:
+ special_image_mask = inputs_embeds == self.get_input_embeddings()(
+ torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
+ )
+ special_image_mask = special_image_mask.all(-1)
+ else:
+ special_image_mask = input_ids == self.config.image_token_id
+
+ special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
+ image_hidden_states = image_hidden_states.to(inputs_embeds.device, inputs_embeds.dtype)
+ inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_hidden_states)
+ return inputs_embeds
+
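+ # Illustrative example for `inputs_merger` above (assumed toy values, not from the original source):
+ # with image_seq_len = 3 and input_ids = [bos, fake, img, img, img, fake, tok], the three `img`
+ # positions (equal to `image_token_id`) are selected by `special_image_mask` and replaced in place
+ # by the three rows of `image_hidden_states` for that image via `masked_scatter`; every other
+ # token embedding is left untouched.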
+ def get_image_features(
+ self, pixel_values: torch.FloatTensor, pixel_attention_mask: Optional[torch.LongTensor] = None
+ ):
+ """
+ Encodes images into continuous embeddings that can be forwarded to the language model.
+
+ Args:
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
+ The tensors corresponding to the input images.
+ pixel_attention_mask (`torch.LongTensor`, *optional*):
+ The attention mask indicating padded regions in the image.
+ """
+ batch_size, num_images, num_channels, height, width = pixel_values.shape
+ pixel_values = pixel_values.to(dtype=self.dtype) # fp16 compatibility
+ pixel_values = pixel_values.view(batch_size * num_images, *pixel_values.shape[2:])
+
+ # Remove padding images - padding images are full 0.
+ nb_values_per_image = pixel_values.shape[1:].numel()
+ real_images_inds = (pixel_values == 0.0).sum(dim=(-1, -2, -3)) != nb_values_per_image
+ pixel_values = pixel_values[real_images_inds].contiguous()
+
+ # Handle the vision attention mask
+ if pixel_attention_mask is None:
+ pixel_attention_mask = torch.ones(
+ size=(pixel_values.size(0), pixel_values.size(2), pixel_values.size(3)),
+ dtype=torch.bool,
+ device=pixel_values.device,
+ )
+ else:
+ # Remove padding images from the mask
+ pixel_attention_mask = pixel_attention_mask.view(batch_size * num_images, *pixel_attention_mask.shape[2:])
+ pixel_attention_mask = pixel_attention_mask[real_images_inds].contiguous()
+
+ patch_size = self.config.vision_config.patch_size
+ patches_subgrid = pixel_attention_mask.unfold(dimension=1, size=patch_size, step=patch_size)
+ patches_subgrid = patches_subgrid.unfold(dimension=2, size=patch_size, step=patch_size)
+ patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) == patch_size * patch_size).bool()
+ # Get sequence from the vision encoder
+ image_hidden_states = self.vision_model(pixel_values=pixel_values, patch_attention_mask=patch_attention_mask)
+ image_hidden_states = image_hidden_states.last_hidden_state
+
+ # Modality projection & resampling
+ image_hidden_states = self.connector(
+ image_hidden_states, attention_mask=patch_attention_mask.view(pixel_values.size(0), -1)
+ )
+ image_hidden_states = image_hidden_states.view(-1, image_hidden_states.shape[-1])
+ return image_hidden_states
+
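+ # Note on the patch-mask pooling in `get_image_features` above (illustrative comment, not part of
+ # the original source): the pixel-level mask is unfolded into non-overlapping
+ # (patch_size x patch_size) windows and a patch is attended to only when all of its pixels are
+ # real, e.g. a 980x980 mask with patch_size=14 yields a 70x70 grid of patch attention flags.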
+ @can_return_tuple
+ @auto_docstring(
+ custom_intro="""
+ Inputs fed to the model can have an arbitrary number of images. To account for this, pixel_values fed to
+ the model have image padding -> (batch_size, max_num_images, 3, max_heights, max_widths) where
+ max_num_images is the maximum number of images among the batch_size samples in the batch.
+
+ Padding images are not needed beyond padding the pixel_values at the entrance of the model.
+ For efficiency, we only pass the real images through the vision_model's forward and discard the
+ padding images, i.e. pixel_values of size (image_batch_size, 3, height, width) where
+ image_batch_size would be 7 when num_images_per_sample=[1, 3, 1, 2] and max_num_images would be 3.
+ """
+ )
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ pixel_attention_mask: Optional[torch.BoolTensor] = None,
+ image_hidden_states: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs: Unpack[FlashAttentionKwargs],
+ ) -> Union[tuple, Idefics2BaseModelOutputWithPast]:
+ r"""
+ pixel_attention_mask (`torch.Tensor` of shape `(batch_size, image_size, image_size)`, *optional*):
+ Mask to avoid performing attention on padding pixel indices.
+ image_hidden_states (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
+ The hidden states of the image encoder after modality projection and perceiver resampling.
+ """
+
+ if self.training and self.text_model.gradient_checkpointing and use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+ )
+ use_cache = False
+
+ # retrieve input_ids and inputs_embeds
+ if input_ids is not None:
+ batch_size, seq_length = input_ids.shape
+ elif inputs_embeds is not None:
+ batch_size, seq_length, _ = inputs_embeds.shape
+ else:
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+ if use_cache and past_key_values is None:
+ past_key_values = DynamicCache(config=self.config)
+
+ if inputs_embeds is None:
+ inputs_embeds = self.text_model.get_input_embeddings()(input_ids)
+
+ # START VISUAL INPUTS INTEGRATION
+ if pixel_values is not None and image_hidden_states is not None:
+ raise ValueError("You cannot specify both pixel_values and image_hidden_states at the same time")
+ elif pixel_values is not None:
+ image_hidden_states = self.get_image_features(pixel_values, pixel_attention_mask)
+ elif image_hidden_states is not None:
+ image_hidden_states = image_hidden_states.to(dtype=self.dtype, device=input_ids.device)
+
+ if image_hidden_states is not None:
+ # When we generate, we don't want to replace the potential image_token_id that we generated by images
+ # that simply don't exist
+ inputs_embeds = self.inputs_merger(
+ input_ids=input_ids,
+ inputs_embeds=inputs_embeds,
+ image_hidden_states=image_hidden_states,
+ )
+
+ kwargs["return_dict"] = True
+ outputs = self.text_model(
+ inputs_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+ return Idefics2BaseModelOutputWithPast(
+ last_hidden_state=outputs.last_hidden_state,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ image_hidden_states=image_hidden_states,
+ )
+
+
+@auto_docstring(
+ custom_intro="""
+ The Idefics2 Model with a language modeling head. It is made up of a SigLIP vision encoder and a Mistral language decoder, with a language modeling head on top.
+ """
+)
+class Idefics2ForConditionalGeneration(Idefics2PreTrainedModel, GenerationMixin):
+ _tied_weights_keys = ["lm_head.weight"]
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.model = Idefics2Model(config)
+ self.image_token_id = self.config.image_token_id
+
+ self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
+ self.vocab_size = config.text_config.vocab_size
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def enable_input_require_grads(self):
+ """
+ Enables the gradients for the input embeddings. This is useful for fine-tuning adapter weights while keeping
+ the model weights fixed.
+ """
+
+ def make_inputs_require_grads(module, input, output):
+ output.requires_grad_(True)
+
+ self._text_require_grads_hook = self.get_input_embeddings().register_forward_hook(make_inputs_require_grads)
+ self._vision_require_grads_hook = self.model.vision_model.get_input_embeddings().register_forward_hook(
+ make_inputs_require_grads
+ )
+
+ def disable_input_require_grads(self):
+ self._text_require_grads_hook.remove()
+ self._vision_require_grads_hook.remove()
+
+ def get_input_embeddings(self):
+ return self.model.text_model.get_input_embeddings()
+
+ def set_input_embeddings(self, value):
+ self.model.text_model.set_input_embeddings(value)
+
+ def get_image_features(
+ self, pixel_values: torch.FloatTensor, pixel_attention_mask: Optional[torch.LongTensor] = None
+ ):
+ return self.model.get_image_features(pixel_values=pixel_values, pixel_attention_mask=pixel_attention_mask)
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ pixel_attention_mask: Optional[torch.BoolTensor] = None,
+ image_hidden_states: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ logits_to_keep: Union[int, torch.Tensor] = 0,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> Union[tuple, Idefics2CausalLMOutputWithPast]:
+ r"""
+ pixel_attention_mask (`torch.Tensor` of shape `(batch_size, image_size, image_size)`, *optional*):
+ Mask to avoid performing attention on padding pixel indices.
+ image_hidden_states (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
+ The hidden states of the image encoder after modality projection and perceiver resampling.
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or `model.image_token_id` (where `model` is your instance of `Idefics2ForConditionalGeneration`).
+ Tokens with indices set to `model.image_token_id` are ignored (masked), the loss is only
+ computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+ Example:
+
+ ```python
+ >>> import requests
+ >>> import torch
+ >>> from PIL import Image
+ >>> from io import BytesIO
+
+ >>> from transformers import AutoProcessor, AutoModelForVision2Seq
+ >>> from transformers.image_utils import load_image
+
+ >>> # Note that passing the image urls (instead of the actual pil images) to the processor is also possible
+ >>> image1 = load_image("https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg")
+ >>> image2 = load_image("https://cdn.britannica.com/59/94459-050-DBA42467/Skyline-Chicago.jpg")
+ >>> image3 = load_image("https://cdn.britannica.com/68/170868-050-8DDE8263/Golden-Gate-Bridge-San-Francisco.jpg")
+
+ >>> processor = AutoProcessor.from_pretrained("HuggingFaceM4/idefics2-8b-base")
+ >>> model = AutoModelForVision2Seq.from_pretrained("HuggingFaceM4/idefics2-8b-base", device_map="auto")
+
+ >>> BAD_WORDS_IDS = processor.tokenizer(["<image>", "<fake_token_around_image>"], add_special_tokens=False).input_ids
+ >>> EOS_WORDS_IDS = [processor.tokenizer.eos_token_id]
+
+ >>> # Create inputs
+ >>> prompts = [
+ ... "<image>In this image, we can see the city of New York, and more specifically the Statue of Liberty.<image>In this image,",
+ ... "In which city is that bridge located?<image>",
+ ... ]
+ >>> images = [[image1, image2], [image3]]
+ >>> inputs = processor(images=images, text=prompts, padding=True, return_tensors="pt").to("cuda")
+
+ >>> # Generate
+ >>> generated_ids = model.generate(**inputs, bad_words_ids=BAD_WORDS_IDS, max_new_tokens=20)
+ >>> generated_texts = processor.batch_decode(generated_ids, skip_special_tokens=True)
+
+ >>> print(generated_texts)
+ ['In this image, we can see the city of New York, and more specifically the Statue of Liberty. In this image, we can see the city of New York, and more specifically the Statue of Liberty.\n\n', 'In which city is that bridge located?\n\nThe bridge is located in the city of Pittsburgh, Pennsylvania.\n\n\nThe bridge is']
+ ```"""
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+ outputs = self.model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ pixel_values=pixel_values,
+ pixel_attention_mask=pixel_attention_mask,
+ image_hidden_states=image_hidden_states,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ cache_position=cache_position,
+ return_dict=True,
+ **kwargs,
+ )
+
+ hidden_states = outputs[0]
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(
+ logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
+ )
+
+ return Idefics2CausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ image_hidden_states=outputs.image_hidden_states,
+ )
+
+ def prepare_inputs_for_generation(
+ self,
+ input_ids,
+ past_key_values=None,
+ attention_mask=None,
+ inputs_embeds=None,
+ cache_position=None,
+ pixel_values=None,
+ pixel_attention_mask=None,
+ image_hidden_states=None,
+ logits_to_keep=None,
+ **kwargs,
+ ):
+ # Overwritten -- there are mutually exclusive inputs (if the logic to make `image_hidden_states` take
+ # precedence is moved to the model, we can remove this fn)
+
+ model_inputs = super().prepare_inputs_for_generation(
+ input_ids,
+ past_key_values=past_key_values,
+ attention_mask=attention_mask,
+ inputs_embeds=inputs_embeds,
+ cache_position=cache_position,
+ pixel_values=pixel_values,
+ pixel_attention_mask=pixel_attention_mask,
+ image_hidden_states=image_hidden_states,
+ logits_to_keep=logits_to_keep,
+ **kwargs,
+ )
+
+ if image_hidden_states is not None or cache_position[0] != 0:
+ model_inputs["pixel_values"] = None
+ model_inputs["pixel_attention_mask"] = None
+
+ return model_inputs
+
+
+__all__ = ["Idefics2ForConditionalGeneration", "Idefics2PreTrainedModel", "Idefics2Model"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/idefics2/processing_idefics2.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/idefics2/processing_idefics2.py
new file mode 100644
index 0000000000000000000000000000000000000000..550ca877409504ac3a608e5f8ce99259cf46b501
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/idefics2/processing_idefics2.py
@@ -0,0 +1,261 @@
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Processor class for IDEFICS2.
+"""
+
+from itertools import accumulate
+from typing import TYPE_CHECKING, Optional, Union
+
+from ...feature_extraction_utils import BatchFeature
+from ...image_utils import ImageInput, is_valid_image, load_image
+from ...processing_utils import (
+ ImagesKwargs,
+ ProcessingKwargs,
+ ProcessorMixin,
+ Unpack,
+)
+from ...tokenization_utils_base import AddedToken, TextInput
+from ...utils import logging
+
+
+if TYPE_CHECKING:
+ from ...tokenization_utils_base import PreTokenizedInput
+
+
+logger = logging.get_logger(__name__)
+
+
+def is_url(val) -> bool:
+ return isinstance(val, str) and val.startswith("http")
+
+
+def is_image_or_image_url(elem):
+ return is_url(elem) or is_valid_image(elem)
+
+
+class Idefics2ImagesKwargs(ImagesKwargs, total=False):
+ image_seq_len: Optional[int]
+
+
+class Idefics2ProcessorKwargs(ProcessingKwargs, total=False):
+ images_kwargs: Idefics2ImagesKwargs
+
+ _defaults = {
+ "text_kwargs": {
+ "add_special_tokens": True,
+ "padding": False,
+ "is_split_into_words": False,
+ },
+ "images_kwargs": {},
+ }
+
+
+class Idefics2Processor(ProcessorMixin):
+ r"""
+ Constructs an IDEFICS2 processor which wraps a Llama tokenizer and an IDEFICS2 image processor into a single processor.
+
+ [`Idefics2Processor`] offers all the functionalities of [`Idefics2ImageProcessor`] and [`LlamaTokenizerFast`]. See
+ the docstring of [`~Idefics2Processor.__call__`] and [`~Idefics2Processor.decode`] for more information.
+
+ Args:
+ image_processor (`Idefics2ImageProcessor`):
+ An instance of [`Idefics2ImageProcessor`]. The image processor is a required input.
+ tokenizer (`PreTrainedTokenizerBase`, *optional*):
+ An instance of [`PreTrainedTokenizerBase`]. This should correspond with the model's text model. The tokenizer is a required input.
+ image_seq_len (`int`, *optional*, defaults to 64):
+ The length of the image sequence i.e. the number of tokens per image in the input.
+ This parameter is used to build the string from the input prompt and image tokens and should match the
+ config.perceiver_config.resampler_n_latents value for the model used.
+ chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
+ in a chat into a tokenizable string.
+ """
+
+ attributes = ["image_processor", "tokenizer"]
+ image_processor_class = "Idefics2ImageProcessor"
+ tokenizer_class = "AutoTokenizer"
+
+ def __init__(
+ self, image_processor, tokenizer=None, image_seq_len: int = 64, chat_template: Optional[str] = None, **kwargs
+ ):
+ if not hasattr(tokenizer, "image_token"):
+ self.fake_image_token = AddedToken("<fake_token_around_image>", normalized=False, special=True).content
+ self.image_token = AddedToken("<image>", normalized=False, special=True).content
+ tokens_to_add = {"additional_special_tokens": [self.fake_image_token, self.image_token]}
+ tokenizer.add_special_tokens(tokens_to_add)
+ self.image_token_id = tokenizer.convert_tokens_to_ids(self.image_token)
+ else:
+ self.fake_image_token = tokenizer.image_boundary_token
+ self.image_token = tokenizer.image_token
+ self.image_token_id = tokenizer.image_token_id
+
+ self.end_of_utterance_token = AddedToken("<end_of_utterance>", normalized=False, special=True)
+ tokenizer.add_special_tokens({"additional_special_tokens": [self.end_of_utterance_token]})
+ self.image_seq_len = image_seq_len
+
+ super().__init__(image_processor, tokenizer, chat_template=chat_template)
+
+ def _extract_images_from_prompts(self, prompts):
+ prompt_images = []
+ for prompt in prompts:
+ images = []
+ for elem in prompt:
+ if is_valid_image(elem):
+ images.append(elem)
+ elif is_url(elem):
+ images.append(load_image(elem))
+ prompt_images.append(images)
+ return prompt_images
+
+ def __call__(
+ self,
+ images: Union[ImageInput, list[ImageInput], list[list[ImageInput]]] = None,
+ text: Union[TextInput, "PreTokenizedInput", list[TextInput], list["PreTokenizedInput"]] = None,
+ audio=None,
+ videos=None,
+ **kwargs: Unpack[Idefics2ProcessorKwargs],
+ ) -> BatchFeature:
+ """
+ Processes the input prompts and returns a BatchEncoding.
+
+ Example:
+
+ ```python
+ >>> import requests
+ >>> from transformers import Idefics2Processor
+ >>> from transformers.image_utils import load_image
+
+ >>> processor = Idefics2Processor.from_pretrained("HuggingFaceM4/idefics2-8b", image_seq_len=2)
+ >>> processor.image_processor.do_image_splitting = False # Force as False to simplify the example
+
+ >>> url1 = "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg"
+ >>> url2 = "https://cdn.britannica.com/59/94459-050-DBA42467/Skyline-Chicago.jpg"
+
+ >>> image1, image2 = load_image(url1), load_image(url2)
+ >>> images = [[image1], [image2]]
+
+ >>> text = [
+ ... "<image>In this image, we see",
+ ... "bla bla bla<image>",
+ ... ]
+ >>> outputs = processor(images=images, text=text, return_tensors="pt", padding=True)
+ >>> input_ids = outputs.input_ids
+ >>> input_tokens = processor.tokenizer.batch_decode(input_ids)
+ >>> print(input_tokens)
+ ['<s><fake_token_around_image><image><image><fake_token_around_image> In this image, we see', '<s> bla bla bla<fake_token_around_image><image><image><fake_token_around_image>']
+ ```
+
+ Args:
+ images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `list[PIL.Image.Image]`, `list[np.ndarray]`, `list[torch.Tensor]`, *optional*):
+ The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
+ tensor. If is of type `list[ImageInput]`, it's assumed that this is for a single prompt i.e. of batch size 1.
+ text (`Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]]`, *optional*):
+ The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
+ (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
+ `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
+
+ Wherever an image token, `<image>`, is encountered, it is expanded to
+ `<fake_token_around_image>` + `<image>` * `image_seq_len` + `<fake_token_around_image>`.
+ return_tensors (`Union[str, TensorType]`, *optional*):
+ If set, will return tensors of a particular framework. See [`PreTrainedTokenizerFast.__call__`] for more
+ information.
+
+ """
+ if text is None and images is None:
+ raise ValueError("You must provide either `text` or `images`.")
+
+ output_kwargs = self._merge_kwargs(
+ Idefics2ProcessorKwargs,
+ tokenizer_init_kwargs=self.tokenizer.init_kwargs,
+ **kwargs,
+ )
+ image_seq_len = output_kwargs["images_kwargs"].pop("image_seq_len", None)
+ image_seq_len = image_seq_len if image_seq_len is not None else self.image_seq_len
+ return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
+
+ n_images_in_text = []
+ inputs = {}
+
+ if text is not None:
+ if isinstance(text, str):
+ text = [text]
+ elif not isinstance(text, list) or not isinstance(text[0], str):
+ raise ValueError("Invalid input text. Please provide a string, or a list of strings")
+
+ # Replace the image token with fake tokens around the expanded image token sequence of length `image_seq_len`
+ fake_image_token = self.fake_image_token
+ image_token = self.image_token
+ image_str = f"{fake_image_token}{image_token * image_seq_len}{fake_image_token}"
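+ # Illustrative expansion (assumed values, not part of the original code): with image_seq_len = 2,
+ # `image_str` is "<fake_token_around_image><image><image><fake_token_around_image>"; when
+ # `do_image_splitting` is enabled below, it is repeated 5 times (4 crops + the original image).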
+
+ if self.image_processor.do_image_splitting:
+ # A single image token is split into 4 patches + 1 original image
+ image_str = image_str * 5
+ image_seq_len *= 5
+
+ prompt_strings = []
+ for sample in text:
+ n_images_in_text.append(sample.count(image_token))
+ sample = sample.replace(image_token, image_str)
+ # Remove any double fake tokens if images are adjacent
+ sample = sample.replace(f"{fake_image_token}{fake_image_token}", f"{fake_image_token}")
+ prompt_strings.append(sample)
+
+ text_inputs = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"])
+ self._check_special_mm_tokens(prompt_strings, text_inputs, modalities=["image"])
+ inputs.update(text_inputs)
+
+ if images is not None:
+ if is_image_or_image_url(images):
+ images = [[images]]
+ elif isinstance(images, (list, tuple)) and is_image_or_image_url(images[0]):
+ if text is not None:
+ if sum(n_images_in_text) != len(images):
+ raise ValueError(
+ f"The total number of {image_token} tokens in the prompts should be the same as the number of images passed."
+ f" Found {sum(n_images_in_text)} {image_token} tokens and {len(images)} images."
+ )
+ # Reorganize the images to match the prompts
+ cumsum_images_in_text = [0] + list(accumulate(n_images_in_text))
+ images = [
+ images[cumsum_images_in_text[i] : cumsum_images_in_text[i + 1]]
+ for i in range(len(n_images_in_text))
+ ]
+ else:
+ images = [images]
+
+ elif (
+ not isinstance(images, (list, tuple))
+ and not isinstance(images[0], (list, tuple))
+ and not is_image_or_image_url(images[0][0])
+ ):
+ raise ValueError(
+ "Invalid input images. Please provide a single image or a list of images or a list of list of images."
+ )
+
+ n_images_in_images = [len(sample) for sample in images]
+ if text is not None and n_images_in_images != n_images_in_text:
+ raise ValueError(
+ f"The number of images in the text {n_images_in_text} and images {n_images_in_images} should be the same."
+ )
+
+ # Load images if they are URLs
+ images = [[load_image(im) for im in sample] for sample in images]
+ image_inputs = self.image_processor(images, **output_kwargs["images_kwargs"])
+ inputs.update(image_inputs)
+
+ return BatchFeature(inputs, tensor_type=return_tensors)
+
+
+__all__ = ["Idefics2Processor"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/jetmoe/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/jetmoe/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..7058590acc8859f934330ced9556d6ca66b50a51
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/jetmoe/__init__.py
@@ -0,0 +1,27 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_jetmoe import *
+ from .modeling_jetmoe import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/jetmoe/configuration_jetmoe.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/jetmoe/configuration_jetmoe.py
new file mode 100644
index 0000000000000000000000000000000000000000..118e734143f42ecff966882ad8f77814bf691f21
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/jetmoe/configuration_jetmoe.py
@@ -0,0 +1,153 @@
+# coding=utf-8
+# Copyright 2024 JetMoe AI and the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""JetMoe model configuration"""
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class JetMoeConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`JetMoeModel`]. It is used to instantiate a
+ JetMoe model according to the specified arguments, defining the model architecture. Instantiating a configuration
+ with the defaults will yield a configuration similar to that of the JetMoe-4B.
+
+ [jetmoe/jetmoe-8b](https://huggingface.co/jetmoe/jetmoe-8b)
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 32000):
+ Vocabulary size of the JetMoe model. Defines the number of different tokens that can be represented by the
+ `inputs_ids` passed when calling [`JetMoeModel`]
+ hidden_size (`int`, *optional*, defaults to 2048):
+ Dimension of the hidden representations.
+ num_hidden_layers (`int`, *optional*, defaults to 12):
+ Number of hidden layers in the Transformer encoder.
+ num_key_value_heads (`int`, *optional*, defaults to 16):
+ Number of attention heads for each key and value in the Transformer encoder.
+ kv_channels (`int`, *optional*, defaults to 128):
+ Defines the number of channels for the key and value tensors.
+ intermediate_size (`int`, *optional*, defaults to 5632):
+ Dimension of the MLP representations.
+ max_position_embeddings (`int`, *optional*, defaults to 4096):
+ The maximum sequence length that this model might ever be used with. JetMoe's attention allows sequences of
+ up to 4096 tokens.
+ activation_function (`str`, *optional*, defaults to `"silu"`):
+ Defines the activation function for MLP experts.
+ num_local_experts (`int`, *optional*, defaults to 8):
+ Defines the number of experts in the MoE and MoA.
+ num_experts_per_tok (`int`, *optional*, defaults to 2):
+ The number of experts to route per token, for both the MoE and the MoA.
+ output_router_logits (`bool`, *optional*, defaults to `False`):
+ Whether or not the router logits should be returned by the model. Enabling this will also
+ allow the model to output the auxiliary loss.
+ aux_loss_coef (`float`, *optional*, defaults to 0.01):
+ The coefficient for the auxiliary loss.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
+ relevant if `config.is_decoder=True`.
+ bos_token_id (`int`, *optional*, defaults to 1):
+ The id of the "beginning-of-sequence" token.
+ eos_token_id (`int`, *optional*, defaults to 2):
+ The id of the "end-of-sequence" token.
+ tie_word_embeddings (`bool`, *optional*, defaults to `True`):
+ Whether the model's input and output word embeddings should be tied.
+ rope_theta (`float`, *optional*, defaults to 10000.0):
+ The base period of the RoPE embeddings.
+ rms_norm_eps (`float`, *optional*, defaults to 1e-06):
+ The epsilon used by the rms normalization layers.
+ initializer_range (`float`, *optional*, defaults to 0.01):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+
+ ```python
+ >>> from transformers import JetMoeModel, JetMoeConfig
+
+ >>> # Initializing a JetMoe 4B style configuration
+ >>> configuration = JetMoeConfig()
+
+ >>> # Initializing a model from the JetMoe 4B style configuration
+ >>> model = JetMoeModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "jetmoe"
+ keys_to_ignore_at_inference = ["past_key_values"]
+ attribute_map = {"head_dim": "kv_channels"}
+
+ def __init__(
+ self,
+ vocab_size=32000,
+ hidden_size=2048,
+ num_hidden_layers=12,
+ num_key_value_heads=16,
+ kv_channels=128,
+ intermediate_size=5632,
+ max_position_embeddings=4096,
+ activation_function="silu",
+ num_local_experts=8,
+ num_experts_per_tok=2,
+ output_router_logits=False,
+ aux_loss_coef=0.01,
+ use_cache=True,
+ bos_token_id=1,
+ eos_token_id=2,
+ tie_word_embeddings=True,
+ rope_theta=10000.0,
+ rms_norm_eps=1e-6,
+ initializer_range=0.01,
+ attention_dropout=0.0,
+ **kwargs,
+ ):
+ if num_experts_per_tok > num_local_experts:
+ raise ValueError("`num_experts_per_tok` must be less than or equal to `num_local_experts`")
+ self.vocab_size = vocab_size
+ self.hidden_size = hidden_size
+ self.num_hidden_layers = num_hidden_layers
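+ # Each token is routed to `num_experts_per_tok` attention experts and every expert contributes
+ # `num_key_value_heads` query heads, hence the effective number of attention heads below
+ # (illustrative comment, not part of the original source).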
+ self.num_attention_heads = num_key_value_heads * num_experts_per_tok
+ self.num_key_value_heads = num_key_value_heads
+ self.kv_channels = kv_channels
+ self.intermediate_size = intermediate_size
+ self.max_position_embeddings = max_position_embeddings
+ self.activation_function = activation_function
+ self.num_local_experts = num_local_experts
+ self.num_experts_per_tok = num_experts_per_tok
+ self.output_router_logits = output_router_logits
+ self.aux_loss_coef = aux_loss_coef
+ self.use_cache = use_cache
+ self.initializer_range = initializer_range
+ self.attention_dropout = attention_dropout
+
+ self.bos_token_id = bos_token_id
+ self.eos_token_id = eos_token_id
+
+ self.rope_theta = rope_theta
+ self.rms_norm_eps = rms_norm_eps
+
+ super().__init__(
+ bos_token_id=bos_token_id, eos_token_id=eos_token_id, tie_word_embeddings=tie_word_embeddings, **kwargs
+ )
+
+
+__all__ = ["JetMoeConfig"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/jetmoe/modeling_jetmoe.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/jetmoe/modeling_jetmoe.py
new file mode 100644
index 0000000000000000000000000000000000000000..0ca0a9a4366950ed07c5e966f7e632a1e5ee697a
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/jetmoe/modeling_jetmoe.py
@@ -0,0 +1,1203 @@
+# coding=utf-8
+# Copyright 2024 JetMoe AI and the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch JetMoe model."""
+
+import math
+from typing import Optional, Union
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from ...activations import ACT2FN
+from ...cache_utils import Cache, DynamicCache
+from ...generation import GenerationMixin
+from ...modeling_attn_mask_utils import AttentionMaskConverter
+from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available
+from ...modeling_layers import (
+ GenericForSequenceClassification,
+ GradientCheckpointingLayer,
+)
+from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast
+from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
+from ...modeling_utils import PreTrainedModel
+from ...utils import auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging
+from ...utils.deprecation import deprecate_kwarg
+from .configuration_jetmoe import JetMoeConfig
+
+
+if is_torch_flex_attn_available():
+ from torch.nn.attention.flex_attention import BlockMask
+
+ from ...integrations.flex_attention import make_flex_block_causal_mask
+
+
+if is_flash_attn_available():
+ from ...modeling_flash_attention_utils import _flash_attention_forward
+
+logger = logging.get_logger(__name__)
+
+
+# Copied from transformers.models.qwen2_moe.modeling_qwen2_moe.load_balancing_loss_func
+def load_balancing_loss_func(
+ gate_logits: Union[torch.Tensor, tuple[torch.Tensor], None],
+ num_experts: Optional[int] = None,
+ top_k=2,
+ attention_mask: Optional[torch.Tensor] = None,
+) -> Union[torch.Tensor, int]:
+ r"""
+ Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.
+
+ See Switch Transformer (https://huggingface.co/papers/2101.03961) for more details. This function implements the loss
+ function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between
+ experts is too unbalanced.
+
+ Args:
+ gate_logits:
+ Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of
+ shape [batch_size X sequence_length, num_experts].
+ num_experts:
+ Number of experts
+ top_k:
+ The number of experts to route per token; can also be interpreted as the `top-k` routing
+ parameter.
+ attention_mask (`torch.Tensor`, *optional*):
+ The attention_mask used in forward function
+ shape [batch_size X sequence_length] if not None.
+
+ Returns:
+ The auxiliary loss.
+ """
+ if gate_logits is None or not isinstance(gate_logits, tuple):
+ return 0
+
+ if isinstance(gate_logits, tuple):
+ compute_device = gate_logits[0].device
+ concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0)
+
+ routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)
+
+ _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
+
+ expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)
+
+ if attention_mask is None:
+ # Compute the percentage of tokens routed to each expert
+ tokens_per_expert = torch.mean(expert_mask.float(), dim=0)
+
+ # Compute the average probability of routing to these experts
+ router_prob_per_expert = torch.mean(routing_weights, dim=0)
+ else:
+ batch_size, sequence_length = attention_mask.shape
+ num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length)
+
+ # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask
+ expert_attention_mask = (
+ attention_mask[None, :, :, None, None]
+ .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts))
+ .reshape(-1, top_k, num_experts)
+ .to(compute_device)
+ )
+
+ # Compute the percentage of tokens routed to each expert
+ tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum(
+ expert_attention_mask, dim=0
+ )
+
+ # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert
+ router_per_expert_attention_mask = (
+ attention_mask[None, :, :, None]
+ .expand((num_hidden_layers, batch_size, sequence_length, routing_weights.shape[1]))
+ .reshape(-1, routing_weights.shape[1])
+ .to(compute_device)
+ )
+
+ # Compute the average probability of routing to these experts
+ router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum(
+ router_per_expert_attention_mask, dim=0
+ )
+
+ device_index = routing_weights.device.index if routing_weights.device.index is not None else 0
+ rank = routing_weights.shape[1] * int(device_index)
+ overall_loss = torch.sum(
+ tokens_per_expert[:, rank : rank + routing_weights.shape[1]] * router_prob_per_expert.unsqueeze(0)
+ )
+ return overall_loss * num_experts
+
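+# Illustrative summary of the auxiliary loss above (not part of the original source): ignoring the
+# attention mask, it reduces to the Switch-Transformer load-balancing loss
+#     aux_loss = num_experts * sum_e f_e * P_e
+# where f_e is the fraction of routed (token, top-k slot) assignments that select expert e and
+# P_e is the mean router probability assigned to expert e over all tokens and layers.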
+
+class JetMoeParallelExperts(nn.Module):
+ def __init__(self, num_experts: int, input_size: int, output_size: int) -> None:
+ """
+ Initialize the JetMoeParallelExperts module.
+ The expert weights are stored in [num_experts, output_size, input_size] format, so that the module is compatible
+ with many MoE libraries, such as [Megablock](https://github.com/databricks/megablocks) and
+ [ScatterMoE](https://github.com/shawntan/scattermoe), as well as the
+ [MoE kernel](https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/fused_moe/fused_moe.py)
+ used in vLLM.
+
+ Args:
+ num_experts (int):
+ Number of experts.
+ input_size (int):
+ Size of the input.
+ output_size (int):
+ Size of the output.
+ """
+ super().__init__()
+ self.weight = nn.Parameter(torch.empty(num_experts, output_size, input_size))
+ self.num_experts = num_experts
+ self.input_size = input_size
+ self.output_size = output_size
+
+ def forward(self, inputs, expert_size):
+ """
+ Forward pass of the JetMoeParallelExperts module.
+
+ Args:
+ inputs (Tensor):
+ Input tensor.
+ expert_size:
+ Expert size information.
+
+ Returns:
+ Tensor: Output tensor.
+ """
+ input_list = inputs.split(expert_size, dim=0)
+ output_list = []
+ for i in range(self.num_experts):
+ output_list.append(F.linear(input_list[i], self.weight[i]))
+ results = torch.cat(output_list, dim=0)
+ return results
+
+
+class JetMoeTopKGating(nn.Module):
+ def __init__(self, input_size: int, num_experts: int, top_k: int):
+ """
+ Initialize the top-k gating mechanism.
+
+ Args:
+ input_size (`int`):
+ Size of the input.
+ num_experts (`int`):
+ Number of experts.
+ top_k (`int`):
+ Number of top experts to select.
+ """
+ super().__init__()
+
+ self.num_experts = num_experts
+ self.input_size = input_size
+ self.top_k = top_k
+
+ self.layer = nn.Linear(input_size, num_experts, bias=False)
+
+ def forward(self, hidden_states):
+ # compute the top_k routing decision
+ logits = self.layer(hidden_states).float() # [batch_size x seq_len, num_experts]
+ top_k_logits, top_k_indices = logits.topk(self.top_k, dim=1) # [num_tokens, top_k]
+ top_k_gates = torch.softmax(top_k_logits, dim=1).type_as(hidden_states) # [num_tokens, top_k]
+
+ # compute number of input given to each expert
+ zeros = torch.zeros(
+ [top_k_gates.size(0), self.num_experts], dtype=top_k_gates.dtype, device=top_k_gates.device
+ ) # [num_tokens, num_experts]
+ gates = zeros.scatter(1, top_k_indices, 1) # [num_tokens, num_experts]
+ expert_size = gates.long().sum(0) # [num_experts,]
+ # (This causes torch.compile to fail with `torch._dynamo.exc.Unsupported: Backend compiler failed with a fake tensor exception at`)
+ # (and `DataDependentOutputException`)
+ expert_size = expert_size.tolist()
+
+ # sort and group input tokens according to expert assignment
+ top_k_experts = top_k_indices.flatten() # [num_tokens * top_k]
+ _, index_sorted_experts = top_k_experts.sort(0) # [num_tokens * top_k]
+ batch_index = index_sorted_experts.div(self.top_k, rounding_mode="trunc") # [num_tokens * top_k]
+
+ # gather the gate values for grouped input tokens
+ top_k_gates = top_k_gates.flatten() # [num_tokens * top_k]
+ batch_gates = top_k_gates[index_sorted_experts] # [num_tokens * top_k]
+
+ return index_sorted_experts, batch_index, batch_gates, expert_size, logits
+
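+# Toy illustration of the gating outputs above (assumed values, not from the original code): with
+# num_experts = 4, top_k = 2 and two tokens routed to experts {0, 2} and {2, 3} respectively,
+#   expert_size  == [1, 0, 2, 1]        # number of assignments received by each expert
+#   batch_index  == [0, 0, 1, 1]        # original token index for each expert-sorted assignment
+#   batch_gates  holds the softmaxed top-k gate value for each of those assignments.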
+
+class JetMoeMoE(nn.Module):
+ """
+ A Sparsely gated mixture of experts layer with 1-layer Feed-Forward networks as experts.
+
+ Args:
+ config:
+ Configuration object with model hyperparameters.
+ """
+
+ def __init__(self, config: JetMoeConfig):
+ super().__init__()
+
+ self.input_size = config.hidden_size
+ self.hidden_size = config.intermediate_size
+ self.activation = ACT2FN[config.activation_function]
+ self.bias = torch.nn.Parameter(torch.empty(self.input_size))
+ self.input_linear = JetMoeParallelExperts(config.num_local_experts, self.input_size, self.hidden_size * 2)
+ self.output_linear = JetMoeParallelExperts(config.num_local_experts, self.hidden_size, self.input_size)
+
+ self.router = JetMoeTopKGating(
+ input_size=self.input_size,
+ num_experts=config.num_local_experts,
+ top_k=config.num_experts_per_tok,
+ )
+
+ def forward(self, layer_input):
+ """
+ Forward pass of the mixture of experts layer.
+
+ Args:
+ layer_input (Tensor):
+ Input tensor.
+
+ Returns:
+ Tensor:
+ Output tensor.
+ Tensor:
+ Router logits.
+ """
+ bsz, length, emb_size = layer_input.size()
+ layer_input = layer_input.reshape(-1, emb_size)
+ _, batch_index, batch_gates, expert_size, router_logits = self.router(layer_input)
+
+ expert_inputs = layer_input[batch_index]
+ hidden_states = self.input_linear(expert_inputs, expert_size)
+ chunked_hidden_states = hidden_states.chunk(2, dim=-1)
+ hidden_states = self.activation(chunked_hidden_states[0]) * chunked_hidden_states[1]
+ expert_outputs = self.output_linear(hidden_states, expert_size)
+
+ expert_outputs = expert_outputs * batch_gates[:, None]
+
+ zeros = torch.zeros((bsz * length, self.input_size), dtype=expert_outputs.dtype, device=expert_outputs.device)
+ layer_output = zeros.index_add(0, batch_index, expert_outputs)
+ layer_output = layer_output.view(bsz, length, self.input_size)
+ layer_output = layer_output + self.bias
+ return layer_output, router_logits
+
+
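+# Data-flow sketch for JetMoeMoE.forward above (illustrative comment, not part of the original
+# source): tokens are flattened, routed by `JetMoeTopKGating`, grouped per expert, passed through
+# `input_linear` (whose two output chunks are combined as activation(chunk_0) * chunk_1, a
+# SwiGLU-style gate), projected back with `output_linear`, scaled by the routing gates and finally
+# scattered back to their original token positions with `index_add`.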
+class JetMoeMoA(nn.Module):
+ """
+ A Sparsely gated mixture of attention layer with pairs of query- and output-projections as experts.
+
+ Args:
+ config:
+ Configuration object with model hyperparameters.
+ """
+
+ def __init__(self, config: JetMoeConfig):
+ super().__init__()
+
+ self.num_experts = config.num_local_experts
+ self.input_size = config.hidden_size
+ self.hidden_size = config.kv_channels * config.num_key_value_heads
+ self.top_k = config.num_experts_per_tok
+ self.bias = torch.nn.Parameter(torch.empty(self.input_size))
+
+ self.input_linear = JetMoeParallelExperts(self.num_experts, self.input_size, self.hidden_size)
+ self.output_linear = JetMoeParallelExperts(self.num_experts, self.hidden_size, self.input_size)
+
+ self.router = JetMoeTopKGating(
+ input_size=self.input_size,
+ num_experts=self.num_experts,
+ top_k=self.top_k,
+ )
+
+ def map(self, layer_input):
+ """
+ Map inputs to attention experts according to the routing decision and compute the query projection inside each expert.
+ """
+
+ # Compute gating topology
+ bsz, length, emb_size = layer_input.size()
+ layer_input = layer_input.reshape(-1, emb_size) # [bsz * length, emb_size]
+ index_sorted_experts, batch_index, batch_gates, expert_size, router_logits = self.router(layer_input)
+ topo_info = (index_sorted_experts, batch_index, batch_gates, expert_size)
+
+ # Group inputs according to topology and compute query projection
+ expert_inputs = layer_input[batch_index] # [bsz * length * top_k, emb_size]
+ expert_outputs = self.input_linear(expert_inputs, expert_size) # [bsz * length * top_k, hidden_size]
+
+ # Ungroup queries back to original order
+ zeros = torch.zeros(
+ (bsz * length * self.top_k, self.hidden_size), dtype=expert_outputs.dtype, device=expert_outputs.device
+ )
+ layer_output = zeros.index_add(0, index_sorted_experts, expert_outputs)
+ layer_output = layer_output.view(bsz, length, self.top_k, -1) # [bsz, length, top_k, hidden_size]
+ return layer_output, router_logits, topo_info
+
+ def reduce(self, layer_input, topo_info):
+ """
+ Compute the output projection inside each attention expert and merge the outputs of the different experts.
+ """
+ bsz, length, k, hidden_size = layer_input.size()
+ layer_input = layer_input.reshape(-1, hidden_size) # [bsz * length * k, hidden_size]
+ index_sorted_experts, batch_index, batch_gates, expert_size = topo_info
+
+ # Group inputs according to topology and compute output projection
+ expert_inputs = layer_input[index_sorted_experts] # [bsz * length * top_k, hidden_size]
+ expert_outputs = self.output_linear(expert_inputs, expert_size) # [bsz * length * top_k, emb_size]
+
+ # Apply gates to attention expert outputs
+ expert_outputs = expert_outputs * batch_gates[:, None]
+
+ # Ungroup and merge outputs to original order
+ zeros = torch.zeros((bsz * length, self.input_size), dtype=expert_outputs.dtype, device=expert_outputs.device)
+ layer_output = zeros.index_add(0, batch_index, expert_outputs)
+ layer_output = layer_output.view(bsz, length, self.input_size)
+ layer_output = layer_output + self.bias
+ return layer_output
+
+ def forward(self, layer_input):
+ raise NotImplementedError("This module doesn't support call and forward.")
+
+
+# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->JetMoe
+class JetMoeRMSNorm(nn.Module):
+ def __init__(self, hidden_size, eps=1e-6):
+ """
+ JetMoeRMSNorm is equivalent to T5LayerNorm
+ """
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+ self.variance_epsilon = eps
+
+ def forward(self, hidden_states):
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+ return self.weight * hidden_states.to(input_dtype)
+
+ def extra_repr(self):
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
+
+
+# Copied from transformers.models.gemma.modeling_gemma.GemmaRotaryEmbedding with Gemma->JetMoe
+class JetMoeRotaryEmbedding(nn.Module):
+ inv_freq: torch.Tensor # fix linting for `register_buffer`
+
+ def __init__(self, config: JetMoeConfig, device=None):
+ super().__init__()
+ # BC: "rope_type" was originally "type"
+ if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
+ else:
+ self.rope_type = "default"
+ self.max_seq_len_cached = config.max_position_embeddings
+ self.original_max_seq_len = config.max_position_embeddings
+
+ self.config = config
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
+ self.original_inv_freq = self.inv_freq
+
+ @torch.no_grad()
+ @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
+ def forward(self, x, position_ids):
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
+ position_ids_expanded = position_ids[:, None, :].float()
+
+ device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
+ with torch.autocast(device_type=device_type, enabled=False): # Force float32
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+ emb = torch.cat((freqs, freqs), dim=-1)
+ cos = emb.cos() * self.attention_scaling
+ sin = emb.sin() * self.attention_scaling
+
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
+# Copied from transformers.models.llama.modeling_llama.rotate_half
+def rotate_half(x):
+ """Rotates half the hidden dims of the input."""
+ x1 = x[..., : x.shape[-1] // 2]
+ x2 = x[..., x.shape[-1] // 2 :]
+ return torch.cat((-x2, x1), dim=-1)
+
+
+# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+ """Applies Rotary Position Embedding to the query and key tensors.
+
+ Args:
+ q (`torch.Tensor`): The query tensor.
+ k (`torch.Tensor`): The key tensor.
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
+ position_ids (`torch.Tensor`, *optional*):
+ Deprecated and unused.
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+ Returns:
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+ """
+ cos = cos.unsqueeze(unsqueeze_dim)
+ sin = sin.unsqueeze(unsqueeze_dim)
+ q_embed = (q * cos) + (rotate_half(q) * sin)
+ k_embed = (k * cos) + (rotate_half(k) * sin)
+ return q_embed, k_embed
+
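+# Illustrative restatement of apply_rotary_pos_emb above (the per-position angles come from JetMoeRotaryEmbedding):
+#
+#   q_embed = q * cos + rotate_half(q) * sin
+#   k_embed = k * cos + rotate_half(k) * sin
+#
+# With cos/sin of shape [batch, seq_len, head_dim] and unsqueeze_dim=1, the unsqueeze makes them broadcast against
+# q/k of shape [batch, num_heads, seq_len, head_dim].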
+
+class JetMoeAttention(nn.Module):
+ """
+ Multi-headed attention from 'Attention Is All You Need' paper.
+ """
+
+ def __init__(self, config: JetMoeConfig, layer_idx: Optional[int] = None):
+ """
+ Initialize the JetMoeAttention module.
+
+ Args:
+ config:
+ Configuration object with model hyperparameters.
+ layer_idx:
+ Index of the layer in the model.
+ """
+ super().__init__()
+ self.config = config
+ self.layer_idx = layer_idx
+ self.is_causal = True
+ if layer_idx is None:
+ logger.warning_once(
+ f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
+ "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
+ "when creating this class."
+ )
+
+ self.top_k = config.num_experts_per_tok
+ self.attention_dropout = config.attention_dropout
+ self.kv_projection_size = config.kv_channels * config.num_key_value_heads
+ self.num_key_value_heads = config.num_key_value_heads
+ self.num_heads = config.num_attention_heads
+ self.head_dim = config.kv_channels
+
+ self.experts = JetMoeMoA(config)
+
+ self.kv_proj = torch.nn.Linear(config.hidden_size, self.kv_projection_size * 2, bias=False)
+
+ self.rotary_emb = JetMoeRotaryEmbedding(config)
+
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states, router_logits, topo_info = self.experts.map(hidden_states)
+ key_states, value_states = self.kv_proj(hidden_states).chunk(2, dim=-1)
+
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ cos, sin = self.rotary_emb(value_states, position_ids)
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_values is not None:
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ # repeat k/v heads for top-k attention experts
+ key_states = key_states.repeat(1, self.top_k, 1, 1)
+ value_states = value_states.repeat(1, self.top_k, 1, 1)
+
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+
+ if attention_mask is not None: # no matter the length, we just slice it
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+ attn_weights = attn_weights + causal_mask
+
+ # upcast attention to fp32
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+ attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
+ attn_output = torch.matmul(attn_weights, value_states)
+
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+ raise ValueError(
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
+ )
+
+ attn_output = attn_output.transpose(1, 2).contiguous()
+ attn_output = attn_output.reshape(bsz, q_len, self.top_k, self.kv_projection_size)
+
+ attn_output = self.experts.reduce(attn_output, topo_info)
+ attn_output = attn_output.view(bsz, q_len, -1)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, router_logits
+
+
+class JetMoeSdpaAttention(JetMoeAttention):
+ """
+ JetMoe attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
+ `JetMoeAttention`, as the weights of the module stay untouched. The only changes are on the forward pass, to adapt
+ it to the SDPA API.
+ """
+
+ # Adapted from JetMoeAttention.forward
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]], Optional[torch.Tensor]]:
+ if output_attentions:
+ # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
+ logger.warning_once(
+ "JetMoeModel is using JetMoeSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
+ 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+ )
+ return super().forward(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ )
+
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states, router_logits, topo_info = self.experts.map(hidden_states)
+ key_states, value_states = self.kv_proj(hidden_states).chunk(2, dim=-1)
+
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ cos, sin = self.rotary_emb(value_states, position_ids)
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_values is not None:
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ # repeat k/v heads for top-k attention experts
+ key_states = key_states.repeat(1, self.top_k, 1, 1)
+ value_states = value_states.repeat(1, self.top_k, 1, 1)
+
+ causal_mask = attention_mask
+ if attention_mask is not None:
+ causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
+
+ # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
+ # Reference: https://github.com/pytorch/pytorch/issues/112577.
+ if query_states.device.type == "cuda" and causal_mask is not None:
+ query_states = query_states.contiguous()
+ key_states = key_states.contiguous()
+ value_states = value_states.contiguous()
+
+ # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+ # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
+ is_causal = causal_mask is None and q_len > 1
+
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
+ query_states,
+ key_states,
+ value_states,
+ attn_mask=causal_mask,
+ dropout_p=self.attention_dropout if self.training else 0.0,
+ is_causal=is_causal,
+ )
+
+ attn_output = attn_output.transpose(1, 2).contiguous()
+ attn_output = attn_output.reshape(bsz, q_len, self.top_k, self.kv_projection_size)
+
+ attn_output = self.experts.reduce(attn_output, topo_info)
+ attn_output = attn_output.view(bsz, q_len, -1)
+
+ return attn_output, None, router_logits
+
+
+class JetMoeFlashAttention2(JetMoeAttention):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
+ # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
+ self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask()
+
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+ def forward(
+ self,
+ hidden_states: Optional[torch.FloatTensor],
+ attention_mask: Optional[torch.FloatTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ use_cache: Optional[bool] = False,
+ output_attentions: Optional[bool] = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ ) -> Union[
+ tuple[torch.Tensor, tuple[torch.Tensor]],
+ Optional[tuple[torch.Tensor, tuple[torch.Tensor], tuple[torch.Tensor, ...]]],
+ ]:
+ """
+ Forward pass of the JetMoeAttention module.
+
+ Args:
+ hidden_states (Optional[torch.FloatTensor]): Input hidden states.
+ attention_mask (Optional[torch.FloatTensor]): Attention mask.
+ layer_past (Optional[tuple[torch.Tensor]]): Past layer state.
+ use_cache (Optional[bool]): Whether to use cached states.
+ output_attentions (Optional[bool]): Whether to output attention weights.
+ cache_position (Optional[torch.LongTensor]): Position of the cache.
+
+ Returns:
+ Union[tuple[torch.Tensor, tuple[torch.Tensor]], Optional[tuple[...]]]: Tuple containing outputs.
+ """
+ output_attentions = False
+ bsz, q_len, hidden_size = hidden_states.size()
+
+ # calculate query, key, values
+ query_states, router_logits, topo_info = self.experts.map(hidden_states)
+ key_states, value_states = self.kv_proj(hidden_states).chunk(2, dim=-1)
+
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ cos, sin = self.rotary_emb(value_states, position_ids)
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_values is not None:
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ # repeat k/v heads for top-k attention experts
+ key_states = key_states.repeat(1, self.top_k, 1, 1)
+ value_states = value_states.repeat(1, self.top_k, 1, 1)
+
+ # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
+ # to be able to avoid many of these transpose/reshape/view.
+ query_states = query_states.transpose(1, 2)
+ key_states = key_states.transpose(1, 2)
+ value_states = value_states.transpose(1, 2)
+
+ dropout_rate = self.attention_dropout if self.training else 0.0
+
+ # In PEFT, the layer norms are usually cast to float32 for training stability reasons,
+ # so the input hidden states get silently cast to float32 as well. Hence, we need to
+ # cast them back to the correct dtype just to be sure everything works as expected.
+ # This might slow down training & inference, so it is recommended not to cast the LayerNorms
+ # to fp32. (LlamaRMSNorm handles it correctly)
+
+ input_dtype = query_states.dtype
+ device_type = query_states.device.type if query_states.device.type != "mps" else "cpu"
+ if input_dtype == torch.float32:
+ if torch.is_autocast_enabled():
+ target_dtype = (
+ torch.get_autocast_dtype(device_type)
+ if hasattr(torch, "get_autocast_dtype")
+ else torch.get_autocast_gpu_dtype()
+ )
+ # Handle the case where the model is quantized
+ elif hasattr(self.config, "_pre_quantization_dtype"):
+ target_dtype = self.config._pre_quantization_dtype
+ else:
+ target_dtype = self.kv_proj.weight.dtype
+
+ logger.warning_once(
+ f"The input hidden states seem to have been silently cast to float32; this might be related to"
+ f" the fact that you have upcast embedding or layer norm layers to float32. We will cast the input back to"
+ f" {target_dtype}."
+ )
+
+ query_states = query_states.to(target_dtype)
+ key_states = key_states.to(target_dtype)
+ value_states = value_states.to(target_dtype)
+
+ attn_output = _flash_attention_forward(
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ q_len,
+ dropout=dropout_rate,
+ use_top_left_mask=self._flash_attn_uses_top_left_mask,
+ is_causal=self.is_causal,
+ ).to(input_dtype)
+
+ # output projection
+ attn_output = attn_output.reshape(bsz, q_len, self.top_k, self.kv_projection_size)
+ attn_output = self.experts.reduce(attn_output, topo_info)
+ attn_output = attn_output.view(bsz, q_len, hidden_size) # re-assemble all head outputs side by side
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, router_logits
+
+
+JETMOE_ATTENTION_CLASSES = {
+ "eager": JetMoeAttention,
+ "flash_attention_2": JetMoeFlashAttention2,
+ "sdpa": JetMoeSdpaAttention,
+}
+
+
+class JetMoeBlock(GradientCheckpointingLayer):
+ def __init__(self, config: JetMoeConfig, layer_idx: Optional[int] = None):
+ """
+ Initialize the JetMoeBlock module.
+
+ Args:
+ config:
+ Configuration object with model hyperparameters.
+ layer_idx:
+ Index of the layer in the model.
+ """
+ super().__init__()
+ self.input_layernorm = JetMoeRMSNorm(config.hidden_size)
+ self.self_attention = JETMOE_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
+ self.post_attention_layernorm = JetMoeRMSNorm(config.hidden_size)
+
+ self.mlp = JetMoeMoE(config)
+
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+ def forward(
+ self,
+ hidden_states: Optional[torch.FloatTensor],
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ output_attentions: Optional[bool] = False,
+ output_router_logits: Optional[bool] = False,
+ use_cache: Optional[bool] = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ ) -> Union[tuple[torch.Tensor], Optional[tuple[torch.Tensor, tuple[torch.FloatTensor, ...]]]]:
+ # Self Attention
+ attn_output, self_attn_weights, attn_router_logits = self.self_attention(
+ hidden_states=self.input_layernorm(hidden_states),
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ )
+
+ hidden_states = hidden_states + attn_output
+ x_mlp, mlp_router_logits = self.mlp(self.post_attention_layernorm(hidden_states))
+ hidden_states = hidden_states + x_mlp
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (self_attn_weights,)
+
+ if output_router_logits:
+ outputs += (attn_router_logits, mlp_router_logits)
+
+ return outputs
+
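+# A hedged summary of the pre-norm residual structure implemented in JetMoeBlock.forward above (LN denotes
+# JetMoeRMSNorm):
+#
+#   h   = x + MoA_attention(LN_input(x))
+#   out = h + MoE_mlp(LN_post_attention(h))
+#
+# When `output_router_logits=True`, the router logits of both the attention experts and the MLP experts are appended
+# to the returned tuple.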
+
+@auto_docstring
+class JetMoePreTrainedModel(PreTrainedModel):
+ config: JetMoeConfig
+ base_model_prefix = "transformer"
+ supports_gradient_checkpointing = False
+ _no_split_modules = ["JetMoeBlock"]
+ _skip_keys_device_placement = ["past_key_values"]
+ _supports_flash_attn = True
+ _supports_sdpa = True
+
+ def _init_weights(self, module):
+ """Initialize the weights."""
+ if isinstance(module, nn.Linear):
+ # Slightly different from Mesh Transformer JAX which uses truncated_normal for initialization
+ # cf https://github.com/pytorch/pytorch/pull/5617
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+ elif isinstance(module, JetMoeRMSNorm):
+ module.weight.data.fill_(1.0)
+ elif isinstance(module, JetMoeParallelExperts):
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ elif isinstance(module, (JetMoeMoA, JetMoeMoE)):
+ module.bias.data.zero_()
+
+
+@auto_docstring
+class JetMoeModel(JetMoePreTrainedModel):
+ """
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`JetMoeBlock`]
+
+ Args:
+ config:
+ JetMoeConfig
+ """
+
+ def __init__(self, config: JetMoeConfig):
+ super().__init__(config)
+ self.padding_idx = config.pad_token_id
+ self.vocab_size = config.vocab_size
+
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+ self.layers = nn.ModuleList([JetMoeBlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)])
+ self._attn_implementation = config._attn_implementation
+ self.norm = JetMoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+ self.gradient_checkpointing = False
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ output_router_logits: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ ) -> MoeModelOutputWithPast:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ output_router_logits = (
+ output_router_logits if output_router_logits is not None else self.config.output_router_logits
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ if self.gradient_checkpointing and self.training and use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
+ )
+ use_cache = False
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+
+ if use_cache and past_key_values is None:
+ past_key_values = DynamicCache(config=self.config)
+
+ if cache_position is None:
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ cache_position = torch.arange(
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+ )
+ if position_ids is None:
+ position_ids = cache_position.unsqueeze(0)
+
+ causal_mask = self._update_causal_mask(
+ attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
+ )
+
+ hidden_states = inputs_embeds
+
+ # decoder layers
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+ all_router_logits = () if output_router_logits else None
+
+ for decoder_layer in self.layers:
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ layer_outputs = decoder_layer(
+ hidden_states,
+ attention_mask=causal_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ output_attentions=output_attentions,
+ output_router_logits=output_router_logits,
+ use_cache=use_cache,
+ )
+
+ hidden_states = layer_outputs[0]
+
+ if output_attentions:
+ all_self_attns += (layer_outputs[1],)
+
+ if output_router_logits:
+ all_router_logits += (layer_outputs[-2], layer_outputs[-1])
+
+ hidden_states = self.norm(hidden_states)
+
+ # add hidden states from the last decoder layer
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ return MoeModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=past_key_values,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ router_logits=all_router_logits,
+ )
+
+ # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask
+ def _update_causal_mask(
+ self,
+ attention_mask: Union[torch.Tensor, "BlockMask"],
+ input_tensor: torch.Tensor,
+ cache_position: torch.Tensor,
+ past_key_values: Cache,
+ output_attentions: bool = False,
+ ):
+ if self.config._attn_implementation == "flash_attention_2":
+ if attention_mask is not None and (attention_mask == 0.0).any():
+ return attention_mask
+ return None
+ if self.config._attn_implementation == "flex_attention":
+ if isinstance(attention_mask, torch.Tensor):
+ attention_mask = make_flex_block_causal_mask(attention_mask)
+ return attention_mask
+
+ # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
+ # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
+ # to infer the attention mask.
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False
+
+ # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
+ if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions:
+ if AttentionMaskConverter._ignore_causal_mask_sdpa(
+ attention_mask,
+ inputs_embeds=input_tensor,
+ past_key_values_length=past_seen_tokens,
+ is_training=self.training,
+ ):
+ return None
+
+ dtype = input_tensor.dtype
+ sequence_length = input_tensor.shape[1]
+ if using_compilable_cache:
+ target_length = past_key_values.get_max_cache_shape()
+ else:
+ target_length = (
+ attention_mask.shape[-1]
+ if isinstance(attention_mask, torch.Tensor)
+ else past_seen_tokens + sequence_length + 1
+ )
+
+ # In case the provided `attention_mask` is 2D, we generate a causal mask here (4D).
+ causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
+ attention_mask,
+ sequence_length=sequence_length,
+ target_length=target_length,
+ dtype=dtype,
+ cache_position=cache_position,
+ batch_size=input_tensor.shape[0],
+ )
+
+ if (
+ self.config._attn_implementation == "sdpa"
+ and attention_mask is not None
+ and attention_mask.device.type in ["cuda", "xpu", "npu"]
+ and not output_attentions
+ ):
+ # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
+ # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
+ # Details: https://github.com/pytorch/pytorch/issues/110213
+ min_dtype = torch.finfo(dtype).min
+ causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
+
+ return causal_mask
+
+ @staticmethod
+ # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position
+ def _prepare_4d_causal_attention_mask_with_cache_position(
+ attention_mask: torch.Tensor,
+ sequence_length: int,
+ target_length: int,
+ dtype: torch.dtype,
+ cache_position: torch.Tensor,
+ batch_size: int,
+ **kwargs,
+ ):
+ """
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
+ `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
+
+ Args:
+ attention_mask (`torch.Tensor`):
+ A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
+ `(batch_size, 1, query_length, key_value_length)`.
+ sequence_length (`int`):
+ The sequence length being processed.
+ target_length (`int`):
+ The target length: when generating with a static cache, the mask should be as long as the static cache,
+ to account for the 0 padding (the part of the cache that is not filled yet).
+ dtype (`torch.dtype`):
+ The dtype to use for the 4D attention mask.
+ cache_position (`torch.Tensor`):
+ Indices depicting the position of the input sequence tokens in the sequence.
+ batch_size (`int`):
+ Batch size.
+ """
+ if attention_mask is not None and attention_mask.dim() == 4:
+ # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
+ causal_mask = attention_mask
+ else:
+ min_dtype = torch.finfo(dtype).min
+ causal_mask = torch.full(
+ (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
+ )
+ if sequence_length != 1:
+ causal_mask = torch.triu(causal_mask, diagonal=1)
+ causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
+ causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
+ if attention_mask is not None:
+ causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
+ mask_length = attention_mask.shape[-1]
+ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
+ causal_mask.device
+ )
+ padding_mask = padding_mask == 0
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+ padding_mask, min_dtype
+ )
+
+ return causal_mask
+
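+# A small worked example of the additive 4D mask built by _prepare_4d_causal_attention_mask_with_cache_position above
+# (illustrative values only): with sequence_length=2, target_length=3, cache_position=[1, 2] and no padding mask, each
+# query row keeps 0.0 at the positions it may attend to and the dtype minimum ("min") elsewhere:
+#
+#   [[0, 0, min],
+#    [0, 0,   0]]        # then broadcast to [batch_size, 1, 2, 3]
+#
+# A 2D padding mask, when provided, additionally sets fully padded key positions to the dtype minimum.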
+
+class JetMoeForCausalLM(JetMoePreTrainedModel, GenerationMixin):
+ _tied_weights_keys = ["lm_head.weight"]
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.model = JetMoeModel(config)
+ self.vocab_size = config.vocab_size
+ self.aux_loss_coef = config.aux_loss_coef
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+ self.tie_word_embeddings = config.tie_word_embeddings
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ output_router_logits: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ logits_to_keep: Union[int, torch.Tensor] = 0,
+ **kwargs,
+ ) -> MoeCausalLMOutputWithPast:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), and the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+ """
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_router_logits = (
+ output_router_logits if output_router_logits is not None else self.config.output_router_logits
+ )
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+
+ # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
+ outputs: MoeModelOutputWithPast = self.model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ output_router_logits=output_router_logits,
+ cache_position=cache_position,
+ )
+
+ hidden_states = outputs.last_hidden_state
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(
+ logits,
+ labels,
+ vocab_size=self.config.vocab_size,
+ **kwargs,
+ )
+
+ aux_loss = None
+ if output_router_logits:
+ aux_loss = load_balancing_loss_func(
+ outputs.router_logits,
+ self.num_experts,
+ self.num_experts_per_tok,
+ attention_mask,
+ )
+ if labels is not None:
+ loss += self.aux_loss_coef * aux_loss.to(loss.device) # make sure the aux loss resides on the same device as the loss
+
+ return MoeCausalLMOutputWithPast(
+ loss=loss,
+ aux_loss=aux_loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ router_logits=outputs.router_logits,
+ )
+
+
+class JetMoeForSequenceClassification(GenericForSequenceClassification, JetMoePreTrainedModel): ...
+
+
+__all__ = ["JetMoeForCausalLM", "JetMoeModel", "JetMoePreTrainedModel", "JetMoeForSequenceClassification"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/layoutlm/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/layoutlm/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..0f079c33c7157898807c5f405086804cc1533ff9
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/layoutlm/__init__.py
@@ -0,0 +1,30 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_layoutlm import *
+ from .modeling_layoutlm import *
+ from .modeling_tf_layoutlm import *
+ from .tokenization_layoutlm import *
+ from .tokenization_layoutlm_fast import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/layoutlm/configuration_layoutlm.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/layoutlm/configuration_layoutlm.py
new file mode 100644
index 0000000000000000000000000000000000000000..18bfacb755921ffe922fa14c9f385686b2f286ba
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/layoutlm/configuration_layoutlm.py
@@ -0,0 +1,192 @@
+# coding=utf-8
+# Copyright 2010, The Microsoft Research Asia LayoutLM Team authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""LayoutLM model configuration"""
+
+from collections import OrderedDict
+from collections.abc import Mapping
+from typing import Any, Optional
+
+from ... import PretrainedConfig, PreTrainedTokenizer
+from ...onnx import OnnxConfig, PatchingSpec
+from ...utils import TensorType, is_torch_available, logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class LayoutLMConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`LayoutLMModel`]. It is used to instantiate a
+ LayoutLM model according to the specified arguments, defining the model architecture. Instantiating a configuration
+ with the defaults will yield a similar configuration to that of the LayoutLM
+ [microsoft/layoutlm-base-uncased](https://huggingface.co/microsoft/layoutlm-base-uncased) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 30522):
+ Vocabulary size of the LayoutLM model. Defines the number of different tokens that can be represented by
+ the `input_ids` passed to the forward method of [`LayoutLMModel`].
+ hidden_size (`int`, *optional*, defaults to 768):
+ Dimensionality of the encoder layers and the pooler layer.
+ num_hidden_layers (`int`, *optional*, defaults to 12):
+ Number of hidden layers in the Transformer encoder.
+ num_attention_heads (`int`, *optional*, defaults to 12):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ intermediate_size (`int`, *optional*, defaults to 3072):
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+ `"relu"`, `"silu"` and `"gelu_new"` are supported.
+ hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+ attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
+ The dropout ratio for the attention probabilities.
+ max_position_embeddings (`int`, *optional*, defaults to 512):
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
+ just in case (e.g., 512 or 1024 or 2048).
+ type_vocab_size (`int`, *optional*, defaults to 2):
+ The vocabulary size of the `token_type_ids` passed into [`LayoutLMModel`].
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ layer_norm_eps (`float`, *optional*, defaults to 1e-12):
+ The epsilon used by the layer normalization layers.
+ pad_token_id (`int`, *optional*, defaults to 0):
+ The value used to pad input_ids.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
+ relevant if `config.is_decoder=True`.
+ max_2d_position_embeddings (`int`, *optional*, defaults to 1024):
+ The maximum value that the 2D position embedding might ever be used with. Typically set this to something large
+ just in case (e.g., 1024).
+
+ Examples:
+
+ ```python
+ >>> from transformers import LayoutLMConfig, LayoutLMModel
+
+ >>> # Initializing a LayoutLM configuration
+ >>> configuration = LayoutLMConfig()
+
+ >>> # Initializing a model (with random weights) from the configuration
+ >>> model = LayoutLMModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "layoutlm"
+
+ def __init__(
+ self,
+ vocab_size=30522,
+ hidden_size=768,
+ num_hidden_layers=12,
+ num_attention_heads=12,
+ intermediate_size=3072,
+ hidden_act="gelu",
+ hidden_dropout_prob=0.1,
+ attention_probs_dropout_prob=0.1,
+ max_position_embeddings=512,
+ type_vocab_size=2,
+ initializer_range=0.02,
+ layer_norm_eps=1e-12,
+ pad_token_id=0,
+ use_cache=True,
+ max_2d_position_embeddings=1024,
+ **kwargs,
+ ):
+ super().__init__(pad_token_id=pad_token_id, **kwargs)
+ self.vocab_size = vocab_size
+ self.hidden_size = hidden_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.hidden_act = hidden_act
+ self.intermediate_size = intermediate_size
+ self.hidden_dropout_prob = hidden_dropout_prob
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
+ self.max_position_embeddings = max_position_embeddings
+ self.type_vocab_size = type_vocab_size
+ self.initializer_range = initializer_range
+ self.layer_norm_eps = layer_norm_eps
+ self.use_cache = use_cache
+ self.max_2d_position_embeddings = max_2d_position_embeddings
+
+
+class LayoutLMOnnxConfig(OnnxConfig):
+ def __init__(
+ self,
+ config: PretrainedConfig,
+ task: str = "default",
+ patching_specs: Optional[list[PatchingSpec]] = None,
+ ):
+ super().__init__(config, task=task, patching_specs=patching_specs)
+ self.max_2d_positions = config.max_2d_position_embeddings - 1
+
+ @property
+ def inputs(self) -> Mapping[str, Mapping[int, str]]:
+ return OrderedDict(
+ [
+ ("input_ids", {0: "batch", 1: "sequence"}),
+ ("bbox", {0: "batch", 1: "sequence"}),
+ ("attention_mask", {0: "batch", 1: "sequence"}),
+ ("token_type_ids", {0: "batch", 1: "sequence"}),
+ ]
+ )
+
+ def generate_dummy_inputs(
+ self,
+ tokenizer: PreTrainedTokenizer,
+ batch_size: int = -1,
+ seq_length: int = -1,
+ is_pair: bool = False,
+ framework: Optional[TensorType] = None,
+ ) -> Mapping[str, Any]:
+ """
+ Generate inputs to provide to the ONNX exporter for the specific framework
+
+ Args:
+ tokenizer: The tokenizer associated with this model configuration
+ batch_size: The batch size (int) to export the model for (-1 means dynamic axis)
+ seq_length: The sequence length (int) to export the model for (-1 means dynamic axis)
+ is_pair: Indicate if the input is a pair (sentence 1, sentence 2)
+ framework: The framework (optional) the tokenizer will generate tensor for
+
+ Returns:
+ Mapping[str, Tensor] holding the kwargs to provide to the model's forward function
+ """
+
+ input_dict = super().generate_dummy_inputs(
+ tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework
+ )
+
+ # Generate a dummy bbox
+ box = [48, 84, 73, 128]
+
+ if framework != TensorType.PYTORCH:
+ raise NotImplementedError("Exporting LayoutLM to ONNX is currently only supported for PyTorch.")
+
+ if not is_torch_available():
+ raise ValueError("Cannot generate dummy inputs without PyTorch installed.")
+ import torch
+
+ batch_size, seq_length = input_dict["input_ids"].shape
+ input_dict["bbox"] = torch.tensor([*[box] * seq_length]).tile(batch_size, 1, 1)
+ return input_dict
+
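+# A hedged usage sketch for LayoutLMOnnxConfig.generate_dummy_inputs (the checkpoint name is taken from the
+# LayoutLMConfig docstring; AutoTokenizer is assumed to be available from the top-level `transformers` package):
+#
+#   from transformers import AutoTokenizer, LayoutLMConfig
+#   from transformers.utils import TensorType
+#
+#   tokenizer = AutoTokenizer.from_pretrained("microsoft/layoutlm-base-uncased")
+#   onnx_config = LayoutLMOnnxConfig(LayoutLMConfig())
+#   dummy = onnx_config.generate_dummy_inputs(tokenizer, framework=TensorType.PYTORCH)
+#   # `dummy` contains input_ids, attention_mask, token_type_ids and a constant bbox of shape [batch, seq_len, 4]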
+
+__all__ = ["LayoutLMConfig", "LayoutLMOnnxConfig"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/layoutlm/modeling_layoutlm.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/layoutlm/modeling_layoutlm.py
new file mode 100644
index 0000000000000000000000000000000000000000..9e71eb7d8fb984a5a34c1f3b681143b7709eb663
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/layoutlm/modeling_layoutlm.py
@@ -0,0 +1,1144 @@
+# coding=utf-8
+# Copyright 2018 The Microsoft Research Asia LayoutLM Team Authors and the HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch LayoutLM model."""
+
+from typing import Callable, Optional, Union
+
+import torch
+from torch import nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+
+from ...activations import ACT2FN
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_outputs import (
+ BaseModelOutput,
+ BaseModelOutputWithPooling,
+ MaskedLMOutput,
+ QuestionAnsweringModelOutput,
+ SequenceClassifierOutput,
+ TokenClassifierOutput,
+)
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
+from ...utils import auto_docstring, can_return_tuple, logging
+from .configuration_layoutlm import LayoutLMConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+LayoutLMLayerNorm = nn.LayerNorm
+
+
+class LayoutLMEmbeddings(nn.Module):
+ """Construct the embeddings from word, position and token_type embeddings."""
+
+ def __init__(self, config):
+ super().__init__()
+ self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
+ self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
+ self.x_position_embeddings = nn.Embedding(config.max_2d_position_embeddings, config.hidden_size)
+ self.y_position_embeddings = nn.Embedding(config.max_2d_position_embeddings, config.hidden_size)
+ self.h_position_embeddings = nn.Embedding(config.max_2d_position_embeddings, config.hidden_size)
+ self.w_position_embeddings = nn.Embedding(config.max_2d_position_embeddings, config.hidden_size)
+ self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
+
+ self.LayerNorm = LayoutLMLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ self.register_buffer(
+ "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
+ )
+
+ def forward(
+ self,
+ input_ids=None,
+ bbox=None,
+ token_type_ids=None,
+ position_ids=None,
+ inputs_embeds=None,
+ ):
+ if input_ids is not None:
+ input_shape = input_ids.size()
+ else:
+ input_shape = inputs_embeds.size()[:-1]
+
+ seq_length = input_shape[1]
+
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
+
+ if position_ids is None:
+ position_ids = self.position_ids[:, :seq_length]
+
+ if token_type_ids is None:
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
+
+ if inputs_embeds is None:
+ inputs_embeds = self.word_embeddings(input_ids)
+
+ words_embeddings = inputs_embeds
+ position_embeddings = self.position_embeddings(position_ids)
+ try:
+ left_position_embeddings = self.x_position_embeddings(bbox[:, :, 0])
+ upper_position_embeddings = self.y_position_embeddings(bbox[:, :, 1])
+ right_position_embeddings = self.x_position_embeddings(bbox[:, :, 2])
+ lower_position_embeddings = self.y_position_embeddings(bbox[:, :, 3])
+ except IndexError as e:
+ raise IndexError("The `bbox` coordinate values should be within the 0-1000 range.") from e
+
+ h_position_embeddings = self.h_position_embeddings(bbox[:, :, 3] - bbox[:, :, 1])
+ w_position_embeddings = self.w_position_embeddings(bbox[:, :, 2] - bbox[:, :, 0])
+ token_type_embeddings = self.token_type_embeddings(token_type_ids)
+
+ embeddings = (
+ words_embeddings
+ + position_embeddings
+ + left_position_embeddings
+ + upper_position_embeddings
+ + right_position_embeddings
+ + lower_position_embeddings
+ + h_position_embeddings
+ + w_position_embeddings
+ + token_type_embeddings
+ )
+ embeddings = self.LayerNorm(embeddings)
+ embeddings = self.dropout(embeddings)
+ return embeddings
+
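+# Reference note for LayoutLMEmbeddings above (restating the forward pass): `bbox` is expected to have shape
+# [batch_size, seq_len, 4] with (x0, y0, x1, y1) coordinates normalized to the 0-1000 range (hence the
+# max_2d_position_embeddings lookup tables). The final embedding is the sum of the word, 1D position, four box
+# coordinate, box height/width, and token type embeddings, followed by LayerNorm and dropout.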
+
+# Copied from transformers.models.align.modeling_align.eager_attention_forward
+def eager_attention_forward(
+ module: nn.Module,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_mask: Optional[torch.Tensor],
+ scaling: float,
+ dropout: float = 0.0,
+ head_mask: Optional[torch.Tensor] = None,
+ **kwargs,
+):
+ attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
+ if attention_mask is not None:
+ causal_mask = attention_mask[:, :, :, : key.shape[-2]]
+ attn_weights = attn_weights + causal_mask
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+
+ if head_mask is not None:
+ attn_weights = attn_weights * head_mask.view(1, -1, 1, 1)
+
+ attn_output = torch.matmul(attn_weights, value)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+ return attn_output, attn_weights
+
+
+# Copied from transformers.models.align.modeling_align.AlignTextSelfAttention with AlignText->LayoutLM
+class LayoutLMSelfAttention(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
+ raise ValueError(
+ f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
+ f"heads ({config.num_attention_heads})"
+ )
+
+ self.config = config
+ self.num_attention_heads = config.num_attention_heads
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
+
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
+
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
+ self.attention_dropout = config.attention_probs_dropout_prob
+ self.scaling = self.attention_head_size**-0.5
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ output_attentions: Optional[bool] = False,
+ **kwargs,
+ ) -> tuple[torch.Tensor]:
+ input_shape = hidden_states.shape[:-1]
+ hidden_shape = (*input_shape, -1, self.attention_head_size)
+
+ query_states = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
+ key_states = self.key(hidden_states).view(hidden_shape).transpose(1, 2)
+ value_states = self.value(hidden_states).view(hidden_shape).transpose(1, 2)
+
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ attn_output, attn_weights = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ dropout=0.0 if not self.training else self.attention_dropout,
+ scaling=self.scaling,
+ head_mask=head_mask,
+ **kwargs,
+ )
+
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+ outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)
+ return outputs
+
+
+# Copied from transformers.models.bert.modeling_bert.BertSelfOutput with Bert->LayoutLM
+class LayoutLMSelfOutput(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
+ return hidden_states
+
+
+# Copied from transformers.models.align.modeling_align.AlignTextAttention with AlignText->LayoutLM
+class LayoutLMAttention(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.self = LayoutLMSelfAttention(config)
+ self.output = LayoutLMSelfOutput(config)
+ self.pruned_heads = set()
+
+ def prune_heads(self, heads):
+ if len(heads) == 0:
+ return
+ heads, index = find_pruneable_heads_and_indices(
+ heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
+ )
+
+ # Prune linear layers
+ self.self.query = prune_linear_layer(self.self.query, index)
+ self.self.key = prune_linear_layer(self.self.key, index)
+ self.self.value = prune_linear_layer(self.self.value, index)
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
+
+ # Update hyper params and store pruned heads
+ self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
+ self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
+ self.pruned_heads = self.pruned_heads.union(heads)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ output_attentions: Optional[bool] = False,
+ **kwargs,
+ ) -> tuple[torch.Tensor]:
+ self_outputs = self.self(
+ hidden_states,
+ attention_mask=attention_mask,
+ head_mask=head_mask,
+ output_attentions=output_attentions,
+ **kwargs,
+ )
+ attention_output = self.output(self_outputs[0], hidden_states)
+ outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
+ return outputs
+
+
+# Copied from transformers.models.bert.modeling_bert.BertIntermediate
+class LayoutLMIntermediate(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
+ if isinstance(config.hidden_act, str):
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
+ else:
+ self.intermediate_act_fn = config.hidden_act
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.intermediate_act_fn(hidden_states)
+ return hidden_states
+
+
+# Copied from transformers.models.bert.modeling_bert.BertOutput with Bert->LayoutLM
+class LayoutLMOutput(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
+ return hidden_states
+
+
+# Copied from transformers.models.align.modeling_align.AlignTextLayer with AlignText->LayoutLM
+class LayoutLMLayer(GradientCheckpointingLayer):
+ def __init__(self, config):
+ super().__init__()
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
+ self.seq_len_dim = 1
+ self.attention = LayoutLMAttention(config)
+ self.intermediate = LayoutLMIntermediate(config)
+ self.output = LayoutLMOutput(config)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ output_attentions: Optional[bool] = False,
+ **kwargs,
+ ) -> tuple[torch.Tensor]:
+ self_attention_outputs = self.attention(
+ hidden_states,
+ attention_mask=attention_mask,
+ head_mask=head_mask,
+ output_attentions=output_attentions,
+ **kwargs,
+ )
+ attention_output = self_attention_outputs[0]
+
+ outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
+ layer_output = apply_chunking_to_forward(
+ self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
+ )
+ outputs = (layer_output,) + outputs
+
+ return outputs
+
+ def feed_forward_chunk(self, attention_output):
+ intermediate_output = self.intermediate(attention_output)
+ layer_output = self.output(intermediate_output, attention_output)
+ return layer_output
+
+
+# Copied from transformers.models.align.modeling_align.AlignTextEncoder with AlignText->LayoutLM
+class LayoutLMEncoder(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.layer = nn.ModuleList([LayoutLMLayer(config) for i in range(config.num_hidden_layers)])
+ self.gradient_checkpointing = False
+
+ @can_return_tuple
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ output_attentions: Optional[bool] = False,
+ output_hidden_states: Optional[bool] = False,
+ return_dict: Optional[bool] = True,
+ **kwargs,
+ ) -> Union[tuple[torch.Tensor], BaseModelOutput]:
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attentions = () if output_attentions else None
+
+ for i, layer_module in enumerate(self.layer):
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ layer_head_mask = head_mask[i] if head_mask is not None else None
+
+ layer_outputs = layer_module(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ head_mask=layer_head_mask,
+ output_attentions=output_attentions,
+ **kwargs,
+ )
+
+ hidden_states = layer_outputs[0]
+ if output_attentions:
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ return BaseModelOutput(
+ last_hidden_state=hidden_states,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attentions,
+ )
+
+
+# Copied from transformers.models.bert.modeling_bert.BertPooler
+class LayoutLMPooler(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ self.activation = nn.Tanh()
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ # We "pool" the model by simply taking the hidden state corresponding
+ # to the first token.
+ first_token_tensor = hidden_states[:, 0]
+ pooled_output = self.dense(first_token_tensor)
+ pooled_output = self.activation(pooled_output)
+ return pooled_output
+
+
+# Copied from transformers.models.bert.modeling_bert.BertPredictionHeadTransform with Bert->LayoutLM
+class LayoutLMPredictionHeadTransform(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ if isinstance(config.hidden_act, str):
+ self.transform_act_fn = ACT2FN[config.hidden_act]
+ else:
+ self.transform_act_fn = config.hidden_act
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.transform_act_fn(hidden_states)
+ hidden_states = self.LayerNorm(hidden_states)
+ return hidden_states
+
+
+# Copied from transformers.models.bert.modeling_bert.BertLMPredictionHead with Bert->LayoutLM
+class LayoutLMLMPredictionHead(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.transform = LayoutLMPredictionHeadTransform(config)
+
+ # The output weights are the same as the input embeddings, but there is
+ # an output-only bias for each token.
+ self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+ self.bias = nn.Parameter(torch.zeros(config.vocab_size))
+
+ # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
+ self.decoder.bias = self.bias
+
+ def _tie_weights(self):
+ self.decoder.bias = self.bias
+
+ def forward(self, hidden_states):
+ hidden_states = self.transform(hidden_states)
+ hidden_states = self.decoder(hidden_states)
+ return hidden_states
+
+
+# Copied from transformers.models.bert.modeling_bert.BertOnlyMLMHead with Bert->LayoutLM
+class LayoutLMOnlyMLMHead(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.predictions = LayoutLMLMPredictionHead(config)
+
+ def forward(self, sequence_output: torch.Tensor) -> torch.Tensor:
+ prediction_scores = self.predictions(sequence_output)
+ return prediction_scores
+
+
+@auto_docstring
+class LayoutLMPreTrainedModel(PreTrainedModel):
+ config: LayoutLMConfig
+ base_model_prefix = "layoutlm"
+ supports_gradient_checkpointing = True
+
+ def _init_weights(self, module):
+ """Initialize the weights"""
+ if isinstance(module, nn.Linear):
+ # Slightly different from the TF version which uses truncated_normal for initialization
+ # cf https://github.com/pytorch/pytorch/pull/5617
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+ elif isinstance(module, LayoutLMLayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+ elif isinstance(module, LayoutLMLMPredictionHead):
+ module.bias.data.zero_()
+
+
+@auto_docstring
+class LayoutLMModel(LayoutLMPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.config = config
+
+ self.embeddings = LayoutLMEmbeddings(config)
+ self.encoder = LayoutLMEncoder(config)
+ self.pooler = LayoutLMPooler(config)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.embeddings.word_embeddings
+
+ def set_input_embeddings(self, value):
+ self.embeddings.word_embeddings = value
+
+ def _prune_heads(self, heads_to_prune):
+ """
+        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer}. See the
+        base class `PreTrainedModel`.
+ """
+ for layer, heads in heads_to_prune.items():
+ self.encoder.layer[layer].attention.prune_heads(heads)
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ bbox: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[tuple, BaseModelOutputWithPooling]:
+ r"""
+ bbox (`torch.LongTensor` of shape `(batch_size, sequence_length, 4)`, *optional*):
+            Bounding boxes of each input sequence token. Selected in the range `[0,
+ config.max_2d_position_embeddings-1]`. Each bounding box should be a normalized version in (x0, y0, x1, y1)
+ format, where (x0, y0) corresponds to the position of the upper left corner in the bounding box, and (x1,
+ y1) represents the position of the lower right corner. See [Overview](#Overview) for normalization.
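+
+            For illustration, one common way to obtain such normalized boxes from absolute pixel coordinates (here
+            `width` and `height` are assumed page dimensions, not arguments of this method) is:
+
+            ```python
+            def normalize_bbox(bbox, width, height):
+                return [
+                    int(1000 * (bbox[0] / width)),
+                    int(1000 * (bbox[1] / height)),
+                    int(1000 * (bbox[2] / width)),
+                    int(1000 * (bbox[3] / height)),
+                ]
+            ```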
+
+ Examples:
+
+ ```python
+ >>> from transformers import AutoTokenizer, LayoutLMModel
+ >>> import torch
+
+ >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/layoutlm-base-uncased")
+ >>> model = LayoutLMModel.from_pretrained("microsoft/layoutlm-base-uncased")
+
+ >>> words = ["Hello", "world"]
+ >>> normalized_word_boxes = [637, 773, 693, 782], [698, 773, 733, 782]
+
+ >>> token_boxes = []
+ >>> for word, box in zip(words, normalized_word_boxes):
+ ... word_tokens = tokenizer.tokenize(word)
+ ... token_boxes.extend([box] * len(word_tokens))
+ >>> # add bounding boxes of cls + sep tokens
+ >>> token_boxes = [[0, 0, 0, 0]] + token_boxes + [[1000, 1000, 1000, 1000]]
+
+ >>> encoding = tokenizer(" ".join(words), return_tensors="pt")
+ >>> input_ids = encoding["input_ids"]
+ >>> attention_mask = encoding["attention_mask"]
+ >>> token_type_ids = encoding["token_type_ids"]
+ >>> bbox = torch.tensor([token_boxes])
+
+ >>> outputs = model(
+ ... input_ids=input_ids, bbox=bbox, attention_mask=attention_mask, token_type_ids=token_type_ids
+ ... )
+
+ >>> last_hidden_states = outputs.last_hidden_state
+ ```"""
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+ elif input_ids is not None:
+ self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
+ input_shape = input_ids.size()
+ elif inputs_embeds is not None:
+ input_shape = inputs_embeds.size()[:-1]
+ else:
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
+
+ if attention_mask is None:
+ attention_mask = torch.ones(input_shape, device=device)
+ if token_type_ids is None:
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
+
+ if bbox is None:
+ bbox = torch.zeros(input_shape + (4,), dtype=torch.long, device=device)
+
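+        # We create a 4D attention mask of shape [batch_size, 1, 1, seq_length] from the 2D mask so it can be
+        # broadcast over all heads and query positions, then turn it into an additive mask: 0.0 for positions we
+        # attend to and a large negative value (the dtype minimum) for masked positions, which effectively removes
+        # them from the softmax.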
+ extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
+
+ extended_attention_mask = extended_attention_mask.to(dtype=self.dtype)
+ extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(self.dtype).min
+
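+        # Prepare the head mask if needed: 1.0 in head_mask means the head is kept. An input of shape [num_heads]
+        # or [num_hidden_layers, num_heads] is broadcast below to
+        # [num_hidden_layers, batch, num_heads, seq_length, seq_length].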
+ if head_mask is not None:
+ if head_mask.dim() == 1:
+ head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
+ head_mask = head_mask.expand(self.config.num_hidden_layers, -1, -1, -1, -1)
+ elif head_mask.dim() == 2:
+ head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)
+ head_mask = head_mask.to(dtype=next(self.parameters()).dtype)
+ else:
+ head_mask = [None] * self.config.num_hidden_layers
+
+ embedding_output = self.embeddings(
+ input_ids=input_ids,
+ bbox=bbox,
+ position_ids=position_ids,
+ token_type_ids=token_type_ids,
+ inputs_embeds=inputs_embeds,
+ )
+ encoder_outputs = self.encoder(
+ embedding_output,
+ extended_attention_mask,
+ head_mask=head_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=True,
+ )
+ sequence_output = encoder_outputs[0]
+ pooled_output = self.pooler(sequence_output)
+
+ return BaseModelOutputWithPooling(
+ last_hidden_state=sequence_output,
+ pooler_output=pooled_output,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ )
+
+
+@auto_docstring
+class LayoutLMForMaskedLM(LayoutLMPreTrainedModel):
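+    # These keys refer to the MLM decoder weight/bias, which are tied to the input word embeddings
+    # (see get_output_embeddings/set_output_embeddings below), so that they are handled consistently
+    # when weights are tied or the vocabulary is resized.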
+ _tied_weights_keys = ["cls.predictions.decoder.bias", "cls.predictions.decoder.weight"]
+
+ def __init__(self, config):
+ super().__init__(config)
+
+ self.layoutlm = LayoutLMModel(config)
+ self.cls = LayoutLMOnlyMLMHead(config)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.layoutlm.embeddings.word_embeddings
+
+ def get_output_embeddings(self):
+ return self.cls.predictions.decoder
+
+ def set_output_embeddings(self, new_embeddings):
+ self.cls.predictions.decoder = new_embeddings
+ self.cls.predictions.bias = new_embeddings.bias
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ bbox: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[tuple, MaskedLMOutput]:
+ r"""
+ bbox (`torch.LongTensor` of shape `(batch_size, sequence_length, 4)`, *optional*):
+            Bounding boxes of each input sequence token. Selected in the range `[0,
+ config.max_2d_position_embeddings-1]`. Each bounding box should be a normalized version in (x0, y0, x1, y1)
+ format, where (x0, y0) corresponds to the position of the upper left corner in the bounding box, and (x1,
+ y1) represents the position of the lower right corner. See [Overview](#Overview) for normalization.
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
+            config.vocab_size]` (see the `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked);
+            the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+ Examples:
+
+ ```python
+ >>> from transformers import AutoTokenizer, LayoutLMForMaskedLM
+ >>> import torch
+
+ >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/layoutlm-base-uncased")
+ >>> model = LayoutLMForMaskedLM.from_pretrained("microsoft/layoutlm-base-uncased")
+
+ >>> words = ["Hello", "[MASK]"]
+ >>> normalized_word_boxes = [637, 773, 693, 782], [698, 773, 733, 782]
+
+ >>> token_boxes = []
+ >>> for word, box in zip(words, normalized_word_boxes):
+ ... word_tokens = tokenizer.tokenize(word)
+ ... token_boxes.extend([box] * len(word_tokens))
+ >>> # add bounding boxes of cls + sep tokens
+ >>> token_boxes = [[0, 0, 0, 0]] + token_boxes + [[1000, 1000, 1000, 1000]]
+
+ >>> encoding = tokenizer(" ".join(words), return_tensors="pt")
+ >>> input_ids = encoding["input_ids"]
+ >>> attention_mask = encoding["attention_mask"]
+ >>> token_type_ids = encoding["token_type_ids"]
+ >>> bbox = torch.tensor([token_boxes])
+
+ >>> labels = tokenizer("Hello world", return_tensors="pt")["input_ids"]
+
+ >>> outputs = model(
+ ... input_ids=input_ids,
+ ... bbox=bbox,
+ ... attention_mask=attention_mask,
+ ... token_type_ids=token_type_ids,
+ ... labels=labels,
+ ... )
+
+ >>> loss = outputs.loss
+ ```"""
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.layoutlm(
+ input_ids,
+ bbox,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=True,
+ )
+
+ sequence_output = outputs[0]
+ prediction_scores = self.cls(sequence_output)
+
+ masked_lm_loss = None
+ if labels is not None:
+ loss_fct = CrossEntropyLoss()
+ masked_lm_loss = loss_fct(
+ prediction_scores.view(-1, self.config.vocab_size),
+ labels.view(-1),
+ )
+
+ return MaskedLMOutput(
+ loss=masked_lm_loss,
+ logits=prediction_scores,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+@auto_docstring(
+ custom_intro="""
+ LayoutLM Model with a sequence classification head on top (a linear layer on top of the pooled output) e.g. for
+ document image classification tasks such as the [RVL-CDIP](https://www.cs.cmu.edu/~aharley/rvl-cdip/) dataset.
+ """
+)
+class LayoutLMForSequenceClassification(LayoutLMPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+ self.layoutlm = LayoutLMModel(config)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.layoutlm.embeddings.word_embeddings
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ bbox: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[tuple, SequenceClassifierOutput]:
+ r"""
+ bbox (`torch.LongTensor` of shape `(batch_size, sequence_length, 4)`, *optional*):
+            Bounding boxes of each input sequence token. Selected in the range `[0,
+ config.max_2d_position_embeddings-1]`. Each bounding box should be a normalized version in (x0, y0, x1, y1)
+ format, where (x0, y0) corresponds to the position of the upper left corner in the bounding box, and (x1,
+ y1) represents the position of the lower right corner. See [Overview](#Overview) for normalization.
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss); if
+            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+
+ Examples:
+
+ ```python
+ >>> from transformers import AutoTokenizer, LayoutLMForSequenceClassification
+ >>> import torch
+
+ >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/layoutlm-base-uncased")
+ >>> model = LayoutLMForSequenceClassification.from_pretrained("microsoft/layoutlm-base-uncased")
+
+ >>> words = ["Hello", "world"]
+ >>> normalized_word_boxes = [637, 773, 693, 782], [698, 773, 733, 782]
+
+ >>> token_boxes = []
+ >>> for word, box in zip(words, normalized_word_boxes):
+ ... word_tokens = tokenizer.tokenize(word)
+ ... token_boxes.extend([box] * len(word_tokens))
+ >>> # add bounding boxes of cls + sep tokens
+ >>> token_boxes = [[0, 0, 0, 0]] + token_boxes + [[1000, 1000, 1000, 1000]]
+
+ >>> encoding = tokenizer(" ".join(words), return_tensors="pt")
+ >>> input_ids = encoding["input_ids"]
+ >>> attention_mask = encoding["attention_mask"]
+ >>> token_type_ids = encoding["token_type_ids"]
+ >>> bbox = torch.tensor([token_boxes])
+ >>> sequence_label = torch.tensor([1])
+
+ >>> outputs = model(
+ ... input_ids=input_ids,
+ ... bbox=bbox,
+ ... attention_mask=attention_mask,
+ ... token_type_ids=token_type_ids,
+ ... labels=sequence_label,
+ ... )
+
+ >>> loss = outputs.loss
+ >>> logits = outputs.logits
+ ```"""
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.layoutlm(
+ input_ids=input_ids,
+ bbox=bbox,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=True,
+ )
+
+ pooled_output = outputs[1]
+
+ pooled_output = self.dropout(pooled_output)
+ logits = self.classifier(pooled_output)
+
+ loss = None
+ if labels is not None:
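+            # Infer the problem type from num_labels and the label dtype when it is not set in the config:
+            # regression (MSE) for a single label, single-label classification (cross-entropy) for integer labels,
+            # and multi-label classification (BCE with logits) otherwise.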
+ if self.config.problem_type is None:
+ if self.num_labels == 1:
+ self.config.problem_type = "regression"
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+ self.config.problem_type = "single_label_classification"
+ else:
+ self.config.problem_type = "multi_label_classification"
+
+ if self.config.problem_type == "regression":
+ loss_fct = MSELoss()
+ if self.num_labels == 1:
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
+ else:
+ loss = loss_fct(logits, labels)
+ elif self.config.problem_type == "single_label_classification":
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+ elif self.config.problem_type == "multi_label_classification":
+ loss_fct = BCEWithLogitsLoss()
+ loss = loss_fct(logits, labels)
+
+ return SequenceClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+@auto_docstring(
+ custom_intro="""
+ LayoutLM Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
+ sequence labeling (information extraction) tasks such as the [FUNSD](https://guillaumejaume.github.io/FUNSD/)
+ dataset and the [SROIE](https://rrc.cvc.uab.es/?ch=13) dataset.
+ """
+)
+class LayoutLMForTokenClassification(LayoutLMPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+ self.layoutlm = LayoutLMModel(config)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.layoutlm.embeddings.word_embeddings
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ bbox: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[tuple, TokenClassifierOutput]:
+ r"""
+ bbox (`torch.LongTensor` of shape `(batch_size, sequence_length, 4)`, *optional*):
+            Bounding boxes of each input sequence token. Selected in the range `[0,
+ config.max_2d_position_embeddings-1]`. Each bounding box should be a normalized version in (x0, y0, x1, y1)
+ format, where (x0, y0) corresponds to the position of the upper left corner in the bounding box, and (x1,
+ y1) represents the position of the lower right corner. See [Overview](#Overview) for normalization.
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
+
+ Examples:
+
+ ```python
+ >>> from transformers import AutoTokenizer, LayoutLMForTokenClassification
+ >>> import torch
+
+ >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/layoutlm-base-uncased")
+ >>> model = LayoutLMForTokenClassification.from_pretrained("microsoft/layoutlm-base-uncased")
+
+ >>> words = ["Hello", "world"]
+ >>> normalized_word_boxes = [637, 773, 693, 782], [698, 773, 733, 782]
+
+ >>> token_boxes = []
+ >>> for word, box in zip(words, normalized_word_boxes):
+ ... word_tokens = tokenizer.tokenize(word)
+ ... token_boxes.extend([box] * len(word_tokens))
+ >>> # add bounding boxes of cls + sep tokens
+ >>> token_boxes = [[0, 0, 0, 0]] + token_boxes + [[1000, 1000, 1000, 1000]]
+
+ >>> encoding = tokenizer(" ".join(words), return_tensors="pt")
+ >>> input_ids = encoding["input_ids"]
+ >>> attention_mask = encoding["attention_mask"]
+ >>> token_type_ids = encoding["token_type_ids"]
+ >>> bbox = torch.tensor([token_boxes])
+ >>> token_labels = torch.tensor([1, 1, 0, 0]).unsqueeze(0) # batch size of 1
+
+ >>> outputs = model(
+ ... input_ids=input_ids,
+ ... bbox=bbox,
+ ... attention_mask=attention_mask,
+ ... token_type_ids=token_type_ids,
+ ... labels=token_labels,
+ ... )
+
+ >>> loss = outputs.loss
+ >>> logits = outputs.logits
+ ```"""
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.layoutlm(
+ input_ids=input_ids,
+ bbox=bbox,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=True,
+ )
+
+ sequence_output = outputs[0]
+
+ sequence_output = self.dropout(sequence_output)
+ logits = self.classifier(sequence_output)
+
+ loss = None
+ if labels is not None:
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+
+ return TokenClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+@auto_docstring
+class LayoutLMForQuestionAnswering(LayoutLMPreTrainedModel):
+ def __init__(self, config, has_visual_segment_embedding=True):
+ r"""
+ has_visual_segment_embedding (`bool`, *optional*, defaults to `True`):
+ Whether or not to add visual segment embeddings.
+ """
+ super().__init__(config)
+ self.num_labels = config.num_labels
+
+ self.layoutlm = LayoutLMModel(config)
+ self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.layoutlm.embeddings.word_embeddings
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ bbox: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ start_positions: Optional[torch.LongTensor] = None,
+ end_positions: Optional[torch.LongTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[tuple, QuestionAnsweringModelOutput]:
+ r"""
+ bbox (`torch.LongTensor` of shape `(batch_size, sequence_length, 4)`, *optional*):
+            Bounding boxes of each input sequence token. Selected in the range `[0,
+ config.max_2d_position_embeddings-1]`. Each bounding box should be a normalized version in (x0, y0, x1, y1)
+ format, where (x0, y0) corresponds to the position of the upper left corner in the bounding box, and (x1,
+ y1) represents the position of the lower right corner. See [Overview](#Overview) for normalization.
+
+ Example:
+
+        In the example below, we prepare a question + context pair for the LayoutLM model. It will give us a prediction
+        of what it thinks the answer is (the span of the answer within the text parsed from the image).
+
+ ```python
+ >>> from transformers import AutoTokenizer, LayoutLMForQuestionAnswering
+ >>> from datasets import load_dataset
+ >>> import torch
+
+ >>> tokenizer = AutoTokenizer.from_pretrained("impira/layoutlm-document-qa", add_prefix_space=True)
+ >>> model = LayoutLMForQuestionAnswering.from_pretrained("impira/layoutlm-document-qa", revision="1e3ebac")
+
+ >>> dataset = load_dataset("nielsr/funsd", split="train")
+ >>> example = dataset[0]
+ >>> question = "what's his name?"
+ >>> words = example["words"]
+ >>> boxes = example["bboxes"]
+
+ >>> encoding = tokenizer(
+ ... question.split(), words, is_split_into_words=True, return_token_type_ids=True, return_tensors="pt"
+ ... )
+ >>> bbox = []
+ >>> for i, s, w in zip(encoding.input_ids[0], encoding.sequence_ids(0), encoding.word_ids(0)):
+ ... if s == 1:
+ ... bbox.append(boxes[w])
+ ... elif i == tokenizer.sep_token_id:
+ ... bbox.append([1000] * 4)
+ ... else:
+ ... bbox.append([0] * 4)
+ >>> encoding["bbox"] = torch.tensor([bbox])
+
+ >>> word_ids = encoding.word_ids(0)
+ >>> outputs = model(**encoding)
+ >>> loss = outputs.loss
+ >>> start_scores = outputs.start_logits
+ >>> end_scores = outputs.end_logits
+ >>> start, end = word_ids[start_scores.argmax(-1)], word_ids[end_scores.argmax(-1)]
+ >>> print(" ".join(words[start : end + 1]))
+ M. Hamann P. Harper, P. Martinez
+ ```"""
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.layoutlm(
+ input_ids=input_ids,
+ bbox=bbox,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=True,
+ )
+
+ sequence_output = outputs[0]
+
+ logits = self.qa_outputs(sequence_output)
+ start_logits, end_logits = logits.split(1, dim=-1)
+ start_logits = start_logits.squeeze(-1).contiguous()
+ end_logits = end_logits.squeeze(-1).contiguous()
+
+ total_loss = None
+ if start_positions is not None and end_positions is not None:
+            # If we are on multi-GPU, the split adds a dimension
+ if len(start_positions.size()) > 1:
+ start_positions = start_positions.squeeze(-1)
+ if len(end_positions.size()) > 1:
+ end_positions = end_positions.squeeze(-1)
+            # Sometimes the start/end positions are outside our model inputs; we ignore these terms
+ ignored_index = start_logits.size(1)
+ start_positions = start_positions.clamp(0, ignored_index)
+ end_positions = end_positions.clamp(0, ignored_index)
+
+ loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
+ start_loss = loss_fct(start_logits, start_positions)
+ end_loss = loss_fct(end_logits, end_positions)
+ total_loss = (start_loss + end_loss) / 2
+
+ return QuestionAnsweringModelOutput(
+ loss=total_loss,
+ start_logits=start_logits,
+ end_logits=end_logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+__all__ = [
+ "LayoutLMForMaskedLM",
+ "LayoutLMForSequenceClassification",
+ "LayoutLMForTokenClassification",
+ "LayoutLMForQuestionAnswering",
+ "LayoutLMModel",
+ "LayoutLMPreTrainedModel",
+]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/layoutlm/modeling_tf_layoutlm.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/layoutlm/modeling_tf_layoutlm.py
new file mode 100644
index 0000000000000000000000000000000000000000..f6738693843be0fe9af32cfc4fe96f3e6fbceb59
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/layoutlm/modeling_tf_layoutlm.py
@@ -0,0 +1,1691 @@
+# coding=utf-8
+# Copyright 2018 The Microsoft Research Asia LayoutLM Team Authors and the HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""TF 2.0 LayoutLM model."""
+
+from __future__ import annotations
+
+import math
+import warnings
+
+import numpy as np
+import tensorflow as tf
+
+from ...activations_tf import get_tf_activation
+from ...modeling_tf_outputs import (
+ TFBaseModelOutputWithPastAndCrossAttentions,
+ TFBaseModelOutputWithPoolingAndCrossAttentions,
+ TFMaskedLMOutput,
+ TFQuestionAnsweringModelOutput,
+ TFSequenceClassifierOutput,
+ TFTokenClassifierOutput,
+)
+from ...modeling_tf_utils import (
+ TFMaskedLanguageModelingLoss,
+ TFModelInputType,
+ TFPreTrainedModel,
+ TFQuestionAnsweringLoss,
+ TFSequenceClassificationLoss,
+ TFTokenClassificationLoss,
+ get_initializer,
+ keras,
+ keras_serializable,
+ unpack_inputs,
+)
+from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax
+from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
+from .configuration_layoutlm import LayoutLMConfig
+
+
+logger = logging.get_logger(__name__)
+
+_CONFIG_FOR_DOC = "LayoutLMConfig"
+
+
+class TFLayoutLMEmbeddings(keras.layers.Layer):
+ """Construct the embeddings from word, position and token_type embeddings."""
+
+ def __init__(self, config: LayoutLMConfig, **kwargs):
+ super().__init__(**kwargs)
+
+ self.config = config
+ self.hidden_size = config.hidden_size
+ self.max_position_embeddings = config.max_position_embeddings
+ self.max_2d_position_embeddings = config.max_2d_position_embeddings
+ self.initializer_range = config.initializer_range
+ self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
+ self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob)
+
+ def build(self, input_shape=None):
+ with tf.name_scope("word_embeddings"):
+ self.weight = self.add_weight(
+ name="weight",
+ shape=[self.config.vocab_size, self.hidden_size],
+ initializer=get_initializer(self.initializer_range),
+ )
+
+ with tf.name_scope("token_type_embeddings"):
+ self.token_type_embeddings = self.add_weight(
+ name="embeddings",
+ shape=[self.config.type_vocab_size, self.hidden_size],
+ initializer=get_initializer(self.initializer_range),
+ )
+
+ with tf.name_scope("position_embeddings"):
+ self.position_embeddings = self.add_weight(
+ name="embeddings",
+ shape=[self.max_position_embeddings, self.hidden_size],
+ initializer=get_initializer(self.initializer_range),
+ )
+
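+        # The four tables below hold the 2D position embeddings: x/y lookups for the box coordinates and
+        # h/w lookups for the box height and width, each with max_2d_position_embeddings entries.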
+ with tf.name_scope("x_position_embeddings"):
+ self.x_position_embeddings = self.add_weight(
+ name="embeddings",
+ shape=[self.max_2d_position_embeddings, self.hidden_size],
+ initializer=get_initializer(self.initializer_range),
+ )
+
+ with tf.name_scope("y_position_embeddings"):
+ self.y_position_embeddings = self.add_weight(
+ name="embeddings",
+ shape=[self.max_2d_position_embeddings, self.hidden_size],
+ initializer=get_initializer(self.initializer_range),
+ )
+
+ with tf.name_scope("h_position_embeddings"):
+ self.h_position_embeddings = self.add_weight(
+ name="embeddings",
+ shape=[self.max_2d_position_embeddings, self.hidden_size],
+ initializer=get_initializer(self.initializer_range),
+ )
+
+ with tf.name_scope("w_position_embeddings"):
+ self.w_position_embeddings = self.add_weight(
+ name="embeddings",
+ shape=[self.max_2d_position_embeddings, self.hidden_size],
+ initializer=get_initializer(self.initializer_range),
+ )
+
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "LayerNorm", None) is not None:
+ with tf.name_scope(self.LayerNorm.name):
+ self.LayerNorm.build([None, None, self.config.hidden_size])
+
+ def call(
+ self,
+ input_ids: tf.Tensor | None = None,
+ bbox: tf.Tensor | None = None,
+ position_ids: tf.Tensor | None = None,
+ token_type_ids: tf.Tensor | None = None,
+ inputs_embeds: tf.Tensor | None = None,
+ training: bool = False,
+ ) -> tf.Tensor:
+ """
+        Applies the embeddings to the input tensors.
+
+ Returns:
+ final_embeddings (`tf.Tensor`): output embedding tensor.
+ """
+ assert not (input_ids is None and inputs_embeds is None)
+
+ if input_ids is not None:
+ check_embeddings_within_bounds(input_ids, self.config.vocab_size)
+ inputs_embeds = tf.gather(params=self.weight, indices=input_ids)
+
+ input_shape = shape_list(inputs_embeds)[:-1]
+
+ if token_type_ids is None:
+ token_type_ids = tf.fill(dims=input_shape, value=0)
+
+ if position_ids is None:
+ position_ids = tf.expand_dims(tf.range(start=0, limit=input_shape[-1]), axis=0)
+
+ if bbox is None:
+ bbox = tf.fill(input_shape + [4], value=0)
+ try:
+ left_position_embeddings = tf.gather(self.x_position_embeddings, bbox[:, :, 0])
+ upper_position_embeddings = tf.gather(self.y_position_embeddings, bbox[:, :, 1])
+ right_position_embeddings = tf.gather(self.x_position_embeddings, bbox[:, :, 2])
+ lower_position_embeddings = tf.gather(self.y_position_embeddings, bbox[:, :, 3])
+ except IndexError as e:
+ raise IndexError("The `bbox`coordinate values should be within 0-1000 range.") from e
+ h_position_embeddings = tf.gather(self.h_position_embeddings, bbox[:, :, 3] - bbox[:, :, 1])
+ w_position_embeddings = tf.gather(self.w_position_embeddings, bbox[:, :, 2] - bbox[:, :, 0])
+
+ position_embeds = tf.gather(params=self.position_embeddings, indices=position_ids)
+ token_type_embeds = tf.gather(params=self.token_type_embeddings, indices=token_type_ids)
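+        # LayoutLM sums the usual word, 1D position and token type embeddings with the 2D embeddings derived
+        # from the box coordinates (left, upper, right, lower) and the box height and width.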
+ final_embeddings = (
+ inputs_embeds
+ + position_embeds
+ + token_type_embeds
+ + left_position_embeddings
+ + upper_position_embeddings
+ + right_position_embeddings
+ + lower_position_embeddings
+ + h_position_embeddings
+ + w_position_embeddings
+ )
+ final_embeddings = self.LayerNorm(inputs=final_embeddings)
+ final_embeddings = self.dropout(inputs=final_embeddings, training=training)
+
+ return final_embeddings
+
+
+# Copied from transformers.models.bert.modeling_tf_bert.TFBertSelfAttention with Bert->LayoutLM
+class TFLayoutLMSelfAttention(keras.layers.Layer):
+ def __init__(self, config: LayoutLMConfig, **kwargs):
+ super().__init__(**kwargs)
+
+ if config.hidden_size % config.num_attention_heads != 0:
+ raise ValueError(
+ f"The hidden size ({config.hidden_size}) is not a multiple of the number "
+ f"of attention heads ({config.num_attention_heads})"
+ )
+
+ self.num_attention_heads = config.num_attention_heads
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
+ self.sqrt_att_head_size = math.sqrt(self.attention_head_size)
+
+ self.query = keras.layers.Dense(
+ units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="query"
+ )
+ self.key = keras.layers.Dense(
+ units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="key"
+ )
+ self.value = keras.layers.Dense(
+ units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="value"
+ )
+ self.dropout = keras.layers.Dropout(rate=config.attention_probs_dropout_prob)
+
+ self.is_decoder = config.is_decoder
+ self.config = config
+
+ def transpose_for_scores(self, tensor: tf.Tensor, batch_size: int) -> tf.Tensor:
+ # Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size]
+ tensor = tf.reshape(tensor=tensor, shape=(batch_size, -1, self.num_attention_heads, self.attention_head_size))
+
+ # Transpose the tensor from [batch_size, seq_length, num_attention_heads, attention_head_size] to [batch_size, num_attention_heads, seq_length, attention_head_size]
+ return tf.transpose(tensor, perm=[0, 2, 1, 3])
+
+ def call(
+ self,
+ hidden_states: tf.Tensor,
+ attention_mask: tf.Tensor,
+ head_mask: tf.Tensor,
+ encoder_hidden_states: tf.Tensor,
+ encoder_attention_mask: tf.Tensor,
+ past_key_value: tuple[tf.Tensor],
+ output_attentions: bool,
+ training: bool = False,
+ ) -> tuple[tf.Tensor]:
+ batch_size = shape_list(hidden_states)[0]
+ mixed_query_layer = self.query(inputs=hidden_states)
+
+ # If this is instantiated as a cross-attention module, the keys
+ # and values come from an encoder; the attention mask needs to be
+ # such that the encoder's padding tokens are not attended to.
+ is_cross_attention = encoder_hidden_states is not None
+
+ if is_cross_attention and past_key_value is not None:
+ # reuse k,v, cross_attentions
+ key_layer = past_key_value[0]
+ value_layer = past_key_value[1]
+ attention_mask = encoder_attention_mask
+ elif is_cross_attention:
+ key_layer = self.transpose_for_scores(self.key(inputs=encoder_hidden_states), batch_size)
+ value_layer = self.transpose_for_scores(self.value(inputs=encoder_hidden_states), batch_size)
+ attention_mask = encoder_attention_mask
+ elif past_key_value is not None:
+ key_layer = self.transpose_for_scores(self.key(inputs=hidden_states), batch_size)
+ value_layer = self.transpose_for_scores(self.value(inputs=hidden_states), batch_size)
+ key_layer = tf.concat([past_key_value[0], key_layer], axis=2)
+ value_layer = tf.concat([past_key_value[1], value_layer], axis=2)
+ else:
+ key_layer = self.transpose_for_scores(self.key(inputs=hidden_states), batch_size)
+ value_layer = self.transpose_for_scores(self.value(inputs=hidden_states), batch_size)
+
+ query_layer = self.transpose_for_scores(mixed_query_layer, batch_size)
+
+ if self.is_decoder:
+ # if cross_attention save Tuple(tf.Tensor, tf.Tensor) of all cross attention key/value_states.
+ # Further calls to cross_attention layer can then reuse all cross-attention
+ # key/value_states (first "if" case)
+ # if uni-directional self-attention (decoder) save Tuple(tf.Tensor, tf.Tensor) of
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
+ past_key_value = (key_layer, value_layer)
+
+ # Take the dot product between "query" and "key" to get the raw attention scores.
+ # (batch size, num_heads, seq_len_q, seq_len_k)
+ attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True)
+ dk = tf.cast(self.sqrt_att_head_size, dtype=attention_scores.dtype)
+ attention_scores = tf.divide(attention_scores, dk)
+
+ if attention_mask is not None:
+            # Apply the attention mask (precomputed for all layers in TFLayoutLMModel call() function)
+ attention_scores = tf.add(attention_scores, attention_mask)
+
+ # Normalize the attention scores to probabilities.
+ attention_probs = stable_softmax(logits=attention_scores, axis=-1)
+
+ # This is actually dropping out entire tokens to attend to, which might
+ # seem a bit unusual, but is taken from the original Transformer paper.
+ attention_probs = self.dropout(inputs=attention_probs, training=training)
+
+ # Mask heads if we want to
+ if head_mask is not None:
+ attention_probs = tf.multiply(attention_probs, head_mask)
+
+ attention_output = tf.matmul(attention_probs, value_layer)
+ attention_output = tf.transpose(attention_output, perm=[0, 2, 1, 3])
+
+ # (batch_size, seq_len_q, all_head_size)
+ attention_output = tf.reshape(tensor=attention_output, shape=(batch_size, -1, self.all_head_size))
+ outputs = (attention_output, attention_probs) if output_attentions else (attention_output,)
+
+ if self.is_decoder:
+ outputs = outputs + (past_key_value,)
+ return outputs
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "query", None) is not None:
+ with tf.name_scope(self.query.name):
+ self.query.build([None, None, self.config.hidden_size])
+ if getattr(self, "key", None) is not None:
+ with tf.name_scope(self.key.name):
+ self.key.build([None, None, self.config.hidden_size])
+ if getattr(self, "value", None) is not None:
+ with tf.name_scope(self.value.name):
+ self.value.build([None, None, self.config.hidden_size])
+
+
+# Copied from transformers.models.bert.modeling_tf_bert.TFBertSelfOutput with Bert->LayoutLM
+class TFLayoutLMSelfOutput(keras.layers.Layer):
+ def __init__(self, config: LayoutLMConfig, **kwargs):
+ super().__init__(**kwargs)
+
+ self.dense = keras.layers.Dense(
+ units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
+ )
+ self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
+ self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob)
+ self.config = config
+
+ def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor:
+ hidden_states = self.dense(inputs=hidden_states)
+ hidden_states = self.dropout(inputs=hidden_states, training=training)
+ hidden_states = self.LayerNorm(inputs=hidden_states + input_tensor)
+
+ return hidden_states
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "dense", None) is not None:
+ with tf.name_scope(self.dense.name):
+ self.dense.build([None, None, self.config.hidden_size])
+ if getattr(self, "LayerNorm", None) is not None:
+ with tf.name_scope(self.LayerNorm.name):
+ self.LayerNorm.build([None, None, self.config.hidden_size])
+
+
+# Copied from transformers.models.bert.modeling_tf_bert.TFBertAttention with Bert->LayoutLM
+class TFLayoutLMAttention(keras.layers.Layer):
+ def __init__(self, config: LayoutLMConfig, **kwargs):
+ super().__init__(**kwargs)
+
+ self.self_attention = TFLayoutLMSelfAttention(config, name="self")
+ self.dense_output = TFLayoutLMSelfOutput(config, name="output")
+
+ def prune_heads(self, heads):
+ raise NotImplementedError
+
+ def call(
+ self,
+ input_tensor: tf.Tensor,
+ attention_mask: tf.Tensor,
+ head_mask: tf.Tensor,
+ encoder_hidden_states: tf.Tensor,
+ encoder_attention_mask: tf.Tensor,
+ past_key_value: tuple[tf.Tensor],
+ output_attentions: bool,
+ training: bool = False,
+ ) -> tuple[tf.Tensor]:
+ self_outputs = self.self_attention(
+ hidden_states=input_tensor,
+ attention_mask=attention_mask,
+ head_mask=head_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ training=training,
+ )
+ attention_output = self.dense_output(
+ hidden_states=self_outputs[0], input_tensor=input_tensor, training=training
+ )
+ # add attentions (possibly with past_key_value) if we output them
+ outputs = (attention_output,) + self_outputs[1:]
+
+ return outputs
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "self_attention", None) is not None:
+ with tf.name_scope(self.self_attention.name):
+ self.self_attention.build(None)
+ if getattr(self, "dense_output", None) is not None:
+ with tf.name_scope(self.dense_output.name):
+ self.dense_output.build(None)
+
+
+# Copied from transformers.models.bert.modeling_tf_bert.TFBertIntermediate with Bert->LayoutLM
+class TFLayoutLMIntermediate(keras.layers.Layer):
+ def __init__(self, config: LayoutLMConfig, **kwargs):
+ super().__init__(**kwargs)
+
+ self.dense = keras.layers.Dense(
+ units=config.intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
+ )
+
+ if isinstance(config.hidden_act, str):
+ self.intermediate_act_fn = get_tf_activation(config.hidden_act)
+ else:
+ self.intermediate_act_fn = config.hidden_act
+ self.config = config
+
+ def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
+ hidden_states = self.dense(inputs=hidden_states)
+ hidden_states = self.intermediate_act_fn(hidden_states)
+
+ return hidden_states
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "dense", None) is not None:
+ with tf.name_scope(self.dense.name):
+ self.dense.build([None, None, self.config.hidden_size])
+
+
+# Copied from transformers.models.bert.modeling_tf_bert.TFBertOutput with Bert->LayoutLM
+class TFLayoutLMOutput(keras.layers.Layer):
+ def __init__(self, config: LayoutLMConfig, **kwargs):
+ super().__init__(**kwargs)
+
+ self.dense = keras.layers.Dense(
+ units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
+ )
+ self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
+ self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob)
+ self.config = config
+
+ def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor:
+ hidden_states = self.dense(inputs=hidden_states)
+ hidden_states = self.dropout(inputs=hidden_states, training=training)
+ hidden_states = self.LayerNorm(inputs=hidden_states + input_tensor)
+
+ return hidden_states
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "dense", None) is not None:
+ with tf.name_scope(self.dense.name):
+ self.dense.build([None, None, self.config.intermediate_size])
+ if getattr(self, "LayerNorm", None) is not None:
+ with tf.name_scope(self.LayerNorm.name):
+ self.LayerNorm.build([None, None, self.config.hidden_size])
+
+
+# Copied from transformers.models.bert.modeling_tf_bert.TFBertLayer with Bert->LayoutLM
+class TFLayoutLMLayer(keras.layers.Layer):
+ def __init__(self, config: LayoutLMConfig, **kwargs):
+ super().__init__(**kwargs)
+
+ self.attention = TFLayoutLMAttention(config, name="attention")
+ self.is_decoder = config.is_decoder
+ self.add_cross_attention = config.add_cross_attention
+ if self.add_cross_attention:
+ if not self.is_decoder:
+ raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
+ self.crossattention = TFLayoutLMAttention(config, name="crossattention")
+ self.intermediate = TFLayoutLMIntermediate(config, name="intermediate")
+ self.bert_output = TFLayoutLMOutput(config, name="output")
+
+ def call(
+ self,
+ hidden_states: tf.Tensor,
+ attention_mask: tf.Tensor,
+ head_mask: tf.Tensor,
+ encoder_hidden_states: tf.Tensor | None,
+ encoder_attention_mask: tf.Tensor | None,
+ past_key_value: tuple[tf.Tensor] | None,
+ output_attentions: bool,
+ training: bool = False,
+ ) -> tuple[tf.Tensor]:
+ # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
+ self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
+ self_attention_outputs = self.attention(
+ input_tensor=hidden_states,
+ attention_mask=attention_mask,
+ head_mask=head_mask,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ past_key_value=self_attn_past_key_value,
+ output_attentions=output_attentions,
+ training=training,
+ )
+ attention_output = self_attention_outputs[0]
+
+ # if decoder, the last output is tuple of self-attn cache
+ if self.is_decoder:
+ outputs = self_attention_outputs[1:-1]
+ present_key_value = self_attention_outputs[-1]
+ else:
+ outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
+
+ cross_attn_present_key_value = None
+ if self.is_decoder and encoder_hidden_states is not None:
+ if not hasattr(self, "crossattention"):
+ raise ValueError(
+ f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
+ " by setting `config.add_cross_attention=True`"
+ )
+
+ # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
+ cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
+ cross_attention_outputs = self.crossattention(
+ input_tensor=attention_output,
+ attention_mask=attention_mask,
+ head_mask=head_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ past_key_value=cross_attn_past_key_value,
+ output_attentions=output_attentions,
+ training=training,
+ )
+ attention_output = cross_attention_outputs[0]
+ outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
+
+ # add cross-attn cache to positions 3,4 of present_key_value tuple
+ cross_attn_present_key_value = cross_attention_outputs[-1]
+ present_key_value = present_key_value + cross_attn_present_key_value
+
+ intermediate_output = self.intermediate(hidden_states=attention_output)
+ layer_output = self.bert_output(
+ hidden_states=intermediate_output, input_tensor=attention_output, training=training
+ )
+ outputs = (layer_output,) + outputs # add attentions if we output them
+
+ # if decoder, return the attn key/values as the last output
+ if self.is_decoder:
+ outputs = outputs + (present_key_value,)
+
+ return outputs
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "attention", None) is not None:
+ with tf.name_scope(self.attention.name):
+ self.attention.build(None)
+ if getattr(self, "intermediate", None) is not None:
+ with tf.name_scope(self.intermediate.name):
+ self.intermediate.build(None)
+ if getattr(self, "bert_output", None) is not None:
+ with tf.name_scope(self.bert_output.name):
+ self.bert_output.build(None)
+ if getattr(self, "crossattention", None) is not None:
+ with tf.name_scope(self.crossattention.name):
+ self.crossattention.build(None)
+
+
+# Copied from transformers.models.bert.modeling_tf_bert.TFBertEncoder with Bert->LayoutLM
+class TFLayoutLMEncoder(keras.layers.Layer):
+ def __init__(self, config: LayoutLMConfig, **kwargs):
+ super().__init__(**kwargs)
+ self.config = config
+ self.layer = [TFLayoutLMLayer(config, name=f"layer_._{i}") for i in range(config.num_hidden_layers)]
+
+ def call(
+ self,
+ hidden_states: tf.Tensor,
+ attention_mask: tf.Tensor,
+ head_mask: tf.Tensor,
+ encoder_hidden_states: tf.Tensor | None,
+ encoder_attention_mask: tf.Tensor | None,
+ past_key_values: tuple[tuple[tf.Tensor]] | None,
+ use_cache: bool | None,
+ output_attentions: bool,
+ output_hidden_states: bool,
+ return_dict: bool,
+ training: bool = False,
+ ) -> TFBaseModelOutputWithPastAndCrossAttentions | tuple[tf.Tensor]:
+ all_hidden_states = () if output_hidden_states else None
+ all_attentions = () if output_attentions else None
+ all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
+
+ next_decoder_cache = () if use_cache else None
+ for i, layer_module in enumerate(self.layer):
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ past_key_value = past_key_values[i] if past_key_values is not None else None
+
+ layer_outputs = layer_module(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ head_mask=head_mask[i],
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ training=training,
+ )
+ hidden_states = layer_outputs[0]
+
+ if use_cache:
+ next_decoder_cache += (layer_outputs[-1],)
+
+ if output_attentions:
+ all_attentions = all_attentions + (layer_outputs[1],)
+ if self.config.add_cross_attention and encoder_hidden_states is not None:
+ all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
+
+ # Add last layer
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(
+ v for v in [hidden_states, all_hidden_states, all_attentions, all_cross_attentions] if v is not None
+ )
+
+ return TFBaseModelOutputWithPastAndCrossAttentions(
+ last_hidden_state=hidden_states,
+ past_key_values=next_decoder_cache,
+ hidden_states=all_hidden_states,
+ attentions=all_attentions,
+ cross_attentions=all_cross_attentions,
+ )
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "layer", None) is not None:
+ for layer in self.layer:
+ with tf.name_scope(layer.name):
+ layer.build(None)
+
+
+# Copied from transformers.models.bert.modeling_tf_bert.TFBertPooler with Bert->LayoutLM
+class TFLayoutLMPooler(keras.layers.Layer):
+ def __init__(self, config: LayoutLMConfig, **kwargs):
+ super().__init__(**kwargs)
+
+ self.dense = keras.layers.Dense(
+ units=config.hidden_size,
+ kernel_initializer=get_initializer(config.initializer_range),
+ activation="tanh",
+ name="dense",
+ )
+ self.config = config
+
+ def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
+ # We "pool" the model by simply taking the hidden state corresponding
+ # to the first token.
+ first_token_tensor = hidden_states[:, 0]
+ pooled_output = self.dense(inputs=first_token_tensor)
+
+ return pooled_output
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "dense", None) is not None:
+ with tf.name_scope(self.dense.name):
+ self.dense.build([None, None, self.config.hidden_size])
+
+
+# Copied from transformers.models.bert.modeling_tf_bert.TFBertPredictionHeadTransform with Bert->LayoutLM
+class TFLayoutLMPredictionHeadTransform(keras.layers.Layer):
+ def __init__(self, config: LayoutLMConfig, **kwargs):
+ super().__init__(**kwargs)
+
+ self.dense = keras.layers.Dense(
+ units=config.hidden_size,
+ kernel_initializer=get_initializer(config.initializer_range),
+ name="dense",
+ )
+
+ if isinstance(config.hidden_act, str):
+ self.transform_act_fn = get_tf_activation(config.hidden_act)
+ else:
+ self.transform_act_fn = config.hidden_act
+
+ self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
+ self.config = config
+
+ def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
+ hidden_states = self.dense(inputs=hidden_states)
+ hidden_states = self.transform_act_fn(hidden_states)
+ hidden_states = self.LayerNorm(inputs=hidden_states)
+
+ return hidden_states
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "dense", None) is not None:
+ with tf.name_scope(self.dense.name):
+ self.dense.build([None, None, self.config.hidden_size])
+ if getattr(self, "LayerNorm", None) is not None:
+ with tf.name_scope(self.LayerNorm.name):
+ self.LayerNorm.build([None, None, self.config.hidden_size])
+
+
+# Copied from transformers.models.bert.modeling_tf_bert.TFBertLMPredictionHead with Bert->LayoutLM
+class TFLayoutLMLMPredictionHead(keras.layers.Layer):
+ def __init__(self, config: LayoutLMConfig, input_embeddings: keras.layers.Layer, **kwargs):
+ super().__init__(**kwargs)
+
+ self.config = config
+ self.hidden_size = config.hidden_size
+
+ self.transform = TFLayoutLMPredictionHeadTransform(config, name="transform")
+
+ # The output weights are the same as the input embeddings, but there is
+ # an output-only bias for each token.
+ self.input_embeddings = input_embeddings
+
+ def build(self, input_shape=None):
+ self.bias = self.add_weight(shape=(self.config.vocab_size,), initializer="zeros", trainable=True, name="bias")
+
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "transform", None) is not None:
+ with tf.name_scope(self.transform.name):
+ self.transform.build(None)
+
+ def get_output_embeddings(self) -> keras.layers.Layer:
+ return self.input_embeddings
+
+ def set_output_embeddings(self, value: tf.Variable):
+ self.input_embeddings.weight = value
+ self.input_embeddings.vocab_size = shape_list(value)[0]
+
+ def get_bias(self) -> dict[str, tf.Variable]:
+ return {"bias": self.bias}
+
+ def set_bias(self, value: tf.Variable):
+ self.bias = value["bias"]
+ self.config.vocab_size = shape_list(value["bias"])[0]
+
+ def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
+ hidden_states = self.transform(hidden_states=hidden_states)
+ seq_length = shape_list(hidden_states)[1]
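+        # Project back to vocabulary size using the (tied) input embedding matrix as the decoder weights,
+        # then add the output-only bias.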
+ hidden_states = tf.reshape(tensor=hidden_states, shape=[-1, self.hidden_size])
+ hidden_states = tf.matmul(a=hidden_states, b=self.input_embeddings.weight, transpose_b=True)
+ hidden_states = tf.reshape(tensor=hidden_states, shape=[-1, seq_length, self.config.vocab_size])
+ hidden_states = tf.nn.bias_add(value=hidden_states, bias=self.bias)
+
+ return hidden_states
+
+
+# Copied from transformers.models.bert.modeling_tf_bert.TFBertMLMHead with Bert->LayoutLM
+class TFLayoutLMMLMHead(keras.layers.Layer):
+ def __init__(self, config: LayoutLMConfig, input_embeddings: keras.layers.Layer, **kwargs):
+ super().__init__(**kwargs)
+
+ self.predictions = TFLayoutLMLMPredictionHead(config, input_embeddings, name="predictions")
+
+ def call(self, sequence_output: tf.Tensor) -> tf.Tensor:
+ prediction_scores = self.predictions(hidden_states=sequence_output)
+
+ return prediction_scores
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "predictions", None) is not None:
+ with tf.name_scope(self.predictions.name):
+ self.predictions.build(None)
+
+
+@keras_serializable
+class TFLayoutLMMainLayer(keras.layers.Layer):
+ config_class = LayoutLMConfig
+
+ def __init__(self, config: LayoutLMConfig, add_pooling_layer: bool = True, **kwargs):
+ super().__init__(**kwargs)
+
+ self.config = config
+
+ self.embeddings = TFLayoutLMEmbeddings(config, name="embeddings")
+ self.encoder = TFLayoutLMEncoder(config, name="encoder")
+ self.pooler = TFLayoutLMPooler(config, name="pooler") if add_pooling_layer else None
+
+ def get_input_embeddings(self) -> keras.layers.Layer:
+ return self.embeddings
+
+ def set_input_embeddings(self, value: tf.Variable):
+ self.embeddings.weight = value
+ self.embeddings.vocab_size = shape_list(value)[0]
+
+ def _prune_heads(self, heads_to_prune):
+ """
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer}. See the
+ base class `PreTrainedModel`.
+ """
+ raise NotImplementedError
+
+ @unpack_inputs
+ def call(
+ self,
+ input_ids: TFModelInputType | None = None,
+ bbox: np.ndarray | tf.Tensor | None = None,
+ attention_mask: np.ndarray | tf.Tensor | None = None,
+ token_type_ids: np.ndarray | tf.Tensor | None = None,
+ position_ids: np.ndarray | tf.Tensor | None = None,
+ head_mask: np.ndarray | tf.Tensor | None = None,
+ inputs_embeds: np.ndarray | tf.Tensor | None = None,
+ encoder_hidden_states: np.ndarray | tf.Tensor | None = None,
+ encoder_attention_mask: np.ndarray | tf.Tensor | None = None,
+ output_attentions: bool | None = None,
+ output_hidden_states: bool | None = None,
+ return_dict: bool | None = None,
+ training: bool = False,
+ ) -> TFBaseModelOutputWithPoolingAndCrossAttentions | tuple[tf.Tensor]:
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+ elif input_ids is not None:
+ input_shape = shape_list(input_ids)
+ elif inputs_embeds is not None:
+ input_shape = shape_list(inputs_embeds)[:-1]
+ else:
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+ if attention_mask is None:
+ attention_mask = tf.fill(dims=input_shape, value=1)
+
+ if token_type_ids is None:
+ token_type_ids = tf.fill(dims=input_shape, value=0)
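+ # A missing `bbox` defaults to the all-zero box [0, 0, 0, 0] for every token, i.e. no layout information.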
+ if bbox is None:
+ bbox = tf.fill(dims=input_shape + [4], value=0)
+
+ embedding_output = self.embeddings(
+ input_ids=input_ids,
+ bbox=bbox,
+ position_ids=position_ids,
+ token_type_ids=token_type_ids,
+ inputs_embeds=inputs_embeds,
+ training=training,
+ )
+
+ # We create a 3D attention mask from a 2D tensor mask.
+ # Sizes are [batch_size, 1, 1, to_seq_length]
+ # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
+ # this attention mask is simpler than the triangular masking of causal attention
+ # used in OpenAI GPT; we just need to prepare the broadcast dimension here.
+ extended_attention_mask = tf.reshape(attention_mask, (input_shape[0], 1, 1, input_shape[1]))
+
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
+ # masked positions, this operation will create a tensor which is 0.0 for
+ # positions we want to attend and -10000.0 for masked positions.
+ # Since we are adding it to the raw scores before the softmax, this is
+ # effectively the same as removing these entirely.
+ extended_attention_mask = tf.cast(extended_attention_mask, dtype=embedding_output.dtype)
+ one_cst = tf.constant(1.0, dtype=embedding_output.dtype)
+ ten_thousand_cst = tf.constant(-10000.0, dtype=embedding_output.dtype)
+ extended_attention_mask = tf.multiply(tf.subtract(one_cst, extended_attention_mask), ten_thousand_cst)
+
+ # Prepare head mask if needed
+ # 1.0 in head_mask indicates we keep the head
+ # attention_probs has shape bsz x n_heads x N x N
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+ if head_mask is not None:
+ raise NotImplementedError
+ else:
+ head_mask = [None] * self.config.num_hidden_layers
+
+ encoder_outputs = self.encoder(
+ hidden_states=embedding_output,
+ attention_mask=extended_attention_mask,
+ head_mask=head_mask,
+ # Need to pass these required positional arguments to `Encoder`
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=None,
+ past_key_values=None,
+ use_cache=False,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ training=training,
+ )
+
+ sequence_output = encoder_outputs[0]
+ pooled_output = self.pooler(hidden_states=sequence_output) if self.pooler is not None else None
+
+ if not return_dict:
+ return (
+ sequence_output,
+ pooled_output,
+ ) + encoder_outputs[1:]
+
+ return TFBaseModelOutputWithPoolingAndCrossAttentions(
+ last_hidden_state=sequence_output,
+ pooler_output=pooled_output,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ cross_attentions=encoder_outputs.cross_attentions,
+ )
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "embeddings", None) is not None:
+ with tf.name_scope(self.embeddings.name):
+ self.embeddings.build(None)
+ if getattr(self, "encoder", None) is not None:
+ with tf.name_scope(self.encoder.name):
+ self.encoder.build(None)
+ if getattr(self, "pooler", None) is not None:
+ with tf.name_scope(self.pooler.name):
+ self.pooler.build(None)
+
+
+class TFLayoutLMPreTrainedModel(TFPreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = LayoutLMConfig
+ base_model_prefix = "layoutlm"
+
+ @property
+ def input_signature(self):
+ signature = super().input_signature
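+ # LayoutLM additionally expects one bounding box of 4 coordinates per token, hence the extra `bbox` spec.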
+ signature["bbox"] = tf.TensorSpec(shape=(None, None, 4), dtype=tf.int32, name="bbox")
+ return signature
+
+
+LAYOUTLM_START_DOCSTRING = r"""
+
+ This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
+ etc.)
+
+ This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it
+ as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matters related to general usage and
+ behavior.
+
+ <Tip>
+
+ TensorFlow models and layers in `transformers` accept two formats as input:
+
+ - having all inputs as keyword arguments (like PyTorch models), or
+ - having all inputs as a list, tuple or dict in the first positional argument.
+
+ The reason the second format is supported is that Keras methods prefer this format when passing inputs to models
+ and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just
+ pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second
+ format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with
+ the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first
+ positional argument:
+
+ - a single Tensor with `input_ids` only and nothing else: `model(input_ids)`
+ - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
+ `model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])`
+ - a dictionary with one or several input Tensors associated to the input names given in the docstring:
+ `model({"input_ids": input_ids, "token_type_ids": token_type_ids})`
+
+ Note that when creating models and layers with
+ [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry
+ about any of this, as you can just pass inputs like you would to any other Python function!
+
+ </Tip>
+
+ Args:
+ config ([`LayoutLMConfig`]): Model configuration class with all the parameters of the model.
+ Initializing with a config file does not load the weights associated with the model, only the
+ configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+LAYOUTLM_INPUTS_DOCSTRING = r"""
+ Args:
+ input_ids (`Numpy array` or `tf.Tensor` of shape `({0})`):
+ Indices of input sequence tokens in the vocabulary.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.__call__`] and
+ [`PreTrainedTokenizer.encode`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ bbox (`Numpy array` or `tf.Tensor` of shape `({0}, 4)`, *optional*):
+ Bounding boxes of each input sequence token. Selected in the range `[0, config.max_2d_position_embeddings -
+ 1]`. Each bounding box should be given in `(x0, y0, x1, y1)` format, where `(x0, y0)` is the upper left
+ corner and `(x1, y1)` the lower right corner.
+ attention_mask (`Numpy array` or `tf.Tensor` of shape `({0})`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ token_type_ids (`Numpy array` or `tf.Tensor` of shape `({0})`, *optional*):
+ Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
+ 1]`:
+
+ - 0 corresponds to a *sentence A* token,
+ - 1 corresponds to a *sentence B* token.
+
+ [What are token type IDs?](../glossary#token-type-ids)
+ position_ids (`Numpy array` or `tf.Tensor` of shape `({0})`, *optional*):
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+ config.max_position_embeddings - 1]`.
+
+ [What are position IDs?](../glossary#position-ids)
+ head_mask (`Numpy array` or `tf.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ inputs_embeds (`tf.Tensor` of shape `({0}, hidden_size)`, *optional*):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
+ model's internal embedding lookup matrix.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+ training (`bool`, *optional*, defaults to `False`):
+ Whether or not to use the model in training mode (some modules like dropout modules have different
+ behaviors between training and evaluation).
+"""
+
+
+@add_start_docstrings(
+ "The bare LayoutLM Model transformer outputting raw hidden-states without any specific head on top.",
+ LAYOUTLM_START_DOCSTRING,
+)
+class TFLayoutLMModel(TFLayoutLMPreTrainedModel):
+ def __init__(self, config: LayoutLMConfig, *inputs, **kwargs):
+ super().__init__(config, *inputs, **kwargs)
+
+ self.layoutlm = TFLayoutLMMainLayer(config, name="layoutlm")
+
+ @unpack_inputs
+ @add_start_docstrings_to_model_forward(LAYOUTLM_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+ @replace_return_docstrings(
+ output_type=TFBaseModelOutputWithPoolingAndCrossAttentions, config_class=_CONFIG_FOR_DOC
+ )
+ def call(
+ self,
+ input_ids: TFModelInputType | None = None,
+ bbox: np.ndarray | tf.Tensor | None = None,
+ attention_mask: np.ndarray | tf.Tensor | None = None,
+ token_type_ids: np.ndarray | tf.Tensor | None = None,
+ position_ids: np.ndarray | tf.Tensor | None = None,
+ head_mask: np.ndarray | tf.Tensor | None = None,
+ inputs_embeds: np.ndarray | tf.Tensor | None = None,
+ encoder_hidden_states: np.ndarray | tf.Tensor | None = None,
+ encoder_attention_mask: np.ndarray | tf.Tensor | None = None,
+ output_attentions: bool | None = None,
+ output_hidden_states: bool | None = None,
+ return_dict: bool | None = None,
+ training: bool | None = False,
+ ) -> TFBaseModelOutputWithPoolingAndCrossAttentions | tuple[tf.Tensor]:
+ r"""
+ Returns:
+
+ Examples:
+
+ ```python
+ >>> from transformers import AutoTokenizer, TFLayoutLMModel
+ >>> import tensorflow as tf
+
+ >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/layoutlm-base-uncased")
+ >>> model = TFLayoutLMModel.from_pretrained("microsoft/layoutlm-base-uncased")
+
+ >>> words = ["Hello", "world"]
+ >>> normalized_word_boxes = [637, 773, 693, 782], [698, 773, 733, 782]
+
+ >>> token_boxes = []
+ >>> for word, box in zip(words, normalized_word_boxes):
+ ... word_tokens = tokenizer.tokenize(word)
+ ... token_boxes.extend([box] * len(word_tokens))
+ >>> # add bounding boxes of cls + sep tokens
+ >>> token_boxes = [[0, 0, 0, 0]] + token_boxes + [[1000, 1000, 1000, 1000]]
+
+ >>> encoding = tokenizer(" ".join(words), return_tensors="tf")
+ >>> input_ids = encoding["input_ids"]
+ >>> attention_mask = encoding["attention_mask"]
+ >>> token_type_ids = encoding["token_type_ids"]
+ >>> bbox = tf.convert_to_tensor([token_boxes])
+
+ >>> outputs = model(
+ ... input_ids=input_ids, bbox=bbox, attention_mask=attention_mask, token_type_ids=token_type_ids
+ ... )
+
+ >>> last_hidden_states = outputs.last_hidden_state
+ ```"""
+ outputs = self.layoutlm(
+ input_ids=input_ids,
+ bbox=bbox,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ training=training,
+ )
+
+ return outputs
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "layoutlm", None) is not None:
+ with tf.name_scope(self.layoutlm.name):
+ self.layoutlm.build(None)
+
+
+@add_start_docstrings("""LayoutLM Model with a `language modeling` head on top.""", LAYOUTLM_START_DOCSTRING)
+class TFLayoutLMForMaskedLM(TFLayoutLMPreTrainedModel, TFMaskedLanguageModelingLoss):
+ # names with a '.' represent the authorized unexpected/missing layers when a TF model is loaded from a PT model
+ _keys_to_ignore_on_load_unexpected = [
+ r"pooler",
+ r"cls.seq_relationship",
+ r"cls.predictions.decoder.weight",
+ r"nsp___cls",
+ ]
+
+ def __init__(self, config: LayoutLMConfig, *inputs, **kwargs):
+ super().__init__(config, *inputs, **kwargs)
+
+ if config.is_decoder:
+ logger.warning(
+ "If you want to use `TFLayoutLMForMaskedLM` make sure `config.is_decoder=False` for "
+ "bi-directional self-attention."
+ )
+
+ self.layoutlm = TFLayoutLMMainLayer(config, add_pooling_layer=True, name="layoutlm")
+ self.mlm = TFLayoutLMMLMHead(config, input_embeddings=self.layoutlm.embeddings, name="mlm___cls")
+
+ def get_lm_head(self) -> keras.layers.Layer:
+ return self.mlm.predictions
+
+ def get_prefix_bias_name(self) -> str:
+ warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning)
+ return self.name + "/" + self.mlm.name + "/" + self.mlm.predictions.name
+
+ @unpack_inputs
+ @add_start_docstrings_to_model_forward(LAYOUTLM_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+ @replace_return_docstrings(output_type=TFMaskedLMOutput, config_class=_CONFIG_FOR_DOC)
+ def call(
+ self,
+ input_ids: TFModelInputType | None = None,
+ bbox: np.ndarray | tf.Tensor | None = None,
+ attention_mask: np.ndarray | tf.Tensor | None = None,
+ token_type_ids: np.ndarray | tf.Tensor | None = None,
+ position_ids: np.ndarray | tf.Tensor | None = None,
+ head_mask: np.ndarray | tf.Tensor | None = None,
+ inputs_embeds: np.ndarray | tf.Tensor | None = None,
+ output_attentions: bool | None = None,
+ output_hidden_states: bool | None = None,
+ return_dict: bool | None = None,
+ labels: np.ndarray | tf.Tensor | None = None,
+ training: bool | None = False,
+ ) -> TFMaskedLMOutput | tuple[tf.Tensor]:
+ r"""
+ labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
+ config.vocab_size]` (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked); the
+ loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+ Returns:
+
+ Examples:
+
+ ```python
+ >>> from transformers import AutoTokenizer, TFLayoutLMForMaskedLM
+ >>> import tensorflow as tf
+
+ >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/layoutlm-base-uncased")
+ >>> model = TFLayoutLMForMaskedLM.from_pretrained("microsoft/layoutlm-base-uncased")
+
+ >>> words = ["Hello", "[MASK]"]
+ >>> normalized_word_boxes = [637, 773, 693, 782], [698, 773, 733, 782]
+
+ >>> token_boxes = []
+ >>> for word, box in zip(words, normalized_word_boxes):
+ ... word_tokens = tokenizer.tokenize(word)
+ ... token_boxes.extend([box] * len(word_tokens))
+ >>> # add bounding boxes of cls + sep tokens
+ >>> token_boxes = [[0, 0, 0, 0]] + token_boxes + [[1000, 1000, 1000, 1000]]
+
+ >>> encoding = tokenizer(" ".join(words), return_tensors="tf")
+ >>> input_ids = encoding["input_ids"]
+ >>> attention_mask = encoding["attention_mask"]
+ >>> token_type_ids = encoding["token_type_ids"]
+ >>> bbox = tf.convert_to_tensor([token_boxes])
+
+ >>> labels = tokenizer("Hello world", return_tensors="tf")["input_ids"]
+
+ >>> outputs = model(
+ ... input_ids=input_ids,
+ ... bbox=bbox,
+ ... attention_mask=attention_mask,
+ ... token_type_ids=token_type_ids,
+ ... labels=labels,
+ ... )
+
+ >>> loss = outputs.loss
+ ```"""
+ outputs = self.layoutlm(
+ input_ids=input_ids,
+ bbox=bbox,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ training=training,
+ )
+ sequence_output = outputs[0]
+ prediction_scores = self.mlm(sequence_output=sequence_output, training=training)
+ loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=prediction_scores)
+
+ if not return_dict:
+ output = (prediction_scores,) + outputs[2:]
+ return ((loss,) + output) if loss is not None else output
+
+ return TFMaskedLMOutput(
+ loss=loss,
+ logits=prediction_scores,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "layoutlm", None) is not None:
+ with tf.name_scope(self.layoutlm.name):
+ self.layoutlm.build(None)
+ if getattr(self, "mlm", None) is not None:
+ with tf.name_scope(self.mlm.name):
+ self.mlm.build(None)
+
+
+@add_start_docstrings(
+ """
+ LayoutLM Model transformer with a sequence classification/regression head on top (a linear layer on top of the
+ pooled output) e.g. for GLUE tasks.
+ """,
+ LAYOUTLM_START_DOCSTRING,
+)
+class TFLayoutLMForSequenceClassification(TFLayoutLMPreTrainedModel, TFSequenceClassificationLoss):
+ # names with a '.' represent the authorized unexpected/missing layers when a TF model is loaded from a PT model
+ _keys_to_ignore_on_load_unexpected = [r"mlm___cls", r"nsp___cls", r"cls.predictions", r"cls.seq_relationship"]
+ _keys_to_ignore_on_load_missing = [r"dropout"]
+
+ def __init__(self, config: LayoutLMConfig, *inputs, **kwargs):
+ super().__init__(config, *inputs, **kwargs)
+
+ self.num_labels = config.num_labels
+
+ self.layoutlm = TFLayoutLMMainLayer(config, name="layoutlm")
+ self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob)
+ self.classifier = keras.layers.Dense(
+ units=config.num_labels,
+ kernel_initializer=get_initializer(config.initializer_range),
+ name="classifier",
+ )
+ self.config = config
+
+ @unpack_inputs
+ @add_start_docstrings_to_model_forward(LAYOUTLM_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+ @replace_return_docstrings(output_type=TFSequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
+ def call(
+ self,
+ input_ids: TFModelInputType | None = None,
+ bbox: np.ndarray | tf.Tensor | None = None,
+ attention_mask: np.ndarray | tf.Tensor | None = None,
+ token_type_ids: np.ndarray | tf.Tensor | None = None,
+ position_ids: np.ndarray | tf.Tensor | None = None,
+ head_mask: np.ndarray | tf.Tensor | None = None,
+ inputs_embeds: np.ndarray | tf.Tensor | None = None,
+ output_attentions: bool | None = None,
+ output_hidden_states: bool | None = None,
+ return_dict: bool | None = None,
+ labels: np.ndarray | tf.Tensor | None = None,
+ training: bool | None = False,
+ ) -> TFSequenceClassifierOutput | tuple[tf.Tensor]:
+ r"""
+ labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size,)`, *optional*):
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels == 1`, a regression loss is computed (Mean-Square loss); if
+ `config.num_labels > 1`, a classification loss is computed (Cross-Entropy).
+
+ Returns:
+
+ Examples:
+
+ ```python
+ >>> from transformers import AutoTokenizer, TFLayoutLMForSequenceClassification
+ >>> import tensorflow as tf
+
+ >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/layoutlm-base-uncased")
+ >>> model = TFLayoutLMForSequenceClassification.from_pretrained("microsoft/layoutlm-base-uncased")
+
+ >>> words = ["Hello", "world"]
+ >>> normalized_word_boxes = [637, 773, 693, 782], [698, 773, 733, 782]
+
+ >>> token_boxes = []
+ >>> for word, box in zip(words, normalized_word_boxes):
+ ... word_tokens = tokenizer.tokenize(word)
+ ... token_boxes.extend([box] * len(word_tokens))
+ >>> # add bounding boxes of cls + sep tokens
+ >>> token_boxes = [[0, 0, 0, 0]] + token_boxes + [[1000, 1000, 1000, 1000]]
+
+ >>> encoding = tokenizer(" ".join(words), return_tensors="tf")
+ >>> input_ids = encoding["input_ids"]
+ >>> attention_mask = encoding["attention_mask"]
+ >>> token_type_ids = encoding["token_type_ids"]
+ >>> bbox = tf.convert_to_tensor([token_boxes])
+ >>> sequence_label = tf.convert_to_tensor([1])
+
+ >>> outputs = model(
+ ... input_ids=input_ids,
+ ... bbox=bbox,
+ ... attention_mask=attention_mask,
+ ... token_type_ids=token_type_ids,
+ ... labels=sequence_label,
+ ... )
+
+ >>> loss = outputs.loss
+ >>> logits = outputs.logits
+ ```"""
+ outputs = self.layoutlm(
+ input_ids=input_ids,
+ bbox=bbox,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ training=training,
+ )
+ pooled_output = outputs[1]
+ pooled_output = self.dropout(inputs=pooled_output, training=training)
+ logits = self.classifier(inputs=pooled_output)
+ loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits)
+
+ if not return_dict:
+ output = (logits,) + outputs[2:]
+ return ((loss,) + output) if loss is not None else output
+
+ return TFSequenceClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "layoutlm", None) is not None:
+ with tf.name_scope(self.layoutlm.name):
+ self.layoutlm.build(None)
+ if getattr(self, "classifier", None) is not None:
+ with tf.name_scope(self.classifier.name):
+ self.classifier.build([None, None, self.config.hidden_size])
+
+
+@add_start_docstrings(
+ """
+ LayoutLM Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
+ Named-Entity-Recognition (NER) tasks.
+ """,
+ LAYOUTLM_START_DOCSTRING,
+)
+class TFLayoutLMForTokenClassification(TFLayoutLMPreTrainedModel, TFTokenClassificationLoss):
+ # names with a '.' represent the authorized unexpected/missing layers when a TF model is loaded from a PT model
+ _keys_to_ignore_on_load_unexpected = [
+ r"pooler",
+ r"mlm___cls",
+ r"nsp___cls",
+ r"cls.predictions",
+ r"cls.seq_relationship",
+ ]
+ _keys_to_ignore_on_load_missing = [r"dropout"]
+
+ def __init__(self, config: LayoutLMConfig, *inputs, **kwargs):
+ super().__init__(config, *inputs, **kwargs)
+
+ self.num_labels = config.num_labels
+
+ self.layoutlm = TFLayoutLMMainLayer(config, add_pooling_layer=True, name="layoutlm")
+ self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob)
+ self.classifier = keras.layers.Dense(
+ units=config.num_labels,
+ kernel_initializer=get_initializer(config.initializer_range),
+ name="classifier",
+ )
+ self.config = config
+
+ @unpack_inputs
+ @add_start_docstrings_to_model_forward(LAYOUTLM_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+ @replace_return_docstrings(output_type=TFTokenClassifierOutput, config_class=_CONFIG_FOR_DOC)
+ def call(
+ self,
+ input_ids: TFModelInputType | None = None,
+ bbox: np.ndarray | tf.Tensor | None = None,
+ attention_mask: np.ndarray | tf.Tensor | None = None,
+ token_type_ids: np.ndarray | tf.Tensor | None = None,
+ position_ids: np.ndarray | tf.Tensor | None = None,
+ head_mask: np.ndarray | tf.Tensor | None = None,
+ inputs_embeds: np.ndarray | tf.Tensor | None = None,
+ output_attentions: bool | None = None,
+ output_hidden_states: bool | None = None,
+ return_dict: bool | None = None,
+ labels: np.ndarray | tf.Tensor | None = None,
+ training: bool | None = False,
+ ) -> TFTokenClassifierOutput | tuple[tf.Tensor]:
+ r"""
+ labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
+
+ Returns:
+
+ Examples:
+
+ ```python
+ >>> import tensorflow as tf
+ >>> from transformers import AutoTokenizer, TFLayoutLMForTokenClassification
+
+ >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/layoutlm-base-uncased")
+ >>> model = TFLayoutLMForTokenClassification.from_pretrained("microsoft/layoutlm-base-uncased")
+
+ >>> words = ["Hello", "world"]
+ >>> normalized_word_boxes = [637, 773, 693, 782], [698, 773, 733, 782]
+
+ >>> token_boxes = []
+ >>> for word, box in zip(words, normalized_word_boxes):
+ ... word_tokens = tokenizer.tokenize(word)
+ ... token_boxes.extend([box] * len(word_tokens))
+ >>> # add bounding boxes of cls + sep tokens
+ >>> token_boxes = [[0, 0, 0, 0]] + token_boxes + [[1000, 1000, 1000, 1000]]
+
+ >>> encoding = tokenizer(" ".join(words), return_tensors="tf")
+ >>> input_ids = encoding["input_ids"]
+ >>> attention_mask = encoding["attention_mask"]
+ >>> token_type_ids = encoding["token_type_ids"]
+ >>> bbox = tf.convert_to_tensor([token_boxes])
+ >>> token_labels = tf.convert_to_tensor([1, 1, 0, 0])
+
+ >>> outputs = model(
+ ... input_ids=input_ids,
+ ... bbox=bbox,
+ ... attention_mask=attention_mask,
+ ... token_type_ids=token_type_ids,
+ ... labels=token_labels,
+ ... )
+
+ >>> loss = outputs.loss
+ >>> logits = outputs.logits
+ ```"""
+ outputs = self.layoutlm(
+ input_ids=input_ids,
+ bbox=bbox,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ training=training,
+ )
+ sequence_output = outputs[0]
+ sequence_output = self.dropout(inputs=sequence_output, training=training)
+ logits = self.classifier(inputs=sequence_output)
+ loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits)
+
+ if not return_dict:
+ output = (logits,) + outputs[2:]
+ return ((loss,) + output) if loss is not None else output
+
+ return TFTokenClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "layoutlm", None) is not None:
+ with tf.name_scope(self.layoutlm.name):
+ self.layoutlm.build(None)
+ if getattr(self, "classifier", None) is not None:
+ with tf.name_scope(self.classifier.name):
+ self.classifier.build([None, None, self.config.hidden_size])
+
+
+@add_start_docstrings(
+ """
+ LayoutLM Model with a span classification head on top for extractive question-answering tasks such as
+ [DocVQA](https://rrc.cvc.uab.es/?ch=17) (a linear layer on top of the final hidden-states output to compute `span
+ start logits` and `span end logits`).
+ """,
+ LAYOUTLM_START_DOCSTRING,
+)
+class TFLayoutLMForQuestionAnswering(TFLayoutLMPreTrainedModel, TFQuestionAnsweringLoss):
+ # names with a '.' represent the authorized unexpected/missing layers when a TF model is loaded from a PT model
+ _keys_to_ignore_on_load_unexpected = [
+ r"pooler",
+ r"mlm___cls",
+ r"nsp___cls",
+ r"cls.predictions",
+ r"cls.seq_relationship",
+ ]
+
+ def __init__(self, config: LayoutLMConfig, *inputs, **kwargs):
+ super().__init__(config, *inputs, **kwargs)
+ self.num_labels = config.num_labels
+
+ self.layoutlm = TFLayoutLMMainLayer(config, add_pooling_layer=True, name="layoutlm")
+ self.qa_outputs = keras.layers.Dense(
+ units=config.num_labels,
+ kernel_initializer=get_initializer(config.initializer_range),
+ name="qa_outputs",
+ )
+ self.config = config
+
+ @unpack_inputs
+ @add_start_docstrings_to_model_forward(LAYOUTLM_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+ @replace_return_docstrings(output_type=TFQuestionAnsweringModelOutput, config_class=_CONFIG_FOR_DOC)
+ def call(
+ self,
+ input_ids: TFModelInputType | None = None,
+ bbox: np.ndarray | tf.Tensor | None = None,
+ attention_mask: np.ndarray | tf.Tensor | None = None,
+ token_type_ids: np.ndarray | tf.Tensor | None = None,
+ position_ids: np.ndarray | tf.Tensor | None = None,
+ head_mask: np.ndarray | tf.Tensor | None = None,
+ inputs_embeds: np.ndarray | tf.Tensor | None = None,
+ output_attentions: bool | None = None,
+ output_hidden_states: bool | None = None,
+ return_dict: bool | None = None,
+ start_positions: np.ndarray | tf.Tensor | None = None,
+ end_positions: np.ndarray | tf.Tensor | None = None,
+ training: bool | None = False,
+ ) -> TFQuestionAnsweringModelOutput | tuple[tf.Tensor]:
+ r"""
+ start_positions (`tf.Tensor` or `np.ndarray` of shape `(batch_size,)`, *optional*):
+ Labels for position (index) of the start of the labelled span for computing the token classification loss.
+ Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
+ are not taken into account for computing the loss.
+ end_positions (`tf.Tensor` or `np.ndarray` of shape `(batch_size,)`, *optional*):
+ Labels for position (index) of the end of the labelled span for computing the token classification loss.
+ Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
+ are not taken into account for computing the loss.
+
+ Returns:
+
+ Examples:
+
+ ```python
+ >>> import tensorflow as tf
+ >>> from transformers import AutoTokenizer, TFLayoutLMForQuestionAnswering
+ >>> from datasets import load_dataset
+
+ >>> tokenizer = AutoTokenizer.from_pretrained("impira/layoutlm-document-qa", add_prefix_space=True)
+ >>> model = TFLayoutLMForQuestionAnswering.from_pretrained("impira/layoutlm-document-qa", revision="1e3ebac")
+
+ >>> dataset = load_dataset("nielsr/funsd", split="train")
+ >>> example = dataset[0]
+ >>> question = "what's his name?"
+ >>> words = example["words"]
+ >>> boxes = example["bboxes"]
+
+ >>> encoding = tokenizer(
+ ... question.split(), words, is_split_into_words=True, return_token_type_ids=True, return_tensors="tf"
+ ... )
+ >>> bbox = []
+ >>> for i, s, w in zip(encoding.input_ids[0], encoding.sequence_ids(0), encoding.word_ids(0)):
+ ... if s == 1:
+ ... bbox.append(boxes[w])
+ ... elif i == tokenizer.sep_token_id:
+ ... bbox.append([1000] * 4)
+ ... else:
+ ... bbox.append([0] * 4)
+ >>> encoding["bbox"] = tf.convert_to_tensor([bbox])
+
+ >>> word_ids = encoding.word_ids(0)
+ >>> outputs = model(**encoding)
+ >>> loss = outputs.loss
+ >>> start_scores = outputs.start_logits
+ >>> end_scores = outputs.end_logits
+ >>> start, end = word_ids[tf.math.argmax(start_scores, -1)[0]], word_ids[tf.math.argmax(end_scores, -1)[0]]
+ >>> print(" ".join(words[start : end + 1]))
+ M. Hamann P. Harper, P. Martinez
+ ```"""
+
+ outputs = self.layoutlm(
+ input_ids=input_ids,
+ bbox=bbox,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ training=training,
+ )
+
+ sequence_output = outputs[0]
+
+ logits = self.qa_outputs(inputs=sequence_output)
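+ # The QA head produces two logits per token; split them into start- and end-position scores and drop the
+ # trailing axis.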
+ start_logits, end_logits = tf.split(value=logits, num_or_size_splits=2, axis=-1)
+ start_logits = tf.squeeze(input=start_logits, axis=-1)
+ end_logits = tf.squeeze(input=end_logits, axis=-1)
+ loss = None
+
+ if start_positions is not None and end_positions is not None:
+ labels = {"start_position": start_positions}
+ labels["end_position"] = end_positions
+ loss = self.hf_compute_loss(labels=labels, logits=(start_logits, end_logits))
+
+ if not return_dict:
+ output = (start_logits, end_logits) + outputs[2:]
+ return ((loss,) + output) if loss is not None else output
+
+ return TFQuestionAnsweringModelOutput(
+ loss=loss,
+ start_logits=start_logits,
+ end_logits=end_logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "layoutlm", None) is not None:
+ with tf.name_scope(self.layoutlm.name):
+ self.layoutlm.build(None)
+ if getattr(self, "qa_outputs", None) is not None:
+ with tf.name_scope(self.qa_outputs.name):
+ self.qa_outputs.build([None, None, self.config.hidden_size])
+
+
+__all__ = [
+ "TFLayoutLMForMaskedLM",
+ "TFLayoutLMForSequenceClassification",
+ "TFLayoutLMForTokenClassification",
+ "TFLayoutLMForQuestionAnswering",
+ "TFLayoutLMMainLayer",
+ "TFLayoutLMModel",
+ "TFLayoutLMPreTrainedModel",
+]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/layoutlm/tokenization_layoutlm.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/layoutlm/tokenization_layoutlm.py
new file mode 100644
index 0000000000000000000000000000000000000000..4caccd691d0e5fcb64637f72d5c2860f6f096e9e
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/layoutlm/tokenization_layoutlm.py
@@ -0,0 +1,483 @@
+# coding=utf-8
+# Copyright 2018 The Microsoft Research Asia LayoutLM Team Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tokenization class for model LayoutLM."""
+
+import collections
+import os
+import unicodedata
+from typing import Optional
+
+from ...tokenization_utils import PreTrainedTokenizer, _is_control, _is_punctuation, _is_whitespace
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}
+
+
+# Copied from transformers.models.bert.tokenization_bert.load_vocab
+def load_vocab(vocab_file):
+ """Loads a vocabulary file into a dictionary."""
+ vocab = collections.OrderedDict()
+ with open(vocab_file, "r", encoding="utf-8") as reader:
+ tokens = reader.readlines()
+ for index, token in enumerate(tokens):
+ token = token.rstrip("\n")
+ vocab[token] = index
+ return vocab
+
+
+# Copied from transformers.models.bert.tokenization_bert.whitespace_tokenize
+def whitespace_tokenize(text):
+ """Runs basic whitespace cleaning and splitting on a piece of text."""
+ text = text.strip()
+ if not text:
+ return []
+ tokens = text.split()
+ return tokens
+
+
+# Copied from transformers.models.bert.tokenization_bert.BertTokenizer with Bert->LayoutLM,BERT->LayoutLM
+class LayoutLMTokenizer(PreTrainedTokenizer):
+ r"""
+ Construct a LayoutLM tokenizer. Based on WordPiece.
+
+ This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
+ this superclass for more information regarding those methods.
+
+ Args:
+ vocab_file (`str`):
+ File containing the vocabulary.
+ do_lower_case (`bool`, *optional*, defaults to `True`):
+ Whether or not to lowercase the input when tokenizing.
+ do_basic_tokenize (`bool`, *optional*, defaults to `True`):
+ Whether or not to do basic tokenization before WordPiece.
+ never_split (`Iterable`, *optional*):
+ Collection of tokens which will never be split during tokenization. Only has an effect when
+ `do_basic_tokenize=True`.
+ unk_token (`str`, *optional*, defaults to `"[UNK]"`):
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+ token instead.
+ sep_token (`str`, *optional*, defaults to `"[SEP]"`):
+ The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
+ sequence classification or for a text and a question for question answering. It is also used as the last
+ token of a sequence built with special tokens.
+ pad_token (`str`, *optional*, defaults to `"[PAD]"`):
+ The token used for padding, for example when batching sequences of different lengths.
+ cls_token (`str`, *optional*, defaults to `"[CLS]"`):
+ The classifier token which is used when doing sequence classification (classification of the whole sequence
+ instead of per-token classification). It is the first token of the sequence when built with special tokens.
+ mask_token (`str`, *optional*, defaults to `"[MASK]"`):
+ The token used for masking values. This is the token used when training this model with masked language
+ modeling. This is the token which the model will try to predict.
+ tokenize_chinese_chars (`bool`, *optional*, defaults to `True`):
+ Whether or not to tokenize Chinese characters.
+
+ This should likely be deactivated for Japanese (see this
+ [issue](https://github.com/huggingface/transformers/issues/328)).
+ strip_accents (`bool`, *optional*):
+ Whether or not to strip all accents. If this option is not specified, then it will be determined by the
+ value for `lowercase` (as in the original LayoutLM).
+ clean_up_tokenization_spaces (`bool`, *optional*, defaults to `True`):
+ Whether or not to clean up spaces after decoding; cleanup consists of removing potential artifacts like
+ extra spaces.
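+
+ Example (a minimal usage sketch; the checkpoint name below is only illustrative):
+
+ ```python
+ >>> from transformers import LayoutLMTokenizer
+
+ >>> tokenizer = LayoutLMTokenizer.from_pretrained("microsoft/layoutlm-base-uncased")
+ >>> encoding = tokenizer("Hello world")
+ >>> input_ids = encoding["input_ids"]  # [CLS] hello world [SEP] ids, to be paired with per-token `bbox` values
+ ```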
+ """
+
+ vocab_files_names = VOCAB_FILES_NAMES
+
+ def __init__(
+ self,
+ vocab_file,
+ do_lower_case=True,
+ do_basic_tokenize=True,
+ never_split=None,
+ unk_token="[UNK]",
+ sep_token="[SEP]",
+ pad_token="[PAD]",
+ cls_token="[CLS]",
+ mask_token="[MASK]",
+ tokenize_chinese_chars=True,
+ strip_accents=None,
+ clean_up_tokenization_spaces=True,
+ **kwargs,
+ ):
+ if not os.path.isfile(vocab_file):
+ raise ValueError(
+ f"Can't find a vocabulary file at path '{vocab_file}'. To load the vocabulary from a Google pretrained"
+ " model use `tokenizer = LayoutLMTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`"
+ )
+ self.vocab = load_vocab(vocab_file)
+ self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])
+ self.do_basic_tokenize = do_basic_tokenize
+ if do_basic_tokenize:
+ self.basic_tokenizer = BasicTokenizer(
+ do_lower_case=do_lower_case,
+ never_split=never_split,
+ tokenize_chinese_chars=tokenize_chinese_chars,
+ strip_accents=strip_accents,
+ )
+
+ self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=str(unk_token))
+
+ super().__init__(
+ do_lower_case=do_lower_case,
+ do_basic_tokenize=do_basic_tokenize,
+ never_split=never_split,
+ unk_token=unk_token,
+ sep_token=sep_token,
+ pad_token=pad_token,
+ cls_token=cls_token,
+ mask_token=mask_token,
+ tokenize_chinese_chars=tokenize_chinese_chars,
+ strip_accents=strip_accents,
+ clean_up_tokenization_spaces=clean_up_tokenization_spaces,
+ **kwargs,
+ )
+
+ @property
+ def do_lower_case(self):
+ return self.basic_tokenizer.do_lower_case
+
+ @property
+ def vocab_size(self):
+ return len(self.vocab)
+
+ def get_vocab(self):
+ return dict(self.vocab, **self.added_tokens_encoder)
+
+ def _tokenize(self, text, split_special_tokens=False):
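+ # Two-stage pipeline: the BasicTokenizer first splits on whitespace/punctuation (and optionally lowercases),
+ # then the WordpieceTokenizer breaks each resulting word into sub-word units; tokens in `never_split` skip
+ # the WordPiece step.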
+ split_tokens = []
+ if self.do_basic_tokenize:
+ for token in self.basic_tokenizer.tokenize(
+ text, never_split=self.all_special_tokens if not split_special_tokens else None
+ ):
+ # If the token is part of the never_split set
+ if token in self.basic_tokenizer.never_split:
+ split_tokens.append(token)
+ else:
+ split_tokens += self.wordpiece_tokenizer.tokenize(token)
+ else:
+ split_tokens = self.wordpiece_tokenizer.tokenize(text)
+ return split_tokens
+
+ def _convert_token_to_id(self, token):
+ """Converts a token (str) in an id using the vocab."""
+ return self.vocab.get(token, self.vocab.get(self.unk_token))
+
+ def _convert_id_to_token(self, index):
+ """Converts an index (integer) in a token (str) using the vocab."""
+ return self.ids_to_tokens.get(index, self.unk_token)
+
+ def convert_tokens_to_string(self, tokens):
+ """Converts a sequence of tokens (string) in a single string."""
+ out_string = " ".join(tokens).replace(" ##", "").strip()
+ return out_string
+
+ def build_inputs_with_special_tokens(
+ self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None
+ ) -> list[int]:
+ """
+ Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
+ adding special tokens. A LayoutLM sequence has the following format:
+
+ - single sequence: `[CLS] X [SEP]`
+ - pair of sequences: `[CLS] A [SEP] B [SEP]`
+
+ Args:
+ token_ids_0 (`List[int]`):
+ List of IDs to which the special tokens will be added.
+ token_ids_1 (`List[int]`, *optional*):
+ Optional second list of IDs for sequence pairs.
+
+ Returns:
+ `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
+ """
+ if token_ids_1 is None:
+ return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
+ cls = [self.cls_token_id]
+ sep = [self.sep_token_id]
+ return cls + token_ids_0 + sep + token_ids_1 + sep
+
+ def get_special_tokens_mask(
+ self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None, already_has_special_tokens: bool = False
+ ) -> list[int]:
+ """
+ Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
+ special tokens using the tokenizer `prepare_for_model` method.
+
+ Args:
+ token_ids_0 (`List[int]`):
+ List of IDs.
+ token_ids_1 (`List[int]`, *optional*):
+ Optional second list of IDs for sequence pairs.
+ already_has_special_tokens (`bool`, *optional*, defaults to `False`):
+ Whether or not the token list is already formatted with special tokens for the model.
+
+ Returns:
+ `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
+ """
+
+ if already_has_special_tokens:
+ return super().get_special_tokens_mask(
+ token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
+ )
+
+ if token_ids_1 is not None:
+ return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
+ return [1] + ([0] * len(token_ids_0)) + [1]
+
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]:
+ index = 0
+ if os.path.isdir(save_directory):
+ vocab_file = os.path.join(
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
+ )
+ else:
+ vocab_file = (filename_prefix + "-" if filename_prefix else "") + save_directory
+ with open(vocab_file, "w", encoding="utf-8") as writer:
+ for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
+ if index != token_index:
+ logger.warning(
+ f"Saving vocabulary to {vocab_file}: vocabulary indices are not consecutive."
+ " Please check that the vocabulary is not corrupted!"
+ )
+ index = token_index
+ writer.write(token + "\n")
+ index += 1
+ return (vocab_file,)
+
+
+# Copied from transformers.models.bert.tokenization_bert.BasicTokenizer
+class BasicTokenizer:
+ """
+ Constructs a BasicTokenizer that will run basic tokenization (punctuation splitting, lower casing, etc.).
+
+ Args:
+ do_lower_case (`bool`, *optional*, defaults to `True`):
+ Whether or not to lowercase the input when tokenizing.
+ never_split (`Iterable`, *optional*):
+ Collection of tokens which will never be split during tokenization. Only has an effect when
+ `do_basic_tokenize=True`.
+ tokenize_chinese_chars (`bool`, *optional*, defaults to `True`):
+ Whether or not to tokenize Chinese characters.
+
+ This should likely be deactivated for Japanese (see this
+ [issue](https://github.com/huggingface/transformers/issues/328)).
+ strip_accents (`bool`, *optional*):
+ Whether or not to strip all accents. If this option is not specified, then it will be determined by the
+ value for `lowercase` (as in the original BERT).
+ do_split_on_punc (`bool`, *optional*, defaults to `True`):
+ In some instances we want to skip the basic punctuation splitting so that later tokenization can capture
+ the full context of the words, such as contractions.
+ """
+
+ def __init__(
+ self,
+ do_lower_case=True,
+ never_split=None,
+ tokenize_chinese_chars=True,
+ strip_accents=None,
+ do_split_on_punc=True,
+ ):
+ if never_split is None:
+ never_split = []
+ self.do_lower_case = do_lower_case
+ self.never_split = set(never_split)
+ self.tokenize_chinese_chars = tokenize_chinese_chars
+ self.strip_accents = strip_accents
+ self.do_split_on_punc = do_split_on_punc
+
+ def tokenize(self, text, never_split=None):
+ """
+ Basic Tokenization of a piece of text. For sub-word tokenization, see WordPieceTokenizer.
+
+ Args:
+ never_split (`List[str]`, *optional*):
+ Kept for backward compatibility purposes. Now implemented directly at the base class level (see
+ [`PreTrainedTokenizer.tokenize`]). List of tokens not to split.
+ """
+ # union() returns a new set by concatenating the two sets.
+ never_split = self.never_split.union(set(never_split)) if never_split else self.never_split
+ text = self._clean_text(text)
+
+ # This was added on November 1st, 2018 for the multilingual and Chinese
+ # models. This is also applied to the English models now, but it doesn't
+ # matter since the English models were not trained on any Chinese data
+ # and generally don't have any Chinese data in them (there are Chinese
+ # characters in the vocabulary because Wikipedia does have some Chinese
+ # words in the English Wikipedia.).
+ if self.tokenize_chinese_chars:
+ text = self._tokenize_chinese_chars(text)
+ # prevents treating the same character with different unicode codepoints as different characters
+ unicode_normalized_text = unicodedata.normalize("NFC", text)
+ orig_tokens = whitespace_tokenize(unicode_normalized_text)
+ split_tokens = []
+ for token in orig_tokens:
+ if token not in never_split:
+ if self.do_lower_case:
+ token = token.lower()
+ if self.strip_accents is not False:
+ token = self._run_strip_accents(token)
+ elif self.strip_accents:
+ token = self._run_strip_accents(token)
+ split_tokens.extend(self._run_split_on_punc(token, never_split))
+
+ output_tokens = whitespace_tokenize(" ".join(split_tokens))
+ return output_tokens
+
+ def _run_strip_accents(self, text):
+ """Strips accents from a piece of text."""
+ text = unicodedata.normalize("NFD", text)
+ output = []
+ for char in text:
+ cat = unicodedata.category(char)
+ if cat == "Mn":
+ continue
+ output.append(char)
+ return "".join(output)
+
+ def _run_split_on_punc(self, text, never_split=None):
+ """Splits punctuation on a piece of text."""
+ if not self.do_split_on_punc or (never_split is not None and text in never_split):
+ return [text]
+ chars = list(text)
+ i = 0
+ start_new_word = True
+ output = []
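+ # Walk through the characters: every punctuation character becomes its own token, while consecutive
+ # non-punctuation characters are grouped into the current word.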
+ while i < len(chars):
+ char = chars[i]
+ if _is_punctuation(char):
+ output.append([char])
+ start_new_word = True
+ else:
+ if start_new_word:
+ output.append([])
+ start_new_word = False
+ output[-1].append(char)
+ i += 1
+
+ return ["".join(x) for x in output]
+
+ def _tokenize_chinese_chars(self, text):
+ """Adds whitespace around any CJK character."""
+ output = []
+ for char in text:
+ cp = ord(char)
+ if self._is_chinese_char(cp):
+ output.append(" ")
+ output.append(char)
+ output.append(" ")
+ else:
+ output.append(char)
+ return "".join(output)
+
+ def _is_chinese_char(self, cp):
+ """Checks whether CP is the codepoint of a CJK character."""
+ # This defines a "chinese character" as anything in the CJK Unicode block:
+ # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
+ #
+ # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
+ # despite its name. The modern Korean Hangul alphabet is a different block,
+ # as is Japanese Hiragana and Katakana. Those alphabets are used to write
+ # space-separated words, so they are not treated specially and handled
+ # like all of the other languages.
+ if (
+ (cp >= 0x4E00 and cp <= 0x9FFF)
+ or (cp >= 0x3400 and cp <= 0x4DBF)
+ or (cp >= 0x20000 and cp <= 0x2A6DF)
+ or (cp >= 0x2A700 and cp <= 0x2B73F)
+ or (cp >= 0x2B740 and cp <= 0x2B81F)
+ or (cp >= 0x2B820 and cp <= 0x2CEAF)
+ or (cp >= 0xF900 and cp <= 0xFAFF)
+ or (cp >= 0x2F800 and cp <= 0x2FA1F)
+ ):
+ return True
+
+ return False
+
+ def _clean_text(self, text):
+ """Performs invalid character removal and whitespace cleanup on text."""
+ output = []
+ for char in text:
+ cp = ord(char)
+ if cp == 0 or cp == 0xFFFD or _is_control(char):
+ continue
+ if _is_whitespace(char):
+ output.append(" ")
+ else:
+ output.append(char)
+ return "".join(output)
+
+
+# Copied from transformers.models.bert.tokenization_bert.WordpieceTokenizer
+class WordpieceTokenizer:
+ """Runs WordPiece tokenization."""
+
+ def __init__(self, vocab, unk_token, max_input_chars_per_word=100):
+ self.vocab = vocab
+ self.unk_token = unk_token
+ self.max_input_chars_per_word = max_input_chars_per_word
+
+ def tokenize(self, text):
+ """
+ Tokenizes a piece of text into its word pieces. This uses a greedy longest-match-first algorithm to perform
+ tokenization using the given vocabulary.
+
+ For example, `input = "unaffable"` will return as output `["un", "##aff", "##able"]`.
+
+ Args:
+ text: A single token or whitespace separated tokens. This should have
+ already been passed through *BasicTokenizer*.
+
+ Returns:
+ A list of wordpiece tokens.
+ """
+
+ output_tokens = []
+ for token in whitespace_tokenize(text):
+ chars = list(token)
+ if len(chars) > self.max_input_chars_per_word:
+ output_tokens.append(self.unk_token)
+ continue
+
+ is_bad = False
+ start = 0
+ sub_tokens = []
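+ # Greedy longest-match-first: from `start`, take the longest substring (prefixed with "##" when not
+ # word-initial) found in the vocabulary, then continue from where it ends.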
+ while start < len(chars):
+ end = len(chars)
+ cur_substr = None
+ while start < end:
+ substr = "".join(chars[start:end])
+ if start > 0:
+ substr = "##" + substr
+ if substr in self.vocab:
+ cur_substr = substr
+ break
+ end -= 1
+ if cur_substr is None:
+ is_bad = True
+ break
+ sub_tokens.append(cur_substr)
+ start = end
+
+ if is_bad:
+ output_tokens.append(self.unk_token)
+ else:
+ output_tokens.extend(sub_tokens)
+ return output_tokens
+
+
+__all__ = ["LayoutLMTokenizer"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/layoutlm/tokenization_layoutlm_fast.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/layoutlm/tokenization_layoutlm_fast.py
new file mode 100644
index 0000000000000000000000000000000000000000..c7ade6e0b8cdf04f2d4f06b6191b93a8ed7ee2a6
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/layoutlm/tokenization_layoutlm_fast.py
@@ -0,0 +1,147 @@
+# coding=utf-8
+# Copyright 2018 The Microsoft Research Asia LayoutLM Team Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tokenization class for model LayoutLM."""
+
+import json
+from typing import Optional
+
+from tokenizers import normalizers
+
+from ...tokenization_utils_fast import PreTrainedTokenizerFast
+from ...utils import logging
+from .tokenization_layoutlm import LayoutLMTokenizer
+
+
+logger = logging.get_logger(__name__)
+
+VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt", "tokenizer_file": "tokenizer.json"}
+
+
+# Copied from transformers.models.bert.tokenization_bert_fast.BertTokenizerFast with Bert->LayoutLM,BERT->LayoutLM
+class LayoutLMTokenizerFast(PreTrainedTokenizerFast):
+ r"""
+ Construct a "fast" LayoutLM tokenizer (backed by HuggingFace's *tokenizers* library). Based on WordPiece.
+
+ This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should
+ refer to this superclass for more information regarding those methods.
+
+ Args:
+ vocab_file (`str`):
+ File containing the vocabulary.
+ do_lower_case (`bool`, *optional*, defaults to `True`):
+ Whether or not to lowercase the input when tokenizing.
+ unk_token (`str`, *optional*, defaults to `"[UNK]"`):
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+ token instead.
+ sep_token (`str`, *optional*, defaults to `"[SEP]"`):
+ The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
+ sequence classification or for a text and a question for question answering. It is also used as the last
+ token of a sequence built with special tokens.
+ pad_token (`str`, *optional*, defaults to `"[PAD]"`):
+ The token used for padding, for example when batching sequences of different lengths.
+ cls_token (`str`, *optional*, defaults to `"[CLS]"`):
+ The classifier token which is used when doing sequence classification (classification of the whole sequence
+ instead of per-token classification). It is the first token of the sequence when built with special tokens.
+ mask_token (`str`, *optional*, defaults to `"[MASK]"`):
+ The token used for masking values. This is the token used when training this model with masked language
+ modeling. This is the token which the model will try to predict.
+ clean_text (`bool`, *optional*, defaults to `True`):
+ Whether or not to clean the text before tokenization by removing any control characters and replacing all
+ whitespace characters with the classic one.
+ tokenize_chinese_chars (`bool`, *optional*, defaults to `True`):
+ Whether or not to tokenize Chinese characters. This should likely be deactivated for Japanese (see [this
+ issue](https://github.com/huggingface/transformers/issues/328)).
+ strip_accents (`bool`, *optional*):
+ Whether or not to strip all accents. If this option is not specified, then it will be determined by the
+ value for `lowercase` (as in the original LayoutLM).
+ wordpieces_prefix (`str`, *optional*, defaults to `"##"`):
+ The prefix for subwords.
+ """
+
+ vocab_files_names = VOCAB_FILES_NAMES
+ slow_tokenizer_class = LayoutLMTokenizer
+
+ def __init__(
+ self,
+ vocab_file=None,
+ tokenizer_file=None,
+ do_lower_case=True,
+ unk_token="[UNK]",
+ sep_token="[SEP]",
+ pad_token="[PAD]",
+ cls_token="[CLS]",
+ mask_token="[MASK]",
+ tokenize_chinese_chars=True,
+ strip_accents=None,
+ **kwargs,
+ ):
+ super().__init__(
+ vocab_file,
+ tokenizer_file=tokenizer_file,
+ do_lower_case=do_lower_case,
+ unk_token=unk_token,
+ sep_token=sep_token,
+ pad_token=pad_token,
+ cls_token=cls_token,
+ mask_token=mask_token,
+ tokenize_chinese_chars=tokenize_chinese_chars,
+ strip_accents=strip_accents,
+ **kwargs,
+ )
+
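+ # If the normalizer serialized in tokenizer.json disagrees with the arguments passed here (lowercasing,
+ # accent stripping, Chinese character handling), rebuild it so the backend tokenizer matches the requested
+ # behavior.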
+ normalizer_state = json.loads(self.backend_tokenizer.normalizer.__getstate__())
+ if (
+ normalizer_state.get("lowercase", do_lower_case) != do_lower_case
+ or normalizer_state.get("strip_accents", strip_accents) != strip_accents
+ or normalizer_state.get("handle_chinese_chars", tokenize_chinese_chars) != tokenize_chinese_chars
+ ):
+ normalizer_class = getattr(normalizers, normalizer_state.pop("type"))
+ normalizer_state["lowercase"] = do_lower_case
+ normalizer_state["strip_accents"] = strip_accents
+ normalizer_state["handle_chinese_chars"] = tokenize_chinese_chars
+ self.backend_tokenizer.normalizer = normalizer_class(**normalizer_state)
+
+ self.do_lower_case = do_lower_case
+
+ def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
+ """
+        Build model inputs from a sequence or a pair of sequences for sequence classification tasks by concatenating and
+ adding special tokens. A LayoutLM sequence has the following format:
+
+ - single sequence: `[CLS] X [SEP]`
+ - pair of sequences: `[CLS] A [SEP] B [SEP]`
+
+ Args:
+            token_ids_0 (`list[int]`):
+                List of IDs to which the special tokens will be added.
+            token_ids_1 (`list[int]`, *optional*):
+                Optional second list of IDs for sequence pairs.
+
+        Returns:
+            `list[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
+ """
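+        # Illustrative only: assuming cls_token_id=101 and sep_token_id=102 (the usual BERT-style vocabulary),
+        # build_inputs_with_special_tokens([7, 8]) returns [101, 7, 8, 102] and
+        # build_inputs_with_special_tokens([7, 8], [9]) returns [101, 7, 8, 102, 9, 102].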
+ output = [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
+
+ if token_ids_1 is not None:
+ output += token_ids_1 + [self.sep_token_id]
+
+ return output
+
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]:
+ files = self._tokenizer.model.save(save_directory, name=filename_prefix)
+ return tuple(files)
+
+
+__all__ = ["LayoutLMTokenizerFast"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/layoutlmv2/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/layoutlmv2/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b68a523c0b0c362d9930f5bee492cea73f3937f0
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/layoutlmv2/__init__.py
@@ -0,0 +1,33 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_layoutlmv2 import *
+ from .feature_extraction_layoutlmv2 import *
+ from .image_processing_layoutlmv2 import *
+ from .image_processing_layoutlmv2_fast import *
+ from .modeling_layoutlmv2 import *
+ from .processing_layoutlmv2 import *
+ from .tokenization_layoutlmv2 import *
+ from .tokenization_layoutlmv2_fast import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/layoutlmv2/configuration_layoutlmv2.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/layoutlmv2/configuration_layoutlmv2.py
new file mode 100644
index 0000000000000000000000000000000000000000..b729ddbb1d429f3cfe464f303ec55ae391d498c3
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/layoutlmv2/configuration_layoutlmv2.py
@@ -0,0 +1,222 @@
+# coding=utf-8
+# Copyright Microsoft Research and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""LayoutLMv2 model configuration"""
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import is_detectron2_available, logging
+
+
+logger = logging.get_logger(__name__)
+
+
+# soft dependency
+if is_detectron2_available():
+ import detectron2
+
+
+class LayoutLMv2Config(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`LayoutLMv2Model`]. It is used to instantiate an
+ LayoutLMv2 model according to the specified arguments, defining the model architecture. Instantiating a
+ configuration with the defaults will yield a similar configuration to that of the LayoutLMv2
+ [microsoft/layoutlmv2-base-uncased](https://huggingface.co/microsoft/layoutlmv2-base-uncased) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 30522):
+ Vocabulary size of the LayoutLMv2 model. Defines the number of different tokens that can be represented by
+ the `inputs_ids` passed when calling [`LayoutLMv2Model`] or [`TFLayoutLMv2Model`].
+ hidden_size (`int`, *optional*, defaults to 768):
+ Dimension of the encoder layers and the pooler layer.
+ num_hidden_layers (`int`, *optional*, defaults to 12):
+ Number of hidden layers in the Transformer encoder.
+ num_attention_heads (`int`, *optional*, defaults to 12):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ intermediate_size (`int`, *optional*, defaults to 3072):
+ Dimension of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+ `"relu"`, `"selu"` and `"gelu_new"` are supported.
+ hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+ attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
+ The dropout ratio for the attention probabilities.
+ max_position_embeddings (`int`, *optional*, defaults to 512):
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
+ just in case (e.g., 512 or 1024 or 2048).
+ type_vocab_size (`int`, *optional*, defaults to 2):
+ The vocabulary size of the `token_type_ids` passed when calling [`LayoutLMv2Model`] or
+ [`TFLayoutLMv2Model`].
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ layer_norm_eps (`float`, *optional*, defaults to 1e-12):
+ The epsilon used by the layer normalization layers.
+ max_2d_position_embeddings (`int`, *optional*, defaults to 1024):
+ The maximum value that the 2D position embedding might ever be used with. Typically set this to something
+ large just in case (e.g., 1024).
+ max_rel_pos (`int`, *optional*, defaults to 128):
+ The maximum number of relative positions to be used in the self-attention mechanism.
+ rel_pos_bins (`int`, *optional*, defaults to 32):
+ The number of relative position bins to be used in the self-attention mechanism.
+ fast_qkv (`bool`, *optional*, defaults to `True`):
+ Whether or not to use a single matrix for the queries, keys, values in the self-attention layers.
+ max_rel_2d_pos (`int`, *optional*, defaults to 256):
+ The maximum number of relative 2D positions in the self-attention mechanism.
+ rel_2d_pos_bins (`int`, *optional*, defaults to 64):
+ The number of 2D relative position bins in the self-attention mechanism.
+ image_feature_pool_shape (`list[int]`, *optional*, defaults to [7, 7, 256]):
+ The shape of the average-pooled feature map.
+ coordinate_size (`int`, *optional*, defaults to 128):
+ Dimension of the coordinate embeddings.
+ shape_size (`int`, *optional*, defaults to 128):
+ Dimension of the width and height embeddings.
+ has_relative_attention_bias (`bool`, *optional*, defaults to `True`):
+ Whether or not to use a relative attention bias in the self-attention mechanism.
+ has_spatial_attention_bias (`bool`, *optional*, defaults to `True`):
+ Whether or not to use a spatial attention bias in the self-attention mechanism.
+ has_visual_segment_embedding (`bool`, *optional*, defaults to `False`):
+ Whether or not to add visual segment embeddings.
+ detectron2_config_args (`dict`, *optional*):
+ Dictionary containing the configuration arguments of the Detectron2 visual backbone. Refer to [this
+ file](https://github.com/microsoft/unilm/blob/master/layoutlmft/layoutlmft/models/layoutlmv2/detectron2_config.py)
+ for details regarding default values.
+
+ Example:
+
+ ```python
+ >>> from transformers import LayoutLMv2Config, LayoutLMv2Model
+
+ >>> # Initializing a LayoutLMv2 microsoft/layoutlmv2-base-uncased style configuration
+ >>> configuration = LayoutLMv2Config()
+
+ >>> # Initializing a model (with random weights) from the microsoft/layoutlmv2-base-uncased style configuration
+ >>> model = LayoutLMv2Model(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "layoutlmv2"
+
+ def __init__(
+ self,
+ vocab_size=30522,
+ hidden_size=768,
+ num_hidden_layers=12,
+ num_attention_heads=12,
+ intermediate_size=3072,
+ hidden_act="gelu",
+ hidden_dropout_prob=0.1,
+ attention_probs_dropout_prob=0.1,
+ max_position_embeddings=512,
+ type_vocab_size=2,
+ initializer_range=0.02,
+ layer_norm_eps=1e-12,
+ pad_token_id=0,
+ max_2d_position_embeddings=1024,
+ max_rel_pos=128,
+ rel_pos_bins=32,
+ fast_qkv=True,
+ max_rel_2d_pos=256,
+ rel_2d_pos_bins=64,
+ convert_sync_batchnorm=True,
+ image_feature_pool_shape=[7, 7, 256],
+ coordinate_size=128,
+ shape_size=128,
+ has_relative_attention_bias=True,
+ has_spatial_attention_bias=True,
+ has_visual_segment_embedding=False,
+ detectron2_config_args=None,
+ **kwargs,
+ ):
+ super().__init__(
+ vocab_size=vocab_size,
+ hidden_size=hidden_size,
+ num_hidden_layers=num_hidden_layers,
+ num_attention_heads=num_attention_heads,
+ intermediate_size=intermediate_size,
+ hidden_act=hidden_act,
+ hidden_dropout_prob=hidden_dropout_prob,
+ attention_probs_dropout_prob=attention_probs_dropout_prob,
+ max_position_embeddings=max_position_embeddings,
+ type_vocab_size=type_vocab_size,
+ initializer_range=initializer_range,
+ layer_norm_eps=layer_norm_eps,
+ pad_token_id=pad_token_id,
+ **kwargs,
+ )
+ self.max_2d_position_embeddings = max_2d_position_embeddings
+ self.max_rel_pos = max_rel_pos
+ self.rel_pos_bins = rel_pos_bins
+ self.fast_qkv = fast_qkv
+ self.max_rel_2d_pos = max_rel_2d_pos
+ self.rel_2d_pos_bins = rel_2d_pos_bins
+ self.convert_sync_batchnorm = convert_sync_batchnorm
+ self.image_feature_pool_shape = image_feature_pool_shape
+ self.coordinate_size = coordinate_size
+ self.shape_size = shape_size
+ self.has_relative_attention_bias = has_relative_attention_bias
+ self.has_spatial_attention_bias = has_spatial_attention_bias
+ self.has_visual_segment_embedding = has_visual_segment_embedding
+ self.detectron2_config_args = (
+ detectron2_config_args if detectron2_config_args is not None else self.get_default_detectron2_config()
+ )
+
+ @classmethod
+ def get_default_detectron2_config(cls):
+ return {
+ "MODEL.MASK_ON": True,
+ "MODEL.PIXEL_STD": [57.375, 57.120, 58.395],
+ "MODEL.BACKBONE.NAME": "build_resnet_fpn_backbone",
+ "MODEL.FPN.IN_FEATURES": ["res2", "res3", "res4", "res5"],
+ "MODEL.ANCHOR_GENERATOR.SIZES": [[32], [64], [128], [256], [512]],
+ "MODEL.RPN.IN_FEATURES": ["p2", "p3", "p4", "p5", "p6"],
+ "MODEL.RPN.PRE_NMS_TOPK_TRAIN": 2000,
+ "MODEL.RPN.PRE_NMS_TOPK_TEST": 1000,
+ "MODEL.RPN.POST_NMS_TOPK_TRAIN": 1000,
+ "MODEL.POST_NMS_TOPK_TEST": 1000,
+ "MODEL.ROI_HEADS.NAME": "StandardROIHeads",
+ "MODEL.ROI_HEADS.NUM_CLASSES": 5,
+ "MODEL.ROI_HEADS.IN_FEATURES": ["p2", "p3", "p4", "p5"],
+ "MODEL.ROI_BOX_HEAD.NAME": "FastRCNNConvFCHead",
+ "MODEL.ROI_BOX_HEAD.NUM_FC": 2,
+ "MODEL.ROI_BOX_HEAD.POOLER_RESOLUTION": 14,
+ "MODEL.ROI_MASK_HEAD.NAME": "MaskRCNNConvUpsampleHead",
+ "MODEL.ROI_MASK_HEAD.NUM_CONV": 4,
+ "MODEL.ROI_MASK_HEAD.POOLER_RESOLUTION": 7,
+ "MODEL.RESNETS.DEPTH": 101,
+ "MODEL.RESNETS.SIZES": [[32], [64], [128], [256], [512]],
+ "MODEL.RESNETS.ASPECT_RATIOS": [[0.5, 1.0, 2.0]],
+ "MODEL.RESNETS.OUT_FEATURES": ["res2", "res3", "res4", "res5"],
+ "MODEL.RESNETS.NUM_GROUPS": 32,
+ "MODEL.RESNETS.WIDTH_PER_GROUP": 8,
+ "MODEL.RESNETS.STRIDE_IN_1X1": False,
+ }
+
+ def get_detectron2_config(self):
+ detectron2_config = detectron2.config.get_cfg()
+ for k, v in self.detectron2_config_args.items():
+ attributes = k.split(".")
+ to_set = detectron2_config
+ for attribute in attributes[:-1]:
+ to_set = getattr(to_set, attribute)
+ setattr(to_set, attributes[-1], v)
+
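+        # For illustration: a dotted key such as "MODEL.ROI_HEADS.NUM_CLASSES" is applied as
+        # detectron2_config.MODEL.ROI_HEADS.NUM_CLASSES = 5, i.e. each key is walked attribute by
+        # attribute on the detectron2 CfgNode before the final value is set.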
+ return detectron2_config
+
+
+__all__ = ["LayoutLMv2Config"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/layoutlmv2/feature_extraction_layoutlmv2.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/layoutlmv2/feature_extraction_layoutlmv2.py
new file mode 100644
index 0000000000000000000000000000000000000000..8c70e1ed643101401373ae29637d4d597873485e
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/layoutlmv2/feature_extraction_layoutlmv2.py
@@ -0,0 +1,40 @@
+# coding=utf-8
+# Copyright 2021 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Feature extractor class for LayoutLMv2.
+"""
+
+import warnings
+
+from ...utils import logging
+from ...utils.import_utils import requires
+from .image_processing_layoutlmv2 import LayoutLMv2ImageProcessor
+
+
+logger = logging.get_logger(__name__)
+
+
+@requires(backends=("vision",))
+class LayoutLMv2FeatureExtractor(LayoutLMv2ImageProcessor):
+ def __init__(self, *args, **kwargs) -> None:
+ warnings.warn(
+ "The class LayoutLMv2FeatureExtractor is deprecated and will be removed in version 5 of Transformers."
+ " Please use LayoutLMv2ImageProcessor instead.",
+ FutureWarning,
+ )
+ super().__init__(*args, **kwargs)
+
+
+__all__ = ["LayoutLMv2FeatureExtractor"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/layoutlmv2/image_processing_layoutlmv2.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/layoutlmv2/image_processing_layoutlmv2.py
new file mode 100644
index 0000000000000000000000000000000000000000..de2e7361a6d3f5abe3a9e4f6c0d864db62837717
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/layoutlmv2/image_processing_layoutlmv2.py
@@ -0,0 +1,303 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Image processor class for LayoutLMv2."""
+
+from typing import Optional, Union
+
+import numpy as np
+
+from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
+from ...image_transforms import flip_channel_order, resize, to_channel_dimension_format, to_pil_image
+from ...image_utils import (
+ ChannelDimension,
+ ImageInput,
+ PILImageResampling,
+ infer_channel_dimension_format,
+ make_flat_list_of_images,
+ to_numpy_array,
+ valid_images,
+ validate_preprocess_arguments,
+)
+from ...utils import (
+ TensorType,
+ filter_out_non_signature_kwargs,
+ is_pytesseract_available,
+ is_vision_available,
+ logging,
+ requires_backends,
+)
+from ...utils.import_utils import requires
+
+
+if is_vision_available():
+ import PIL
+
+# soft dependency
+if is_pytesseract_available():
+ import pytesseract
+
+logger = logging.get_logger(__name__)
+
+
+def normalize_box(box, width, height):
+ return [
+ int(1000 * (box[0] / width)),
+ int(1000 * (box[1] / height)),
+ int(1000 * (box[2] / width)),
+ int(1000 * (box[3] / height)),
+ ]
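+
+# Worked example (hypothetical values): normalize_box([15, 30, 60, 90], width=600, height=800)
+# returns [25, 37, 100, 112], i.e. pixel coordinates rescaled to the 0-1000 range expected by the
+# LayoutLMv2 2D position embeddings.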
+
+
+def apply_tesseract(
+ image: np.ndarray,
+ lang: Optional[str],
+ tesseract_config: Optional[str] = None,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+):
+ """Applies Tesseract OCR on a document image, and returns recognized words + normalized bounding boxes."""
+ tesseract_config = tesseract_config if tesseract_config is not None else ""
+
+ # apply OCR
+ pil_image = to_pil_image(image, input_data_format=input_data_format)
+ image_width, image_height = pil_image.size
+ data = pytesseract.image_to_data(pil_image, lang=lang, output_type="dict", config=tesseract_config)
+ words, left, top, width, height = data["text"], data["left"], data["top"], data["width"], data["height"]
+
+ # filter empty words and corresponding coordinates
+ irrelevant_indices = [idx for idx, word in enumerate(words) if not word.strip()]
+ words = [word for idx, word in enumerate(words) if idx not in irrelevant_indices]
+ left = [coord for idx, coord in enumerate(left) if idx not in irrelevant_indices]
+ top = [coord for idx, coord in enumerate(top) if idx not in irrelevant_indices]
+ width = [coord for idx, coord in enumerate(width) if idx not in irrelevant_indices]
+ height = [coord for idx, coord in enumerate(height) if idx not in irrelevant_indices]
+
+ # turn coordinates into (left, top, left+width, top+height) format
+ actual_boxes = []
+ for x, y, w, h in zip(left, top, width, height):
+ actual_box = [x, y, x + w, y + h]
+ actual_boxes.append(actual_box)
+
+ # finally, normalize the bounding boxes
+ normalized_boxes = []
+ for box in actual_boxes:
+ normalized_boxes.append(normalize_box(box, image_width, image_height))
+
+ assert len(words) == len(normalized_boxes), "Not as many words as there are bounding boxes"
+
+ return words, normalized_boxes
+
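+# Sketch of the return value (the word strings and coordinates below are hypothetical): for a simple
+# receipt image this might be words = ["Invoice", "Total"] with normalized_boxes = [[60, 40, 210, 75],
+# [60, 820, 180, 855]], one 0-1000 normalized box per recognized word, in the same order as `words`.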
+
+@requires(backends=("vision",))
+class LayoutLMv2ImageProcessor(BaseImageProcessor):
+ r"""
+ Constructs a LayoutLMv2 image processor.
+
+ Args:
+ do_resize (`bool`, *optional*, defaults to `True`):
+ Whether to resize the image's (height, width) dimensions to `(size["height"], size["width"])`. Can be
+ overridden by `do_resize` in `preprocess`.
+        size (`dict[str, int]`, *optional*, defaults to `{"height": 224, "width": 224}`):
+ Size of the image after resizing. Can be overridden by `size` in `preprocess`.
+ resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`):
+ Resampling filter to use if resizing the image. Can be overridden by the `resample` parameter in the
+ `preprocess` method.
+ apply_ocr (`bool`, *optional*, defaults to `True`):
+ Whether to apply the Tesseract OCR engine to get words + normalized bounding boxes. Can be overridden by
+ `apply_ocr` in `preprocess`.
+ ocr_lang (`str`, *optional*):
+ The language, specified by its ISO code, to be used by the Tesseract OCR engine. By default, English is
+ used. Can be overridden by `ocr_lang` in `preprocess`.
+ tesseract_config (`str`, *optional*, defaults to `""`):
+ Any additional custom configuration flags that are forwarded to the `config` parameter when calling
+ Tesseract. For example: '--psm 6'. Can be overridden by `tesseract_config` in `preprocess`.
+ """
+
+ model_input_names = ["pixel_values"]
+
+ def __init__(
+ self,
+ do_resize: bool = True,
+ size: Optional[dict[str, int]] = None,
+ resample: PILImageResampling = PILImageResampling.BILINEAR,
+ apply_ocr: bool = True,
+ ocr_lang: Optional[str] = None,
+ tesseract_config: Optional[str] = "",
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+ size = size if size is not None else {"height": 224, "width": 224}
+ size = get_size_dict(size)
+
+ self.do_resize = do_resize
+ self.size = size
+ self.resample = resample
+ self.apply_ocr = apply_ocr
+ self.ocr_lang = ocr_lang
+ self.tesseract_config = tesseract_config
+
+ # Copied from transformers.models.vit.image_processing_vit.ViTImageProcessor.resize
+ def resize(
+ self,
+ image: np.ndarray,
+ size: dict[str, int],
+ resample: PILImageResampling = PILImageResampling.BILINEAR,
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ **kwargs,
+ ) -> np.ndarray:
+ """
+ Resize an image to `(size["height"], size["width"])`.
+
+ Args:
+ image (`np.ndarray`):
+ Image to resize.
+ size (`dict[str, int]`):
+ Dictionary in the format `{"height": int, "width": int}` specifying the size of the output image.
+ resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
+ `PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BILINEAR`.
+ data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format for the output image. If unset, the channel dimension format of the input
+ image is used. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
+ input_data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
+ from the input image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
+
+ Returns:
+ `np.ndarray`: The resized image.
+ """
+ size = get_size_dict(size)
+ if "height" not in size or "width" not in size:
+ raise ValueError(f"The `size` dictionary must contain the keys `height` and `width`. Got {size.keys()}")
+ output_size = (size["height"], size["width"])
+ return resize(
+ image,
+ size=output_size,
+ resample=resample,
+ data_format=data_format,
+ input_data_format=input_data_format,
+ **kwargs,
+ )
+
+ @filter_out_non_signature_kwargs()
+ def preprocess(
+ self,
+ images: ImageInput,
+ do_resize: Optional[bool] = None,
+ size: Optional[dict[str, int]] = None,
+ resample: Optional[PILImageResampling] = None,
+ apply_ocr: Optional[bool] = None,
+ ocr_lang: Optional[str] = None,
+ tesseract_config: Optional[str] = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ data_format: ChannelDimension = ChannelDimension.FIRST,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ ) -> PIL.Image.Image:
+ """
+ Preprocess an image or batch of images.
+
+ Args:
+ images (`ImageInput`):
+ Image to preprocess.
+ do_resize (`bool`, *optional*, defaults to `self.do_resize`):
+ Whether to resize the image.
+ size (`dict[str, int]`, *optional*, defaults to `self.size`):
+ Desired size of the output image after resizing.
+ resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
+                Resampling filter to use if resizing the image. This can be one of the `PILImageResampling` enum
+                filters. Only has an effect if `do_resize` is set to `True`.
+ apply_ocr (`bool`, *optional*, defaults to `self.apply_ocr`):
+ Whether to apply the Tesseract OCR engine to get words + normalized bounding boxes.
+ ocr_lang (`str`, *optional*, defaults to `self.ocr_lang`):
+ The language, specified by its ISO code, to be used by the Tesseract OCR engine. By default, English is
+ used.
+ tesseract_config (`str`, *optional*, defaults to `self.tesseract_config`):
+ Any additional custom configuration flags that are forwarded to the `config` parameter when calling
+ Tesseract.
+ return_tensors (`str` or `TensorType`, *optional*):
+ The type of tensors to return. Can be one of:
+ - Unset: Return a list of `np.ndarray`.
+ - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
+ - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
+ data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
+ The channel dimension format for the output image. Can be one of:
+ - `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ """
+ do_resize = do_resize if do_resize is not None else self.do_resize
+ size = size if size is not None else self.size
+ size = get_size_dict(size)
+ resample = resample if resample is not None else self.resample
+ apply_ocr = apply_ocr if apply_ocr is not None else self.apply_ocr
+ ocr_lang = ocr_lang if ocr_lang is not None else self.ocr_lang
+ tesseract_config = tesseract_config if tesseract_config is not None else self.tesseract_config
+
+ images = make_flat_list_of_images(images)
+
+ if not valid_images(images):
+ raise ValueError(
+ "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
+ "torch.Tensor, tf.Tensor or jax.ndarray."
+ )
+ validate_preprocess_arguments(
+ do_resize=do_resize,
+ size=size,
+ resample=resample,
+ )
+
+ # All transformations expect numpy arrays.
+ images = [to_numpy_array(image) for image in images]
+
+ if input_data_format is None:
+ # We assume that all images have the same channel dimension format.
+ input_data_format = infer_channel_dimension_format(images[0])
+
+ if apply_ocr:
+ requires_backends(self, "pytesseract")
+ words_batch = []
+ boxes_batch = []
+ for image in images:
+ words, boxes = apply_tesseract(image, ocr_lang, tesseract_config, input_data_format=input_data_format)
+ words_batch.append(words)
+ boxes_batch.append(boxes)
+
+ if do_resize:
+ images = [
+ self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
+ for image in images
+ ]
+
+ # flip color channels from RGB to BGR (as Detectron2 requires this)
+ images = [flip_channel_order(image, input_data_format=input_data_format) for image in images]
+ images = [
+ to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
+ ]
+
+ data = BatchFeature(data={"pixel_values": images}, tensor_type=return_tensors)
+
+ if apply_ocr:
+ data["words"] = words_batch
+ data["boxes"] = boxes_batch
+ return data
+
+
+__all__ = ["LayoutLMv2ImageProcessor"]
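+
+# Minimal usage sketch (assumes Pillow and pytesseract are installed; "document.png" is a placeholder path):
+#
+#   from PIL import Image
+#
+#   from transformers import LayoutLMv2ImageProcessor
+#
+#   image_processor = LayoutLMv2ImageProcessor(apply_ocr=True)
+#   encoding = image_processor(Image.open("document.png").convert("RGB"), return_tensors="pt")
+#   # encoding["pixel_values"] has shape (1, 3, 224, 224); encoding["words"] and encoding["boxes"]
+#   # hold the OCR words and their 0-1000 normalized bounding boxes for each image.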
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py
new file mode 100644
index 0000000000000000000000000000000000000000..354bbe21c4dba00fcb1b373d916950e9f643b7ab
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py
@@ -0,0 +1,135 @@
+# coding=utf-8
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Fast Image processor class for LayoutLMv2."""
+
+from typing import Optional, Union
+
+import torch
+from torchvision.transforms.v2 import functional as F
+
+from ...image_processing_utils_fast import BaseImageProcessorFast, BatchFeature, DefaultFastImageProcessorKwargs
+from ...image_transforms import ChannelDimension, group_images_by_shape, reorder_images
+from ...image_utils import ImageInput, PILImageResampling, SizeDict
+from ...processing_utils import Unpack
+from ...utils import (
+ TensorType,
+ auto_docstring,
+ logging,
+ requires_backends,
+)
+from .image_processing_layoutlmv2 import apply_tesseract
+
+
+logger = logging.get_logger(__name__)
+
+
+class LayoutLMv2FastImageProcessorKwargs(DefaultFastImageProcessorKwargs):
+ """
+ Args:
+ apply_ocr (`bool`, *optional*, defaults to `True`):
+ Whether to apply the Tesseract OCR engine to get words + normalized bounding boxes. Can be overridden by
+ the `apply_ocr` parameter in the `preprocess` method.
+ ocr_lang (`str`, *optional*):
+ The language, specified by its ISO code, to be used by the Tesseract OCR engine. By default, English is
+ used. Can be overridden by the `ocr_lang` parameter in the `preprocess` method.
+ tesseract_config (`str`, *optional*):
+ Any additional custom configuration flags that are forwarded to the `config` parameter when calling
+ Tesseract. For example: '--psm 6'. Can be overridden by the `tesseract_config` parameter in the
+ `preprocess` method.
+ """
+
+ apply_ocr: Optional[bool]
+ ocr_lang: Optional[str]
+ tesseract_config: Optional[str]
+
+
+@auto_docstring
+class LayoutLMv2ImageProcessorFast(BaseImageProcessorFast):
+ resample = PILImageResampling.BILINEAR
+ size = {"height": 224, "width": 224}
+ rescale_factor = None
+ do_resize = True
+ apply_ocr = True
+ ocr_lang = None
+ tesseract_config = ""
+ valid_kwargs = LayoutLMv2FastImageProcessorKwargs
+
+ def __init__(self, **kwargs: Unpack[LayoutLMv2FastImageProcessorKwargs]):
+ super().__init__(**kwargs)
+
+ @auto_docstring
+ def preprocess(self, images: ImageInput, **kwargs: Unpack[LayoutLMv2FastImageProcessorKwargs]) -> BatchFeature:
+ return super().preprocess(images, **kwargs)
+
+ def _preprocess(
+ self,
+ images: list["torch.Tensor"],
+ do_resize: bool,
+ size: SizeDict,
+ interpolation: Optional["F.InterpolationMode"],
+ apply_ocr: bool,
+ ocr_lang: Optional[str],
+ tesseract_config: Optional[str],
+ disable_grouping: Optional[bool],
+ return_tensors: Optional[Union[str, TensorType]],
+ **kwargs,
+ ) -> BatchFeature:
+ # Tesseract OCR to get words + normalized bounding boxes
+ if apply_ocr:
+ requires_backends(self, "pytesseract")
+ words_batch = []
+ boxes_batch = []
+ for image in images:
+ if image.is_cuda:
+ logger.warning_once(
+ "apply_ocr can only be performed on cpu. Tensors will be transferred to cpu before processing."
+ )
+ words, boxes = apply_tesseract(
+ image.cpu(), ocr_lang, tesseract_config, input_data_format=ChannelDimension.FIRST
+ )
+ words_batch.append(words)
+ boxes_batch.append(boxes)
+
+ # Group images by size for batched resizing
+ grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
+ resized_images_grouped = {}
+ for shape, stacked_images in grouped_images.items():
+ if do_resize:
+ stacked_images = self.resize(image=stacked_images, size=size, interpolation=interpolation)
+ resized_images_grouped[shape] = stacked_images
+ resized_images = reorder_images(resized_images_grouped, grouped_images_index)
+
+ # Group images by size for further processing
+ # Needed in case do_resize is False, or resize returns images with different sizes
+ grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping)
+ processed_images_grouped = {}
+ for shape, stacked_images in grouped_images.items():
+ # flip color channels from RGB to BGR (as Detectron2 requires this)
+ stacked_images = stacked_images.flip(1)
+ processed_images_grouped[shape] = stacked_images
+
+ processed_images = reorder_images(processed_images_grouped, grouped_images_index)
+ processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
+
+ data = BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)
+
+ if apply_ocr:
+ data["words"] = words_batch
+ data["boxes"] = boxes_batch
+
+ return data
+
+
+__all__ = ["LayoutLMv2ImageProcessorFast"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/layoutlmv2/modeling_layoutlmv2.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/layoutlmv2/modeling_layoutlmv2.py
new file mode 100644
index 0000000000000000000000000000000000000000..3f444fbb6b281b42477d730881e719586cfc3522
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/layoutlmv2/modeling_layoutlmv2.py
@@ -0,0 +1,1394 @@
+# coding=utf-8
+# Copyright 2021 Microsoft Research The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch LayoutLMv2 model."""
+
+import math
+from typing import Optional, Union
+
+import torch
+from torch import nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+
+from ...activations import ACT2FN
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_outputs import (
+ BaseModelOutput,
+ BaseModelOutputWithPooling,
+ QuestionAnsweringModelOutput,
+ SequenceClassifierOutput,
+ TokenClassifierOutput,
+)
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import apply_chunking_to_forward
+from ...utils import auto_docstring, is_detectron2_available, logging, requires_backends
+from .configuration_layoutlmv2 import LayoutLMv2Config
+
+
+# soft dependency
+if is_detectron2_available():
+ import detectron2
+ from detectron2.modeling import META_ARCH_REGISTRY
+
+    # This is needed as otherwise their override will break sequential loading by overwriting the buffers over and over. See
+ # https://github.com/facebookresearch/detectron2/blob/9604f5995cc628619f0e4fd913453b4d7d61db3f/detectron2/layers/batch_norm.py#L83-L86
+ detectron2.layers.batch_norm.FrozenBatchNorm2d._load_from_state_dict = torch.nn.Module._load_from_state_dict
+
+logger = logging.get_logger(__name__)
+
+
+class LayoutLMv2Embeddings(nn.Module):
+ """Construct the embeddings from word, position and token_type embeddings."""
+
+ def __init__(self, config):
+ super().__init__()
+ self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
+ self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
+
+ self.x_position_embeddings = nn.Embedding(config.max_2d_position_embeddings, config.coordinate_size)
+ self.y_position_embeddings = nn.Embedding(config.max_2d_position_embeddings, config.coordinate_size)
+ self.h_position_embeddings = nn.Embedding(config.max_2d_position_embeddings, config.shape_size)
+ self.w_position_embeddings = nn.Embedding(config.max_2d_position_embeddings, config.shape_size)
+ self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
+
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ self.register_buffer(
+ "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
+ )
+
+ def _calc_spatial_position_embeddings(self, bbox):
+ try:
+ left_position_embeddings = self.x_position_embeddings(bbox[:, :, 0])
+ upper_position_embeddings = self.y_position_embeddings(bbox[:, :, 1])
+ right_position_embeddings = self.x_position_embeddings(bbox[:, :, 2])
+ lower_position_embeddings = self.y_position_embeddings(bbox[:, :, 3])
+ except IndexError as e:
+ raise IndexError("The `bbox` coordinate values should be within 0-1000 range.") from e
+
+ h_position_embeddings = self.h_position_embeddings(bbox[:, :, 3] - bbox[:, :, 1])
+ w_position_embeddings = self.w_position_embeddings(bbox[:, :, 2] - bbox[:, :, 0])
+
+ spatial_position_embeddings = torch.cat(
+ [
+ left_position_embeddings,
+ upper_position_embeddings,
+ right_position_embeddings,
+ lower_position_embeddings,
+ h_position_embeddings,
+ w_position_embeddings,
+ ],
+ dim=-1,
+ )
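+        # With the default configuration (coordinate_size=128, shape_size=128, hidden_size=768) this
+        # concatenation has dimension 4 * 128 + 2 * 128 = 768, i.e. it matches hidden_size so it can be
+        # summed with the word, 1D position and token type embeddings.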
+ return spatial_position_embeddings
+
+
+class LayoutLMv2SelfAttention(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
+ raise ValueError(
+ f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
+ f"heads ({config.num_attention_heads})"
+ )
+ self.fast_qkv = config.fast_qkv
+ self.num_attention_heads = config.num_attention_heads
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
+
+ self.has_relative_attention_bias = config.has_relative_attention_bias
+ self.has_spatial_attention_bias = config.has_spatial_attention_bias
+
+ if config.fast_qkv:
+ self.qkv_linear = nn.Linear(config.hidden_size, 3 * self.all_head_size, bias=False)
+ self.q_bias = nn.Parameter(torch.zeros(1, 1, self.all_head_size))
+ self.v_bias = nn.Parameter(torch.zeros(1, 1, self.all_head_size))
+ else:
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
+
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
+
+ def compute_qkv(self, hidden_states):
+ if self.fast_qkv:
+ qkv = self.qkv_linear(hidden_states)
+ q, k, v = torch.chunk(qkv, 3, dim=-1)
+ if q.ndimension() == self.q_bias.ndimension():
+ q = q + self.q_bias
+ v = v + self.v_bias
+ else:
+ _sz = (1,) * (q.ndimension() - 1) + (-1,)
+ q = q + self.q_bias.view(*_sz)
+ v = v + self.v_bias.view(*_sz)
+ else:
+ q = self.query(hidden_states)
+ k = self.key(hidden_states)
+ v = self.value(hidden_states)
+ return q, k, v
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask=None,
+ head_mask=None,
+ output_attentions=False,
+ rel_pos=None,
+ rel_2d_pos=None,
+ ):
+ batch_size, seq_length, _ = hidden_states.shape
+ query, key, value = self.compute_qkv(hidden_states)
+
+ # (B, L, H*D) -> (B, H, L, D)
+ query_layer = query.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2)
+ key_layer = key.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2)
+ value_layer = value.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2)
+
+ query_layer = query_layer / math.sqrt(self.attention_head_size)
+ # [BSZ, NAT, L, L]
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
+ if self.has_relative_attention_bias:
+ attention_scores += rel_pos
+ if self.has_spatial_attention_bias:
+ attention_scores += rel_2d_pos
+ attention_scores = attention_scores.float().masked_fill_(
+ attention_mask.to(torch.bool), torch.finfo(attention_scores.dtype).min
+ )
+ attention_probs = nn.functional.softmax(attention_scores, dim=-1, dtype=torch.float32).type_as(value_layer)
+ # This is actually dropping out entire tokens to attend to, which might
+ # seem a bit unusual, but is taken from the original Transformer paper.
+ attention_probs = self.dropout(attention_probs)
+
+ # Mask heads if we want to
+ if head_mask is not None:
+ attention_probs = attention_probs * head_mask
+
+ context_layer = torch.matmul(attention_probs, value_layer)
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+ context_layer = context_layer.view(*new_context_layer_shape)
+
+ outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
+ return outputs
+
+
+class LayoutLMv2Attention(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.self = LayoutLMv2SelfAttention(config)
+ self.output = LayoutLMv2SelfOutput(config)
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask=None,
+ head_mask=None,
+ output_attentions=False,
+ rel_pos=None,
+ rel_2d_pos=None,
+ ):
+ self_outputs = self.self(
+ hidden_states,
+ attention_mask,
+ head_mask,
+ output_attentions,
+ rel_pos=rel_pos,
+ rel_2d_pos=rel_2d_pos,
+ )
+ attention_output = self.output(self_outputs[0], hidden_states)
+ outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
+ return outputs
+
+
+class LayoutLMv2SelfOutput(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ def forward(self, hidden_states, input_tensor):
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
+ return hidden_states
+
+
+# Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->LayoutLMv2
+class LayoutLMv2Intermediate(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
+ if isinstance(config.hidden_act, str):
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
+ else:
+ self.intermediate_act_fn = config.hidden_act
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.intermediate_act_fn(hidden_states)
+ return hidden_states
+
+
+# Copied from transformers.models.bert.modeling_bert.BertOutput with Bert->LayoutLM
+class LayoutLMv2Output(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
+ return hidden_states
+
+
+class LayoutLMv2Layer(GradientCheckpointingLayer):
+ def __init__(self, config):
+ super().__init__()
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
+ self.seq_len_dim = 1
+ self.attention = LayoutLMv2Attention(config)
+ self.intermediate = LayoutLMv2Intermediate(config)
+ self.output = LayoutLMv2Output(config)
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask=None,
+ head_mask=None,
+ output_attentions=False,
+ rel_pos=None,
+ rel_2d_pos=None,
+ ):
+ self_attention_outputs = self.attention(
+ hidden_states,
+ attention_mask,
+ head_mask,
+ output_attentions=output_attentions,
+ rel_pos=rel_pos,
+ rel_2d_pos=rel_2d_pos,
+ )
+ attention_output = self_attention_outputs[0]
+
+ outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
+
+ layer_output = apply_chunking_to_forward(
+ self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
+ )
+ outputs = (layer_output,) + outputs
+
+ return outputs
+
+ def feed_forward_chunk(self, attention_output):
+ intermediate_output = self.intermediate(attention_output)
+ layer_output = self.output(intermediate_output, attention_output)
+ return layer_output
+
+
+def relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
+ """
+ Adapted from Mesh Tensorflow:
+ https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593
+ Translate relative position to a bucket number for relative attention. The relative position is defined as
+ memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
+ position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for small
+ absolute relative_position and larger buckets for larger absolute relative_positions. All relative positions
+ >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket. This should
+ allow for more graceful generalization to longer sequences than the model has been trained on.
+
+ Args:
+ relative_position: an int32 Tensor
+ bidirectional: a boolean - whether the attention is bidirectional
+ num_buckets: an integer
+ max_distance: an integer
+
+ Returns:
+ a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
+ """
+
+ ret = 0
+ if bidirectional:
+ num_buckets //= 2
+ ret += (relative_position > 0).long() * num_buckets
+ n = torch.abs(relative_position)
+ else:
+ n = torch.max(-relative_position, torch.zeros_like(relative_position))
+ # now n is in the range [0, inf)
+
+ # half of the buckets are for exact increments in positions
+ max_exact = num_buckets // 2
+ is_small = n < max_exact
+
+ # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
+ val_if_large = max_exact + (
+ torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
+ ).to(torch.long)
+ val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))
+
+ ret += torch.where(is_small, n, val_if_large)
+ return ret
+
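+# Worked example with the values LayoutLMv2Encoder passes in by default (bidirectional=True,
+# num_buckets=32, max_distance=128): 16 buckets are reserved per sign, so relative_position=0 maps to
+# bucket 0, +3 to bucket 19, -3 to bucket 3, and any positive distance >= 128 saturates at bucket 31.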
+
+class LayoutLMv2Encoder(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.layer = nn.ModuleList([LayoutLMv2Layer(config) for _ in range(config.num_hidden_layers)])
+
+ self.has_relative_attention_bias = config.has_relative_attention_bias
+ self.has_spatial_attention_bias = config.has_spatial_attention_bias
+
+ if self.has_relative_attention_bias:
+ self.rel_pos_bins = config.rel_pos_bins
+ self.max_rel_pos = config.max_rel_pos
+ self.rel_pos_bias = nn.Linear(self.rel_pos_bins, config.num_attention_heads, bias=False)
+
+ if self.has_spatial_attention_bias:
+ self.max_rel_2d_pos = config.max_rel_2d_pos
+ self.rel_2d_pos_bins = config.rel_2d_pos_bins
+ self.rel_pos_x_bias = nn.Linear(self.rel_2d_pos_bins, config.num_attention_heads, bias=False)
+ self.rel_pos_y_bias = nn.Linear(self.rel_2d_pos_bins, config.num_attention_heads, bias=False)
+
+ self.gradient_checkpointing = False
+
+ def _calculate_1d_position_embeddings(self, position_ids):
+ rel_pos_mat = position_ids.unsqueeze(-2) - position_ids.unsqueeze(-1)
+ rel_pos = relative_position_bucket(
+ rel_pos_mat,
+ num_buckets=self.rel_pos_bins,
+ max_distance=self.max_rel_pos,
+ )
+ # Since this is a simple indexing operation that is independent of the input,
+ # no need to track gradients for this operation
+ #
+ # Without this no_grad context, training speed slows down significantly
+ with torch.no_grad():
+ rel_pos = self.rel_pos_bias.weight.t()[rel_pos].permute(0, 3, 1, 2)
+ rel_pos = rel_pos.contiguous()
+ return rel_pos
+
+ def _calculate_2d_position_embeddings(self, bbox):
+ position_coord_x = bbox[:, :, 0]
+ position_coord_y = bbox[:, :, 3]
+ rel_pos_x_2d_mat = position_coord_x.unsqueeze(-2) - position_coord_x.unsqueeze(-1)
+ rel_pos_y_2d_mat = position_coord_y.unsqueeze(-2) - position_coord_y.unsqueeze(-1)
+ rel_pos_x = relative_position_bucket(
+ rel_pos_x_2d_mat,
+ num_buckets=self.rel_2d_pos_bins,
+ max_distance=self.max_rel_2d_pos,
+ )
+ rel_pos_y = relative_position_bucket(
+ rel_pos_y_2d_mat,
+ num_buckets=self.rel_2d_pos_bins,
+ max_distance=self.max_rel_2d_pos,
+ )
+ # Since this is a simple indexing operation that is independent of the input,
+ # no need to track gradients for this operation
+ #
+ # Without this no_grad context, training speed slows down significantly
+ with torch.no_grad():
+ rel_pos_x = self.rel_pos_x_bias.weight.t()[rel_pos_x].permute(0, 3, 1, 2)
+ rel_pos_y = self.rel_pos_y_bias.weight.t()[rel_pos_y].permute(0, 3, 1, 2)
+ rel_pos_x = rel_pos_x.contiguous()
+ rel_pos_y = rel_pos_y.contiguous()
+ rel_2d_pos = rel_pos_x + rel_pos_y
+ return rel_2d_pos
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask=None,
+ head_mask=None,
+ output_attentions=False,
+ output_hidden_states=False,
+ return_dict=True,
+ bbox=None,
+ position_ids=None,
+ ):
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attentions = () if output_attentions else None
+
+ rel_pos = self._calculate_1d_position_embeddings(position_ids) if self.has_relative_attention_bias else None
+ rel_2d_pos = self._calculate_2d_position_embeddings(bbox) if self.has_spatial_attention_bias else None
+
+ for i, layer_module in enumerate(self.layer):
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ layer_head_mask = head_mask[i] if head_mask is not None else None
+
+ layer_outputs = layer_module(
+ hidden_states,
+ attention_mask,
+ layer_head_mask,
+ output_attentions,
+ rel_pos=rel_pos,
+ rel_2d_pos=rel_2d_pos,
+ )
+
+ hidden_states = layer_outputs[0]
+ if output_attentions:
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(
+ v
+ for v in [
+ hidden_states,
+ all_hidden_states,
+ all_self_attentions,
+ ]
+ if v is not None
+ )
+ return BaseModelOutput(
+ last_hidden_state=hidden_states,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attentions,
+ )
+
+
+@auto_docstring
+class LayoutLMv2PreTrainedModel(PreTrainedModel):
+ config: LayoutLMv2Config
+ base_model_prefix = "layoutlmv2"
+
+ def _init_weights(self, module):
+ """Initialize the weights"""
+ if isinstance(module, nn.Linear):
+ # Slightly different from the TF version which uses truncated_normal for initialization
+ # cf https://github.com/pytorch/pytorch/pull/5617
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+ elif isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+ elif isinstance(module, LayoutLMv2SelfAttention):
+ if self.config.fast_qkv:
+ module.q_bias.data.zero_()
+ module.v_bias.data.zero_()
+ elif isinstance(module, LayoutLMv2Model):
+ if hasattr(module, "visual_segment_embedding"):
+ module.visual_segment_embedding.data.normal_(mean=0.0, std=self.config.initializer_range)
+
+
+def my_convert_sync_batchnorm(module, process_group=None):
+ # same as `nn.modules.SyncBatchNorm.convert_sync_batchnorm` but allowing converting from `detectron2.layers.FrozenBatchNorm2d`
+ if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
+ return nn.modules.SyncBatchNorm.convert_sync_batchnorm(module, process_group)
+ module_output = module
+ if isinstance(module, detectron2.layers.FrozenBatchNorm2d):
+ module_output = torch.nn.SyncBatchNorm(
+ num_features=module.num_features,
+ eps=module.eps,
+ affine=True,
+ track_running_stats=True,
+ process_group=process_group,
+ )
+ module_output.weight = torch.nn.Parameter(module.weight)
+ module_output.bias = torch.nn.Parameter(module.bias)
+ module_output.running_mean = module.running_mean
+ module_output.running_var = module.running_var
+ module_output.num_batches_tracked = torch.tensor(0, dtype=torch.long, device=module.running_mean.device)
+ for name, child in module.named_children():
+ module_output.add_module(name, my_convert_sync_batchnorm(child, process_group))
+ del module
+ return module_output
+
+
+class LayoutLMv2VisualBackbone(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.cfg = config.get_detectron2_config()
+ meta_arch = self.cfg.MODEL.META_ARCHITECTURE
+ model = META_ARCH_REGISTRY.get(meta_arch)(self.cfg)
+ assert isinstance(model.backbone, detectron2.modeling.backbone.FPN)
+ self.backbone = model.backbone
+
+ assert len(self.cfg.MODEL.PIXEL_MEAN) == len(self.cfg.MODEL.PIXEL_STD)
+ num_channels = len(self.cfg.MODEL.PIXEL_MEAN)
+ self.register_buffer(
+ "pixel_mean",
+ torch.Tensor(self.cfg.MODEL.PIXEL_MEAN).view(num_channels, 1, 1),
+ persistent=False,
+ )
+ self.register_buffer(
+ "pixel_std", torch.Tensor(self.cfg.MODEL.PIXEL_STD).view(num_channels, 1, 1), persistent=False
+ )
+ self.out_feature_key = "p2"
+ if torch.are_deterministic_algorithms_enabled():
+ logger.warning("using `AvgPool2d` instead of `AdaptiveAvgPool2d`")
+ input_shape = (224, 224)
+ backbone_stride = self.backbone.output_shape()[self.out_feature_key].stride
+ self.pool = nn.AvgPool2d(
+ (
+ math.ceil(math.ceil(input_shape[0] / backbone_stride) / config.image_feature_pool_shape[0]),
+ math.ceil(math.ceil(input_shape[1] / backbone_stride) / config.image_feature_pool_shape[1]),
+ )
+ )
+ else:
+ self.pool = nn.AdaptiveAvgPool2d(config.image_feature_pool_shape[:2])
+ if len(config.image_feature_pool_shape) == 2:
+ config.image_feature_pool_shape.append(self.backbone.output_shape()[self.out_feature_key].channels)
+ assert self.backbone.output_shape()[self.out_feature_key].channels == config.image_feature_pool_shape[2]
+
+ def forward(self, images):
+ images_input = ((images if torch.is_tensor(images) else images.tensor) - self.pixel_mean) / self.pixel_std
+ features = self.backbone(images_input)
+ features = features[self.out_feature_key]
+ features = self.pool(features).flatten(start_dim=2).transpose(1, 2).contiguous()
+ return features
+
+ def synchronize_batch_norm(self):
+ if not (
+ torch.distributed.is_available()
+ and torch.distributed.is_initialized()
+ and torch.distributed.get_rank() > -1
+ ):
+ raise RuntimeError("Make sure torch.distributed is set up properly.")
+
+ self_rank = torch.distributed.get_rank()
+ node_size = torch.cuda.device_count()
+ world_size = torch.distributed.get_world_size()
+ if not (world_size % node_size == 0):
+ raise RuntimeError("Make sure the number of processes can be divided by the number of nodes")
+
+ node_global_ranks = [list(range(i * node_size, (i + 1) * node_size)) for i in range(world_size // node_size)]
+ sync_bn_groups = [
+ torch.distributed.new_group(ranks=node_global_ranks[i]) for i in range(world_size // node_size)
+ ]
+ node_rank = self_rank // node_size
+
+ self.backbone = my_convert_sync_batchnorm(self.backbone, process_group=sync_bn_groups[node_rank])
+
+
+class LayoutLMv2Pooler(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ self.activation = nn.Tanh()
+
+ def forward(self, hidden_states):
+ # We "pool" the model by simply taking the hidden state corresponding
+ # to the first token.
+ first_token_tensor = hidden_states[:, 0]
+ pooled_output = self.dense(first_token_tensor)
+ pooled_output = self.activation(pooled_output)
+ return pooled_output
+
+
+@auto_docstring
+class LayoutLMv2Model(LayoutLMv2PreTrainedModel):
+ def __init__(self, config):
+ requires_backends(self, "detectron2")
+ super().__init__(config)
+ self.config = config
+ self.has_visual_segment_embedding = config.has_visual_segment_embedding
+ self.embeddings = LayoutLMv2Embeddings(config)
+
+ self.visual = LayoutLMv2VisualBackbone(config)
+ self.visual_proj = nn.Linear(config.image_feature_pool_shape[-1], config.hidden_size)
+ if self.has_visual_segment_embedding:
+ self.visual_segment_embedding = nn.Parameter(nn.Embedding(1, config.hidden_size).weight[0])
+ self.visual_LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.visual_dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ self.encoder = LayoutLMv2Encoder(config)
+ self.pooler = LayoutLMv2Pooler(config)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.embeddings.word_embeddings
+
+ def set_input_embeddings(self, value):
+ self.embeddings.word_embeddings = value
+
+ def _calc_text_embeddings(self, input_ids, bbox, position_ids, token_type_ids, inputs_embeds=None):
+ if input_ids is not None:
+ input_shape = input_ids.size()
+ else:
+ input_shape = inputs_embeds.size()[:-1]
+
+ seq_length = input_shape[1]
+
+ if position_ids is None:
+ position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
+ position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
+ if token_type_ids is None:
+ token_type_ids = torch.zeros_like(input_ids)
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embeddings.word_embeddings(input_ids)
+ position_embeddings = self.embeddings.position_embeddings(position_ids)
+ spatial_position_embeddings = self.embeddings._calc_spatial_position_embeddings(bbox)
+ token_type_embeddings = self.embeddings.token_type_embeddings(token_type_ids)
+
+ embeddings = inputs_embeds + position_embeddings + spatial_position_embeddings + token_type_embeddings
+ embeddings = self.embeddings.LayerNorm(embeddings)
+ embeddings = self.embeddings.dropout(embeddings)
+ return embeddings
+
+ def _calc_img_embeddings(self, image, bbox, position_ids):
+ visual_embeddings = self.visual_proj(self.visual(image))
+ position_embeddings = self.embeddings.position_embeddings(position_ids)
+ spatial_position_embeddings = self.embeddings._calc_spatial_position_embeddings(bbox)
+ embeddings = visual_embeddings + position_embeddings + spatial_position_embeddings
+ if self.has_visual_segment_embedding:
+ embeddings += self.visual_segment_embedding
+ embeddings = self.visual_LayerNorm(embeddings)
+ embeddings = self.visual_dropout(embeddings)
+ return embeddings
+
+ def _calc_visual_bbox(self, image_feature_pool_shape, bbox, device, final_shape):
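+        # build one bounding box per visual token: an evenly spaced grid covering the full page on the 0-1000
+        # normalized coordinate scale used by LayoutLMv2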
+ visual_bbox_x = torch.div(
+ torch.arange(
+ 0,
+ 1000 * (image_feature_pool_shape[1] + 1),
+ 1000,
+ device=device,
+ dtype=bbox.dtype,
+ ),
+ self.config.image_feature_pool_shape[1],
+ rounding_mode="floor",
+ )
+ visual_bbox_y = torch.div(
+ torch.arange(
+ 0,
+ 1000 * (self.config.image_feature_pool_shape[0] + 1),
+ 1000,
+ device=device,
+ dtype=bbox.dtype,
+ ),
+ self.config.image_feature_pool_shape[0],
+ rounding_mode="floor",
+ )
+ visual_bbox = torch.stack(
+ [
+ visual_bbox_x[:-1].repeat(image_feature_pool_shape[0], 1),
+ visual_bbox_y[:-1].repeat(image_feature_pool_shape[1], 1).transpose(0, 1),
+ visual_bbox_x[1:].repeat(image_feature_pool_shape[0], 1),
+ visual_bbox_y[1:].repeat(image_feature_pool_shape[1], 1).transpose(0, 1),
+ ],
+ dim=-1,
+ ).view(-1, bbox.size(-1))
+
+ visual_bbox = visual_bbox.repeat(final_shape[0], 1, 1)
+
+ return visual_bbox
+
+ def _get_input_shape(self, input_ids=None, inputs_embeds=None):
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+ elif input_ids is not None:
+ return input_ids.size()
+ elif inputs_embeds is not None:
+ return inputs_embeds.size()[:-1]
+ else:
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ bbox: Optional[torch.LongTensor] = None,
+ image: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[tuple, BaseModelOutputWithPooling]:
+ r"""
+        bbox (`torch.LongTensor` of shape `(batch_size, sequence_length, 4)`, *optional*):
+ Bounding boxes of each input sequence tokens. Selected in the range `[0,
+ config.max_2d_position_embeddings-1]`. Each bounding box should be a normalized version in (x0, y0, x1, y1)
+ format, where (x0, y0) corresponds to the position of the upper left corner in the bounding box, and (x1,
+ y1) represents the position of the lower right corner.
+        image (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `detectron2.structures.ImageList` whose `tensors` is of shape `(batch_size, num_channels, height, width)`):
+ Batch of document images.
+
+ Examples:
+
+ ```python
+ >>> from transformers import AutoProcessor, LayoutLMv2Model, set_seed
+ >>> from PIL import Image
+ >>> import torch
+ >>> from datasets import load_dataset
+
+ >>> set_seed(0)
+
+ >>> processor = AutoProcessor.from_pretrained("microsoft/layoutlmv2-base-uncased")
+ >>> model = LayoutLMv2Model.from_pretrained("microsoft/layoutlmv2-base-uncased")
+
+
+ >>> dataset = load_dataset("hf-internal-testing/fixtures_docvqa")
+ >>> image = dataset["test"][0]["image"]
+
+ >>> encoding = processor(image, return_tensors="pt")
+
+ >>> outputs = model(**encoding)
+ >>> last_hidden_states = outputs.last_hidden_state
+
+ >>> last_hidden_states.shape
+ torch.Size([1, 342, 768])
+ ```
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ input_shape = self._get_input_shape(input_ids, inputs_embeds)
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
+
+ visual_shape = list(input_shape)
+ visual_shape[1] = self.config.image_feature_pool_shape[0] * self.config.image_feature_pool_shape[1]
+ visual_shape = torch.Size(visual_shape)
+ # needs a new copy of input_shape for tracing. Otherwise wrong dimensions will occur
+ final_shape = list(self._get_input_shape(input_ids, inputs_embeds))
+ final_shape[1] += visual_shape[1]
+ final_shape = torch.Size(final_shape)
+
+        # the default `bbox` must be created before it is used to build `visual_bbox` and `final_bbox`
+        if bbox is None:
+            bbox = torch.zeros(tuple(list(input_shape) + [4]), dtype=torch.long, device=device)
+
+        visual_bbox = self._calc_visual_bbox(self.config.image_feature_pool_shape, bbox, device, final_shape)
+        final_bbox = torch.cat([bbox, visual_bbox], dim=1)
+
+ if attention_mask is None:
+ attention_mask = torch.ones(input_shape, device=device)
+
+ visual_attention_mask = torch.ones(visual_shape, device=device)
+ final_attention_mask = torch.cat([attention_mask, visual_attention_mask], dim=1)
+
+ if token_type_ids is None:
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
+
+ if position_ids is None:
+ seq_length = input_shape[1]
+ position_ids = self.embeddings.position_ids[:, :seq_length]
+ position_ids = position_ids.expand(input_shape)
+
+ visual_position_ids = torch.arange(0, visual_shape[1], dtype=torch.long, device=device).repeat(
+ input_shape[0], 1
+ )
+ final_position_ids = torch.cat([position_ids, visual_position_ids], dim=1)
+
+ text_layout_emb = self._calc_text_embeddings(
+ input_ids=input_ids,
+ bbox=bbox,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ inputs_embeds=inputs_embeds,
+ )
+
+ visual_emb = self._calc_img_embeddings(
+ image=image,
+ bbox=visual_bbox,
+ position_ids=visual_position_ids,
+ )
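+        # concatenate the text and visual token embeddings along the sequence dimension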
+ final_emb = torch.cat([text_layout_emb, visual_emb], dim=1)
+
+ extended_attention_mask = final_attention_mask.unsqueeze(1).unsqueeze(2)
+
+ extended_attention_mask = extended_attention_mask.to(dtype=self.dtype)
+ extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(self.dtype).min
+
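+        # prepare the head mask if needed: a [num_heads] or [num_hidden_layers x num_heads] mask is expanded so it
+        # can broadcast over [num_hidden_layers x batch x num_heads x seq_length x seq_length] (1.0 keeps a head,
+        # 0.0 masks it)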
+ if head_mask is not None:
+ if head_mask.dim() == 1:
+ head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
+ head_mask = head_mask.expand(self.config.num_hidden_layers, -1, -1, -1, -1)
+ elif head_mask.dim() == 2:
+ head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)
+ head_mask = head_mask.to(dtype=next(self.parameters()).dtype)
+ else:
+ head_mask = [None] * self.config.num_hidden_layers
+
+ encoder_outputs = self.encoder(
+ final_emb,
+ extended_attention_mask,
+ bbox=final_bbox,
+ position_ids=final_position_ids,
+ head_mask=head_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ sequence_output = encoder_outputs[0]
+ pooled_output = self.pooler(sequence_output)
+
+ if not return_dict:
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
+
+ return BaseModelOutputWithPooling(
+ last_hidden_state=sequence_output,
+ pooler_output=pooled_output,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ )
+
+
+@auto_docstring(
+ custom_intro="""
+ LayoutLMv2 Model with a sequence classification head on top (a linear layer on top of the concatenation of the
+ final hidden state of the [CLS] token, average-pooled initial visual embeddings and average-pooled final visual
+ embeddings, e.g. for document image classification tasks such as the
+ [RVL-CDIP](https://www.cs.cmu.edu/~aharley/rvl-cdip/) dataset.
+ """
+)
+class LayoutLMv2ForSequenceClassification(LayoutLMv2PreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+ self.layoutlmv2 = LayoutLMv2Model(config)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
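+        # the classification head consumes the concatenation of the [CLS] token, the average-pooled initial visual
+        # embeddings and the average-pooled final visual embeddings, hence `config.hidden_size * 3`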
+ self.classifier = nn.Linear(config.hidden_size * 3, config.num_labels)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.layoutlmv2.embeddings.word_embeddings
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ bbox: Optional[torch.LongTensor] = None,
+ image: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[tuple, SequenceClassifierOutput]:
+ r"""
+ input_ids (`torch.LongTensor` of shape `batch_size, sequence_length`):
+ Indices of input sequence tokens in the vocabulary.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ bbox (`torch.LongTensor` of shape `(batch_size, sequence_length, 4)`, *optional*):
+ Bounding boxes of each input sequence tokens. Selected in the range `[0,
+ config.max_2d_position_embeddings-1]`. Each bounding box should be a normalized version in (x0, y0, x1, y1)
+ format, where (x0, y0) corresponds to the position of the upper left corner in the bounding box, and (x1,
+ y1) represents the position of the lower right corner.
+        image (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `detectron2.structures.ImageList` whose `tensors` is of shape `(batch_size, num_channels, height, width)`):
+ Batch of document images.
+ token_type_ids (`torch.LongTensor` of shape `batch_size, sequence_length`, *optional*):
+ Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
+ 1]`:
+
+ - 0 corresponds to a *sentence A* token,
+ - 1 corresponds to a *sentence B* token.
+
+ [What are token type IDs?](../glossary#token-type-ids)
+ position_ids (`torch.LongTensor` of shape `batch_size, sequence_length`, *optional*):
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+ config.max_position_embeddings - 1]`.
+
+ [What are position IDs?](../glossary#position-ids)
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+
+ Example:
+
+ ```python
+ >>> from transformers import AutoProcessor, LayoutLMv2ForSequenceClassification, set_seed
+ >>> from PIL import Image
+ >>> import torch
+ >>> from datasets import load_dataset
+
+ >>> set_seed(0)
+
+ >>> dataset = load_dataset("aharley/rvl_cdip", split="train", streaming=True)
+ >>> data = next(iter(dataset))
+ >>> image = data["image"].convert("RGB")
+
+ >>> processor = AutoProcessor.from_pretrained("microsoft/layoutlmv2-base-uncased")
+ >>> model = LayoutLMv2ForSequenceClassification.from_pretrained(
+ ... "microsoft/layoutlmv2-base-uncased", num_labels=dataset.info.features["label"].num_classes
+ ... )
+
+ >>> encoding = processor(image, return_tensors="pt")
+ >>> sequence_label = torch.tensor([data["label"]])
+
+ >>> outputs = model(**encoding, labels=sequence_label)
+
+ >>> loss, logits = outputs.loss, outputs.logits
+ >>> predicted_idx = logits.argmax(dim=-1).item()
+ >>> predicted_answer = dataset.info.features["label"].names[4]
+ >>> predicted_idx, predicted_answer # results are not good without further fine-tuning
+ (7, 'advertisement')
+ ```
+ """
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+ elif input_ids is not None:
+ self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
+ input_shape = input_ids.size()
+ elif inputs_embeds is not None:
+ input_shape = inputs_embeds.size()[:-1]
+ else:
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
+
+ visual_shape = list(input_shape)
+ visual_shape[1] = self.config.image_feature_pool_shape[0] * self.config.image_feature_pool_shape[1]
+ visual_shape = torch.Size(visual_shape)
+ final_shape = list(input_shape)
+ final_shape[1] += visual_shape[1]
+ final_shape = torch.Size(final_shape)
+
+ visual_bbox = self.layoutlmv2._calc_visual_bbox(
+ self.config.image_feature_pool_shape, bbox, device, final_shape
+ )
+
+ visual_position_ids = torch.arange(0, visual_shape[1], dtype=torch.long, device=device).repeat(
+ input_shape[0], 1
+ )
+
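+        # visual embeddings before the encoder; the final (contextualized) visual embeddings are sliced from the
+        # encoder output below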
+ initial_image_embeddings = self.layoutlmv2._calc_img_embeddings(
+ image=image,
+ bbox=visual_bbox,
+ position_ids=visual_position_ids,
+ )
+
+ outputs = self.layoutlmv2(
+ input_ids=input_ids,
+ bbox=bbox,
+ image=image,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ if input_ids is not None:
+ input_shape = input_ids.size()
+ else:
+ input_shape = inputs_embeds.size()[:-1]
+
+ seq_length = input_shape[1]
+ sequence_output, final_image_embeddings = outputs[0][:, :seq_length], outputs[0][:, seq_length:]
+
+ cls_final_output = sequence_output[:, 0, :]
+
+ # average-pool the visual embeddings
+ pooled_initial_image_embeddings = initial_image_embeddings.mean(dim=1)
+ pooled_final_image_embeddings = final_image_embeddings.mean(dim=1)
+ # concatenate with cls_final_output
+ sequence_output = torch.cat(
+ [cls_final_output, pooled_initial_image_embeddings, pooled_final_image_embeddings], dim=1
+ )
+ sequence_output = self.dropout(sequence_output)
+ logits = self.classifier(sequence_output)
+
+ loss = None
+ if labels is not None:
+ if self.config.problem_type is None:
+ if self.num_labels == 1:
+ self.config.problem_type = "regression"
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+ self.config.problem_type = "single_label_classification"
+ else:
+ self.config.problem_type = "multi_label_classification"
+
+ if self.config.problem_type == "regression":
+ loss_fct = MSELoss()
+ if self.num_labels == 1:
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
+ else:
+ loss = loss_fct(logits, labels)
+ elif self.config.problem_type == "single_label_classification":
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+ elif self.config.problem_type == "multi_label_classification":
+ loss_fct = BCEWithLogitsLoss()
+ loss = loss_fct(logits, labels)
+ if not return_dict:
+ output = (logits,) + outputs[2:]
+ return ((loss,) + output) if loss is not None else output
+
+ return SequenceClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+@auto_docstring(
+ custom_intro="""
+ LayoutLMv2 Model with a token classification head on top (a linear layer on top of the text part of the hidden
+ states) e.g. for sequence labeling (information extraction) tasks such as
+ [FUNSD](https://guillaumejaume.github.io/FUNSD/), [SROIE](https://rrc.cvc.uab.es/?ch=13),
+ [CORD](https://github.com/clovaai/cord) and [Kleister-NDA](https://github.com/applicaai/kleister-nda).
+ """
+)
+class LayoutLMv2ForTokenClassification(LayoutLMv2PreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+ self.layoutlmv2 = LayoutLMv2Model(config)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.layoutlmv2.embeddings.word_embeddings
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ bbox: Optional[torch.LongTensor] = None,
+ image: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[tuple, TokenClassifierOutput]:
+ r"""
+ input_ids (`torch.LongTensor` of shape `batch_size, sequence_length`):
+ Indices of input sequence tokens in the vocabulary.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ bbox (`torch.LongTensor` of shape `(batch_size, sequence_length, 4)`, *optional*):
+ Bounding boxes of each input sequence tokens. Selected in the range `[0,
+ config.max_2d_position_embeddings-1]`. Each bounding box should be a normalized version in (x0, y0, x1, y1)
+ format, where (x0, y0) corresponds to the position of the upper left corner in the bounding box, and (x1,
+ y1) represents the position of the lower right corner.
+        image (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `detectron2.structures.ImageList` whose `tensors` is of shape `(batch_size, num_channels, height, width)`):
+ Batch of document images.
+ token_type_ids (`torch.LongTensor` of shape `batch_size, sequence_length`, *optional*):
+ Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
+ 1]`:
+
+ - 0 corresponds to a *sentence A* token,
+ - 1 corresponds to a *sentence B* token.
+
+ [What are token type IDs?](../glossary#token-type-ids)
+ position_ids (`torch.LongTensor` of shape `batch_size, sequence_length`, *optional*):
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+ config.max_position_embeddings - 1]`.
+
+ [What are position IDs?](../glossary#position-ids)
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
+
+ Example:
+
+ ```python
+ >>> from transformers import AutoProcessor, LayoutLMv2ForTokenClassification, set_seed
+ >>> from PIL import Image
+ >>> from datasets import load_dataset
+
+ >>> set_seed(0)
+
+ >>> datasets = load_dataset("nielsr/funsd", split="test")
+ >>> labels = datasets.features["ner_tags"].feature.names
+        >>> id2label = {idx: label for idx, label in enumerate(labels)}
+
+ >>> processor = AutoProcessor.from_pretrained("microsoft/layoutlmv2-base-uncased", revision="no_ocr")
+ >>> model = LayoutLMv2ForTokenClassification.from_pretrained(
+ ... "microsoft/layoutlmv2-base-uncased", num_labels=len(labels)
+ ... )
+
+ >>> data = datasets[0]
+ >>> image = Image.open(data["image_path"]).convert("RGB")
+ >>> words = data["words"]
+ >>> boxes = data["bboxes"] # make sure to normalize your bounding boxes
+ >>> word_labels = data["ner_tags"]
+ >>> encoding = processor(
+ ... image,
+ ... words,
+ ... boxes=boxes,
+ ... word_labels=word_labels,
+ ... padding="max_length",
+ ... truncation=True,
+ ... return_tensors="pt",
+ ... )
+
+ >>> outputs = model(**encoding)
+ >>> logits, loss = outputs.logits, outputs.loss
+
+ >>> predicted_token_class_ids = logits.argmax(-1)
+ >>> predicted_tokens_classes = [id2label[t.item()] for t in predicted_token_class_ids[0]]
+ >>> predicted_tokens_classes[:5] # results are not good without further fine-tuning
+ ['I-HEADER', 'I-HEADER', 'I-QUESTION', 'I-HEADER', 'I-QUESTION']
+ ```
+ """
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.layoutlmv2(
+ input_ids=input_ids,
+ bbox=bbox,
+ image=image,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ if input_ids is not None:
+ input_shape = input_ids.size()
+ else:
+ input_shape = inputs_embeds.size()[:-1]
+
+ seq_length = input_shape[1]
+ # only take the text part of the output representations
+ sequence_output = outputs[0][:, :seq_length]
+ sequence_output = self.dropout(sequence_output)
+ logits = self.classifier(sequence_output)
+
+ loss = None
+ if labels is not None:
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+
+ if not return_dict:
+ output = (logits,) + outputs[2:]
+ return ((loss,) + output) if loss is not None else output
+
+ return TokenClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+@auto_docstring
+class LayoutLMv2ForQuestionAnswering(LayoutLMv2PreTrainedModel):
+ def __init__(self, config, has_visual_segment_embedding=True):
+ r"""
+ has_visual_segment_embedding (`bool`, *optional*, defaults to `True`):
+ Whether or not to add visual segment embeddings.
+ """
+ super().__init__(config)
+ self.num_labels = config.num_labels
+ config.has_visual_segment_embedding = has_visual_segment_embedding
+ self.layoutlmv2 = LayoutLMv2Model(config)
+ self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.layoutlmv2.embeddings.word_embeddings
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ bbox: Optional[torch.LongTensor] = None,
+ image: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ start_positions: Optional[torch.LongTensor] = None,
+ end_positions: Optional[torch.LongTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[tuple, QuestionAnsweringModelOutput]:
+ r"""
+ input_ids (`torch.LongTensor` of shape `batch_size, sequence_length`):
+ Indices of input sequence tokens in the vocabulary.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ bbox (`torch.LongTensor` of shape `(batch_size, sequence_length, 4)`, *optional*):
+ Bounding boxes of each input sequence tokens. Selected in the range `[0,
+ config.max_2d_position_embeddings-1]`. Each bounding box should be a normalized version in (x0, y0, x1, y1)
+ format, where (x0, y0) corresponds to the position of the upper left corner in the bounding box, and (x1,
+ y1) represents the position of the lower right corner.
+        image (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `detectron2.structures.ImageList` whose `tensors` is of shape `(batch_size, num_channels, height, width)`):
+ Batch of document images.
+ token_type_ids (`torch.LongTensor` of shape `batch_size, sequence_length`, *optional*):
+ Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
+ 1]`:
+
+ - 0 corresponds to a *sentence A* token,
+ - 1 corresponds to a *sentence B* token.
+
+ [What are token type IDs?](../glossary#token-type-ids)
+ position_ids (`torch.LongTensor` of shape `batch_size, sequence_length`, *optional*):
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+ config.max_position_embeddings - 1]`.
+
+ [What are position IDs?](../glossary#position-ids)
+
+ Example:
+
+        In the example below, we give the LayoutLMv2 model an image (of text) and ask it a question. It will give us
+        a prediction of what it thinks the answer is (the span of the answer within the text parsed from the image).
+
+ ```python
+ >>> from transformers import AutoProcessor, LayoutLMv2ForQuestionAnswering, set_seed
+ >>> import torch
+ >>> from PIL import Image
+ >>> from datasets import load_dataset
+
+ >>> set_seed(0)
+ >>> processor = AutoProcessor.from_pretrained("microsoft/layoutlmv2-base-uncased")
+ >>> model = LayoutLMv2ForQuestionAnswering.from_pretrained("microsoft/layoutlmv2-base-uncased")
+
+ >>> dataset = load_dataset("hf-internal-testing/fixtures_docvqa")
+ >>> image = dataset["test"][0]["image"]
+ >>> question = "When is coffee break?"
+ >>> encoding = processor(image, question, return_tensors="pt")
+
+ >>> outputs = model(**encoding)
+ >>> predicted_start_idx = outputs.start_logits.argmax(-1).item()
+ >>> predicted_end_idx = outputs.end_logits.argmax(-1).item()
+ >>> predicted_start_idx, predicted_end_idx
+ (30, 191)
+
+ >>> predicted_answer_tokens = encoding.input_ids.squeeze()[predicted_start_idx : predicted_end_idx + 1]
+ >>> predicted_answer = processor.tokenizer.decode(predicted_answer_tokens)
+ >>> predicted_answer # results are not good without further fine-tuning
+ '44 a. m. to 12 : 25 p. m. 12 : 25 to 12 : 58 p. m. 12 : 58 to 4 : 00 p. m. 2 : 00 to 5 : 00 p. m. coffee break coffee will be served for men and women in the lobby adjacent to exhibit area. please move into exhibit area. ( exhibits open ) trrf general session ( part | ) presiding : lee a. waller trrf vice president “ introductory remarks ” lee a. waller, trrf vice presi - dent individual interviews with trrf public board members and sci - entific advisory council mem - bers conducted by trrf treasurer philip g. kuehn to get answers which the public refrigerated warehousing industry is looking for. plus questions from'
+ ```
+
+ ```python
+ >>> target_start_index = torch.tensor([7])
+ >>> target_end_index = torch.tensor([14])
+ >>> outputs = model(**encoding, start_positions=target_start_index, end_positions=target_end_index)
+ >>> predicted_answer_span_start = outputs.start_logits.argmax(-1).item()
+ >>> predicted_answer_span_end = outputs.end_logits.argmax(-1).item()
+ >>> predicted_answer_span_start, predicted_answer_span_end
+ (30, 191)
+ ```
+ """
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.layoutlmv2(
+ input_ids=input_ids,
+ bbox=bbox,
+ image=image,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ if input_ids is not None:
+ input_shape = input_ids.size()
+ else:
+ input_shape = inputs_embeds.size()[:-1]
+
+ seq_length = input_shape[1]
+ # only take the text part of the output representations
+ sequence_output = outputs[0][:, :seq_length]
+
+ logits = self.qa_outputs(sequence_output)
+ start_logits, end_logits = logits.split(1, dim=-1)
+ start_logits = start_logits.squeeze(-1).contiguous()
+ end_logits = end_logits.squeeze(-1).contiguous()
+
+ total_loss = None
+ if start_positions is not None and end_positions is not None:
+ # If we are on multi-GPU, split add a dimension
+ if len(start_positions.size()) > 1:
+ start_positions = start_positions.squeeze(-1)
+ if len(end_positions.size()) > 1:
+ end_positions = end_positions.squeeze(-1)
+ # sometimes the start/end positions are outside our model inputs, we ignore these terms
+ ignored_index = start_logits.size(1)
+ start_positions = start_positions.clamp(0, ignored_index)
+ end_positions = end_positions.clamp(0, ignored_index)
+
+ loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
+ start_loss = loss_fct(start_logits, start_positions)
+ end_loss = loss_fct(end_logits, end_positions)
+ total_loss = (start_loss + end_loss) / 2
+
+ if not return_dict:
+ output = (start_logits, end_logits) + outputs[2:]
+ return ((total_loss,) + output) if total_loss is not None else output
+
+ return QuestionAnsweringModelOutput(
+ loss=total_loss,
+ start_logits=start_logits,
+ end_logits=end_logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+__all__ = [
+ "LayoutLMv2ForQuestionAnswering",
+ "LayoutLMv2ForSequenceClassification",
+ "LayoutLMv2ForTokenClassification",
+ "LayoutLMv2Layer",
+ "LayoutLMv2Model",
+ "LayoutLMv2PreTrainedModel",
+]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/layoutlmv2/processing_layoutlmv2.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/layoutlmv2/processing_layoutlmv2.py
new file mode 100644
index 0000000000000000000000000000000000000000..603cdf4df4e93e615fbb127bdf2057ae8d1d3b2c
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/layoutlmv2/processing_layoutlmv2.py
@@ -0,0 +1,186 @@
+# coding=utf-8
+# Copyright 2021 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Processor class for LayoutLMv2.
+"""
+
+import warnings
+from typing import Optional, Union
+
+from ...processing_utils import ProcessorMixin
+from ...tokenization_utils_base import BatchEncoding, PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
+from ...utils import TensorType
+
+
+class LayoutLMv2Processor(ProcessorMixin):
+ r"""
+ Constructs a LayoutLMv2 processor which combines a LayoutLMv2 image processor and a LayoutLMv2 tokenizer into a
+ single processor.
+
+ [`LayoutLMv2Processor`] offers all the functionalities you need to prepare data for the model.
+
+ It first uses [`LayoutLMv2ImageProcessor`] to resize document images to a fixed size, and optionally applies OCR to
+ get words and normalized bounding boxes. These are then provided to [`LayoutLMv2Tokenizer`] or
+ [`LayoutLMv2TokenizerFast`], which turns the words and bounding boxes into token-level `input_ids`,
+ `attention_mask`, `token_type_ids`, `bbox`. Optionally, one can provide integer `word_labels`, which are turned
+ into token-level `labels` for token classification tasks (such as FUNSD, CORD).
+
+ Args:
+ image_processor (`LayoutLMv2ImageProcessor`, *optional*):
+ An instance of [`LayoutLMv2ImageProcessor`]. The image processor is a required input.
+ tokenizer (`LayoutLMv2Tokenizer` or `LayoutLMv2TokenizerFast`, *optional*):
+ An instance of [`LayoutLMv2Tokenizer`] or [`LayoutLMv2TokenizerFast`]. The tokenizer is a required input.
+ """
+
+ attributes = ["image_processor", "tokenizer"]
+ image_processor_class = "LayoutLMv2ImageProcessor"
+ tokenizer_class = ("LayoutLMv2Tokenizer", "LayoutLMv2TokenizerFast")
+
+ def __init__(self, image_processor=None, tokenizer=None, **kwargs):
+ feature_extractor = None
+ if "feature_extractor" in kwargs:
+ warnings.warn(
+ "The `feature_extractor` argument is deprecated and will be removed in v5, use `image_processor`"
+ " instead.",
+ FutureWarning,
+ )
+ feature_extractor = kwargs.pop("feature_extractor")
+
+ image_processor = image_processor if image_processor is not None else feature_extractor
+
+ super().__init__(image_processor, tokenizer)
+
+ def __call__(
+ self,
+ images,
+ text: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]] = None,
+ text_pair: Optional[Union[PreTokenizedInput, list[PreTokenizedInput]]] = None,
+ boxes: Optional[Union[list[list[int]], list[list[list[int]]]]] = None,
+ word_labels: Optional[Union[list[int], list[list[int]]]] = None,
+ add_special_tokens: bool = True,
+ padding: Union[bool, str, PaddingStrategy] = False,
+ truncation: Union[bool, str, TruncationStrategy] = False,
+ max_length: Optional[int] = None,
+ stride: int = 0,
+ pad_to_multiple_of: Optional[int] = None,
+ return_token_type_ids: Optional[bool] = None,
+ return_attention_mask: Optional[bool] = None,
+ return_overflowing_tokens: bool = False,
+ return_special_tokens_mask: bool = False,
+ return_offsets_mapping: bool = False,
+ return_length: bool = False,
+ verbose: bool = True,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ **kwargs,
+ ) -> BatchEncoding:
+ """
+ This method first forwards the `images` argument to [`~LayoutLMv2ImageProcessor.__call__`]. In case
+ [`LayoutLMv2ImageProcessor`] was initialized with `apply_ocr` set to `True`, it passes the obtained words and
+ bounding boxes along with the additional arguments to [`~LayoutLMv2Tokenizer.__call__`] and returns the output,
+ together with resized `images`. In case [`LayoutLMv2ImageProcessor`] was initialized with `apply_ocr` set to
+        `False`, it passes the words (`text`/`text_pair`) and `boxes` specified by the user along with the additional
+        arguments to [`~LayoutLMv2Tokenizer.__call__`] and returns the output, together with resized `images`.
+
+ Please refer to the docstring of the above two methods for more information.
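+
+        Example (a minimal sketch, assuming the `microsoft/layoutlmv2-base-uncased` checkpoint and a document image
+        taken from the `hf-internal-testing/fixtures_docvqa` dataset; with the default `apply_ocr=True`, only the
+        image needs to be provided):
+
+        ```python
+        >>> from transformers import LayoutLMv2Processor
+        >>> from datasets import load_dataset
+
+        >>> processor = LayoutLMv2Processor.from_pretrained("microsoft/layoutlmv2-base-uncased")
+        >>> image = load_dataset("hf-internal-testing/fixtures_docvqa")["test"][0]["image"]
+
+        >>> # OCR is run by the image processor, so words and boxes do not have to be passed explicitly
+        >>> encoding = processor(image, return_tensors="pt")
+        ```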
+ """
+ # verify input
+ if self.image_processor.apply_ocr and (boxes is not None):
+ raise ValueError(
+ "You cannot provide bounding boxes if you initialized the image processor with apply_ocr set to True."
+ )
+
+ if self.image_processor.apply_ocr and (word_labels is not None):
+ raise ValueError(
+ "You cannot provide word labels if you initialized the image processor with apply_ocr set to True."
+ )
+
+ if return_overflowing_tokens is True and return_offsets_mapping is False:
+ raise ValueError("You cannot return overflowing tokens without returning the offsets mapping.")
+
+ # first, apply the image processor
+ features = self.image_processor(images=images, return_tensors=return_tensors)
+
+ # second, apply the tokenizer
+ if text is not None and self.image_processor.apply_ocr and text_pair is None:
+ if isinstance(text, str):
+ text = [text] # add batch dimension (as the image processor always adds a batch dimension)
+ text_pair = features["words"]
+
+ encoded_inputs = self.tokenizer(
+ text=text if text is not None else features["words"],
+ text_pair=text_pair if text_pair is not None else None,
+ boxes=boxes if boxes is not None else features["boxes"],
+ word_labels=word_labels,
+ add_special_tokens=add_special_tokens,
+ padding=padding,
+ truncation=truncation,
+ max_length=max_length,
+ stride=stride,
+ pad_to_multiple_of=pad_to_multiple_of,
+ return_token_type_ids=return_token_type_ids,
+ return_attention_mask=return_attention_mask,
+ return_overflowing_tokens=return_overflowing_tokens,
+ return_special_tokens_mask=return_special_tokens_mask,
+ return_offsets_mapping=return_offsets_mapping,
+ return_length=return_length,
+ verbose=verbose,
+ return_tensors=return_tensors,
+ **kwargs,
+ )
+
+ # add pixel values
+ images = features.pop("pixel_values")
+ if return_overflowing_tokens is True:
+ images = self.get_overflowing_images(images, encoded_inputs["overflow_to_sample_mapping"])
+ encoded_inputs["image"] = images
+
+ return encoded_inputs
+
+ def get_overflowing_images(self, images, overflow_to_sample_mapping):
+ # in case there's an overflow, ensure each `input_ids` sample is mapped to its corresponding image
+ images_with_overflow = []
+ for sample_idx in overflow_to_sample_mapping:
+ images_with_overflow.append(images[sample_idx])
+
+ if len(images_with_overflow) != len(overflow_to_sample_mapping):
+ raise ValueError(
+ "Expected length of images to be the same as the length of `overflow_to_sample_mapping`, but got"
+ f" {len(images_with_overflow)} and {len(overflow_to_sample_mapping)}"
+ )
+
+ return images_with_overflow
+
+ @property
+ def model_input_names(self):
+ return ["input_ids", "bbox", "token_type_ids", "attention_mask", "image"]
+
+ @property
+ def feature_extractor_class(self):
+ warnings.warn(
+ "`feature_extractor_class` is deprecated and will be removed in v5. Use `image_processor_class` instead.",
+ FutureWarning,
+ )
+ return self.image_processor_class
+
+ @property
+ def feature_extractor(self):
+ warnings.warn(
+ "`feature_extractor` is deprecated and will be removed in v5. Use `image_processor` instead.",
+ FutureWarning,
+ )
+ return self.image_processor
+
+
+__all__ = ["LayoutLMv2Processor"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/layoutlmv2/tokenization_layoutlmv2.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/layoutlmv2/tokenization_layoutlmv2.py
new file mode 100644
index 0000000000000000000000000000000000000000..7d82b5cf41041b0024134c0f1a6294c0cace824c
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/layoutlmv2/tokenization_layoutlmv2.py
@@ -0,0 +1,1545 @@
+# coding=utf-8
+# Copyright Microsoft Research and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tokenization class for LayoutLMv2."""
+
+import collections
+import os
+import sys
+import unicodedata
+from typing import Optional, Union
+
+from ...tokenization_utils import AddedToken, PreTrainedTokenizer, _is_control, _is_punctuation, _is_whitespace
+from ...tokenization_utils_base import (
+ BatchEncoding,
+ EncodedInput,
+ PreTokenizedInput,
+ TextInput,
+ TextInputPair,
+ TruncationStrategy,
+)
+from ...utils import PaddingStrategy, TensorType, add_end_docstrings, logging
+
+
+logger = logging.get_logger(__name__)
+
+VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}
+
+
+LAYOUTLMV2_ENCODE_KWARGS_DOCSTRING = r"""
+ add_special_tokens (`bool`, *optional*, defaults to `True`):
+ Whether or not to encode the sequences with the special tokens relative to their model.
+ padding (`bool`, `str` or [`~file_utils.PaddingStrategy`], *optional*, defaults to `False`):
+ Activates and controls padding. Accepts the following values:
+
+ - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
+              sequence is provided).
+ - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
+ acceptable input length for the model if that argument is not provided.
+ - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
+ lengths).
+ truncation (`bool`, `str` or [`~tokenization_utils_base.TruncationStrategy`], *optional*, defaults to `False`):
+ Activates and controls truncation. Accepts the following values:
+
+ - `True` or `'longest_first'`: Truncate to a maximum length specified with the argument `max_length` or
+ to the maximum acceptable input length for the model if that argument is not provided. This will
+ truncate token by token, removing a token from the longest sequence in the pair if a pair of
+ sequences (or a batch of pairs) is provided.
+ - `'only_first'`: Truncate to a maximum length specified with the argument `max_length` or to the
+ maximum acceptable input length for the model if that argument is not provided. This will only
+ truncate the first sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
+ - `'only_second'`: Truncate to a maximum length specified with the argument `max_length` or to the
+ maximum acceptable input length for the model if that argument is not provided. This will only
+ truncate the second sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
+ - `False` or `'do_not_truncate'` (default): No truncation (i.e., can output batch with sequence lengths
+ greater than the model maximum admissible input size).
+ max_length (`int`, *optional*):
+ Controls the maximum length to use by one of the truncation/padding parameters.
+
+ If left unset or set to `None`, this will use the predefined model maximum length if a maximum length
+ is required by one of the truncation/padding parameters. If the model has no specific maximum input
+ length (like XLNet) truncation/padding to a maximum length will be deactivated.
+ stride (`int`, *optional*, defaults to 0):
+ If set to a number along with `max_length`, the overflowing tokens returned when
+ `return_overflowing_tokens=True` will contain some tokens from the end of the truncated sequence
+ returned to provide some overlap between truncated and overflowing sequences. The value of this
+ argument defines the number of overlapping tokens.
+ pad_to_multiple_of (`int`, *optional*):
+ If set will pad the sequence to a multiple of the provided value. This is especially useful to enable
+ the use of Tensor Cores on NVIDIA hardware with compute capability `>= 7.5` (Volta).
+ return_tensors (`str` or [`~file_utils.TensorType`], *optional*):
+ If set, will return tensors instead of list of python integers. Acceptable values are:
+
+ - `'tf'`: Return TensorFlow `tf.constant` objects.
+ - `'pt'`: Return PyTorch `torch.Tensor` objects.
+ - `'np'`: Return Numpy `np.ndarray` objects.
+"""
+
+LAYOUTLMV2_ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING = r"""
+ return_token_type_ids (`bool`, *optional*):
+ Whether to return token type IDs. If left to the default, will return the token type IDs according to
+ the specific tokenizer's default, defined by the `return_outputs` attribute.
+
+ [What are token type IDs?](../glossary#token-type-ids)
+ return_attention_mask (`bool`, *optional*):
+ Whether to return the attention mask. If left to the default, will return the attention mask according
+ to the specific tokenizer's default, defined by the `return_outputs` attribute.
+
+ [What are attention masks?](../glossary#attention-mask)
+ return_overflowing_tokens (`bool`, *optional*, defaults to `False`):
+ Whether or not to return overflowing token sequences. If a pair of sequences of input ids (or a batch
+ of pairs) is provided with `truncation_strategy = longest_first` or `True`, an error is raised instead
+ of returning overflowing tokens.
+ return_special_tokens_mask (`bool`, *optional*, defaults to `False`):
+ Whether or not to return special tokens mask information.
+ return_offsets_mapping (`bool`, *optional*, defaults to `False`):
+ Whether or not to return `(char_start, char_end)` for each token.
+
+ This is only available on fast tokenizers inheriting from [`PreTrainedTokenizerFast`], if using
+ Python's tokenizer, this method will raise `NotImplementedError`.
+ return_length (`bool`, *optional*, defaults to `False`):
+ Whether or not to return the lengths of the encoded inputs.
+ verbose (`bool`, *optional*, defaults to `True`):
+ Whether or not to print more information and warnings.
+ **kwargs: passed to the `self.tokenize()` method
+
+ Return:
+ [`BatchEncoding`]: A [`BatchEncoding`] with the following fields:
+
+ - **input_ids** -- List of token ids to be fed to a model.
+
+ [What are input IDs?](../glossary#input-ids)
+
+ - **bbox** -- List of bounding boxes to be fed to a model.
+
+ - **token_type_ids** -- List of token type ids to be fed to a model (when `return_token_type_ids=True` or
+ if *"token_type_ids"* is in `self.model_input_names`).
+
+ [What are token type IDs?](../glossary#token-type-ids)
+
+ - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
+ `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names`).
+
+ [What are attention masks?](../glossary#attention-mask)
+
+ - **labels** -- List of labels to be fed to a model. (when `word_labels` is specified).
+ - **overflowing_tokens** -- List of overflowing tokens sequences (when a `max_length` is specified and
+ `return_overflowing_tokens=True`).
+ - **num_truncated_tokens** -- Number of tokens truncated (when a `max_length` is specified and
+ `return_overflowing_tokens=True`).
+ - **special_tokens_mask** -- List of 0s and 1s, with 1 specifying added special tokens and 0 specifying
+ regular sequence tokens (when `add_special_tokens=True` and `return_special_tokens_mask=True`).
+ - **length** -- The length of the inputs (when `return_length=True`).
+"""
+
+
+def load_vocab(vocab_file):
+ """Loads a vocabulary file into a dictionary."""
+ vocab = collections.OrderedDict()
+ with open(vocab_file, "r", encoding="utf-8") as reader:
+ tokens = reader.readlines()
+ for index, token in enumerate(tokens):
+ token = token.rstrip("\n")
+ vocab[token] = index
+ return vocab
+
+
+def whitespace_tokenize(text):
+ """Runs basic whitespace cleaning and splitting on a piece of text."""
+ text = text.strip()
+ if not text:
+ return []
+ tokens = text.split()
+ return tokens
+
+
+table = dict.fromkeys(i for i in range(sys.maxunicode) if unicodedata.category(chr(i)).startswith("P"))
+
+
+def subfinder(mylist, pattern):
+    """Return the first occurrence of `pattern` in `mylist` and its start index, or `(None, 0)` if not found."""
+    for idx in range(len(mylist)):
+        if mylist[idx] == pattern[0] and mylist[idx : idx + len(pattern)] == pattern:
+            return pattern, idx
+    return None, 0
+
+
+class LayoutLMv2Tokenizer(PreTrainedTokenizer):
+ r"""
+ Construct a LayoutLMv2 tokenizer. Based on WordPiece. [`LayoutLMv2Tokenizer`] can be used to turn words, word-level
+ bounding boxes and optional word labels to token-level `input_ids`, `attention_mask`, `token_type_ids`, `bbox`, and
+ optional `labels` (for token classification).
+
+ This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
+ this superclass for more information regarding those methods.
+
+ [`LayoutLMv2Tokenizer`] runs end-to-end tokenization: punctuation splitting and wordpiece. It also turns the
+ word-level bounding boxes into token-level bounding boxes.
+
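+    Example (a minimal sketch, assuming the `microsoft/layoutlmv2-base-uncased` checkpoint; the words and the
+    0-1000 normalized bounding boxes below are made-up values):
+
+    ```python
+    >>> from transformers import LayoutLMv2Tokenizer
+
+    >>> tokenizer = LayoutLMv2Tokenizer.from_pretrained("microsoft/layoutlmv2-base-uncased")
+    >>> words = ["hello", "world"]
+    >>> boxes = [[637, 773, 693, 782], [698, 773, 733, 782]]  # one (x0, y0, x1, y1) box per word
+    >>> encoding = tokenizer(words, boxes=boxes, return_tensors="pt")
+    ```
+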
+ """
+
+ vocab_files_names = VOCAB_FILES_NAMES
+
+ def __init__(
+ self,
+ vocab_file,
+ do_lower_case=True,
+ do_basic_tokenize=True,
+ never_split=None,
+ unk_token="[UNK]",
+ sep_token="[SEP]",
+ pad_token="[PAD]",
+ cls_token="[CLS]",
+ mask_token="[MASK]",
+ cls_token_box=[0, 0, 0, 0],
+ sep_token_box=[1000, 1000, 1000, 1000],
+ pad_token_box=[0, 0, 0, 0],
+ pad_token_label=-100,
+ only_label_first_subword=True,
+ tokenize_chinese_chars=True,
+ strip_accents=None,
+ model_max_length: int = 512,
+ additional_special_tokens: Optional[list[str]] = None,
+ **kwargs,
+ ):
+ sep_token = AddedToken(sep_token, special=True) if isinstance(sep_token, str) else sep_token
+ unk_token = AddedToken(unk_token, special=True) if isinstance(unk_token, str) else unk_token
+ pad_token = AddedToken(pad_token, special=True) if isinstance(pad_token, str) else pad_token
+ cls_token = AddedToken(cls_token, special=True) if isinstance(cls_token, str) else cls_token
+ mask_token = AddedToken(mask_token, special=True) if isinstance(mask_token, str) else mask_token
+
+ if not os.path.isfile(vocab_file):
+ raise ValueError(
+ f"Can't find a vocabulary file at path '{vocab_file}'. To load the vocabulary from a Google pretrained"
+ " model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`"
+ )
+ self.vocab = load_vocab(vocab_file)
+ self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])
+ self.do_basic_tokenize = do_basic_tokenize
+ if do_basic_tokenize:
+ self.basic_tokenizer = BasicTokenizer(
+ do_lower_case=do_lower_case,
+ never_split=never_split,
+ tokenize_chinese_chars=tokenize_chinese_chars,
+ strip_accents=strip_accents,
+ )
+ self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=str(unk_token))
+
+ # additional properties
+ self.cls_token_box = cls_token_box
+ self.sep_token_box = sep_token_box
+ self.pad_token_box = pad_token_box
+ self.pad_token_label = pad_token_label
+ self.only_label_first_subword = only_label_first_subword
+ super().__init__(
+ do_lower_case=do_lower_case,
+ do_basic_tokenize=do_basic_tokenize,
+ never_split=never_split,
+ unk_token=unk_token,
+ sep_token=sep_token,
+ pad_token=pad_token,
+ cls_token=cls_token,
+ mask_token=mask_token,
+ cls_token_box=cls_token_box,
+ sep_token_box=sep_token_box,
+ pad_token_box=pad_token_box,
+ pad_token_label=pad_token_label,
+ only_label_first_subword=only_label_first_subword,
+ tokenize_chinese_chars=tokenize_chinese_chars,
+ strip_accents=strip_accents,
+ model_max_length=model_max_length,
+ additional_special_tokens=additional_special_tokens,
+ **kwargs,
+ )
+
+ @property
+ def do_lower_case(self):
+ return self.basic_tokenizer.do_lower_case
+
+ @property
+ def vocab_size(self):
+ return len(self.vocab)
+
+ def get_vocab(self):
+ return dict(self.vocab, **self.added_tokens_encoder)
+
+ def _tokenize(self, text):
+ split_tokens = []
+ if self.do_basic_tokenize:
+ for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
+ # If the token is part of the never_split set
+ if token in self.basic_tokenizer.never_split:
+ split_tokens.append(token)
+ else:
+ split_tokens += self.wordpiece_tokenizer.tokenize(token)
+ else:
+ split_tokens = self.wordpiece_tokenizer.tokenize(text)
+ return split_tokens
+
+ def _convert_token_to_id(self, token):
+ """Converts a token (str) in an id using the vocab."""
+ return self.vocab.get(token, self.vocab.get(self.unk_token))
+
+ def _convert_id_to_token(self, index):
+ """Converts an index (integer) in a token (str) using the vocab."""
+ return self.ids_to_tokens.get(index, self.unk_token)
+
+ def convert_tokens_to_string(self, tokens):
+ """Converts a sequence of tokens (string) in a single string."""
+ out_string = " ".join(tokens).replace(" ##", "").strip()
+ return out_string
+
+ def build_inputs_with_special_tokens(
+ self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None
+ ) -> list[int]:
+ """
+ Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
+ adding special tokens. A BERT sequence has the following format:
+
+ - single sequence: `[CLS] X [SEP]`
+ - pair of sequences: `[CLS] A [SEP] B [SEP]`
+
+ Args:
+ token_ids_0 (`List[int]`):
+ List of IDs to which the special tokens will be added.
+ token_ids_1 (`List[int]`, *optional*):
+ Optional second list of IDs for sequence pairs.
+
+ Returns:
+ `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
+ """
+ if token_ids_1 is None:
+ return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
+ cls = [self.cls_token_id]
+ sep = [self.sep_token_id]
+ return cls + token_ids_0 + sep + token_ids_1 + sep
+
+ def get_special_tokens_mask(
+ self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None, already_has_special_tokens: bool = False
+ ) -> list[int]:
+ """
+ Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
+ special tokens using the tokenizer `prepare_for_model` method.
+
+ Args:
+ token_ids_0 (`List[int]`):
+ List of IDs.
+ token_ids_1 (`List[int]`, *optional*):
+ Optional second list of IDs for sequence pairs.
+ already_has_special_tokens (`bool`, *optional*, defaults to `False`):
+ Whether or not the token list is already formatted with special tokens for the model.
+
+ Returns:
+ `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
+ """
+
+ if already_has_special_tokens:
+ return super().get_special_tokens_mask(
+ token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
+ )
+
+ if token_ids_1 is not None:
+ return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
+ return [1] + ([0] * len(token_ids_0)) + [1]
+
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]:
+ index = 0
+ if os.path.isdir(save_directory):
+ vocab_file = os.path.join(
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
+ )
+ else:
+ vocab_file = (filename_prefix + "-" if filename_prefix else "") + save_directory
+ with open(vocab_file, "w", encoding="utf-8") as writer:
+ for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
+ if index != token_index:
+ logger.warning(
+ f"Saving vocabulary to {vocab_file}: vocabulary indices are not consecutive."
+ " Please check that the vocabulary is not corrupted!"
+ )
+ index = token_index
+ writer.write(token + "\n")
+ index += 1
+ return (vocab_file,)
+
+ @add_end_docstrings(LAYOUTLMV2_ENCODE_KWARGS_DOCSTRING, LAYOUTLMV2_ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING)
+ def __call__(
+ self,
+ text: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]],
+ text_pair: Optional[Union[PreTokenizedInput, list[PreTokenizedInput]]] = None,
+ boxes: Optional[Union[list[list[int]], list[list[list[int]]]]] = None,
+ word_labels: Optional[Union[list[int], list[list[int]]]] = None,
+ add_special_tokens: bool = True,
+ padding: Union[bool, str, PaddingStrategy] = False,
+ truncation: Union[bool, str, TruncationStrategy] = None,
+ max_length: Optional[int] = None,
+ stride: int = 0,
+ pad_to_multiple_of: Optional[int] = None,
+ padding_side: Optional[str] = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ return_token_type_ids: Optional[bool] = None,
+ return_attention_mask: Optional[bool] = None,
+ return_overflowing_tokens: bool = False,
+ return_special_tokens_mask: bool = False,
+ return_offsets_mapping: bool = False,
+ return_length: bool = False,
+ verbose: bool = True,
+ **kwargs,
+ ) -> BatchEncoding:
+ """
+ Main method to tokenize and prepare for the model one or several sequence(s) or one or several pair(s) of
+ sequences with word-level normalized bounding boxes and optional labels.
+
+ Args:
+ text (`str`, `List[str]`, `List[List[str]]`):
+ The sequence or batch of sequences to be encoded. Each sequence can be a string, a list of strings
+ (words of a single example or questions of a batch of examples) or a list of list of strings (batch of
+ words).
+ text_pair (`List[str]`, `List[List[str]]`):
+ The sequence or batch of sequences to be encoded. Each sequence should be a list of strings
+ (pretokenized string).
+ boxes (`List[List[int]]`, `List[List[List[int]]]`):
+ Word-level bounding boxes. Each bounding box should be normalized to be on a 0-1000 scale.
+ word_labels (`List[int]`, `List[List[int]]`, *optional*):
+ Word-level integer labels (for token classification tasks such as FUNSD, CORD).
+ """
+
+ # Input type checking for clearer error
+ def _is_valid_text_input(t):
+ if isinstance(t, str):
+ # Strings are fine
+ return True
+ elif isinstance(t, (list, tuple)):
+                # Lists are fine as long as they are...
+ if len(t) == 0:
+ # ... empty
+ return True
+ elif isinstance(t[0], str):
+ # ... list of strings
+ return True
+ elif isinstance(t[0], (list, tuple)):
+ # ... list with an empty list or with a list of strings
+ return len(t[0]) == 0 or isinstance(t[0][0], str)
+ else:
+ return False
+ else:
+ return False
+
+ if text_pair is not None:
+ # in case text + text_pair are provided, text = questions, text_pair = words
+ if not _is_valid_text_input(text):
+                raise ValueError("text input must be of type `str` (single example) or `List[str]` (batch of examples).")
+ if not isinstance(text_pair, (list, tuple)):
+ raise ValueError(
+ "Words must be of type `List[str]` (single pretokenized example), "
+ "or `List[List[str]]` (batch of pretokenized examples)."
+ )
+ else:
+ # in case only text is provided => must be words
+ if not isinstance(text, (list, tuple)):
+ raise ValueError(
+ "Words must be of type `List[str]` (single pretokenized example), "
+ "or `List[List[str]]` (batch of pretokenized examples)."
+ )
+
+ if text_pair is not None:
+ is_batched = isinstance(text, (list, tuple))
+ else:
+ is_batched = isinstance(text, (list, tuple)) and text and isinstance(text[0], (list, tuple))
+
+ words = text if text_pair is None else text_pair
+ if boxes is None:
+ raise ValueError("You must provide corresponding bounding boxes")
+ if is_batched:
+ if len(words) != len(boxes):
+ raise ValueError("You must provide words and boxes for an equal amount of examples")
+ for words_example, boxes_example in zip(words, boxes):
+ if len(words_example) != len(boxes_example):
+ raise ValueError("You must provide as many words as there are bounding boxes")
+ else:
+ if len(words) != len(boxes):
+ raise ValueError("You must provide as many words as there are bounding boxes")
+
+ if is_batched:
+ if text_pair is not None and len(text) != len(text_pair):
+ raise ValueError(
+ f"batch length of `text`: {len(text)} does not match batch length of `text_pair`:"
+ f" {len(text_pair)}."
+ )
+ batch_text_or_text_pairs = list(zip(text, text_pair)) if text_pair is not None else text
+ is_pair = bool(text_pair is not None)
+ return self.batch_encode_plus(
+ batch_text_or_text_pairs=batch_text_or_text_pairs,
+ is_pair=is_pair,
+ boxes=boxes,
+ word_labels=word_labels,
+ add_special_tokens=add_special_tokens,
+ padding=padding,
+ truncation=truncation,
+ max_length=max_length,
+ stride=stride,
+ pad_to_multiple_of=pad_to_multiple_of,
+ padding_side=padding_side,
+ return_tensors=return_tensors,
+ return_token_type_ids=return_token_type_ids,
+ return_attention_mask=return_attention_mask,
+ return_overflowing_tokens=return_overflowing_tokens,
+ return_special_tokens_mask=return_special_tokens_mask,
+ return_offsets_mapping=return_offsets_mapping,
+ return_length=return_length,
+ verbose=verbose,
+ **kwargs,
+ )
+ else:
+ return self.encode_plus(
+ text=text,
+ text_pair=text_pair,
+ boxes=boxes,
+ word_labels=word_labels,
+ add_special_tokens=add_special_tokens,
+ padding=padding,
+ truncation=truncation,
+ max_length=max_length,
+ stride=stride,
+ pad_to_multiple_of=pad_to_multiple_of,
+ padding_side=padding_side,
+ return_tensors=return_tensors,
+ return_token_type_ids=return_token_type_ids,
+ return_attention_mask=return_attention_mask,
+ return_overflowing_tokens=return_overflowing_tokens,
+ return_special_tokens_mask=return_special_tokens_mask,
+ return_offsets_mapping=return_offsets_mapping,
+ return_length=return_length,
+ verbose=verbose,
+ **kwargs,
+ )
+
+ @add_end_docstrings(LAYOUTLMV2_ENCODE_KWARGS_DOCSTRING, LAYOUTLMV2_ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING)
+ def batch_encode_plus(
+ self,
+ batch_text_or_text_pairs: Union[
+ list[TextInput],
+ list[TextInputPair],
+ list[PreTokenizedInput],
+ ],
+ is_pair: Optional[bool] = None,
+ boxes: Optional[list[list[list[int]]]] = None,
+ word_labels: Optional[Union[list[int], list[list[int]]]] = None,
+ add_special_tokens: bool = True,
+ padding: Union[bool, str, PaddingStrategy] = False,
+ truncation: Union[bool, str, TruncationStrategy] = None,
+ max_length: Optional[int] = None,
+ stride: int = 0,
+ pad_to_multiple_of: Optional[int] = None,
+ padding_side: Optional[str] = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ return_token_type_ids: Optional[bool] = None,
+ return_attention_mask: Optional[bool] = None,
+ return_overflowing_tokens: bool = False,
+ return_special_tokens_mask: bool = False,
+ return_offsets_mapping: bool = False,
+ return_length: bool = False,
+ verbose: bool = True,
+ **kwargs,
+ ) -> BatchEncoding:
+ # Backward compatibility for 'truncation_strategy', 'pad_to_max_length'
+ padding_strategy, truncation_strategy, max_length, kwargs = self._get_padding_truncation_strategies(
+ padding=padding,
+ truncation=truncation,
+ max_length=max_length,
+ pad_to_multiple_of=pad_to_multiple_of,
+ verbose=verbose,
+ **kwargs,
+ )
+
+ return self._batch_encode_plus(
+ batch_text_or_text_pairs=batch_text_or_text_pairs,
+ is_pair=is_pair,
+ boxes=boxes,
+ word_labels=word_labels,
+ add_special_tokens=add_special_tokens,
+ padding_strategy=padding_strategy,
+ truncation_strategy=truncation_strategy,
+ max_length=max_length,
+ stride=stride,
+ pad_to_multiple_of=pad_to_multiple_of,
+ padding_side=padding_side,
+ return_tensors=return_tensors,
+ return_token_type_ids=return_token_type_ids,
+ return_attention_mask=return_attention_mask,
+ return_overflowing_tokens=return_overflowing_tokens,
+ return_special_tokens_mask=return_special_tokens_mask,
+ return_offsets_mapping=return_offsets_mapping,
+ return_length=return_length,
+ verbose=verbose,
+ **kwargs,
+ )
+
+ def _batch_encode_plus(
+ self,
+ batch_text_or_text_pairs: Union[
+ list[TextInput],
+ list[TextInputPair],
+ list[PreTokenizedInput],
+ ],
+ is_pair: Optional[bool] = None,
+ boxes: Optional[list[list[list[int]]]] = None,
+ word_labels: Optional[list[list[int]]] = None,
+ add_special_tokens: bool = True,
+ padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
+ truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
+ max_length: Optional[int] = None,
+ stride: int = 0,
+ pad_to_multiple_of: Optional[int] = None,
+ padding_side: Optional[str] = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ return_token_type_ids: Optional[bool] = None,
+ return_attention_mask: Optional[bool] = None,
+ return_overflowing_tokens: bool = False,
+ return_special_tokens_mask: bool = False,
+ return_offsets_mapping: bool = False,
+ return_length: bool = False,
+ verbose: bool = True,
+ **kwargs,
+ ) -> BatchEncoding:
+ if return_offsets_mapping:
+ raise NotImplementedError(
+ "return_offset_mapping is not available when using Python tokenizers. "
+ "To use this feature, change your tokenizer to one deriving from "
+ "transformers.PreTrainedTokenizerFast."
+ )
+
+ batch_outputs = self._batch_prepare_for_model(
+ batch_text_or_text_pairs=batch_text_or_text_pairs,
+ is_pair=is_pair,
+ boxes=boxes,
+ word_labels=word_labels,
+ add_special_tokens=add_special_tokens,
+ padding_strategy=padding_strategy,
+ truncation_strategy=truncation_strategy,
+ max_length=max_length,
+ stride=stride,
+ pad_to_multiple_of=pad_to_multiple_of,
+ padding_side=padding_side,
+ return_attention_mask=return_attention_mask,
+ return_token_type_ids=return_token_type_ids,
+ return_overflowing_tokens=return_overflowing_tokens,
+ return_special_tokens_mask=return_special_tokens_mask,
+ return_length=return_length,
+ return_tensors=return_tensors,
+ verbose=verbose,
+ )
+
+ return BatchEncoding(batch_outputs)
+
+ @add_end_docstrings(LAYOUTLMV2_ENCODE_KWARGS_DOCSTRING, LAYOUTLMV2_ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING)
+ def _batch_prepare_for_model(
+ self,
+ batch_text_or_text_pairs,
+ is_pair: Optional[bool] = None,
+ boxes: Optional[list[list[int]]] = None,
+ word_labels: Optional[list[list[int]]] = None,
+ add_special_tokens: bool = True,
+ padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
+ truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
+ max_length: Optional[int] = None,
+ stride: int = 0,
+ pad_to_multiple_of: Optional[int] = None,
+ padding_side: Optional[str] = None,
+ return_tensors: Optional[str] = None,
+ return_token_type_ids: Optional[bool] = None,
+ return_attention_mask: Optional[bool] = None,
+ return_overflowing_tokens: bool = False,
+ return_special_tokens_mask: bool = False,
+ return_length: bool = False,
+ verbose: bool = True,
+ ) -> BatchEncoding:
+ """
+        Prepares a sequence of input ids, or a pair of sequences of input ids, so that it can be used by the model. It
+ adds special tokens, truncates sequences if overflowing while taking into account the special tokens and
+ manages a moving window (with user defined stride) for overflowing tokens.
+
+ Args:
+            batch_text_or_text_pairs: list of examples to prepare, each either a list of words or a (question, words) pair
+ """
+
+ batch_outputs = {}
+ for idx, example in enumerate(zip(batch_text_or_text_pairs, boxes)):
+ batch_text_or_text_pair, boxes_example = example
+ outputs = self.prepare_for_model(
+ batch_text_or_text_pair[0] if is_pair else batch_text_or_text_pair,
+ batch_text_or_text_pair[1] if is_pair else None,
+ boxes_example,
+ word_labels=word_labels[idx] if word_labels is not None else None,
+ add_special_tokens=add_special_tokens,
+ padding=PaddingStrategy.DO_NOT_PAD.value, # we pad in batch afterward
+ truncation=truncation_strategy.value,
+ max_length=max_length,
+ stride=stride,
+ pad_to_multiple_of=None, # we pad in batch afterward
+ padding_side=None, # we pad in batch afterward
+ return_attention_mask=False, # we pad in batch afterward
+ return_token_type_ids=return_token_type_ids,
+ return_overflowing_tokens=return_overflowing_tokens,
+ return_special_tokens_mask=return_special_tokens_mask,
+ return_length=return_length,
+ return_tensors=None, # We convert the whole batch to tensors at the end
+ prepend_batch_axis=False,
+ verbose=verbose,
+ )
+
+ for key, value in outputs.items():
+ if key not in batch_outputs:
+ batch_outputs[key] = []
+ batch_outputs[key].append(value)
+
+ batch_outputs = self.pad(
+ batch_outputs,
+ padding=padding_strategy.value,
+ max_length=max_length,
+ pad_to_multiple_of=pad_to_multiple_of,
+ padding_side=padding_side,
+ return_attention_mask=return_attention_mask,
+ )
+
+ batch_outputs = BatchEncoding(batch_outputs, tensor_type=return_tensors)
+
+ return batch_outputs
+
+ @add_end_docstrings(LAYOUTLMV2_ENCODE_KWARGS_DOCSTRING)
+ def encode(
+ self,
+ text: Union[TextInput, PreTokenizedInput],
+ text_pair: Optional[PreTokenizedInput] = None,
+ boxes: Optional[list[list[int]]] = None,
+ word_labels: Optional[list[int]] = None,
+ add_special_tokens: bool = True,
+ padding: Union[bool, str, PaddingStrategy] = False,
+ truncation: Union[bool, str, TruncationStrategy] = None,
+ max_length: Optional[int] = None,
+ stride: int = 0,
+ pad_to_multiple_of: Optional[int] = None,
+ padding_side: Optional[str] = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ return_token_type_ids: Optional[bool] = None,
+ return_attention_mask: Optional[bool] = None,
+ return_overflowing_tokens: bool = False,
+ return_special_tokens_mask: bool = False,
+ return_offsets_mapping: bool = False,
+ return_length: bool = False,
+ verbose: bool = True,
+ **kwargs,
+ ) -> list[int]:
+ encoded_inputs = self.encode_plus(
+ text=text,
+ text_pair=text_pair,
+ boxes=boxes,
+ word_labels=word_labels,
+ add_special_tokens=add_special_tokens,
+ padding=padding,
+ truncation=truncation,
+ max_length=max_length,
+ stride=stride,
+ pad_to_multiple_of=pad_to_multiple_of,
+ padding_side=padding_side,
+ return_tensors=return_tensors,
+ return_token_type_ids=return_token_type_ids,
+ return_attention_mask=return_attention_mask,
+ return_overflowing_tokens=return_overflowing_tokens,
+ return_special_tokens_mask=return_special_tokens_mask,
+ return_offsets_mapping=return_offsets_mapping,
+ return_length=return_length,
+ verbose=verbose,
+ **kwargs,
+ )
+
+ return encoded_inputs["input_ids"]
+
+ @add_end_docstrings(LAYOUTLMV2_ENCODE_KWARGS_DOCSTRING, LAYOUTLMV2_ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING)
+ def encode_plus(
+ self,
+ text: Union[TextInput, PreTokenizedInput],
+ text_pair: Optional[PreTokenizedInput] = None,
+ boxes: Optional[list[list[int]]] = None,
+ word_labels: Optional[list[int]] = None,
+ add_special_tokens: bool = True,
+ padding: Union[bool, str, PaddingStrategy] = False,
+ truncation: Union[bool, str, TruncationStrategy] = None,
+ max_length: Optional[int] = None,
+ stride: int = 0,
+ pad_to_multiple_of: Optional[int] = None,
+ padding_side: Optional[str] = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ return_token_type_ids: Optional[bool] = None,
+ return_attention_mask: Optional[bool] = None,
+ return_overflowing_tokens: bool = False,
+ return_special_tokens_mask: bool = False,
+ return_offsets_mapping: bool = False,
+ return_length: bool = False,
+ verbose: bool = True,
+ **kwargs,
+ ) -> BatchEncoding:
+ """
+        Tokenize and prepare for the model a sequence or a pair of sequences.
+
+        .. warning::
+            This method is deprecated, `__call__` should be used instead.
+
+ Args:
+ text (`str`, `List[str]`, `List[List[str]]`):
+ The first sequence to be encoded. This can be a string, a list of strings or a list of list of strings.
+ text_pair (`List[str]` or `List[int]`, *optional*):
+ Optional second sequence to be encoded. This can be a list of strings (words of a single example) or a
+ list of list of strings (words of a batch of examples).
+ """
+
+ # Backward compatibility for 'truncation_strategy', 'pad_to_max_length'
+ padding_strategy, truncation_strategy, max_length, kwargs = self._get_padding_truncation_strategies(
+ padding=padding,
+ truncation=truncation,
+ max_length=max_length,
+ pad_to_multiple_of=pad_to_multiple_of,
+ verbose=verbose,
+ **kwargs,
+ )
+
+ return self._encode_plus(
+ text=text,
+ boxes=boxes,
+ text_pair=text_pair,
+ word_labels=word_labels,
+ add_special_tokens=add_special_tokens,
+ padding_strategy=padding_strategy,
+ truncation_strategy=truncation_strategy,
+ max_length=max_length,
+ stride=stride,
+ pad_to_multiple_of=pad_to_multiple_of,
+ padding_side=padding_side,
+ return_tensors=return_tensors,
+ return_token_type_ids=return_token_type_ids,
+ return_attention_mask=return_attention_mask,
+ return_overflowing_tokens=return_overflowing_tokens,
+ return_special_tokens_mask=return_special_tokens_mask,
+ return_offsets_mapping=return_offsets_mapping,
+ return_length=return_length,
+ verbose=verbose,
+ **kwargs,
+ )
+
+ def _encode_plus(
+ self,
+ text: Union[TextInput, PreTokenizedInput],
+ text_pair: Optional[PreTokenizedInput] = None,
+ boxes: Optional[list[list[int]]] = None,
+ word_labels: Optional[list[int]] = None,
+ add_special_tokens: bool = True,
+ padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
+ truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
+ max_length: Optional[int] = None,
+ stride: int = 0,
+ pad_to_multiple_of: Optional[int] = None,
+ padding_side: Optional[str] = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ return_token_type_ids: Optional[bool] = None,
+ return_attention_mask: Optional[bool] = None,
+ return_overflowing_tokens: bool = False,
+ return_special_tokens_mask: bool = False,
+ return_offsets_mapping: bool = False,
+ return_length: bool = False,
+ verbose: bool = True,
+ **kwargs,
+ ) -> BatchEncoding:
+ if return_offsets_mapping:
+ raise NotImplementedError(
+ "return_offset_mapping is not available when using Python tokenizers. "
+ "To use this feature, change your tokenizer to one deriving from "
+ "transformers.PreTrainedTokenizerFast. "
+ "More information on available tokenizers at "
+ "https://github.com/huggingface/transformers/pull/2674"
+ )
+
+ return self.prepare_for_model(
+ text=text,
+ text_pair=text_pair,
+ boxes=boxes,
+ word_labels=word_labels,
+ add_special_tokens=add_special_tokens,
+ padding=padding_strategy.value,
+ truncation=truncation_strategy.value,
+ max_length=max_length,
+ stride=stride,
+ pad_to_multiple_of=pad_to_multiple_of,
+ padding_side=padding_side,
+ return_tensors=return_tensors,
+ prepend_batch_axis=True,
+ return_attention_mask=return_attention_mask,
+ return_token_type_ids=return_token_type_ids,
+ return_overflowing_tokens=return_overflowing_tokens,
+ return_special_tokens_mask=return_special_tokens_mask,
+ return_length=return_length,
+ verbose=verbose,
+ )
+
+ @add_end_docstrings(LAYOUTLMV2_ENCODE_KWARGS_DOCSTRING, LAYOUTLMV2_ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING)
+ def prepare_for_model(
+ self,
+ text: Union[TextInput, PreTokenizedInput],
+ text_pair: Optional[PreTokenizedInput] = None,
+ boxes: Optional[list[list[int]]] = None,
+ word_labels: Optional[list[int]] = None,
+ add_special_tokens: bool = True,
+ padding: Union[bool, str, PaddingStrategy] = False,
+ truncation: Union[bool, str, TruncationStrategy] = None,
+ max_length: Optional[int] = None,
+ stride: int = 0,
+ pad_to_multiple_of: Optional[int] = None,
+ padding_side: Optional[str] = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ return_token_type_ids: Optional[bool] = None,
+ return_attention_mask: Optional[bool] = None,
+ return_overflowing_tokens: bool = False,
+ return_special_tokens_mask: bool = False,
+ return_offsets_mapping: bool = False,
+ return_length: bool = False,
+ verbose: bool = True,
+ prepend_batch_axis: bool = False,
+ **kwargs,
+ ) -> BatchEncoding:
+ """
+ Prepares a sequence or a pair of sequences so that it can be used by the model. It adds special tokens,
+ truncates sequences if overflowing while taking into account the special tokens and manages a moving window
+        (with user defined stride) for overflowing tokens. Note that for a *text_pair* other than `None` combined with
+ *truncation_strategy = longest_first* or `True`, it is not possible to return overflowing tokens. Such a
+ combination of arguments will raise an error.
+
+ Word-level `boxes` are turned into token-level `bbox`. If provided, word-level `word_labels` are turned into
+ token-level `labels`. The word label is used for the first token of the word, while remaining tokens are
+ labeled with -100, such that they will be ignored by the loss function.
+
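+        For example (an illustrative sketch; the actual subword split depends on the vocabulary): a word tokenized
+        into `["nice", "##ly"]` reuses its word-level box for both tokens, and a word label of `3` becomes token-level
+        labels `[3, -100]` when `only_label_first_subword=True` (or `[3, 3]` otherwise).
+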
+ Args:
+ text (`str`, `List[str]`, `List[List[str]]`):
+ The first sequence to be encoded. This can be a string, a list of strings or a list of list of strings.
+ text_pair (`List[str]` or `List[int]`, *optional*):
+ Optional second sequence to be encoded. This can be a list of strings (words of a single example) or a
+ list of list of strings (words of a batch of examples).
+ """
+
+ # Backward compatibility for 'truncation_strategy', 'pad_to_max_length'
+ padding_strategy, truncation_strategy, max_length, kwargs = self._get_padding_truncation_strategies(
+ padding=padding,
+ truncation=truncation,
+ max_length=max_length,
+ pad_to_multiple_of=pad_to_multiple_of,
+ verbose=verbose,
+ **kwargs,
+ )
+
+ tokens = []
+ pair_tokens = []
+ token_boxes = []
+ pair_token_boxes = []
+ labels = []
+
+ if text_pair is None:
+ if word_labels is None:
+ # CASE 1: document image classification (training + inference) + CASE 2: token classification (inference)
+ for word, box in zip(text, boxes):
+ if len(word) < 1: # skip empty words
+ continue
+ word_tokens = self.tokenize(word)
+ tokens.extend(word_tokens)
+ token_boxes.extend([box] * len(word_tokens))
+ else:
+ # CASE 2: token classification (training)
+ for word, box, label in zip(text, boxes, word_labels):
+ if len(word) < 1: # skip empty words
+ continue
+ word_tokens = self.tokenize(word)
+ tokens.extend(word_tokens)
+ token_boxes.extend([box] * len(word_tokens))
+ if self.only_label_first_subword:
+ # Use the real label id for the first token of the word, and padding ids for the remaining tokens
+ labels.extend([label] + [self.pad_token_label] * (len(word_tokens) - 1))
+ else:
+ labels.extend([label] * len(word_tokens))
+ else:
+ # CASE 3: document visual question answering (inference)
+ # text = question
+ # text_pair = words
+ tokens = self.tokenize(text)
+ token_boxes = [self.pad_token_box for _ in range(len(tokens))]
+
+ for word, box in zip(text_pair, boxes):
+ if len(word) < 1: # skip empty words
+ continue
+ word_tokens = self.tokenize(word)
+ pair_tokens.extend(word_tokens)
+ pair_token_boxes.extend([box] * len(word_tokens))
+
+ # Create ids + pair_ids
+ ids = self.convert_tokens_to_ids(tokens)
+ pair_ids = self.convert_tokens_to_ids(pair_tokens) if pair_tokens else None
+
+ if (
+ return_overflowing_tokens
+ and truncation_strategy == TruncationStrategy.LONGEST_FIRST
+ and pair_ids is not None
+ ):
+ raise ValueError(
+                "Not possible to return overflowing tokens for pairs of sequences with the "
+                "`longest_first` truncation strategy. Please select a truncation strategy other than `longest_first`, "
+ "for instance `only_second` or `only_first`."
+ )
+
+ # Compute the total size of the returned encodings
+ pair = bool(pair_ids is not None)
+ len_ids = len(ids)
+ len_pair_ids = len(pair_ids) if pair else 0
+ total_len = len_ids + len_pair_ids + (self.num_special_tokens_to_add(pair=pair) if add_special_tokens else 0)
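+        # e.g. (illustrative, assuming BERT-style [CLS]/[SEP] specials) 5 + 3 ids in a pair -> total_len = 5 + 3 + 3 = 11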
+
+ # Truncation: Handle max sequence length
+ overflowing_tokens = []
+ overflowing_token_boxes = []
+ overflowing_labels = []
+ if truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE and max_length and total_len > max_length:
+ (
+ ids,
+ token_boxes,
+ pair_ids,
+ pair_token_boxes,
+ labels,
+ overflowing_tokens,
+ overflowing_token_boxes,
+ overflowing_labels,
+ ) = self.truncate_sequences(
+ ids,
+ token_boxes,
+ pair_ids=pair_ids,
+ pair_token_boxes=pair_token_boxes,
+ labels=labels,
+ num_tokens_to_remove=total_len - max_length,
+ truncation_strategy=truncation_strategy,
+ stride=stride,
+ )
+
+ if return_token_type_ids and not add_special_tokens:
+ raise ValueError(
+ "Asking to return token_type_ids while setting add_special_tokens to False "
+ "results in an undefined behavior. Please set add_special_tokens to True or "
+ "set return_token_type_ids to None."
+ )
+
+ # Load from model defaults
+ if return_token_type_ids is None:
+ return_token_type_ids = "token_type_ids" in self.model_input_names
+ if return_attention_mask is None:
+ return_attention_mask = "attention_mask" in self.model_input_names
+
+ encoded_inputs = {}
+
+ if return_overflowing_tokens:
+ encoded_inputs["overflowing_tokens"] = overflowing_tokens
+ encoded_inputs["overflowing_token_boxes"] = overflowing_token_boxes
+ encoded_inputs["overflowing_labels"] = overflowing_labels
+ encoded_inputs["num_truncated_tokens"] = total_len - max_length
+
+ # Add special tokens
+ if add_special_tokens:
+ sequence = self.build_inputs_with_special_tokens(ids, pair_ids)
+ token_type_ids = self.create_token_type_ids_from_sequences(ids, pair_ids)
+ token_boxes = [self.cls_token_box] + token_boxes + [self.sep_token_box]
+ if pair_token_boxes:
+ pair_token_boxes = pair_token_boxes + [self.sep_token_box]
+ if labels:
+ labels = [self.pad_token_label] + labels + [self.pad_token_label]
+ else:
+ sequence = ids + pair_ids if pair else ids
+ token_type_ids = [0] * len(ids) + ([0] * len(pair_ids) if pair else [])
+
+ # Build output dictionary
+ encoded_inputs["input_ids"] = sequence
+ encoded_inputs["bbox"] = token_boxes + pair_token_boxes
+ if return_token_type_ids:
+ encoded_inputs["token_type_ids"] = token_type_ids
+ if return_special_tokens_mask:
+ if add_special_tokens:
+ encoded_inputs["special_tokens_mask"] = self.get_special_tokens_mask(ids, pair_ids)
+ else:
+ encoded_inputs["special_tokens_mask"] = [0] * len(sequence)
+
+ if labels:
+ encoded_inputs["labels"] = labels
+
+ # Check lengths
+ self._eventual_warn_about_too_long_sequence(encoded_inputs["input_ids"], max_length, verbose)
+
+ # Padding
+ if padding_strategy != PaddingStrategy.DO_NOT_PAD or return_attention_mask:
+ encoded_inputs = self.pad(
+ encoded_inputs,
+ max_length=max_length,
+ padding=padding_strategy.value,
+ pad_to_multiple_of=pad_to_multiple_of,
+ padding_side=padding_side,
+ return_attention_mask=return_attention_mask,
+ )
+
+ if return_length:
+ encoded_inputs["length"] = len(encoded_inputs["input_ids"])
+
+ batch_outputs = BatchEncoding(
+ encoded_inputs, tensor_type=return_tensors, prepend_batch_axis=prepend_batch_axis
+ )
+
+ return batch_outputs
+
+ def truncate_sequences(
+ self,
+ ids: list[int],
+ token_boxes: list[list[int]],
+ pair_ids: Optional[list[int]] = None,
+ pair_token_boxes: Optional[list[list[int]]] = None,
+ labels: Optional[list[int]] = None,
+ num_tokens_to_remove: int = 0,
+ truncation_strategy: Union[str, TruncationStrategy] = "longest_first",
+ stride: int = 0,
+ ) -> tuple[list[int], list[int], list[int]]:
+ """
+ Truncates a sequence pair in-place following the strategy.
+
+ Args:
+ ids (`List[int]`):
+ Tokenized input ids of the first sequence. Can be obtained from a string by chaining the `tokenize` and
+ `convert_tokens_to_ids` methods.
+ token_boxes (`List[List[int]]`):
+ Bounding boxes of the first sequence.
+ pair_ids (`List[int]`, *optional*):
+ Tokenized input ids of the second sequence. Can be obtained from a string by chaining the `tokenize`
+ and `convert_tokens_to_ids` methods.
+ pair_token_boxes (`List[List[int]]`, *optional*):
+ Bounding boxes of the second sequence.
+ labels (`List[int]`, *optional*):
+ Labels of the first sequence (for token classification tasks).
+ num_tokens_to_remove (`int`, *optional*, defaults to 0):
+ Number of tokens to remove using the truncation strategy.
+            truncation_strategy (`str` or [`~tokenization_utils_base.TruncationStrategy`], *optional*, defaults to `'longest_first'`):
+ The strategy to follow for truncation. Can be:
+
+ - `'longest_first'`: Truncate to a maximum length specified with the argument `max_length` or to the
+ maximum acceptable input length for the model if that argument is not provided. This will truncate
+ token by token, removing a token from the longest sequence in the pair if a pair of sequences (or a
+ batch of pairs) is provided.
+ - `'only_first'`: Truncate to a maximum length specified with the argument `max_length` or to the
+ maximum acceptable input length for the model if that argument is not provided. This will only
+ truncate the first sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
+ - `'only_second'`: Truncate to a maximum length specified with the argument `max_length` or to the
+ maximum acceptable input length for the model if that argument is not provided. This will only
+ truncate the second sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
+                - `'do_not_truncate'`: No truncation (i.e., can output batch with sequence lengths greater
+ than the model maximum admissible input size).
+ stride (`int`, *optional*, defaults to 0):
+ If set to a positive number, the overflowing tokens returned will contain some tokens from the main
+ sequence returned. The value of this argument defines the number of additional tokens.
+
+ Returns:
+            `Tuple`: The truncated `ids`, `token_boxes`, `pair_ids`, `pair_token_boxes` and `labels`, followed by the
+            lists of overflowing tokens, overflowing token boxes and overflowing labels. Note: the *longest_first*
+            strategy returns empty lists of overflowing tokens if a pair of sequences (or a batch of pairs) is
+            provided.
+ """
+ if num_tokens_to_remove <= 0:
+ return ids, token_boxes, pair_ids, pair_token_boxes, labels, [], [], []
+
+ if not isinstance(truncation_strategy, TruncationStrategy):
+ truncation_strategy = TruncationStrategy(truncation_strategy)
+
+ overflowing_tokens = []
+ overflowing_token_boxes = []
+ overflowing_labels = []
+ if truncation_strategy == TruncationStrategy.ONLY_FIRST or (
+ truncation_strategy == TruncationStrategy.LONGEST_FIRST and pair_ids is None
+ ):
+ if len(ids) > num_tokens_to_remove:
+ window_len = min(len(ids), stride + num_tokens_to_remove)
+ overflowing_tokens = ids[-window_len:]
+ overflowing_token_boxes = token_boxes[-window_len:]
+ overflowing_labels = labels[-window_len:]
+ ids = ids[:-num_tokens_to_remove]
+ token_boxes = token_boxes[:-num_tokens_to_remove]
+ labels = labels[:-num_tokens_to_remove]
+ else:
+ error_msg = (
+                    f"We need to remove {num_tokens_to_remove} tokens to truncate the input, "
+                    f"but the first sequence has a length of {len(ids)}. "
+ )
+ if truncation_strategy == TruncationStrategy.ONLY_FIRST:
+ error_msg = (
+                        error_msg + "Please select a truncation strategy other than "
+ f"{truncation_strategy}, for instance 'longest_first' or 'only_second'."
+ )
+ logger.error(error_msg)
+ elif truncation_strategy == TruncationStrategy.LONGEST_FIRST:
+ logger.warning(
+ "Be aware, overflowing tokens are not returned for the setting you have chosen,"
+ f" i.e. sequence pairs with the '{TruncationStrategy.LONGEST_FIRST.value}' "
+ "truncation strategy. So the returned list will always be empty even if some "
+ "tokens have been removed."
+ )
+ for _ in range(num_tokens_to_remove):
+ if pair_ids is None or len(ids) > len(pair_ids):
+ ids = ids[:-1]
+ token_boxes = token_boxes[:-1]
+ labels = labels[:-1]
+ else:
+ pair_ids = pair_ids[:-1]
+ pair_token_boxes = pair_token_boxes[:-1]
+ elif truncation_strategy == TruncationStrategy.ONLY_SECOND and pair_ids is not None:
+ if len(pair_ids) > num_tokens_to_remove:
+ window_len = min(len(pair_ids), stride + num_tokens_to_remove)
+ overflowing_tokens = pair_ids[-window_len:]
+ overflowing_token_boxes = pair_token_boxes[-window_len:]
+ pair_ids = pair_ids[:-num_tokens_to_remove]
+ pair_token_boxes = pair_token_boxes[:-num_tokens_to_remove]
+ else:
+ logger.error(
+                    f"We need to remove {num_tokens_to_remove} tokens to truncate the input, "
+                    f"but the second sequence has a length of {len(pair_ids)}. "
+                    f"Please select a truncation strategy other than {truncation_strategy}, "
+ "for instance 'longest_first' or 'only_first'."
+ )
+
+ return (
+ ids,
+ token_boxes,
+ pair_ids,
+ pair_token_boxes,
+ labels,
+ overflowing_tokens,
+ overflowing_token_boxes,
+ overflowing_labels,
+ )
+
+ def _pad(
+ self,
+ encoded_inputs: Union[dict[str, EncodedInput], BatchEncoding],
+ max_length: Optional[int] = None,
+ padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
+ pad_to_multiple_of: Optional[int] = None,
+ padding_side: Optional[str] = None,
+ return_attention_mask: Optional[bool] = None,
+ ) -> dict:
+ """
+ Pad encoded inputs (on left/right and up to predefined length or max length in the batch)
+
+ Args:
+ encoded_inputs:
+ Dictionary of tokenized inputs (`List[int]`) or batch of tokenized inputs (`List[List[int]]`).
+ max_length: maximum length of the returned list and optionally padding length (see below).
+ Will truncate by taking into account the special tokens.
+ padding_strategy: PaddingStrategy to use for padding.
+
+                - PaddingStrategy.LONGEST: Pad to the longest sequence in the batch
+                - PaddingStrategy.MAX_LENGTH: Pad to the max length
+                - PaddingStrategy.DO_NOT_PAD: Do not pad (default)
+                The tokenizer padding side is defined in self.padding_side:
+
+ - 'left': pads on the left of the sequences
+ - 'right': pads on the right of the sequences
+            pad_to_multiple_of: (optional) Integer; if set, will pad the sequence to a multiple of the provided value.
+                This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
+ `>= 7.5` (Volta).
+ padding_side:
+ The side on which the model should have padding applied. Should be selected between ['right', 'left'].
+ Default value is picked from the class attribute of the same name.
+ return_attention_mask:
+ (optional) Set to False to avoid returning attention mask (default: set to model specifics)
+ """
+ # Load from model defaults
+ if return_attention_mask is None:
+ return_attention_mask = "attention_mask" in self.model_input_names
+
+ required_input = encoded_inputs[self.model_input_names[0]]
+
+ if padding_strategy == PaddingStrategy.LONGEST:
+ max_length = len(required_input)
+
+ if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
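+            # e.g. (illustrative) max_length=60 with pad_to_multiple_of=8 is rounded up to 64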
+ max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
+
+ needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length
+
+ # Initialize attention mask if not present.
+ if return_attention_mask and "attention_mask" not in encoded_inputs:
+ encoded_inputs["attention_mask"] = [1] * len(required_input)
+
+ if needs_to_be_padded:
+ difference = max_length - len(required_input)
+ padding_side = padding_side if padding_side is not None else self.padding_side
+ if padding_side == "right":
+ if return_attention_mask:
+ encoded_inputs["attention_mask"] = encoded_inputs["attention_mask"] + [0] * difference
+ if "token_type_ids" in encoded_inputs:
+ encoded_inputs["token_type_ids"] = (
+ encoded_inputs["token_type_ids"] + [self.pad_token_type_id] * difference
+ )
+ if "bbox" in encoded_inputs:
+ encoded_inputs["bbox"] = encoded_inputs["bbox"] + [self.pad_token_box] * difference
+ if "labels" in encoded_inputs:
+ encoded_inputs["labels"] = encoded_inputs["labels"] + [self.pad_token_label] * difference
+ if "special_tokens_mask" in encoded_inputs:
+ encoded_inputs["special_tokens_mask"] = encoded_inputs["special_tokens_mask"] + [1] * difference
+ encoded_inputs[self.model_input_names[0]] = required_input + [self.pad_token_id] * difference
+ elif padding_side == "left":
+ if return_attention_mask:
+ encoded_inputs["attention_mask"] = [0] * difference + encoded_inputs["attention_mask"]
+ if "token_type_ids" in encoded_inputs:
+ encoded_inputs["token_type_ids"] = [self.pad_token_type_id] * difference + encoded_inputs[
+ "token_type_ids"
+ ]
+ if "bbox" in encoded_inputs:
+ encoded_inputs["bbox"] = [self.pad_token_box] * difference + encoded_inputs["bbox"]
+ if "labels" in encoded_inputs:
+ encoded_inputs["labels"] = [self.pad_token_label] * difference + encoded_inputs["labels"]
+ if "special_tokens_mask" in encoded_inputs:
+ encoded_inputs["special_tokens_mask"] = [1] * difference + encoded_inputs["special_tokens_mask"]
+ encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input
+ else:
+                raise ValueError("Invalid padding side: " + str(padding_side))
+
+ return encoded_inputs
+
+
+# Copied from transformers.models.bert.tokenization_bert.BasicTokenizer
+class BasicTokenizer:
+ """
+ Constructs a BasicTokenizer that will run basic tokenization (punctuation splitting, lower casing, etc.).
+
+ Args:
+ do_lower_case (`bool`, *optional*, defaults to `True`):
+ Whether or not to lowercase the input when tokenizing.
+ never_split (`Iterable`, *optional*):
+ Collection of tokens which will never be split during tokenization. Only has an effect when
+ `do_basic_tokenize=True`
+ tokenize_chinese_chars (`bool`, *optional*, defaults to `True`):
+ Whether or not to tokenize Chinese characters.
+
+ This should likely be deactivated for Japanese (see this
+ [issue](https://github.com/huggingface/transformers/issues/328)).
+ strip_accents (`bool`, *optional*):
+ Whether or not to strip all accents. If this option is not specified, then it will be determined by the
+ value for `lowercase` (as in the original BERT).
+ do_split_on_punc (`bool`, *optional*, defaults to `True`):
+ In some instances we want to skip the basic punctuation splitting so that later tokenization can capture
+ the full context of the words, such as contractions.
+ """
+
+ def __init__(
+ self,
+ do_lower_case=True,
+ never_split=None,
+ tokenize_chinese_chars=True,
+ strip_accents=None,
+ do_split_on_punc=True,
+ ):
+ if never_split is None:
+ never_split = []
+ self.do_lower_case = do_lower_case
+ self.never_split = set(never_split)
+ self.tokenize_chinese_chars = tokenize_chinese_chars
+ self.strip_accents = strip_accents
+ self.do_split_on_punc = do_split_on_punc
+
+ def tokenize(self, text, never_split=None):
+ """
+        Basic tokenization of a piece of text. For sub-word tokenization, see WordpieceTokenizer.
+
+ Args:
+            never_split (`List[str]`, *optional*):
+                List of tokens not to split. Kept for backward compatibility purposes; now implemented directly at the
+                base class level (see [`PreTrainedTokenizer.tokenize`]).
+ """
+ # union() returns a new set by concatenating the two sets.
+ never_split = self.never_split.union(set(never_split)) if never_split else self.never_split
+ text = self._clean_text(text)
+
+ # This was added on November 1st, 2018 for the multilingual and Chinese
+ # models. This is also applied to the English models now, but it doesn't
+ # matter since the English models were not trained on any Chinese data
+ # and generally don't have any Chinese data in them (there are Chinese
+ # characters in the vocabulary because Wikipedia does have some Chinese
+        # words in the English Wikipedia).
+ if self.tokenize_chinese_chars:
+ text = self._tokenize_chinese_chars(text)
+ # prevents treating the same character with different unicode codepoints as different characters
+ unicode_normalized_text = unicodedata.normalize("NFC", text)
+ orig_tokens = whitespace_tokenize(unicode_normalized_text)
+ split_tokens = []
+ for token in orig_tokens:
+ if token not in never_split:
+ if self.do_lower_case:
+ token = token.lower()
+ if self.strip_accents is not False:
+ token = self._run_strip_accents(token)
+ elif self.strip_accents:
+ token = self._run_strip_accents(token)
+ split_tokens.extend(self._run_split_on_punc(token, never_split))
+
+ output_tokens = whitespace_tokenize(" ".join(split_tokens))
+ return output_tokens
+
+ def _run_strip_accents(self, text):
+ """Strips accents from a piece of text."""
+ text = unicodedata.normalize("NFD", text)
+ output = []
+ for char in text:
+ cat = unicodedata.category(char)
+ if cat == "Mn":
+ continue
+ output.append(char)
+ return "".join(output)
+
+ def _run_split_on_punc(self, text, never_split=None):
+ """Splits punctuation on a piece of text."""
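+        # Illustrative: "hello!" -> ["hello", "!"] and "re-use" -> ["re", "-", "use"] (unless the text is in never_split)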
+ if not self.do_split_on_punc or (never_split is not None and text in never_split):
+ return [text]
+ chars = list(text)
+ i = 0
+ start_new_word = True
+ output = []
+ while i < len(chars):
+ char = chars[i]
+ if _is_punctuation(char):
+ output.append([char])
+ start_new_word = True
+ else:
+ if start_new_word:
+ output.append([])
+ start_new_word = False
+ output[-1].append(char)
+ i += 1
+
+ return ["".join(x) for x in output]
+
+ def _tokenize_chinese_chars(self, text):
+ """Adds whitespace around any CJK character."""
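+        # Illustrative: "中文ab" -> " 中  文 ab" (each CJK codepoint gets surrounding spaces; other characters are kept as-is)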
+ output = []
+ for char in text:
+ cp = ord(char)
+ if self._is_chinese_char(cp):
+ output.append(" ")
+ output.append(char)
+ output.append(" ")
+ else:
+ output.append(char)
+ return "".join(output)
+
+ def _is_chinese_char(self, cp):
+ """Checks whether CP is the codepoint of a CJK character."""
+ # This defines a "chinese character" as anything in the CJK Unicode block:
+ # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
+ #
+ # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
+ # despite its name. The modern Korean Hangul alphabet is a different block,
+ # as is Japanese Hiragana and Katakana. Those alphabets are used to write
+ # space-separated words, so they are not treated specially and handled
+        # like all of the other languages.
+ if (
+ (cp >= 0x4E00 and cp <= 0x9FFF)
+ or (cp >= 0x3400 and cp <= 0x4DBF)
+ or (cp >= 0x20000 and cp <= 0x2A6DF)
+ or (cp >= 0x2A700 and cp <= 0x2B73F)
+ or (cp >= 0x2B740 and cp <= 0x2B81F)
+ or (cp >= 0x2B820 and cp <= 0x2CEAF)
+ or (cp >= 0xF900 and cp <= 0xFAFF)
+ or (cp >= 0x2F800 and cp <= 0x2FA1F)
+ ):
+ return True
+
+ return False
+
+ def _clean_text(self, text):
+ """Performs invalid character removal and whitespace cleanup on text."""
+ output = []
+ for char in text:
+ cp = ord(char)
+ if cp == 0 or cp == 0xFFFD or _is_control(char):
+ continue
+ if _is_whitespace(char):
+ output.append(" ")
+ else:
+ output.append(char)
+ return "".join(output)
+
+
+# Copied from transformers.models.bert.tokenization_bert.WordpieceTokenizer
+class WordpieceTokenizer:
+ """Runs WordPiece tokenization."""
+
+ def __init__(self, vocab, unk_token, max_input_chars_per_word=100):
+ self.vocab = vocab
+ self.unk_token = unk_token
+ self.max_input_chars_per_word = max_input_chars_per_word
+
+ def tokenize(self, text):
+ """
+ Tokenizes a piece of text into its word pieces. This uses a greedy longest-match-first algorithm to perform
+ tokenization using the given vocabulary.
+
+ For example, `input = "unaffable"` will return as output `["un", "##aff", "##able"]`.
+
+ Args:
+ text: A single token or whitespace separated tokens. This should have
+ already been passed through *BasicTokenizer*.
+
+ Returns:
+ A list of wordpiece tokens.
+ """
+
+ output_tokens = []
+ for token in whitespace_tokenize(text):
+ chars = list(token)
+ if len(chars) > self.max_input_chars_per_word:
+ output_tokens.append(self.unk_token)
+ continue
+
+ is_bad = False
+ start = 0
+ sub_tokens = []
+ while start < len(chars):
+ end = len(chars)
+ cur_substr = None
+ while start < end:
+ substr = "".join(chars[start:end])
+ if start > 0:
+ substr = "##" + substr
+ if substr in self.vocab:
+ cur_substr = substr
+ break
+ end -= 1
+ if cur_substr is None:
+ is_bad = True
+ break
+ sub_tokens.append(cur_substr)
+ start = end
+
+ if is_bad:
+ output_tokens.append(self.unk_token)
+ else:
+ output_tokens.extend(sub_tokens)
+ return output_tokens
+
+
+__all__ = ["LayoutLMv2Tokenizer"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/layoutlmv2/tokenization_layoutlmv2_fast.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/layoutlmv2/tokenization_layoutlmv2_fast.py
new file mode 100644
index 0000000000000000000000000000000000000000..8e324ee0b8fe971b06cd102eaa9930990a1f5b99
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/layoutlmv2/tokenization_layoutlmv2_fast.py
@@ -0,0 +1,789 @@
+# coding=utf-8
+# Copyright 2021 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Fast tokenization class for LayoutLMv2. It overwrites 2 methods of the slow tokenizer class, namely _batch_encode_plus
+and _encode_plus, in which the Rust tokenizer is used.
+"""
+
+import json
+from typing import Optional, Union
+
+from tokenizers import normalizers
+
+from ...tokenization_utils_base import (
+ BatchEncoding,
+ EncodedInput,
+ PaddingStrategy,
+ PreTokenizedInput,
+ TensorType,
+ TextInput,
+ TextInputPair,
+ TruncationStrategy,
+)
+from ...tokenization_utils_fast import PreTrainedTokenizerFast
+from ...utils import add_end_docstrings, logging
+from .tokenization_layoutlmv2 import (
+ LAYOUTLMV2_ENCODE_KWARGS_DOCSTRING,
+ LAYOUTLMV2_ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING,
+ LayoutLMv2Tokenizer,
+)
+
+
+logger = logging.get_logger(__name__)
+
+VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt", "tokenizer_file": "tokenizer.json"}
+
+
+class LayoutLMv2TokenizerFast(PreTrainedTokenizerFast):
+ r"""
+ Construct a "fast" LayoutLMv2 tokenizer (backed by HuggingFace's *tokenizers* library). Based on WordPiece.
+
+ This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should
+ refer to this superclass for more information regarding those methods.
+
+ Args:
+ vocab_file (`str`):
+ File containing the vocabulary.
+ do_lower_case (`bool`, *optional*, defaults to `True`):
+ Whether or not to lowercase the input when tokenizing.
+ unk_token (`str`, *optional*, defaults to `"[UNK]"`):
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+ token instead.
+ sep_token (`str`, *optional*, defaults to `"[SEP]"`):
+ The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
+ sequence classification or for a text and a question for question answering. It is also used as the last
+ token of a sequence built with special tokens.
+ pad_token (`str`, *optional*, defaults to `"[PAD]"`):
+ The token used for padding, for example when batching sequences of different lengths.
+ cls_token (`str`, *optional*, defaults to `"[CLS]"`):
+ The classifier token which is used when doing sequence classification (classification of the whole sequence
+ instead of per-token classification). It is the first token of the sequence when built with special tokens.
+ mask_token (`str`, *optional*, defaults to `"[MASK]"`):
+ The token used for masking values. This is the token used when training this model with masked language
+ modeling. This is the token which the model will try to predict.
+ cls_token_box (`List[int]`, *optional*, defaults to `[0, 0, 0, 0]`):
+ The bounding box to use for the special [CLS] token.
+ sep_token_box (`List[int]`, *optional*, defaults to `[1000, 1000, 1000, 1000]`):
+ The bounding box to use for the special [SEP] token.
+ pad_token_box (`List[int]`, *optional*, defaults to `[0, 0, 0, 0]`):
+ The bounding box to use for the special [PAD] token.
+ pad_token_label (`int`, *optional*, defaults to -100):
+ The label to use for padding tokens. Defaults to -100, which is the `ignore_index` of PyTorch's
+ CrossEntropyLoss.
+ only_label_first_subword (`bool`, *optional*, defaults to `True`):
+ Whether or not to only label the first subword, in case word labels are provided.
+ tokenize_chinese_chars (`bool`, *optional*, defaults to `True`):
+ Whether or not to tokenize Chinese characters. This should likely be deactivated for Japanese (see [this
+ issue](https://github.com/huggingface/transformers/issues/328)).
+ strip_accents (`bool`, *optional*):
+ Whether or not to strip all accents. If this option is not specified, then it will be determined by the
+ value for `lowercase` (as in the original LayoutLMv2).
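+
+    Example (a minimal usage sketch; the checkpoint name and the words/boxes/labels below are illustrative):
+
+    ```python
+    >>> from transformers import LayoutLMv2TokenizerFast
+
+    >>> tokenizer = LayoutLMv2TokenizerFast.from_pretrained("microsoft/layoutlmv2-base-uncased")
+    >>> words = ["hello", "world"]
+    >>> boxes = [[1, 2, 3, 4], [5, 6, 7, 8]]  # one 0-1000 normalized box per word
+    >>> word_labels = [0, 1]  # one integer label per word (token classification)
+    >>> encoding = tokenizer(words, boxes=boxes, word_labels=word_labels, return_tensors="pt")
+    ```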
+ """
+
+ vocab_files_names = VOCAB_FILES_NAMES
+ slow_tokenizer_class = LayoutLMv2Tokenizer
+
+ def __init__(
+ self,
+ vocab_file=None,
+ tokenizer_file=None,
+ do_lower_case=True,
+ unk_token="[UNK]",
+ sep_token="[SEP]",
+ pad_token="[PAD]",
+ cls_token="[CLS]",
+ mask_token="[MASK]",
+ cls_token_box=[0, 0, 0, 0],
+ sep_token_box=[1000, 1000, 1000, 1000],
+ pad_token_box=[0, 0, 0, 0],
+ pad_token_label=-100,
+ only_label_first_subword=True,
+ tokenize_chinese_chars=True,
+ strip_accents=None,
+ **kwargs,
+ ):
+ super().__init__(
+ vocab_file,
+ tokenizer_file=tokenizer_file,
+ do_lower_case=do_lower_case,
+ unk_token=unk_token,
+ sep_token=sep_token,
+ pad_token=pad_token,
+ cls_token=cls_token,
+ mask_token=mask_token,
+ cls_token_box=cls_token_box,
+ sep_token_box=sep_token_box,
+ pad_token_box=pad_token_box,
+ pad_token_label=pad_token_label,
+ only_label_first_subword=only_label_first_subword,
+ tokenize_chinese_chars=tokenize_chinese_chars,
+ strip_accents=strip_accents,
+ **kwargs,
+ )
+
+ pre_tok_state = json.loads(self.backend_tokenizer.normalizer.__getstate__())
+ if (
+ pre_tok_state.get("lowercase", do_lower_case) != do_lower_case
+ or pre_tok_state.get("strip_accents", strip_accents) != strip_accents
+ ):
+ pre_tok_class = getattr(normalizers, pre_tok_state.pop("type"))
+ pre_tok_state["lowercase"] = do_lower_case
+ pre_tok_state["strip_accents"] = strip_accents
+ self.backend_tokenizer.normalizer = pre_tok_class(**pre_tok_state)
+
+ self.do_lower_case = do_lower_case
+
+ # additional properties
+ self.cls_token_box = cls_token_box
+ self.sep_token_box = sep_token_box
+ self.pad_token_box = pad_token_box
+ self.pad_token_label = pad_token_label
+ self.only_label_first_subword = only_label_first_subword
+
+ @add_end_docstrings(LAYOUTLMV2_ENCODE_KWARGS_DOCSTRING, LAYOUTLMV2_ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING)
+ def __call__(
+ self,
+ text: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]],
+ text_pair: Optional[Union[PreTokenizedInput, list[PreTokenizedInput]]] = None,
+ boxes: Optional[Union[list[list[int]], list[list[list[int]]]]] = None,
+ word_labels: Optional[Union[list[int], list[list[int]]]] = None,
+ add_special_tokens: bool = True,
+ padding: Union[bool, str, PaddingStrategy] = False,
+ truncation: Union[bool, str, TruncationStrategy] = None,
+ max_length: Optional[int] = None,
+ stride: int = 0,
+ pad_to_multiple_of: Optional[int] = None,
+ padding_side: Optional[str] = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ return_token_type_ids: Optional[bool] = None,
+ return_attention_mask: Optional[bool] = None,
+ return_overflowing_tokens: bool = False,
+ return_special_tokens_mask: bool = False,
+ return_offsets_mapping: bool = False,
+ return_length: bool = False,
+ verbose: bool = True,
+ **kwargs,
+ ) -> BatchEncoding:
+ """
+ Main method to tokenize and prepare for the model one or several sequence(s) or one or several pair(s) of
+ sequences with word-level normalized bounding boxes and optional labels.
+
+ Args:
+ text (`str`, `List[str]`, `List[List[str]]`):
+ The sequence or batch of sequences to be encoded. Each sequence can be a string, a list of strings
+ (words of a single example or questions of a batch of examples) or a list of list of strings (batch of
+ words).
+ text_pair (`List[str]`, `List[List[str]]`):
+ The sequence or batch of sequences to be encoded. Each sequence should be a list of strings
+ (pretokenized string).
+ boxes (`List[List[int]]`, `List[List[List[int]]]`):
+ Word-level bounding boxes. Each bounding box should be normalized to be on a 0-1000 scale.
+ word_labels (`List[int]`, `List[List[int]]`, *optional*):
+ Word-level integer labels (for token classification tasks such as FUNSD, CORD).
+ """
+
+ # Input type checking for clearer error
+ def _is_valid_text_input(t):
+ if isinstance(t, str):
+ # Strings are fine
+ return True
+ elif isinstance(t, (list, tuple)):
+                # Lists are fine as long as they are...
+ if len(t) == 0:
+ # ... empty
+ return True
+ elif isinstance(t[0], str):
+ # ... list of strings
+ return True
+ elif isinstance(t[0], (list, tuple)):
+ # ... list with an empty list or with a list of strings
+ return len(t[0]) == 0 or isinstance(t[0][0], str)
+ else:
+ return False
+ else:
+ return False
+
+ if text_pair is not None:
+ # in case text + text_pair are provided, text = questions, text_pair = words
+ if not _is_valid_text_input(text):
+                raise ValueError("text input must be of type `str` (single example) or `List[str]` (batch of examples).")
+ if not isinstance(text_pair, (list, tuple)):
+ raise ValueError(
+ "Words must be of type `List[str]` (single pretokenized example), "
+ "or `List[List[str]]` (batch of pretokenized examples)."
+ )
+ else:
+ # in case only text is provided => must be words
+ if not isinstance(text, (list, tuple)):
+ raise ValueError(
+ "Words must be of type `List[str]` (single pretokenized example), "
+ "or `List[List[str]]` (batch of pretokenized examples)."
+ )
+
+ if text_pair is not None:
+ is_batched = isinstance(text, (list, tuple))
+ else:
+ is_batched = isinstance(text, (list, tuple)) and text and isinstance(text[0], (list, tuple))
+
+ words = text if text_pair is None else text_pair
+ if boxes is None:
+ raise ValueError("You must provide corresponding bounding boxes")
+ if is_batched:
+ if len(words) != len(boxes):
+ raise ValueError("You must provide words and boxes for an equal amount of examples")
+ for words_example, boxes_example in zip(words, boxes):
+ if len(words_example) != len(boxes_example):
+ raise ValueError("You must provide as many words as there are bounding boxes")
+ else:
+ if len(words) != len(boxes):
+ raise ValueError("You must provide as many words as there are bounding boxes")
+
+ if is_batched:
+ if text_pair is not None and len(text) != len(text_pair):
+ raise ValueError(
+ f"batch length of `text`: {len(text)} does not match batch length of `text_pair`:"
+ f" {len(text_pair)}."
+ )
+ batch_text_or_text_pairs = list(zip(text, text_pair)) if text_pair is not None else text
+ is_pair = bool(text_pair is not None)
+ return self.batch_encode_plus(
+ batch_text_or_text_pairs=batch_text_or_text_pairs,
+ is_pair=is_pair,
+ boxes=boxes,
+ word_labels=word_labels,
+ add_special_tokens=add_special_tokens,
+ padding=padding,
+ truncation=truncation,
+ max_length=max_length,
+ stride=stride,
+ pad_to_multiple_of=pad_to_multiple_of,
+ padding_side=padding_side,
+ return_tensors=return_tensors,
+ return_token_type_ids=return_token_type_ids,
+ return_attention_mask=return_attention_mask,
+ return_overflowing_tokens=return_overflowing_tokens,
+ return_special_tokens_mask=return_special_tokens_mask,
+ return_offsets_mapping=return_offsets_mapping,
+ return_length=return_length,
+ verbose=verbose,
+ **kwargs,
+ )
+ else:
+ return self.encode_plus(
+ text=text,
+ text_pair=text_pair,
+ boxes=boxes,
+ word_labels=word_labels,
+ add_special_tokens=add_special_tokens,
+ padding=padding,
+ truncation=truncation,
+ max_length=max_length,
+ stride=stride,
+ pad_to_multiple_of=pad_to_multiple_of,
+ padding_side=padding_side,
+ return_tensors=return_tensors,
+ return_token_type_ids=return_token_type_ids,
+ return_attention_mask=return_attention_mask,
+ return_overflowing_tokens=return_overflowing_tokens,
+ return_special_tokens_mask=return_special_tokens_mask,
+ return_offsets_mapping=return_offsets_mapping,
+ return_length=return_length,
+ verbose=verbose,
+ **kwargs,
+ )
+
+ @add_end_docstrings(LAYOUTLMV2_ENCODE_KWARGS_DOCSTRING, LAYOUTLMV2_ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING)
+ def batch_encode_plus(
+ self,
+ batch_text_or_text_pairs: Union[
+ list[TextInput],
+ list[TextInputPair],
+ list[PreTokenizedInput],
+ ],
+ is_pair: Optional[bool] = None,
+ boxes: Optional[list[list[list[int]]]] = None,
+ word_labels: Optional[Union[list[int], list[list[int]]]] = None,
+ add_special_tokens: bool = True,
+ padding: Union[bool, str, PaddingStrategy] = False,
+ truncation: Union[bool, str, TruncationStrategy] = None,
+ max_length: Optional[int] = None,
+ stride: int = 0,
+ pad_to_multiple_of: Optional[int] = None,
+ padding_side: Optional[str] = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ return_token_type_ids: Optional[bool] = None,
+ return_attention_mask: Optional[bool] = None,
+ return_overflowing_tokens: bool = False,
+ return_special_tokens_mask: bool = False,
+ return_offsets_mapping: bool = False,
+ return_length: bool = False,
+ verbose: bool = True,
+ **kwargs,
+ ) -> BatchEncoding:
+ # Backward compatibility for 'truncation_strategy', 'pad_to_max_length'
+ padding_strategy, truncation_strategy, max_length, kwargs = self._get_padding_truncation_strategies(
+ padding=padding,
+ truncation=truncation,
+ max_length=max_length,
+ pad_to_multiple_of=pad_to_multiple_of,
+ verbose=verbose,
+ **kwargs,
+ )
+
+ return self._batch_encode_plus(
+ batch_text_or_text_pairs=batch_text_or_text_pairs,
+ is_pair=is_pair,
+ boxes=boxes,
+ word_labels=word_labels,
+ add_special_tokens=add_special_tokens,
+ padding_strategy=padding_strategy,
+ truncation_strategy=truncation_strategy,
+ max_length=max_length,
+ stride=stride,
+ pad_to_multiple_of=pad_to_multiple_of,
+ padding_side=padding_side,
+ return_tensors=return_tensors,
+ return_token_type_ids=return_token_type_ids,
+ return_attention_mask=return_attention_mask,
+ return_overflowing_tokens=return_overflowing_tokens,
+ return_special_tokens_mask=return_special_tokens_mask,
+ return_offsets_mapping=return_offsets_mapping,
+ return_length=return_length,
+ verbose=verbose,
+ **kwargs,
+ )
+
+ def tokenize(self, text: str, pair: Optional[str] = None, add_special_tokens: bool = False, **kwargs) -> list[str]:
+ batched_input = [(text, pair)] if pair else [text]
+ encodings = self._tokenizer.encode_batch(
+ batched_input, add_special_tokens=add_special_tokens, is_pretokenized=False, **kwargs
+ )
+
+ return encodings[0].tokens
+
+ @add_end_docstrings(LAYOUTLMV2_ENCODE_KWARGS_DOCSTRING, LAYOUTLMV2_ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING)
+ def encode_plus(
+ self,
+ text: Union[TextInput, PreTokenizedInput],
+ text_pair: Optional[PreTokenizedInput] = None,
+ boxes: Optional[list[list[int]]] = None,
+ word_labels: Optional[list[int]] = None,
+ add_special_tokens: bool = True,
+ padding: Union[bool, str, PaddingStrategy] = False,
+ truncation: Union[bool, str, TruncationStrategy] = None,
+ max_length: Optional[int] = None,
+ stride: int = 0,
+ pad_to_multiple_of: Optional[int] = None,
+ padding_side: Optional[str] = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ return_token_type_ids: Optional[bool] = None,
+ return_attention_mask: Optional[bool] = None,
+ return_overflowing_tokens: bool = False,
+ return_special_tokens_mask: bool = False,
+ return_offsets_mapping: bool = False,
+ return_length: bool = False,
+ verbose: bool = True,
+ **kwargs,
+ ) -> BatchEncoding:
+ """
+ Tokenize and prepare for the model a sequence or a pair of sequences.
+
+ .. warning::
+ This method is deprecated; `__call__` should be used instead.
+
+ Args:
+ text (`str`, `List[str]`, `List[List[str]]`):
+ The first sequence to be encoded. This can be a string, a list of strings or a list of list of strings.
+ text_pair (`List[str]` or `List[List[str]]`, *optional*):
+ Optional second sequence to be encoded. This can be a list of strings (words of a single example) or a
+ list of list of strings (words of a batch of examples).
+ """
+
+ # Backward compatibility for 'truncation_strategy', 'pad_to_max_length'
+ padding_strategy, truncation_strategy, max_length, kwargs = self._get_padding_truncation_strategies(
+ padding=padding,
+ truncation=truncation,
+ max_length=max_length,
+ pad_to_multiple_of=pad_to_multiple_of,
+ verbose=verbose,
+ **kwargs,
+ )
+
+ return self._encode_plus(
+ text=text,
+ boxes=boxes,
+ text_pair=text_pair,
+ word_labels=word_labels,
+ add_special_tokens=add_special_tokens,
+ padding_strategy=padding_strategy,
+ truncation_strategy=truncation_strategy,
+ max_length=max_length,
+ stride=stride,
+ pad_to_multiple_of=pad_to_multiple_of,
+ padding_side=padding_side,
+ return_tensors=return_tensors,
+ return_token_type_ids=return_token_type_ids,
+ return_attention_mask=return_attention_mask,
+ return_overflowing_tokens=return_overflowing_tokens,
+ return_special_tokens_mask=return_special_tokens_mask,
+ return_offsets_mapping=return_offsets_mapping,
+ return_length=return_length,
+ verbose=verbose,
+ **kwargs,
+ )
+
+ def _batch_encode_plus(
+ self,
+ batch_text_or_text_pairs: Union[
+ list[TextInput],
+ list[TextInputPair],
+ list[PreTokenizedInput],
+ ],
+ is_pair: Optional[bool] = None,
+ boxes: Optional[list[list[list[int]]]] = None,
+ word_labels: Optional[list[list[int]]] = None,
+ add_special_tokens: bool = True,
+ padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
+ truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
+ max_length: Optional[int] = None,
+ stride: int = 0,
+ pad_to_multiple_of: Optional[int] = None,
+ padding_side: Optional[str] = None,
+ return_tensors: Optional[str] = None,
+ return_token_type_ids: Optional[bool] = None,
+ return_attention_mask: Optional[bool] = None,
+ return_overflowing_tokens: bool = False,
+ return_special_tokens_mask: bool = False,
+ return_offsets_mapping: bool = False,
+ return_length: bool = False,
+ verbose: bool = True,
+ ) -> BatchEncoding:
+ if not isinstance(batch_text_or_text_pairs, list):
+ raise TypeError(f"batch_text_or_text_pairs has to be a list (got {type(batch_text_or_text_pairs)})")
+
+ # Set the truncation and padding strategy and restore the initial configuration
+ self.set_truncation_and_padding(
+ padding_strategy=padding_strategy,
+ truncation_strategy=truncation_strategy,
+ max_length=max_length,
+ stride=stride,
+ pad_to_multiple_of=pad_to_multiple_of,
+ padding_side=padding_side,
+ )
+
+ if is_pair:
+ batch_text_or_text_pairs = [(text.split(), text_pair) for text, text_pair in batch_text_or_text_pairs]
+
+ encodings = self._tokenizer.encode_batch(
+ batch_text_or_text_pairs,
+ add_special_tokens=add_special_tokens,
+ is_pretokenized=True, # we set this to True as LayoutLMv2 always expects pretokenized inputs
+ )
+
+ # Convert encoding to dict
+ # `Tokens` has type: Tuple[
+ # List[Dict[str, List[List[int]]]] or List[Dict[str, 2D-Tensor]],
+ # List[EncodingFast]
+ # ]
+ # with nested dimensions corresponding to batch, overflows, sequence length
+ tokens_and_encodings = [
+ self._convert_encoding(
+ encoding=encoding,
+ return_token_type_ids=return_token_type_ids,
+ return_attention_mask=return_attention_mask,
+ return_overflowing_tokens=return_overflowing_tokens,
+ return_special_tokens_mask=return_special_tokens_mask,
+ return_offsets_mapping=True
+ if word_labels is not None
+ else return_offsets_mapping, # we use offsets to create the labels
+ return_length=return_length,
+ verbose=verbose,
+ )
+ for encoding in encodings
+ ]
+
+ # Convert the output to have dict[list] from list[dict] and remove the additional overflows dimension
+ # From (variable) shape (batch, overflows, sequence length) to ~ (batch * overflows, sequence length)
+ # (we say ~ because the number of overflow varies with the example in the batch)
+ #
+ # To match each overflowing sample with the original sample in the batch
+ # we add an overflow_to_sample_mapping array (see below)
+ sanitized_tokens = {}
+ for key in tokens_and_encodings[0][0]:
+ stack = [e for item, _ in tokens_and_encodings for e in item[key]]
+ sanitized_tokens[key] = stack
+ sanitized_encodings = [e for _, item in tokens_and_encodings for e in item]
+
+ # If returning overflowing tokens, we need to return a mapping
+ # from the batch idx to the original sample
+ if return_overflowing_tokens:
+ overflow_to_sample_mapping = []
+ for i, (toks, _) in enumerate(tokens_and_encodings):
+ overflow_to_sample_mapping += [i] * len(toks["input_ids"])
+ sanitized_tokens["overflow_to_sample_mapping"] = overflow_to_sample_mapping
+
+ for input_ids in sanitized_tokens["input_ids"]:
+ self._eventual_warn_about_too_long_sequence(input_ids, max_length, verbose)
+
+ # create the token boxes
+ token_boxes = []
+ for batch_index in range(len(sanitized_tokens["input_ids"])):
+ if return_overflowing_tokens:
+ original_index = sanitized_tokens["overflow_to_sample_mapping"][batch_index]
+ else:
+ original_index = batch_index
+ token_boxes_example = []
+ for id, sequence_id, word_id in zip(
+ sanitized_tokens["input_ids"][batch_index],
+ sanitized_encodings[batch_index].sequence_ids,
+ sanitized_encodings[batch_index].word_ids,
+ ):
+ if word_id is not None:
+ if is_pair and sequence_id == 0:
+ token_boxes_example.append(self.pad_token_box)
+ else:
+ token_boxes_example.append(boxes[original_index][word_id])
+ else:
+ if id == self.cls_token_id:
+ token_boxes_example.append(self.cls_token_box)
+ elif id == self.sep_token_id:
+ token_boxes_example.append(self.sep_token_box)
+ elif id == self.pad_token_id:
+ token_boxes_example.append(self.pad_token_box)
+ else:
+ raise ValueError("Id not recognized")
+ token_boxes.append(token_boxes_example)
+
+ sanitized_tokens["bbox"] = token_boxes
+
+ # optionally, create the labels
+ if word_labels is not None:
+ labels = []
+ for batch_index in range(len(sanitized_tokens["input_ids"])):
+ if return_overflowing_tokens:
+ original_index = sanitized_tokens["overflow_to_sample_mapping"][batch_index]
+ else:
+ original_index = batch_index
+ labels_example = []
+ for id, offset, word_id in zip(
+ sanitized_tokens["input_ids"][batch_index],
+ sanitized_tokens["offset_mapping"][batch_index],
+ sanitized_encodings[batch_index].word_ids,
+ ):
+ if word_id is not None:
+ if self.only_label_first_subword:
+ if offset[0] == 0:
+ # Use the real label id for the first token of the word, and padding ids for the remaining tokens
+ labels_example.append(word_labels[original_index][word_id])
+ else:
+ labels_example.append(self.pad_token_label)
+ else:
+ labels_example.append(word_labels[original_index][word_id])
+ else:
+ labels_example.append(self.pad_token_label)
+ labels.append(labels_example)
+
+ sanitized_tokens["labels"] = labels
+ # finally, remove offsets if the user didn't want them
+ if not return_offsets_mapping:
+ del sanitized_tokens["offset_mapping"]
+
+ return BatchEncoding(sanitized_tokens, sanitized_encodings, tensor_type=return_tensors)
+
+ def _encode_plus(
+ self,
+ text: Union[TextInput, PreTokenizedInput],
+ text_pair: Optional[PreTokenizedInput] = None,
+ boxes: Optional[list[list[int]]] = None,
+ word_labels: Optional[list[int]] = None,
+ add_special_tokens: bool = True,
+ padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
+ truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
+ max_length: Optional[int] = None,
+ stride: int = 0,
+ pad_to_multiple_of: Optional[int] = None,
+ padding_side: Optional[str] = None,
+ return_tensors: Optional[str] = None,
+ return_token_type_ids: Optional[bool] = None,
+ return_attention_mask: Optional[bool] = None,
+ return_overflowing_tokens: bool = False,
+ return_special_tokens_mask: bool = False,
+ return_offsets_mapping: bool = False,
+ return_length: bool = False,
+ verbose: bool = True,
+ **kwargs,
+ ) -> BatchEncoding:
+ # make it a batched input
+ # 2 options:
+ # 1) only text, in which case text must be a list of str
+ # 2) text + text_pair, in which case text = str and text_pair a list of str
+ batched_input = [(text, text_pair)] if text_pair else [text]
+ batched_boxes = [boxes]
+ batched_word_labels = [word_labels] if word_labels is not None else None
+ batched_output = self._batch_encode_plus(
+ batched_input,
+ is_pair=bool(text_pair is not None),
+ boxes=batched_boxes,
+ word_labels=batched_word_labels,
+ add_special_tokens=add_special_tokens,
+ padding_strategy=padding_strategy,
+ truncation_strategy=truncation_strategy,
+ max_length=max_length,
+ stride=stride,
+ pad_to_multiple_of=pad_to_multiple_of,
+ padding_side=padding_side,
+ return_tensors=return_tensors,
+ return_token_type_ids=return_token_type_ids,
+ return_attention_mask=return_attention_mask,
+ return_overflowing_tokens=return_overflowing_tokens,
+ return_special_tokens_mask=return_special_tokens_mask,
+ return_offsets_mapping=return_offsets_mapping,
+ return_length=return_length,
+ verbose=verbose,
+ **kwargs,
+ )
+
+ # If return_tensors is None, we can remove the leading batch axis
+ # Overflowing tokens are returned as a batch of output so we keep them in this case
+ if return_tensors is None and not return_overflowing_tokens:
+ batched_output = BatchEncoding(
+ {
+ key: value[0] if len(value) > 0 and isinstance(value[0], list) else value
+ for key, value in batched_output.items()
+ },
+ batched_output.encodings,
+ )
+
+ self._eventual_warn_about_too_long_sequence(batched_output["input_ids"], max_length, verbose)
+
+ return batched_output
+
+ def _pad(
+ self,
+ encoded_inputs: Union[dict[str, EncodedInput], BatchEncoding],
+ max_length: Optional[int] = None,
+ padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
+ pad_to_multiple_of: Optional[int] = None,
+ padding_side: Optional[str] = None,
+ return_attention_mask: Optional[bool] = None,
+ ) -> dict:
+ """
+ Pad encoded inputs (on left/right and up to predefined length or max length in the batch)
+
+ Args:
+ encoded_inputs:
+ Dictionary of tokenized inputs (`List[int]`) or batch of tokenized inputs (`List[List[int]]`).
+ max_length: maximum length of the returned list and optionally padding length (see below).
+ Will truncate by taking into account the special tokens.
+ padding_strategy: PaddingStrategy to use for padding.
+
+ - PaddingStrategy.LONGEST: Pad to the longest sequence in the batch
+ - PaddingStrategy.MAX_LENGTH: Pad to the max length (default)
+ - PaddingStrategy.DO_NOT_PAD: Do not pad
+ The tokenizer padding sides are defined in self.padding_side:
+
+ - 'left': pads on the left of the sequences
+ - 'right': pads on the right of the sequences
+ pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value.
+ This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
+ `>= 7.0` (Volta).
+ padding_side:
+ The side on which the model should have padding applied. Should be selected between ['right', 'left'].
+ Default value is picked from the class attribute of the same name.
+ return_attention_mask:
+ (optional) Set to False to avoid returning attention mask (default: set to model specifics)
+ """
+ # Load from model defaults
+ if return_attention_mask is None:
+ return_attention_mask = "attention_mask" in self.model_input_names
+
+ required_input = encoded_inputs[self.model_input_names[0]]
+
+ if padding_strategy == PaddingStrategy.LONGEST:
+ max_length = len(required_input)
+
+ if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
+ max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
+
+ needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length
+
+ # Initialize attention mask if not present.
+ if return_attention_mask and "attention_mask" not in encoded_inputs:
+ encoded_inputs["attention_mask"] = [1] * len(required_input)
+
+ if needs_to_be_padded:
+ difference = max_length - len(required_input)
+ padding_side = padding_side if padding_side is not None else self.padding_side
+ if padding_side == "right":
+ if return_attention_mask:
+ encoded_inputs["attention_mask"] = encoded_inputs["attention_mask"] + [0] * difference
+ if "token_type_ids" in encoded_inputs:
+ encoded_inputs["token_type_ids"] = (
+ encoded_inputs["token_type_ids"] + [self.pad_token_type_id] * difference
+ )
+ if "bbox" in encoded_inputs:
+ encoded_inputs["bbox"] = encoded_inputs["bbox"] + [self.pad_token_box] * difference
+ if "labels" in encoded_inputs:
+ encoded_inputs["labels"] = encoded_inputs["labels"] + [self.pad_token_label] * difference
+ if "special_tokens_mask" in encoded_inputs:
+ encoded_inputs["special_tokens_mask"] = encoded_inputs["special_tokens_mask"] + [1] * difference
+ encoded_inputs[self.model_input_names[0]] = required_input + [self.pad_token_id] * difference
+ elif padding_side == "left":
+ if return_attention_mask:
+ encoded_inputs["attention_mask"] = [0] * difference + encoded_inputs["attention_mask"]
+ if "token_type_ids" in encoded_inputs:
+ encoded_inputs["token_type_ids"] = [self.pad_token_type_id] * difference + encoded_inputs[
+ "token_type_ids"
+ ]
+ if "bbox" in encoded_inputs:
+ encoded_inputs["bbox"] = [self.pad_token_box] * difference + encoded_inputs["bbox"]
+ if "labels" in encoded_inputs:
+ encoded_inputs["labels"] = [self.pad_token_label] * difference + encoded_inputs["labels"]
+ if "special_tokens_mask" in encoded_inputs:
+ encoded_inputs["special_tokens_mask"] = [1] * difference + encoded_inputs["special_tokens_mask"]
+ encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input
+ else:
+ raise ValueError("Invalid padding strategy:" + str(padding_side))
+
+ return encoded_inputs
+
+ def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
+ """
+ Build model inputs from a sequence or a pair of sequences for sequence classification tasks by concatenating and
+ adding special tokens. A BERT sequence has the following format:
+
+ - single sequence: `[CLS] X [SEP]`
+ - pair of sequences: `[CLS] A [SEP] B [SEP]`
+
+ Args:
+ token_ids_0 (`List[int]`):
+ List of IDs to which the special tokens will be added.
+ token_ids_1 (`List[int]`, *optional*):
+ Optional second list of IDs for sequence pairs.
+
+ Returns:
+ `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
+ """
+ output = [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
+
+ if token_ids_1:
+ output += token_ids_1 + [self.sep_token_id]
+
+ return output
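+ # Illustrative only (actual ids depend on the loaded vocabulary): with cls_token_id=101 and
+ # sep_token_id=102, token_ids_0=[7592, 2088] becomes [101, 7592, 2088, 102], and adding
+ # token_ids_1=[2293] yields [101, 7592, 2088, 102, 2293, 102].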
+
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]:
+ files = self._tokenizer.model.save(save_directory, name=filename_prefix)
+ return tuple(files)
+
+
+__all__ = ["LayoutLMv2TokenizerFast"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/lightglue/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/lightglue/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..190e4e4329419f92da2141c2027f8212d6178e6f
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/lightglue/__init__.py
@@ -0,0 +1,28 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_lightglue import *
+ from .image_processing_lightglue import *
+ from .modeling_lightglue import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/lightglue/configuration_lightglue.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/lightglue/configuration_lightglue.py
new file mode 100644
index 0000000000000000000000000000000000000000..90e8d41b45156932278872ddd61432e18330db7c
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/lightglue/configuration_lightglue.py
@@ -0,0 +1,155 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/lightglue/modular_lightglue.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_lightglue.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from ...configuration_utils import PretrainedConfig
+from ..auto import CONFIG_MAPPING, AutoConfig
+from ..superpoint import SuperPointConfig
+
+
+class LightGlueConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`LightGlueForKeypointMatching`]. It is used to
+ instantiate a LightGlue model according to the specified arguments, defining the model architecture. Instantiating a
+ configuration with the defaults will yield a similar configuration to that of the LightGlue
+ [ETH-CVG/lightglue_superpoint](https://huggingface.co/ETH-CVG/lightglue_superpoint) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ keypoint_detector_config (`Union[AutoConfig, dict]`, *optional*, defaults to `SuperPointConfig`):
+ The config object or dictionary of the keypoint detector.
+ descriptor_dim (`int`, *optional*, defaults to 256):
+ The dimension of the descriptors.
+ num_hidden_layers (`int`, *optional*, defaults to 9):
+ The number of self and cross attention layers.
+ num_attention_heads (`int`, *optional*, defaults to 4):
+ The number of heads in the multi-head attention.
+ num_key_value_heads (`int`, *optional*):
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
+ `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+ by meanpooling all the original heads within that group. For more details checkout [this
+ paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to
+ `num_attention_heads`.
+ depth_confidence (`float`, *optional*, defaults to 0.95):
+ The confidence threshold used to perform early stopping.
+ width_confidence (`float`, *optional*, defaults to 0.99):
+ The confidence threshold used to prune points.
+ filter_threshold (`float`, *optional*, defaults to 0.1):
+ The confidence threshold used to filter matches.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ hidden_act (`str`, *optional*, defaults to `"gelu"`):
+ The activation function to be used in the hidden layers.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ attention_bias (`bool`, *optional*, defaults to `True`):
+ Whether to use a bias in the query, key, value and output projection layers during self-attention.
+ trust_remote_code (`bool`, *optional*, defaults to `False`):
+ Whether to trust remote code when using other models than SuperPoint as keypoint detector.
+
+ Examples:
+ ```python
+ >>> from transformers import LightGlueConfig, LightGlueForKeypointMatching
+
+ >>> # Initializing a LightGlue style configuration
+ >>> configuration = LightGlueConfig()
+
+ >>> # Initializing a model from the LightGlue style configuration
+ >>> model = LightGlueForKeypointMatching(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```
+ """
+
+ model_type = "lightglue"
+ sub_configs = {"keypoint_detector_config": AutoConfig}
+
+ def __init__(
+ self,
+ keypoint_detector_config: SuperPointConfig = None,
+ descriptor_dim: int = 256,
+ num_hidden_layers: int = 9,
+ num_attention_heads: int = 4,
+ num_key_value_heads=None,
+ depth_confidence: float = 0.95,
+ width_confidence: float = 0.99,
+ filter_threshold: float = 0.1,
+ initializer_range: float = 0.02,
+ hidden_act: str = "gelu",
+ attention_dropout=0.0,
+ attention_bias=True,
+ trust_remote_code: bool = False,
+ **kwargs,
+ ):
+ # LightGlue can be used with other models than SuperPoint as keypoint detector
+ # We provide the trust_remote_code argument to allow the use of other models
+ # that are not registered in the CONFIG_MAPPING dictionary (for example DISK)
+ self.trust_remote_code = trust_remote_code
+
+ if descriptor_dim % num_attention_heads != 0:
+ raise ValueError("descriptor_dim % num_heads is different from zero")
+
+ self.descriptor_dim = descriptor_dim
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
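+ # With the default configuration (descriptor_dim=256, num_attention_heads=4), each attention
+ # head therefore operates on 256 // 4 = 64 channels.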
+
+ # for backward compatibility
+ if num_key_value_heads is None:
+ num_key_value_heads = num_attention_heads
+
+ self.num_key_value_heads = num_key_value_heads
+
+ self.depth_confidence = depth_confidence
+ self.width_confidence = width_confidence
+ self.filter_threshold = filter_threshold
+ self.initializer_range = initializer_range
+
+ # Keypoint Detector is forced into eager attention mode because SuperPoint does not have Attention
+ # See https://github.com/huggingface/transformers/pull/31718#discussion_r2109733153
+ if isinstance(keypoint_detector_config, dict):
+ keypoint_detector_config["model_type"] = keypoint_detector_config.get("model_type", "superpoint")
+ if keypoint_detector_config["model_type"] not in CONFIG_MAPPING:
+ keypoint_detector_config = AutoConfig.from_pretrained(
+ keypoint_detector_config["_name_or_path"], trust_remote_code=self.trust_remote_code
+ )
+ else:
+ keypoint_detector_config = CONFIG_MAPPING[keypoint_detector_config["model_type"]](
+ **keypoint_detector_config, attn_implementation="eager"
+ )
+
+ if keypoint_detector_config is None:
+ keypoint_detector_config = CONFIG_MAPPING["superpoint"](attn_implementation="eager")
+
+ self.keypoint_detector_config = keypoint_detector_config
+
+ self.hidden_size = descriptor_dim
+ self.intermediate_size = descriptor_dim * 2
+ self.hidden_act = hidden_act
+ self.attention_dropout = attention_dropout
+ self.attention_bias = attention_bias
+ super().__init__(**kwargs)
+
+
+__all__ = ["LightGlueConfig"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/lightglue/image_processing_lightglue.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/lightglue/image_processing_lightglue.py
new file mode 100644
index 0000000000000000000000000000000000000000..400475b76c77e6f078f53ad31fe4dcffde5bef8e
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/lightglue/image_processing_lightglue.py
@@ -0,0 +1,526 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/lightglue/modular_lightglue.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_lightglue.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import warnings
+from typing import Optional, Union
+
+import numpy as np
+import torch
+
+from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
+from ...image_transforms import resize, to_channel_dimension_format
+from ...image_utils import (
+ ChannelDimension,
+ ImageInput,
+ ImageType,
+ PILImageResampling,
+ get_image_type,
+ infer_channel_dimension_format,
+ is_pil_image,
+ is_scaled_image,
+ is_valid_image,
+ is_vision_available,
+ to_numpy_array,
+ valid_images,
+ validate_preprocess_arguments,
+)
+from ...utils import TensorType, is_matplotlib_available, logging, requires_backends
+from ...utils.import_utils import requires
+from .modeling_lightglue import LightGlueKeypointMatchingOutput
+
+
+if is_vision_available():
+ import PIL
+ from PIL import Image, ImageDraw
+
+logger = logging.get_logger(__name__)
+
+
+def is_grayscale(
+ image: np.ndarray,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+):
+ if input_data_format == ChannelDimension.FIRST:
+ if image.shape[0] == 1:
+ return True
+ return np.all(image[0, ...] == image[1, ...]) and np.all(image[1, ...] == image[2, ...])
+ elif input_data_format == ChannelDimension.LAST:
+ if image.shape[-1] == 1:
+ return True
+ return np.all(image[..., 0] == image[..., 1]) and np.all(image[..., 1] == image[..., 2])
+
+
+def convert_to_grayscale(
+ image: ImageInput,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+) -> ImageInput:
+ """
+ Converts an image to grayscale format using the NTSC formula. Only numpy arrays and PIL Images are
+ supported. TODO: support torch and tensorflow grayscale conversion.
+
+ This function is supposed to return a 1-channel image, but it returns a 3-channel image with the same value in each
+ channel, because of an issue that is discussed in:
+ https://github.com/huggingface/transformers/pull/25786#issuecomment-1730176446
+
+ Args:
+ image (Image):
+ The image to convert.
+ input_data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format for the input image.
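+
+ Example:
+ With the NTSC weights used below, a pure red pixel `(255, 0, 0)` maps to `0.2989 * 255 ≈ 76.2`,
+ replicated across the three output channels (illustrative only).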
+ """
+ requires_backends(convert_to_grayscale, ["vision"])
+
+ if isinstance(image, np.ndarray):
+ if is_grayscale(image, input_data_format=input_data_format):
+ return image
+ if input_data_format == ChannelDimension.FIRST:
+ gray_image = image[0, ...] * 0.2989 + image[1, ...] * 0.5870 + image[2, ...] * 0.1140
+ gray_image = np.stack([gray_image] * 3, axis=0)
+ elif input_data_format == ChannelDimension.LAST:
+ gray_image = image[..., 0] * 0.2989 + image[..., 1] * 0.5870 + image[..., 2] * 0.1140
+ gray_image = np.stack([gray_image] * 3, axis=-1)
+ return gray_image
+
+ if not isinstance(image, PIL.Image.Image):
+ return image
+
+ image = image.convert("L")
+ return image
+
+
+def validate_and_format_image_pairs(images: ImageInput):
+ error_message = (
+ "Input images must be one of the following: "
+ "a pair of PIL images, a pair of 3D arrays, "
+ "a list of pairs of PIL images, or a list of pairs of 3D arrays."
+ )
+
+ def _is_valid_image(image):
+ """images is a PIL Image or a 3D array."""
+ return is_pil_image(image) or (
+ is_valid_image(image) and get_image_type(image) != ImageType.PIL and len(image.shape) == 3
+ )
+
+ if isinstance(images, list):
+ if len(images) == 2 and all((_is_valid_image(image)) for image in images):
+ return images
+ if all(
+ isinstance(image_pair, list)
+ and len(image_pair) == 2
+ and all(_is_valid_image(image) for image in image_pair)
+ for image_pair in images
+ ):
+ return [image for image_pair in images for image in image_pair]
+ raise ValueError(error_message)
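+# Illustrative example: [[image_a0, image_a1], [image_b0, image_b1]] is flattened to
+# [image_a0, image_a1, image_b0, image_b1], while a single pair [image0, image1] is returned unchanged.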
+
+
+@requires(backends=("torch",))
+class LightGlueImageProcessor(BaseImageProcessor):
+ r"""
+ Constructs a LightGlue image processor.
+
+ Args:
+ do_resize (`bool`, *optional*, defaults to `True`):
+ Controls whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden
+ by `do_resize` in the `preprocess` method.
+ size (`dict[str, int]`, *optional*, defaults to `{"height": 480, "width": 640}`):
+ Resolution of the output image after `resize` is applied. Only has an effect if `do_resize` is set to
+ `True`. Can be overridden by `size` in the `preprocess` method.
+ resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`):
+ Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method.
+ do_rescale (`bool`, *optional*, defaults to `True`):
+ Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in
+ the `preprocess` method.
+ rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
+ Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess`
+ method.
+ do_grayscale (`bool`, *optional*, defaults to `True`):
+ Whether to convert the image to grayscale. Can be overridden by `do_grayscale` in the `preprocess` method.
+ """
+
+ model_input_names = ["pixel_values"]
+
+ def __init__(
+ self,
+ do_resize: bool = True,
+ size: Optional[dict[str, int]] = None,
+ resample: PILImageResampling = PILImageResampling.BILINEAR,
+ do_rescale: bool = True,
+ rescale_factor: float = 1 / 255,
+ do_grayscale: bool = True,
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+ size = size if size is not None else {"height": 480, "width": 640}
+ size = get_size_dict(size, default_to_square=False)
+
+ self.do_resize = do_resize
+ self.size = size
+ self.resample = resample
+ self.do_rescale = do_rescale
+ self.rescale_factor = rescale_factor
+ self.do_grayscale = do_grayscale
+
+ def resize(
+ self,
+ image: np.ndarray,
+ size: dict[str, int],
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ **kwargs,
+ ):
+ """
+ Resize an image.
+
+ Args:
+ image (`np.ndarray`):
+ Image to resize.
+ size (`dict[str, int]`):
+ Dictionary of the form `{"height": int, "width": int}`, specifying the size of the output image.
+ data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format of the output image. If not provided, it will be inferred from the input
+ image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
+ input_data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
+ from the input image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
+ """
+ size = get_size_dict(size, default_to_square=False)
+
+ return resize(
+ image,
+ size=(size["height"], size["width"]),
+ data_format=data_format,
+ input_data_format=input_data_format,
+ **kwargs,
+ )
+
+ def preprocess(
+ self,
+ images,
+ do_resize: Optional[bool] = None,
+ size: Optional[dict[str, int]] = None,
+ resample: Optional[PILImageResampling] = None,
+ do_rescale: Optional[bool] = None,
+ rescale_factor: Optional[float] = None,
+ do_grayscale: Optional[bool] = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ data_format: ChannelDimension = ChannelDimension.FIRST,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ **kwargs,
+ ) -> BatchFeature:
+ """
+ Preprocess an image or batch of images.
+
+ Args:
+ images (`ImageInput`):
+ Image pairs to preprocess. Expects either a list of 2 images or a list of lists of 2 images, with
+ pixel values ranging from 0 to 255. If passing in images with pixel values between 0 and 1, set
+ `do_rescale=False`.
+ do_resize (`bool`, *optional*, defaults to `self.do_resize`):
+ Whether to resize the image.
+ size (`dict[str, int]`, *optional*, defaults to `self.size`):
+ Size of the output image after `resize` has been applied, as a dictionary of the form
+ `{"height": int, "width": int}`. Only has an effect if `do_resize` is set to `True`.
+ resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
+ Resampling filter to use if resizing the image. This can be one of `PILImageResampling`, filters. Only
+ has an effect if `do_resize` is set to `True`.
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
+ Whether to rescale the image values between [0 - 1].
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
+ Rescale factor to rescale the image by if `do_rescale` is set to `True`.
+ do_grayscale (`bool`, *optional*, defaults to `self.do_grayscale`):
+ Whether to convert the image to grayscale.
+ return_tensors (`str` or `TensorType`, *optional*):
+ The type of tensors to return. Can be one of:
+ - Unset: Return a list of `np.ndarray`.
+ - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
+ - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
+ data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
+ The channel dimension format for the output image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - Unset: Use the channel dimension format of the input image.
+ input_data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
+ from the input image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
+ """
+
+ do_resize = do_resize if do_resize is not None else self.do_resize
+ resample = resample if resample is not None else self.resample
+ do_rescale = do_rescale if do_rescale is not None else self.do_rescale
+ rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
+ do_grayscale = do_grayscale if do_grayscale is not None else self.do_grayscale
+
+ size = size if size is not None else self.size
+ size = get_size_dict(size, default_to_square=False)
+
+ # Validate and convert the input images into a flattened list of images for all subsequent processing steps.
+ images = validate_and_format_image_pairs(images)
+
+ if not valid_images(images):
+ raise ValueError(
+ "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
+ "torch.Tensor, tf.Tensor or jax.ndarray."
+ )
+
+ validate_preprocess_arguments(
+ do_resize=do_resize,
+ size=size,
+ resample=resample,
+ do_rescale=do_rescale,
+ rescale_factor=rescale_factor,
+ )
+
+ # All transformations expect numpy arrays.
+ images = [to_numpy_array(image) for image in images]
+
+ if is_scaled_image(images[0]) and do_rescale:
+ logger.warning_once(
+ "It looks like you are trying to rescale already rescaled images. If the input"
+ " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
+ )
+
+ if input_data_format is None:
+ # We assume that all images have the same channel dimension format.
+ input_data_format = infer_channel_dimension_format(images[0])
+
+ all_images = []
+ for image in images:
+ if do_resize:
+ image = self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
+
+ if do_rescale:
+ image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
+
+ if do_grayscale:
+ image = convert_to_grayscale(image, input_data_format=input_data_format)
+
+ image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
+ all_images.append(image)
+
+ # Convert back the flattened list of images into a list of pairs of images.
+ image_pairs = [all_images[i : i + 2] for i in range(0, len(all_images), 2)]
+
+ data = {"pixel_values": image_pairs}
+
+ return BatchFeature(data=data, tensor_type=return_tensors)
+
+ def post_process_keypoint_matching(
+ self,
+ outputs: LightGlueKeypointMatchingOutput,
+ target_sizes: Union[TensorType, list[tuple]],
+ threshold: float = 0.0,
+ ) -> list[dict[str, torch.Tensor]]:
+ """
+ Converts the raw output of [`LightGlueKeypointMatchingOutput`] into lists of matched keypoints and matching
+ scores with coordinates absolute to the original image sizes.
+ Args:
+ outputs ([`LightGlueKeypointMatchingOutput`]):
+ Raw outputs of the model.
+ target_sizes (`torch.Tensor` or `list[tuple[tuple[int, int]]]`):
+ Tensor of shape `(batch_size, 2, 2)` or list of tuples of tuples (`tuple[int, int]`) containing the
+ target size `(height, width)` of each image in the batch. This must be the original image size (before
+ any processing).
+ threshold (`float`, *optional*, defaults to 0.0):
+ Threshold to filter out the matches with low scores.
+ Returns:
+ `list[Dict]`: A list of dictionaries, each dictionary containing the keypoints in the first and second image
+ of the pair, the matching scores and the matching indices.
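+
+ Example (illustrative sketch; `processor` is an instance of this class and `outputs` comes from a
+ [`LightGlueForKeypointMatching`] forward pass on a single pair of PIL images `image0` and `image1`):
+
+ ```python
+ >>> target_sizes = [[(image0.height, image0.width), (image1.height, image1.width)]]
+ >>> matches = processor.post_process_keypoint_matching(outputs, target_sizes, threshold=0.2)
+ >>> sorted(matches[0].keys())
+ ['keypoints0', 'keypoints1', 'matching_scores']
+ ```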
+ """
+ if outputs.mask.shape[0] != len(target_sizes):
+ raise ValueError("Make sure that you pass in as many target sizes as the batch dimension of the mask")
+ if not all(len(target_size) == 2 for target_size in target_sizes):
+ raise ValueError("Each element of target_sizes must contain the size (h, w) of each image of the batch")
+
+ if isinstance(target_sizes, list):
+ image_pair_sizes = torch.tensor(target_sizes, device=outputs.mask.device)
+ else:
+ if target_sizes.shape[1] != 2 or target_sizes.shape[2] != 2:
+ raise ValueError(
+ "Each element of target_sizes must contain the size (h, w) of each image of the batch"
+ )
+ image_pair_sizes = target_sizes
+
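+ # Keypoints are assumed to be normalized to [0, 1] by the model; flipping (height, width) to
+ # (width, height) rescales x by the image width and y by the image height before casting to
+ # integer pixel coordinates.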
+ keypoints = outputs.keypoints.clone()
+ keypoints = keypoints * image_pair_sizes.flip(-1).reshape(-1, 2, 1, 2)
+ keypoints = keypoints.to(torch.int32)
+
+ results = []
+ for mask_pair, keypoints_pair, matches, scores in zip(
+ outputs.mask, keypoints, outputs.matches[:, 0], outputs.matching_scores[:, 0]
+ ):
+ mask0 = mask_pair[0] > 0
+ mask1 = mask_pair[1] > 0
+ keypoints0 = keypoints_pair[0][mask0]
+ keypoints1 = keypoints_pair[1][mask1]
+ matches0 = matches[mask0]
+ scores0 = scores[mask0]
+
+ # Filter out matches with low scores
+ valid_matches = torch.logical_and(scores0 > threshold, matches0 > -1)
+
+ matched_keypoints0 = keypoints0[valid_matches]
+ matched_keypoints1 = keypoints1[matches0[valid_matches]]
+ matching_scores = scores0[valid_matches]
+
+ results.append(
+ {
+ "keypoints0": matched_keypoints0,
+ "keypoints1": matched_keypoints1,
+ "matching_scores": matching_scores,
+ }
+ )
+
+ return results
+
+ def visualize_keypoint_matching(
+ self,
+ images: ImageInput,
+ keypoint_matching_output: list[dict[str, torch.Tensor]],
+ ) -> list["Image.Image"]:
+ """
+ Plots the image pairs side by side with the detected keypoints as well as the matching between them.
+
+ Args:
+ images (`ImageInput`):
+ Image pairs to plot. Same as `LightGlueImageProcessor.preprocess`. Expects either a list of 2
+ images or a list of lists of 2 images, with pixel values ranging from 0 to 255.
+ keypoint_matching_output (`List[Dict[str, torch.Tensor]]`):
+ A post-processed keypoint matching output, as returned by `post_process_keypoint_matching`.
+
+ Returns:
+ `List[PIL.Image.Image]`: A list of PIL images, each containing the image pairs side by side with the detected
+ keypoints as well as the matching between them.
+ """
+ images = validate_and_format_image_pairs(images)
+ images = [to_numpy_array(image) for image in images]
+ image_pairs = [images[i : i + 2] for i in range(0, len(images), 2)]
+
+ results = []
+ for image_pair, pair_output in zip(image_pairs, keypoint_matching_output):
+ height0, width0 = image_pair[0].shape[:2]
+ height1, width1 = image_pair[1].shape[:2]
+ plot_image = np.zeros((max(height0, height1), width0 + width1, 3), dtype=np.uint8)
+ plot_image[:height0, :width0] = image_pair[0]
+ plot_image[:height1, width0:] = image_pair[1]
+
+ plot_image_pil = Image.fromarray(plot_image)
+ draw = ImageDraw.Draw(plot_image_pil)
+
+ keypoints0_x, keypoints0_y = pair_output["keypoints0"].unbind(1)
+ keypoints1_x, keypoints1_y = pair_output["keypoints1"].unbind(1)
+ for keypoint0_x, keypoint0_y, keypoint1_x, keypoint1_y, matching_score in zip(
+ keypoints0_x, keypoints0_y, keypoints1_x, keypoints1_y, pair_output["matching_scores"]
+ ):
+ color = self._get_color(matching_score)
+ draw.line(
+ (keypoint0_x, keypoint0_y, keypoint1_x + width0, keypoint1_y),
+ fill=color,
+ width=3,
+ )
+ draw.ellipse((keypoint0_x - 2, keypoint0_y - 2, keypoint0_x + 2, keypoint0_y + 2), fill="black")
+ draw.ellipse(
+ (keypoint1_x + width0 - 2, keypoint1_y - 2, keypoint1_x + width0 + 2, keypoint1_y + 2),
+ fill="black",
+ )
+
+ results.append(plot_image_pil)
+ return results
+
+ def _get_color(self, score):
+ """Maps a score to a color."""
+ r = int(255 * (1 - score))
+ g = int(255 * score)
+ b = 0
+ return (r, g, b)
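+ # Illustrative mapping: score 0.0 -> red (255, 0, 0), score 0.5 -> (127, 127, 0), score 1.0 -> green (0, 255, 0).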
+
+ def plot_keypoint_matching(self, images: ImageInput, keypoint_matching_output: LightGlueKeypointMatchingOutput):
+ """
+ Plots the image pairs side by side with the detected keypoints as well as the matching between them. Requires
+ matplotlib to be installed.
+
+ .. deprecated::
+ `plot_keypoint_matching` is deprecated and will be removed in a future version. Use `visualize_keypoint_matching` instead.
+
+ Args:
+ images (`ImageInput`):
+ Image pairs to plot. Same as `LightGlueImageProcessor.preprocess`. Expects either a list of 2 images or
+ a list of lists of 2 images, with pixel values ranging from 0 to 255.
+ keypoint_matching_output (`List[Dict[str, torch.Tensor]]`):
+ A post-processed keypoint matching output, as returned by `post_process_keypoint_matching`.
+ """
+ warnings.warn(
+ "`plot_keypoint_matching` is deprecated and will be removed in transformers v. "
+ "Use `visualize_keypoint_matching` instead.",
+ FutureWarning,
+ )
+
+ if is_matplotlib_available():
+ import matplotlib.pyplot as plt
+ else:
+ raise ImportError("Please install matplotlib to use `plot_keypoint_matching` method")
+
+ images = validate_and_format_image_pairs(images)
+ images = [to_numpy_array(image) for image in images]
+ image_pairs = [images[i : i + 2] for i in range(0, len(images), 2)]
+
+ for image_pair, pair_output in zip(image_pairs, keypoint_matching_output):
+ height0, width0 = image_pair[0].shape[:2]
+ height1, width1 = image_pair[1].shape[:2]
+ plot_image = np.zeros((max(height0, height1), width0 + width1, 3))
+ plot_image[:height0, :width0] = image_pair[0] / 255.0
+ plot_image[:height1, width0:] = image_pair[1] / 255.0
+ plt.imshow(plot_image)
+ plt.axis("off")
+
+ keypoints0_x, keypoints0_y = pair_output["keypoints0"].unbind(1)
+ keypoints1_x, keypoints1_y = pair_output["keypoints1"].unbind(1)
+ for keypoint0_x, keypoint0_y, keypoint1_x, keypoint1_y, matching_score in zip(
+ keypoints0_x, keypoints0_y, keypoints1_x, keypoints1_y, pair_output["matching_scores"]
+ ):
+ plt.plot(
+ [keypoint0_x, keypoint1_x + width0],
+ [keypoint0_y, keypoint1_y],
+ color=plt.get_cmap("RdYlGn")(matching_score.item()),
+ alpha=0.9,
+ linewidth=0.5,
+ )
+ plt.scatter(keypoint0_x, keypoint0_y, c="black", s=2)
+ plt.scatter(keypoint1_x + width0, keypoint1_y, c="black", s=2)
+ plt.show()
+
+
+__all__ = ["LightGlueImageProcessor"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/lightglue/modeling_lightglue.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/lightglue/modeling_lightglue.py
new file mode 100644
index 0000000000000000000000000000000000000000..8e9faa3e4e0439750057ff2e1cab19dff16c9867
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/lightglue/modeling_lightglue.py
@@ -0,0 +1,920 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/lightglue/modular_lightglue.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_lightglue.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from dataclasses import dataclass
+from typing import Callable, Optional, Union
+
+import numpy as np
+import torch
+from torch import nn
+from torch.nn.utils.rnn import pad_sequence
+
+from ...activations import ACT2FN
+from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
+from ...utils import ModelOutput, TransformersKwargs, auto_docstring
+from ...utils.deprecation import deprecate_kwarg
+from ...utils.generic import can_return_tuple
+from ..auto.modeling_auto import AutoModelForKeypointDetection
+from .configuration_lightglue import LightGlueConfig
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Base class for outputs of LightGlue keypoint matching models. Due to the nature of keypoint detection and matching,
+ the number of keypoints is not fixed and can vary from image to image, which makes batching non-trivial. In the
+ batch of images, the maximum number of matches is set as the dimension of the matches and matching scores. The mask
+ tensor is used to indicate which values in the keypoints, matches, matching_scores and prune tensors are keypoint
+ matching information.
+ """
+)
+class LightGlueKeypointMatchingOutput(ModelOutput):
+ r"""
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*):
+ Loss computed during training.
+ matches (`torch.FloatTensor` of shape `(batch_size, 2, num_matches)`):
+ Index of keypoint matched in the other image.
+ matching_scores (`torch.FloatTensor` of shape `(batch_size, 2, num_matches)`):
+ Scores of predicted matches.
+ keypoints (`torch.FloatTensor` of shape `(batch_size, num_keypoints, 2)`):
+ Absolute (x, y) coordinates of predicted keypoints in a given image.
+ prune (`torch.IntTensor` of shape `(batch_size, num_keypoints)`):
+ Pruning mask indicating which keypoints are removed and at which layer.
+ mask (`torch.BoolTensor` of shape `(batch_size, num_keypoints)`):
+ Mask indicating which values in matches, matching_scores, keypoints and prune are keypoint matching
+ information.
+ hidden_states (`Tuple[torch.FloatTensor, ...]`, *optional*):
+ Tuple of `torch.FloatTensor` (one for the output of each stage) of shape `(batch_size, 2, num_channels,
+ num_keypoints)` returned when `output_hidden_states=True` is passed or when
+ `config.output_hidden_states=True`
+ attentions (`Tuple[torch.FloatTensor, ...]`, *optional*):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, 2, num_heads, num_keypoints,
+ num_keypoints)` returned when `output_attentions=True` is passed or when
+ `config.output_attentions=True`
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ matches: Optional[torch.FloatTensor] = None
+ matching_scores: Optional[torch.FloatTensor] = None
+ keypoints: Optional[torch.FloatTensor] = None
+ prune: Optional[torch.IntTensor] = None
+ mask: Optional[torch.FloatTensor] = None
+ hidden_states: Optional[tuple[torch.FloatTensor]] = None
+ attentions: Optional[tuple[torch.FloatTensor]] = None
+
+
+class LightGluePositionalEncoder(nn.Module):
+ def __init__(self, config: LightGlueConfig):
+ super().__init__()
+ self.projector = nn.Linear(2, config.descriptor_dim // config.num_attention_heads // 2, bias=False)
+
+ def forward(
+ self, keypoints: torch.Tensor, output_hidden_states: Optional[bool] = False
+ ) -> Union[tuple[torch.Tensor], tuple[torch.Tensor, torch.Tensor]]:
+ projected_keypoints = self.projector(keypoints)
+ embeddings = projected_keypoints.repeat_interleave(2, dim=-1)
+ cosines = torch.cos(embeddings)
+ sines = torch.sin(embeddings)
+ embeddings = (cosines, sines)
+ output = (embeddings, projected_keypoints) if output_hidden_states else (embeddings,)
+ return output
+
+
+def rotate_half(x):
+ # Split and rotate. Note that this function is different from e.g. Llama.
+ x1 = x[..., ::2]
+ x2 = x[..., 1::2]
+ rot_x = torch.stack([-x2, x1], dim=-1).flatten(-2)
+ return rot_x
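+# Illustrative example: for a last dimension [x0, x1, x2, x3] this returns [-x1, x0, -x3, x2],
+# i.e. each interleaved (even, odd) pair is rotated, unlike the Llama-style half-split rotation.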
+
+
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+ """Applies Rotary Position Embedding to the query and key tensors.
+
+ Args:
+ q (`torch.Tensor`): The query tensor.
+ k (`torch.Tensor`): The key tensor.
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
+ position_ids (`torch.Tensor`, *optional*):
+ Deprecated and unused.
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+ Returns:
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+ """
+ dtype = q.dtype
+ q = q.float()
+ k = k.float()
+ cos = cos.unsqueeze(unsqueeze_dim)
+ sin = sin.unsqueeze(unsqueeze_dim)
+ q_embed = (q * cos) + (rotate_half(q) * sin)
+ k_embed = (k * cos) + (rotate_half(k) * sin)
+ return q_embed.to(dtype=dtype), k_embed.to(dtype=dtype)
+
+
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+ """
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+ """
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+ if n_rep == 1:
+ return hidden_states
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
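+# Illustrative example: with num_attention_heads=4 and num_key_value_heads=2, n_rep is 2 and a
+# (batch, 2, seq_len, head_dim) key/value tensor becomes (batch, 4, seq_len, head_dim), duplicating
+# each key/value head for its group of query heads.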
+
+
+def eager_attention_forward(
+ module: nn.Module,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_mask: Optional[torch.Tensor],
+ scaling: float,
+ dropout: float = 0.0,
+ **kwargs: Unpack[TransformersKwargs],
+):
+ key_states = repeat_kv(key, module.num_key_value_groups)
+ value_states = repeat_kv(value, module.num_key_value_groups)
+
+ attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
+ if attention_mask is not None:
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+ attn_weights = attn_weights + causal_mask
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+ attn_output = torch.matmul(attn_weights, value_states)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ return attn_output, attn_weights
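+
+# In formula form, the function above computes softmax(Q K^T * scaling + mask) V, with the softmax
+# evaluated in float32 before casting back to the query dtype.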
+
+
+class LightGlueAttention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(self, config: LightGlueConfig, layer_idx: int):
+ super().__init__()
+ self.config = config
+ self.layer_idx = layer_idx
+ self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+ self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
+ self.scaling = self.head_dim**-0.5
+ self.attention_dropout = config.attention_dropout
+ self.is_causal = True
+
+ self.q_proj = nn.Linear(
+ config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
+ )
+ self.k_proj = nn.Linear(
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
+ )
+ self.v_proj = nn.Linear(
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
+ )
+ self.o_proj = nn.Linear(
+ config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
+ )
+
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ **kwargs: Unpack[FlashAttentionKwargs],
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+ input_shape = hidden_states.shape[:-1]
+ hidden_shape = (*input_shape, -1, self.head_dim)
+
+ query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+
+ is_cross_attention = encoder_hidden_states is not None
+ current_states = encoder_hidden_states if is_cross_attention else hidden_states
+ current_attention_mask = encoder_attention_mask if is_cross_attention else attention_mask
+
+ key_states = self.k_proj(current_states).view(hidden_shape).transpose(1, 2)
+ value_states = self.v_proj(current_states).view(hidden_shape).transpose(1, 2)
+
+ if position_embeddings is not None:
+ cos, sin = position_embeddings
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ attn_output, attn_weights = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ current_attention_mask,
+ dropout=0.0 if not self.training else self.attention_dropout,
+ scaling=self.scaling,
+ **kwargs,
+ )
+
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+ attn_output = self.o_proj(attn_output)
+ return attn_output, attn_weights
+
+
+class LightGlueMLP(nn.Module):
+ def __init__(self, config: LightGlueConfig):
+ super().__init__()
+ self.config = config
+ self.activation_fn = ACT2FN[config.hidden_act]
+ self.fc1 = nn.Linear(config.intermediate_size, config.intermediate_size)
+ self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
+ self.layer_norm = nn.LayerNorm(config.intermediate_size, elementwise_affine=True)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.fc1(hidden_states)
+ hidden_states = self.layer_norm(hidden_states)
+ hidden_states = self.activation_fn(hidden_states)
+ hidden_states = self.fc2(hidden_states)
+ return hidden_states
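+
+# Illustrative note (a sketch): LightGlueConfig sets intermediate_size = 2 * descriptor_dim, and
+# the transformer layer feeds this MLP the concatenation [descriptors, attention_output] of width
+# 2 * descriptor_dim. With the default descriptor_dim=256 that means fc1: 512 -> 512 and
+# fc2: 512 -> 256, so the output can be added back onto the descriptors as a residual.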
+
+
+class LightGlueTransformerLayer(nn.Module):
+ def __init__(self, config: LightGlueConfig, layer_idx: int):
+ super().__init__()
+ self.self_attention = LightGlueAttention(config, layer_idx)
+ self.self_mlp = LightGlueMLP(config)
+ self.cross_attention = LightGlueAttention(config, layer_idx)
+ self.cross_mlp = LightGlueMLP(config)
+
+ def forward(
+ self,
+ descriptors: torch.Tensor,
+ keypoints: torch.Tensor,
+ attention_mask: torch.Tensor,
+ output_hidden_states: Optional[bool] = False,
+ output_attentions: Optional[bool] = False,
+ ) -> tuple[torch.Tensor, Optional[tuple[torch.Tensor]], Optional[tuple[torch.Tensor]]]:
+ all_hidden_states = () if output_hidden_states else None
+ all_attentions = () if output_attentions else None
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (descriptors,)
+
+ batch_size, num_keypoints, descriptor_dim = descriptors.shape
+
+ # Self attention block
+ attention_output, self_attentions = self.self_attention(
+ descriptors,
+ position_embeddings=keypoints,
+ attention_mask=attention_mask,
+ output_attentions=output_attentions,
+ )
+ intermediate_states = torch.cat([descriptors, attention_output], dim=-1)
+ output_states = self.self_mlp(intermediate_states)
+ self_attention_descriptors = descriptors + output_states
+
+ if output_hidden_states:
+ self_attention_hidden_states = (intermediate_states, output_states)
+
+ # Reshape hidden_states to group by image pairs:
+ # (batch_size, num_keypoints, descriptor_dim) -> (batch_size, 2, num_keypoints, descriptor_dim)
+ # Flip dimension 1 to perform cross attention:
+ # (image0, image1) -> (image1, image0)
+ # Reshape back to the original shape:
+ # (batch_size, 2, num_keypoints, descriptor_dim) -> (batch_size, num_keypoints, descriptor_dim)
+ encoder_hidden_states = (
+ self_attention_descriptors.reshape(-1, 2, num_keypoints, descriptor_dim)
+ .flip(1)
+ .reshape(batch_size, num_keypoints, descriptor_dim)
+ )
+ # Same for mask
+ encoder_attention_mask = (
+ attention_mask.reshape(-1, 2, 1, 1, num_keypoints).flip(1).reshape(batch_size, 1, 1, num_keypoints)
+ if attention_mask is not None
+ else None
+ )
+
+ # Cross attention block
+ cross_attention_output, cross_attentions = self.cross_attention(
+ self_attention_descriptors,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ output_attentions=output_attentions,
+ )
+ cross_intermediate_states = torch.cat([self_attention_descriptors, cross_attention_output], dim=-1)
+ cross_output_states = self.cross_mlp(cross_intermediate_states)
+ descriptors = self_attention_descriptors + cross_output_states
+
+ if output_hidden_states:
+ cross_attention_hidden_states = (cross_intermediate_states, cross_output_states)
+ all_hidden_states = (
+ all_hidden_states
+ + (self_attention_descriptors.reshape(batch_size, num_keypoints, descriptor_dim),)
+ + self_attention_hidden_states
+ + (descriptors.reshape(batch_size, num_keypoints, descriptor_dim),)
+ + cross_attention_hidden_states
+ )
+
+ if output_attentions:
+ all_attentions = all_attentions + (self_attentions,) + (cross_attentions,)
+
+ return descriptors, all_hidden_states, all_attentions
+
+
+def sigmoid_log_double_softmax(
+ similarity: torch.Tensor, matchability0: torch.Tensor, matchability1: torch.Tensor
+) -> torch.Tensor:
+ """create the log assignment matrix from logits and similarity"""
+ batch_size, num_keypoints_0, num_keypoints_1 = similarity.shape
+ certainties = nn.functional.logsigmoid(matchability0) + nn.functional.logsigmoid(matchability1).transpose(1, 2)
+ scores0 = nn.functional.log_softmax(similarity, 2)
+ scores1 = nn.functional.log_softmax(similarity.transpose(-1, -2).contiguous(), 2).transpose(-1, -2)
+ scores = similarity.new_full((batch_size, num_keypoints_0 + 1, num_keypoints_1 + 1), 0)
+ scores[:, :num_keypoints_0, :num_keypoints_1] = scores0 + scores1 + certainties
+ scores[:, :-1, -1] = nn.functional.logsigmoid(-matchability0.squeeze(-1))
+ scores[:, -1, :-1] = nn.functional.logsigmoid(-matchability1.squeeze(-1))
+ return scores
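+
+# Illustrative note (a sketch): for M keypoints in image0 and N in image1, `scores` has shape
+# (batch, M + 1, N + 1). The inner M x N block sums a row-wise and a column-wise log-softmax of
+# the similarity with the log-sigmoid matchabilities; the extra last column and last row hold
+# logsigmoid(-matchability), i.e. the log-probability that a keypoint stays unmatched, and the
+# bottom-right corner is left at 0.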
+
+
+class LightGlueMatchAssignmentLayer(nn.Module):
+ def __init__(self, config: LightGlueConfig):
+ super().__init__()
+
+ self.descriptor_dim = config.descriptor_dim
+ self.final_projection = nn.Linear(self.descriptor_dim, self.descriptor_dim, bias=True)
+ self.matchability = nn.Linear(self.descriptor_dim, 1, bias=True)
+
+ def forward(self, descriptors: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
+ batch_size, num_keypoints, descriptor_dim = descriptors.shape
+ # Final projection and similarity computation
+ m_descriptors = self.final_projection(descriptors)
+ m_descriptors = m_descriptors / torch.tensor(self.descriptor_dim, device=m_descriptors.device) ** 0.25
+ m_descriptors = m_descriptors.reshape(batch_size // 2, 2, num_keypoints, descriptor_dim)
+ m_descriptors0 = m_descriptors[:, 0]
+ m_descriptors1 = m_descriptors[:, 1]
+ similarity = m_descriptors0 @ m_descriptors1.transpose(-1, -2)
+ if mask is not None:
+ mask = mask.reshape(batch_size // 2, 2, num_keypoints)
+ mask0 = mask[:, 0].unsqueeze(-1)
+ mask1 = mask[:, 1].unsqueeze(-1).transpose(-1, -2)
+ mask = mask0 * mask1
+ similarity = similarity.masked_fill(mask == 0, torch.finfo(similarity.dtype).min)
+
+ # Compute matchability of descriptors
+ matchability = self.matchability(descriptors)
+ matchability = matchability.reshape(batch_size // 2, 2, num_keypoints, 1)
+ matchability_0 = matchability[:, 0]
+ matchability_1 = matchability[:, 1]
+
+ # Compute scores from similarity and matchability
+ scores = sigmoid_log_double_softmax(similarity, matchability_0, matchability_1)
+ return scores
+
+ def get_matchability(self, descriptors: torch.Tensor) -> torch.Tensor:
+ """Get matchability of descriptors as a probability"""
+ matchability = self.matchability(descriptors)
+ matchability = nn.functional.sigmoid(matchability).squeeze(-1)
+ return matchability
+
+
+class LightGlueTokenConfidenceLayer(nn.Module):
+ def __init__(self, config: LightGlueConfig):
+ super().__init__()
+
+ self.token = nn.Linear(config.descriptor_dim, 1)
+
+ def forward(self, descriptors: torch.Tensor) -> torch.Tensor:
+ token = self.token(descriptors.detach())
+ token = nn.functional.sigmoid(token).squeeze(-1)
+ return token
+
+
+@auto_docstring
+class LightGluePreTrainedModel(PreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config: LightGlueConfig
+ base_model_prefix = "lightglue"
+ main_input_name = "pixel_values"
+ supports_gradient_checkpointing = False
+ _supports_flash_attn = True
+ _supports_sdpa = True
+
+
+def get_matches_from_scores(scores: torch.Tensor, threshold: float) -> tuple[torch.Tensor, torch.Tensor]:
+ """obtain matches from a score matrix [Bx M+1 x N+1]"""
+ batch_size, _, _ = scores.shape
+ # For each keypoint, get the best match
+ max0 = scores[:, :-1, :-1].max(2)
+ max1 = scores[:, :-1, :-1].max(1)
+ matches0 = max0.indices
+ matches1 = max1.indices
+
+ # Mutual check for matches
+ indices0 = torch.arange(matches0.shape[1], device=matches0.device)[None]
+ indices1 = torch.arange(matches1.shape[1], device=matches1.device)[None]
+ mutual0 = indices0 == matches1.gather(1, matches0)
+ mutual1 = indices1 == matches0.gather(1, matches1)
+
+ # Get matching scores and filter based on mutual check and thresholding
+ max0 = max0.values.exp()
+ zero = max0.new_tensor(0)
+ matching_scores0 = torch.where(mutual0, max0, zero)
+ matching_scores1 = torch.where(mutual1, matching_scores0.gather(1, matches1), zero)
+ valid0 = mutual0 & (matching_scores0 > threshold)
+ valid1 = mutual1 & valid0.gather(1, matches1)
+
+ # Filter matches based on mutual check and thresholding of scores
+ matches0 = torch.where(valid0, matches0, -1)
+ matches1 = torch.where(valid1, matches1, -1)
+ matches = torch.stack([matches0, matches1]).transpose(0, 1).reshape(batch_size * 2, -1)
+ matching_scores = torch.stack([matching_scores0, matching_scores1]).transpose(0, 1).reshape(batch_size * 2, -1)
+
+ return matches, matching_scores
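+
+# Illustrative note (a sketch): a pair (i, j) is kept only if it passes the mutual
+# nearest-neighbor check (i is the best match of j and j the best match of i) and its
+# exponentiated score exceeds `threshold` (config.filter_threshold, 0.1 by default);
+# keypoints that fail either check get match index -1.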
+
+
+def normalize_keypoints(keypoints: torch.Tensor, height: int, width: int) -> torch.Tensor:
+ """
+ Normalize keypoint locations based on the image height and width.
+
+ Args:
+ keypoints (`torch.Tensor` of shape `(batch_size, num_keypoints, 2)`):
+ Keypoints locations in (x, y) format.
+ height (`int`):
+ Image height.
+ width (`int`):
+ Image width.
+
+ Returns:
+ Normalized keypoint locations (`torch.Tensor` of shape `(batch_size, num_keypoints, 2)`).
+ """
+ size = torch.tensor([width, height], device=keypoints.device, dtype=keypoints.dtype)[None]
+ shift = size / 2
+ scale = size.max(-1).values / 2
+ keypoints = (keypoints - shift[..., None, :]) / scale[..., None, None]
+ return keypoints
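+
+# Illustrative note (a sketch): for a 640 x 480 image, shift = (320, 240) and
+# scale = max(640, 480) / 2 = 320, so a keypoint at (0, 0) maps to (-1.0, -0.75),
+# the image center (320, 240) maps to (0.0, 0.0) and (640, 480) maps to (1.0, 0.75).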
+
+
+@auto_docstring(
+ custom_intro="""
+ LightGlue model taking images as inputs and outputting the matching of them.
+ """
+)
+class LightGlueForKeypointMatching(LightGluePreTrainedModel):
+ """
+ LightGlue is a model matching keypoints in images by leveraging detections from a keypoint detector such as
+ SuperPoint. It is based on the SuperGlue architecture and is designed to be lightweight and efficient.
+ It consists of:
+ 1. Keypoint Encoder
+ 2. A Graph Neural Network with self and cross attention layers
+ 3. Matching Assignment layers
+
+ The correspondence ids use -1 to indicate non-matching points.
+
+ Philipp Lindenberger, Paul-Edouard Sarlin and Marc Pollefeys. LightGlue: Local Feature Matching at Light Speed.
+ In ICCV 2023. https://huggingface.co/papers/2306.13643
+ """
+
+ def __init__(self, config: LightGlueConfig):
+ super().__init__(config)
+ self.keypoint_detector = AutoModelForKeypointDetection.from_config(
+ config.keypoint_detector_config, trust_remote_code=config.trust_remote_code
+ )
+
+ self.keypoint_detector_descriptor_dim = config.keypoint_detector_config.descriptor_decoder_dim
+ self.descriptor_dim = config.descriptor_dim
+ self.num_layers = config.num_hidden_layers
+ self.filter_threshold = config.filter_threshold
+ self.depth_confidence = config.depth_confidence
+ self.width_confidence = config.width_confidence
+
+ if self.descriptor_dim != self.keypoint_detector_descriptor_dim:
+ self.input_projection = nn.Linear(self.keypoint_detector_descriptor_dim, self.descriptor_dim, bias=True)
+ else:
+ self.input_projection = nn.Identity()
+
+ self.positional_encoder = LightGluePositionalEncoder(config)
+
+ self.transformer_layers = nn.ModuleList(
+ [LightGlueTransformerLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)]
+ )
+ self.match_assignment_layers = nn.ModuleList(
+ [LightGlueMatchAssignmentLayer(config) for _ in range(config.num_hidden_layers)]
+ )
+ self.token_confidence = nn.ModuleList(
+ [LightGlueTokenConfidenceLayer(config) for _ in range(config.num_hidden_layers - 1)]
+ )
+
+ self.post_init()
+
+ def _get_confidence_threshold(self, layer_index: int) -> float:
+ """scaled confidence threshold for a given layer"""
+ threshold = 0.8 + 0.1 * np.exp(-4.0 * layer_index / self.num_layers)
+ return np.clip(threshold, 0, 1)
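+
+ # Illustrative note (a sketch): with the default num_hidden_layers=9, the threshold is
+ # 0.8 + 0.1 * exp(0) = 0.9 at layer 0 and decays toward ~0.8 for deeper layers, so later
+ # layers require less per-keypoint confidence before early stopping or pruning kicks in.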
+
+ def _keypoint_processing(
+ self, descriptors: torch.Tensor, keypoints: torch.Tensor, output_hidden_states: Optional[bool] = False
+ ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
+ descriptors = descriptors.detach().contiguous()
+ projected_descriptors = self.input_projection(descriptors)
+ keypoint_encoding_output = self.positional_encoder(keypoints, output_hidden_states=output_hidden_states)
+ return projected_descriptors, keypoint_encoding_output
+
+ def _get_early_stopped_image_pairs(
+ self, keypoint_confidences: torch.Tensor, layer_index: int, mask: torch.Tensor, num_points: torch.Tensor
+ ) -> torch.Tensor:
+ """evaluate whether we should stop inference based on the confidence of the keypoints"""
+ batch_size, _ = mask.shape
+ if layer_index < self.num_layers - 1:
+ # If the current layer is not the last layer, we compute the confidence of the keypoints and check
+ # if we should stop the forward pass through the transformer layers for each pair of images.
+ keypoint_confidences = keypoint_confidences.masked_fill(mask == 0, 1)
+ keypoint_confidences = keypoint_confidences.reshape(batch_size // 2, -1)
+ threshold = self._get_confidence_threshold(layer_index)
+ ratio_confident = 1.0 - (keypoint_confidences < threshold).float().sum(dim=1) / num_points
+ early_stopped_pairs = ratio_confident > self.depth_confidence
+ else:
+ # If the current layer is the last layer, we stop the forward pass through the transformer layers for
+ # all pairs of images.
+ early_stopped_pairs = torch.ones(batch_size, dtype=torch.bool)
+ return early_stopped_pairs
+
+ def _get_keypoint_matching(self, descriptors, mask, layer_index, early_stops=None):
+ if early_stops is not None:
+ descriptors = descriptors[early_stops]
+ mask = mask[early_stops]
+ scores = self.match_assignment_layers[layer_index](descriptors, mask)
+ matches, matching_scores = get_matches_from_scores(scores, self.filter_threshold)
+ return matches, matching_scores
+
+ def _get_pruning_mask(self, confidences: torch.Tensor, scores: torch.Tensor, layer_index: int) -> torch.Tensor:
+ """mask points which should be removed"""
+ keep = scores > (1 - self.width_confidence)
+ if confidences is not None: # Low-confidence points are never pruned.
+ keep |= confidences <= self._get_confidence_threshold(layer_index)
+ return keep
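+
+ # Illustrative note (a sketch): with the default width_confidence=0.99, a keypoint is kept
+ # if its matchability score is above 1 - 0.99 = 0.01, or if its confidence is at or below
+ # the layer threshold, so uncertain keypoints are never pruned; only confidently
+ # unmatchable ones are removed.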
+
+ def _do_layer_keypoint_pruning(
+ self,
+ descriptors: torch.Tensor,
+ keypoints: torch.Tensor,
+ mask: torch.Tensor,
+ indices: torch.Tensor,
+ prune_output: torch.Tensor,
+ keypoint_confidences: torch.Tensor,
+ layer_index: int,
+ ):
+ """
+ For a given layer, prune keypoints based on the confidence of the keypoints and the matchability of the
+ descriptors.
+ """
+ batch_size, _, _ = descriptors.shape
+ descriptors_matchability = self.match_assignment_layers[layer_index].get_matchability(descriptors)
+ pruned_keypoints_mask = self._get_pruning_mask(keypoint_confidences, descriptors_matchability, layer_index)
+ pruned_keypoints_mask = pruned_keypoints_mask.masked_fill(mask == 0, torch.tensor(False))
+
+ # For each image, we extract the pruned indices and the corresponding descriptors and keypoints.
+ pruned_descriptors, pruned_keypoints_0, pruned_keypoints_1, pruned_mask, pruned_indices = (
+ [t[mask] for t, mask in zip(tensor, pruned_keypoints_mask)]
+ for tensor in [descriptors, keypoints[0], keypoints[1], pruned_keypoints_mask, indices]
+ )
+ for i in range(batch_size):
+ prune_output[i, pruned_indices[i]] += 1
+
+ # Pad the pruned descriptors, keypoints, indices and mask to have the same shape across the batch.
+ pruned_descriptors, pruned_keypoints_0, pruned_keypoints_1, pruned_mask = (
+ pad_sequence(pruned_tensor, batch_first=True)
+ for pruned_tensor in [pruned_descriptors, pruned_keypoints_0, pruned_keypoints_1, pruned_mask]
+ )
+ pruned_keypoints = (pruned_keypoints_0, pruned_keypoints_1)
+ pruned_indices = pad_sequence(pruned_indices, batch_first=True, padding_value=-1)
+
+ return pruned_descriptors, pruned_keypoints, pruned_indices, pruned_mask, prune_output
+
+ def _concat_early_stopped_outputs(
+ self,
+ early_stops_indices,
+ final_pruned_keypoints_indices,
+ final_pruned_keypoints_iterations,
+ matches,
+ matching_scores,
+ ):
+ early_stops_indices = torch.stack(early_stops_indices)
+ # Rearrange tensors to have the same order as the input batch
+ ids = torch.arange(early_stops_indices.shape[0])
+ order_indices = early_stops_indices[ids]
+ early_stops_indices = early_stops_indices[order_indices]
+ matches, final_pruned_keypoints_indices = (
+ pad_sequence(tensor, batch_first=True, padding_value=-1)
+ for tensor in [matches, final_pruned_keypoints_indices]
+ )
+ matching_scores, final_pruned_keypoints_iterations = (
+ pad_sequence(tensor, batch_first=True, padding_value=0)
+ for tensor in [matching_scores, final_pruned_keypoints_iterations]
+ )
+ matches, matching_scores, final_pruned_keypoints_indices, final_pruned_keypoints_iterations = (
+ tensor[early_stops_indices]
+ for tensor in [
+ matches,
+ matching_scores,
+ final_pruned_keypoints_indices,
+ final_pruned_keypoints_iterations,
+ ]
+ )
+ return final_pruned_keypoints_indices, final_pruned_keypoints_iterations, matches, matching_scores
+
+ def _do_final_keypoint_pruning(
+ self,
+ indices: torch.Tensor,
+ matches: torch.Tensor,
+ matching_scores: torch.Tensor,
+ num_keypoints: torch.Tensor,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ # (batch_size, num_keypoints) -> (batch_size // 2, 2, num_keypoints) -> 2 * (batch_size // 2, num_keypoints) to
+ # have one tensor per image of each pair
+ batch_size, _ = indices.shape
+ indices, matches, matching_scores = (
+ tensor.reshape(batch_size // 2, 2, -1) for tensor in [indices, matches, matching_scores]
+ )
+ indices0 = indices[:, 0]
+ indices1 = indices[:, 1]
+ matches0 = matches[:, 0]
+ matches1 = matches[:, 1]
+ matching_scores0 = matching_scores[:, 0]
+ matching_scores1 = matching_scores[:, 1]
+
+ # Prepare final matches and matching scores
+ _matches = torch.full((batch_size // 2, 2, num_keypoints), -1, device=indices.device, dtype=matches.dtype)
+ _matching_scores = torch.zeros(
+ (batch_size // 2, 2, num_keypoints), device=indices.device, dtype=matching_scores.dtype
+ )
+ # Fill the matches and matching scores for each image pair
+ for i in range(batch_size // 2):
+ _matches[i, 0, indices0[i]] = torch.where(
+ matches0[i] == -1, -1, indices1[i].gather(0, matches0[i].clamp(min=0))
+ )
+ _matches[i, 1, indices1[i]] = torch.where(
+ matches1[i] == -1, -1, indices0[i].gather(0, matches1[i].clamp(min=0))
+ )
+ _matching_scores[i, 0, indices0[i]] = matching_scores0[i]
+ _matching_scores[i, 1, indices1[i]] = matching_scores1[i]
+ return _matches, _matching_scores
+
+ def _match_image_pair(
+ self,
+ keypoints: torch.Tensor,
+ descriptors: torch.Tensor,
+ height: int,
+ width: int,
+ mask: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, tuple, tuple]:
+ all_hidden_states = () if output_hidden_states else None
+ all_attentions = () if output_attentions else None
+
+ if keypoints.shape[2] == 0: # no keypoints
+ shape = keypoints.shape[:-1]
+ return (
+ keypoints.new_full(shape, -1, dtype=torch.int),
+ keypoints.new_zeros(shape),
+ keypoints.new_zeros(shape),
+ all_hidden_states,
+ all_attentions,
+ )
+
+ device = keypoints.device
+ batch_size, _, initial_num_keypoints, _ = keypoints.shape
+ num_points_per_pair = torch.sum(mask.reshape(batch_size, -1), dim=1)
+ # (batch_size, 2, num_keypoints, 2) -> (batch_size * 2, num_keypoints, 2)
+ keypoints = keypoints.reshape(batch_size * 2, initial_num_keypoints, 2)
+ mask = mask.reshape(batch_size * 2, initial_num_keypoints) if mask is not None else None
+ descriptors = descriptors.reshape(batch_size * 2, initial_num_keypoints, self.keypoint_detector_descriptor_dim)
+ image_indices = torch.arange(batch_size * 2, device=device)
+ # Keypoint normalization
+ keypoints = normalize_keypoints(keypoints, height, width)
+
+ descriptors, keypoint_encoding_output = self._keypoint_processing(
+ descriptors, keypoints, output_hidden_states=output_hidden_states
+ )
+
+ keypoints = keypoint_encoding_output[0]
+
+ # Early stop consists of stopping the forward pass through the transformer layers when the confidence of the
+ # keypoints is above a certain threshold.
+ do_early_stop = self.depth_confidence > 0
+ # Keypoint pruning consists of removing keypoints from the input of the transformer layers when the confidence of
+ # the keypoints is below a certain threshold.
+ do_keypoint_pruning = self.width_confidence > 0
+
+ early_stops_indices = []
+ matches = []
+ matching_scores = []
+ final_pruned_keypoints_indices = []
+ final_pruned_keypoints_iterations = []
+
+ pruned_keypoints_indices = torch.arange(0, initial_num_keypoints, device=device).expand(batch_size * 2, -1)
+ pruned_keypoints_iterations = torch.ones_like(pruned_keypoints_indices)
+
+ for layer_index in range(self.num_layers):
+ input_shape = descriptors.size()
+ if mask is not None:
+ extended_attention_mask = self.get_extended_attention_mask(mask, input_shape)
+ else:
+ extended_attention_mask = torch.ones((batch_size, input_shape[-2]), device=keypoints.device)
+ layer_output = self.transformer_layers[layer_index](
+ descriptors,
+ keypoints,
+ attention_mask=extended_attention_mask,
+ output_hidden_states=output_hidden_states,
+ output_attentions=output_attentions,
+ )
+ descriptors, hidden_states, attention = layer_output
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + hidden_states
+ if output_attentions:
+ all_attentions = all_attentions + attention
+
+ if do_early_stop:
+ if layer_index < self.num_layers - 1:
+ # Get the confidence of the keypoints for the current layer
+ keypoint_confidences = self.token_confidence[layer_index](descriptors)
+
+ # Determine which pairs of images should be early stopped based on the confidence of the keypoints for
+ # the current layer.
+ early_stopped_pairs = self._get_early_stopped_image_pairs(
+ keypoint_confidences, layer_index, mask, num_points=num_points_per_pair
+ )
+ else:
+ # Early stopping always occurs at the last layer
+ early_stopped_pairs = torch.ones(batch_size, dtype=torch.bool)
+
+ if torch.any(early_stopped_pairs):
+ # If a pair of images is considered early stopped, we compute the matches for the remaining
+ # keypoints and stop the forward pass through the transformer layers for this pair of images.
+ early_stops = early_stopped_pairs.repeat_interleave(2)
+ early_stopped_image_indices = image_indices[early_stops]
+ early_stopped_matches, early_stopped_matching_scores = self._get_keypoint_matching(
+ descriptors, mask, layer_index, early_stops=early_stops
+ )
+ early_stops_indices.extend(list(early_stopped_image_indices))
+ matches.extend(list(early_stopped_matches))
+ matching_scores.extend(list(early_stopped_matching_scores))
+ if do_keypoint_pruning:
+ final_pruned_keypoints_indices.extend(list(pruned_keypoints_indices[early_stops]))
+ final_pruned_keypoints_iterations.extend(list(pruned_keypoints_iterations[early_stops]))
+
+ # Remove image pairs that have been early stopped from the forward pass
+ num_points_per_pair = num_points_per_pair[~early_stopped_pairs]
+ descriptors, keypoints_0, keypoints_1, mask, image_indices = tuple(
+ tensor[~early_stops]
+ for tensor in [descriptors, keypoints[0], keypoints[1], mask, image_indices]
+ )
+ keypoints = (keypoints_0, keypoints_1)
+ if do_keypoint_pruning:
+ pruned_keypoints_indices, pruned_keypoints_iterations, keypoint_confidences = tuple(
+ tensor[~early_stops]
+ for tensor in [
+ pruned_keypoints_indices,
+ pruned_keypoints_iterations,
+ keypoint_confidences,
+ ]
+ )
+ # If all pairs of images are early stopped, we stop the forward pass through the transformer
+ # layers for all pairs of images.
+ if torch.all(early_stopped_pairs):
+ break
+
+ if do_keypoint_pruning:
+ # Prune keypoints from the input of the transformer layers for the next iterations if the confidence of
+ # the keypoints is below a certain threshold.
+ descriptors, keypoints, pruned_keypoints_indices, mask, pruned_keypoints_iterations = (
+ self._do_layer_keypoint_pruning(
+ descriptors,
+ keypoints,
+ mask,
+ pruned_keypoints_indices,
+ pruned_keypoints_iterations,
+ keypoint_confidences,
+ layer_index,
+ )
+ )
+
+ if do_early_stop and do_keypoint_pruning:
+ # Concatenate early stopped outputs together and perform final keypoint pruning
+ final_pruned_keypoints_indices, final_pruned_keypoints_iterations, matches, matching_scores = (
+ self._concat_early_stopped_outputs(
+ early_stops_indices,
+ final_pruned_keypoints_indices,
+ final_pruned_keypoints_iterations,
+ matches,
+ matching_scores,
+ )
+ )
+ matches, matching_scores = self._do_final_keypoint_pruning(
+ final_pruned_keypoints_indices,
+ matches,
+ matching_scores,
+ initial_num_keypoints,
+ )
+ else:
+ matches, matching_scores = self._get_keypoint_matching(descriptors, mask, self.num_layers - 1)
+ final_pruned_keypoints_iterations = torch.ones_like(matching_scores) * self.num_layers
+
+ final_pruned_keypoints_iterations = final_pruned_keypoints_iterations.reshape(
+ batch_size, 2, initial_num_keypoints
+ )
+
+ return (
+ matches,
+ matching_scores,
+ final_pruned_keypoints_iterations,
+ all_hidden_states,
+ all_attentions,
+ )
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ pixel_values: torch.FloatTensor,
+ labels: Optional[torch.LongTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ ) -> Union[tuple, LightGlueKeypointMatchingOutput]:
+ loss = None
+ if labels is not None:
+ raise ValueError("LightGlue is not trainable, no labels should be provided.")
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+
+ if pixel_values.ndim != 5 or pixel_values.size(1) != 2:
+ raise ValueError("Input must be a 5D tensor of shape (batch_size, 2, num_channels, height, width)")
+
+ batch_size, _, channels, height, width = pixel_values.shape
+ pixel_values = pixel_values.reshape(batch_size * 2, channels, height, width)
+ keypoint_detections = self.keypoint_detector(pixel_values)
+
+ keypoints, _, descriptors, mask = keypoint_detections[:4]
+ keypoints = keypoints.reshape(batch_size, 2, -1, 2).to(pixel_values)
+ descriptors = descriptors.reshape(batch_size, 2, -1, self.keypoint_detector_descriptor_dim).to(pixel_values)
+ mask = mask.reshape(batch_size, 2, -1)
+
+ absolute_keypoints = keypoints.clone()
+ absolute_keypoints[:, :, :, 0] = absolute_keypoints[:, :, :, 0] * width
+ absolute_keypoints[:, :, :, 1] = absolute_keypoints[:, :, :, 1] * height
+
+ matches, matching_scores, prune, hidden_states, attentions = self._match_image_pair(
+ absolute_keypoints,
+ descriptors,
+ height,
+ width,
+ mask=mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ )
+
+ return LightGlueKeypointMatchingOutput(
+ loss=loss,
+ matches=matches,
+ matching_scores=matching_scores,
+ keypoints=keypoints,
+ prune=prune,
+ mask=mask,
+ hidden_states=hidden_states,
+ attentions=attentions,
+ )
+
+
+__all__ = ["LightGluePreTrainedModel", "LightGlueForKeypointMatching"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/lightglue/modular_lightglue.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/lightglue/modular_lightglue.py
new file mode 100644
index 0000000000000000000000000000000000000000..29441344c9cdfbdc9c74afa7bec493c19166844c
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/lightglue/modular_lightglue.py
@@ -0,0 +1,1078 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import warnings
+from dataclasses import dataclass
+from typing import Callable, Optional, Union
+
+import numpy as np
+import torch
+from torch import nn
+from torch.nn.utils.rnn import pad_sequence
+
+from ...configuration_utils import PretrainedConfig
+from ...image_utils import ImageInput, is_vision_available, to_numpy_array
+from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
+from ...utils import ModelOutput, TensorType, auto_docstring, is_matplotlib_available, logging
+from ...utils.generic import can_return_tuple
+from ..auto import CONFIG_MAPPING, AutoConfig
+from ..auto.modeling_auto import AutoModelForKeypointDetection
+from ..clip.modeling_clip import CLIPMLP
+from ..cohere.modeling_cohere import apply_rotary_pos_emb
+from ..llama.modeling_llama import LlamaAttention, eager_attention_forward
+from ..superglue.image_processing_superglue import SuperGlueImageProcessor, validate_and_format_image_pairs
+from ..superpoint import SuperPointConfig
+
+
+if is_vision_available():
+ from PIL import Image, ImageDraw
+
+
+logger = logging.get_logger(__name__)
+
+
+class LightGlueConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`LightGlueForKeypointMatching`]. It is used to
+ instantiate a LightGlue model according to the specified arguments, defining the model architecture. Instantiating a
+ configuration with the defaults will yield a similar configuration to that of the LightGlue
+ [ETH-CVG/lightglue_superpoint](https://huggingface.co/ETH-CVG/lightglue_superpoint) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ keypoint_detector_config (`Union[AutoConfig, dict]`, *optional*, defaults to `SuperPointConfig`):
+ The config object or dictionary of the keypoint detector.
+ descriptor_dim (`int`, *optional*, defaults to 256):
+ The dimension of the descriptors.
+ num_hidden_layers (`int`, *optional*, defaults to 9):
+ The number of self and cross attention layers.
+ num_attention_heads (`int`, *optional*, defaults to 4):
+ The number of heads in the multi-head attention.
+ num_key_value_heads (`int`, *optional*):
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
+ `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+ by meanpooling all the original heads within that group. For more details checkout [this
+ paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to
+ `num_attention_heads`.
+ depth_confidence (`float`, *optional*, defaults to 0.95):
+ The confidence threshold used to perform early stopping
+ width_confidence (`float`, *optional*, defaults to 0.99):
+ The confidence threshold used to prune points
+ filter_threshold (`float`, *optional*, defaults to 0.1):
+ The confidence threshold used to filter matches
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ hidden_act (`str`, *optional*, defaults to `"gelu"`):
+ The activation function to be used in the hidden layers.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ attention_bias (`bool`, *optional*, defaults to `True`):
+ Whether to use a bias in the query, key, value and output projection layers during self-attention.
+ trust_remote_code (`bool`, *optional*, defaults to `False`):
+ Whether to trust remote code when using other models than SuperPoint as keypoint detector.
+
+ Examples:
+ ```python
+ >>> from transformers import LightGlueConfig, LightGlueForKeypointMatching
+
+ >>> # Initializing a LightGlue style configuration
+ >>> configuration = LightGlueConfig()
+
+ >>> # Initializing a model from the LightGlue style configuration
+ >>> model = LightGlueForKeypointMatching(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```
+ """
+
+ model_type = "lightglue"
+ sub_configs = {"keypoint_detector_config": AutoConfig}
+
+ def __init__(
+ self,
+ keypoint_detector_config: SuperPointConfig = None,
+ descriptor_dim: int = 256,
+ num_hidden_layers: int = 9,
+ num_attention_heads: int = 4,
+ num_key_value_heads=None,
+ depth_confidence: float = 0.95,
+ width_confidence: float = 0.99,
+ filter_threshold: float = 0.1,
+ initializer_range: float = 0.02,
+ hidden_act: str = "gelu",
+ attention_dropout=0.0,
+ attention_bias=True,
+ trust_remote_code: bool = False,
+ **kwargs,
+ ):
+ # LightGlue can be used with other models than SuperPoint as keypoint detector
+ # We provide the trust_remote_code argument to allow the use of other models
+ # that are not registered in the CONFIG_MAPPING dictionary (for example DISK)
+ self.trust_remote_code = trust_remote_code
+
+ if descriptor_dim % num_attention_heads != 0:
+ raise ValueError("descriptor_dim % num_heads is different from zero")
+
+ self.descriptor_dim = descriptor_dim
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+
+ # for backward compatibility
+ if num_key_value_heads is None:
+ num_key_value_heads = num_attention_heads
+
+ self.num_key_value_heads = num_key_value_heads
+
+ self.depth_confidence = depth_confidence
+ self.width_confidence = width_confidence
+ self.filter_threshold = filter_threshold
+ self.initializer_range = initializer_range
+
+ # Keypoint Detector is forced into eager attention mode because SuperPoint does not have Attention
+ # See https://github.com/huggingface/transformers/pull/31718#discussion_r2109733153
+ if isinstance(keypoint_detector_config, dict):
+ keypoint_detector_config["model_type"] = keypoint_detector_config.get("model_type", "superpoint")
+ if keypoint_detector_config["model_type"] not in CONFIG_MAPPING:
+ keypoint_detector_config = AutoConfig.from_pretrained(
+ keypoint_detector_config["_name_or_path"], trust_remote_code=self.trust_remote_code
+ )
+ else:
+ keypoint_detector_config = CONFIG_MAPPING[keypoint_detector_config["model_type"]](
+ **keypoint_detector_config, attn_implementation="eager"
+ )
+
+ if keypoint_detector_config is None:
+ keypoint_detector_config = CONFIG_MAPPING["superpoint"](attn_implementation="eager")
+
+ self.keypoint_detector_config = keypoint_detector_config
+
+ self.hidden_size = descriptor_dim
+ self.intermediate_size = descriptor_dim * 2
+ self.hidden_act = hidden_act
+ self.attention_dropout = attention_dropout
+ self.attention_bias = attention_bias
+ super().__init__(**kwargs)
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Base class for outputs of LightGlue keypoint matching models. Due to the nature of keypoint detection and matching,
+ the number of keypoints is not fixed and can vary from image to image, which makes batching non-trivial. In the
+ batch of images, the maximum number of matches is set as the dimension of the matches and matching scores. The mask
+ tensor is used to indicate which values in the keypoints, matches, matching_scores and prune tensors are keypoint
+ matching information.
+ """
+)
+class LightGlueKeypointMatchingOutput(ModelOutput):
+ r"""
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*):
+ Loss computed during training.
+ matches (`torch.FloatTensor` of shape `(batch_size, 2, num_matches)`):
+ Index of keypoint matched in the other image.
+ matching_scores (`torch.FloatTensor` of shape `(batch_size, 2, num_matches)`):
+ Scores of predicted matches.
+ keypoints (`torch.FloatTensor` of shape `(batch_size, num_keypoints, 2)`):
+ Absolute (x, y) coordinates of predicted keypoints in a given image.
+ prune (`torch.IntTensor` of shape `(batch_size, num_keypoints)`):
+ Pruning mask indicating which keypoints are removed and at which layer.
+ mask (`torch.BoolTensor` of shape `(batch_size, num_keypoints)`):
+ Mask indicating which values in matches, matching_scores, keypoints and prune are keypoint matching
+ information.
+ hidden_states (`Tuple[torch.FloatTensor, ...]`, *optional*):
+ Tuple of `torch.FloatTensor` (one for the output of each stage) of shape `(batch_size, 2, num_channels,
+ num_keypoints)` returned when `output_hidden_states=True` is passed or when
+ `config.output_hidden_states=True`
+ attentions (`Tuple[torch.FloatTensor, ...]`, *optional*):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, 2, num_heads, num_keypoints,
+ num_keypoints)` returned when `output_attentions=True` is passed or when
+ `config.output_attentions=True`
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ matches: Optional[torch.FloatTensor] = None
+ matching_scores: Optional[torch.FloatTensor] = None
+ keypoints: Optional[torch.FloatTensor] = None
+ prune: Optional[torch.IntTensor] = None
+ mask: Optional[torch.FloatTensor] = None
+ hidden_states: Optional[tuple[torch.FloatTensor]] = None
+ attentions: Optional[tuple[torch.FloatTensor]] = None
+
+
+class LightGlueImageProcessor(SuperGlueImageProcessor):
+ def post_process_keypoint_matching(
+ self,
+ outputs: LightGlueKeypointMatchingOutput,
+ target_sizes: Union[TensorType, list[tuple]],
+ threshold: float = 0.0,
+ ) -> list[dict[str, torch.Tensor]]:
+ return super().post_process_keypoint_matching(outputs, target_sizes, threshold)
+
+ # Copied from transformers.models.efficientloftr.image_processing_efficientloftr.EfficientLoFTRImageProcessor.visualize_keypoint_matching with EfficientLoFTR->LightGlue
+ def visualize_keypoint_matching(
+ self,
+ images: ImageInput,
+ keypoint_matching_output: list[dict[str, torch.Tensor]],
+ ) -> list["Image.Image"]:
+ """
+ Plots the image pairs side by side with the detected keypoints as well as the matching between them.
+
+ Args:
+ images (`ImageInput`):
+ Image pairs to plot. Same as `LightGlueImageProcessor.preprocess`. Expects either a list of 2
+ images or a list of pairs of images, with pixel values ranging from 0 to 255.
+ keypoint_matching_output (`List[Dict[str, torch.Tensor]]`):
+ A post-processed keypoint matching output.
+
+ Returns:
+ `List[PIL.Image.Image]`: A list of PIL images, each containing the image pairs side by side with the detected
+ keypoints as well as the matching between them.
+ """
+ images = validate_and_format_image_pairs(images)
+ images = [to_numpy_array(image) for image in images]
+ image_pairs = [images[i : i + 2] for i in range(0, len(images), 2)]
+
+ results = []
+ for image_pair, pair_output in zip(image_pairs, keypoint_matching_output):
+ height0, width0 = image_pair[0].shape[:2]
+ height1, width1 = image_pair[1].shape[:2]
+ plot_image = np.zeros((max(height0, height1), width0 + width1, 3), dtype=np.uint8)
+ plot_image[:height0, :width0] = image_pair[0]
+ plot_image[:height1, width0:] = image_pair[1]
+
+ plot_image_pil = Image.fromarray(plot_image)
+ draw = ImageDraw.Draw(plot_image_pil)
+
+ keypoints0_x, keypoints0_y = pair_output["keypoints0"].unbind(1)
+ keypoints1_x, keypoints1_y = pair_output["keypoints1"].unbind(1)
+ for keypoint0_x, keypoint0_y, keypoint1_x, keypoint1_y, matching_score in zip(
+ keypoints0_x, keypoints0_y, keypoints1_x, keypoints1_y, pair_output["matching_scores"]
+ ):
+ color = self._get_color(matching_score)
+ draw.line(
+ (keypoint0_x, keypoint0_y, keypoint1_x + width0, keypoint1_y),
+ fill=color,
+ width=3,
+ )
+ draw.ellipse((keypoint0_x - 2, keypoint0_y - 2, keypoint0_x + 2, keypoint0_y + 2), fill="black")
+ draw.ellipse(
+ (keypoint1_x + width0 - 2, keypoint1_y - 2, keypoint1_x + width0 + 2, keypoint1_y + 2),
+ fill="black",
+ )
+
+ results.append(plot_image_pil)
+ return results
+
+ # Copied from transformers.models.efficientloftr.image_processing_efficientloftr.EfficientLoFTRImageProcessor._get_color
+ def _get_color(self, score):
+ """Maps a score to a color."""
+ r = int(255 * (1 - score))
+ g = int(255 * score)
+ b = 0
+ return (r, g, b)
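+
+ # Illustrative note (a sketch): the mapping is a simple red-to-green ramp, e.g. a matching
+ # score of 0.0 gives (255, 0, 0), 0.5 gives (127, 127, 0) and 1.0 gives (0, 255, 0).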
+
+ def plot_keypoint_matching(self, images: ImageInput, keypoint_matching_output: LightGlueKeypointMatchingOutput):
+ """
+ Plots the image pairs side by side with the detected keypoints as well as the matching between them. Requires
+ matplotlib to be installed.
+
+ .. deprecated::
+ `plot_keypoint_matching` is deprecated and will be removed in a future version. Use `visualize_keypoint_matching` instead.
+
+ Args:
+ images (`ImageInput`):
+ Image pairs to plot. Same as `LightGlueImageProcessor.preprocess`. Expects either a list of 2 images or
+ a list of pairs of images, with pixel values ranging from 0 to 255.
+ keypoint_matching_output ([`LightGlueKeypointMatchingOutput`]):
+ Raw outputs of the model.
+ """
+ warnings.warn(
+ "`plot_keypoint_matching` is deprecated and will be removed in transformers v. "
+ "Use `visualize_keypoint_matching` instead.",
+ FutureWarning,
+ )
+
+ if is_matplotlib_available():
+ import matplotlib.pyplot as plt
+ else:
+ raise ImportError("Please install matplotlib to use `plot_keypoint_matching` method")
+
+ images = validate_and_format_image_pairs(images)
+ images = [to_numpy_array(image) for image in images]
+ image_pairs = [images[i : i + 2] for i in range(0, len(images), 2)]
+
+ for image_pair, pair_output in zip(image_pairs, keypoint_matching_output):
+ height0, width0 = image_pair[0].shape[:2]
+ height1, width1 = image_pair[1].shape[:2]
+ plot_image = np.zeros((max(height0, height1), width0 + width1, 3))
+ plot_image[:height0, :width0] = image_pair[0] / 255.0
+ plot_image[:height1, width0:] = image_pair[1] / 255.0
+ plt.imshow(plot_image)
+ plt.axis("off")
+
+ keypoints0_x, keypoints0_y = pair_output["keypoints0"].unbind(1)
+ keypoints1_x, keypoints1_y = pair_output["keypoints1"].unbind(1)
+ for keypoint0_x, keypoint0_y, keypoint1_x, keypoint1_y, matching_score in zip(
+ keypoints0_x, keypoints0_y, keypoints1_x, keypoints1_y, pair_output["matching_scores"]
+ ):
+ plt.plot(
+ [keypoint0_x, keypoint1_x + width0],
+ [keypoint0_y, keypoint1_y],
+ color=plt.get_cmap("RdYlGn")(matching_score.item()),
+ alpha=0.9,
+ linewidth=0.5,
+ )
+ plt.scatter(keypoint0_x, keypoint0_y, c="black", s=2)
+ plt.scatter(keypoint1_x + width0, keypoint1_y, c="black", s=2)
+ plt.show()
+
+
+class LightGluePositionalEncoder(nn.Module):
+ def __init__(self, config: LightGlueConfig):
+ super().__init__()
+ self.projector = nn.Linear(2, config.descriptor_dim // config.num_attention_heads // 2, bias=False)
+
+ def forward(
+ self, keypoints: torch.Tensor, output_hidden_states: Optional[bool] = False
+ ) -> Union[tuple[torch.Tensor], tuple[torch.Tensor, torch.Tensor]]:
+ projected_keypoints = self.projector(keypoints)
+ embeddings = projected_keypoints.repeat_interleave(2, dim=-1)
+ cosines = torch.cos(embeddings)
+ sines = torch.sin(embeddings)
+ embeddings = (cosines, sines)
+ output = (embeddings, projected_keypoints) if output_hidden_states else (embeddings,)
+ return output
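+
+# Illustrative note (a sketch): with the defaults (descriptor_dim=256, num_attention_heads=4),
+# the projector maps each (x, y) keypoint to 256 // 4 // 2 = 32 frequencies, repeat_interleave(2)
+# expands them to head_dim = 64 values, and the resulting (cos, sin) pair is passed to the
+# attention layers as rotary position embeddings applied to the queries and keys.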
+
+
+class LightGlueAttention(LlamaAttention):
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ **kwargs: Unpack[FlashAttentionKwargs],
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+ input_shape = hidden_states.shape[:-1]
+ hidden_shape = (*input_shape, -1, self.head_dim)
+
+ query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+
+ is_cross_attention = encoder_hidden_states is not None
+ current_states = encoder_hidden_states if is_cross_attention else hidden_states
+ current_attention_mask = encoder_attention_mask if is_cross_attention else attention_mask
+
+ key_states = self.k_proj(current_states).view(hidden_shape).transpose(1, 2)
+ value_states = self.v_proj(current_states).view(hidden_shape).transpose(1, 2)
+
+ if position_embeddings is not None:
+ cos, sin = position_embeddings
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ attn_output, attn_weights = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ current_attention_mask,
+ dropout=0.0 if not self.training else self.attention_dropout,
+ scaling=self.scaling,
+ **kwargs,
+ )
+
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+ attn_output = self.o_proj(attn_output)
+ return attn_output, attn_weights
+
+
+class LightGlueMLP(CLIPMLP):
+ def __init__(self, config: LightGlueConfig):
+ super().__init__(config)
+ self.fc1 = nn.Linear(config.intermediate_size, config.intermediate_size)
+ self.layer_norm = nn.LayerNorm(config.intermediate_size, elementwise_affine=True)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.fc1(hidden_states)
+ hidden_states = self.layer_norm(hidden_states)
+ hidden_states = self.activation_fn(hidden_states)
+ hidden_states = self.fc2(hidden_states)
+ return hidden_states
+
+
+class LightGlueTransformerLayer(nn.Module):
+ def __init__(self, config: LightGlueConfig, layer_idx: int):
+ super().__init__()
+ self.self_attention = LightGlueAttention(config, layer_idx)
+ self.self_mlp = LightGlueMLP(config)
+ self.cross_attention = LightGlueAttention(config, layer_idx)
+ self.cross_mlp = LightGlueMLP(config)
+
+ def forward(
+ self,
+ descriptors: torch.Tensor,
+ keypoints: torch.Tensor,
+ attention_mask: torch.Tensor,
+ output_hidden_states: Optional[bool] = False,
+ output_attentions: Optional[bool] = False,
+ ) -> tuple[torch.Tensor, Optional[tuple[torch.Tensor]], Optional[tuple[torch.Tensor]]]:
+ all_hidden_states = () if output_hidden_states else None
+ all_attentions = () if output_attentions else None
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (descriptors,)
+
+ batch_size, num_keypoints, descriptor_dim = descriptors.shape
+
+ # Self attention block
+ attention_output, self_attentions = self.self_attention(
+ descriptors,
+ position_embeddings=keypoints,
+ attention_mask=attention_mask,
+ output_attentions=output_attentions,
+ )
+ intermediate_states = torch.cat([descriptors, attention_output], dim=-1)
+ output_states = self.self_mlp(intermediate_states)
+ self_attention_descriptors = descriptors + output_states
+
+ if output_hidden_states:
+ self_attention_hidden_states = (intermediate_states, output_states)
+
+ # Reshape hidden_states to group by image pairs:
+ # (batch_size, num_keypoints, descriptor_dim) -> (batch_size, 2, num_keypoints, descriptor_dim)
+ # Flip dimension 1 to perform cross attention:
+ # (image0, image1) -> (image1, image0)
+ # Reshape back to the original shape:
+ # (batch_size, 2, num_keypoints, descriptor_dim) -> (batch_size, num_keypoints, descriptor_dim)
+ encoder_hidden_states = (
+ self_attention_descriptors.reshape(-1, 2, num_keypoints, descriptor_dim)
+ .flip(1)
+ .reshape(batch_size, num_keypoints, descriptor_dim)
+ )
+ # Same for mask
+ encoder_attention_mask = (
+ attention_mask.reshape(-1, 2, 1, 1, num_keypoints).flip(1).reshape(batch_size, 1, 1, num_keypoints)
+ if attention_mask is not None
+ else None
+ )
+
+ # Cross attention block
+ cross_attention_output, cross_attentions = self.cross_attention(
+ self_attention_descriptors,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ output_attentions=output_attentions,
+ )
+ cross_intermediate_states = torch.cat([self_attention_descriptors, cross_attention_output], dim=-1)
+ cross_output_states = self.cross_mlp(cross_intermediate_states)
+ descriptors = self_attention_descriptors + cross_output_states
+
+ if output_hidden_states:
+ cross_attention_hidden_states = (cross_intermediate_states, cross_output_states)
+ all_hidden_states = (
+ all_hidden_states
+ + (self_attention_descriptors.reshape(batch_size, num_keypoints, descriptor_dim),)
+ + self_attention_hidden_states
+ + (descriptors.reshape(batch_size, num_keypoints, descriptor_dim),)
+ + cross_attention_hidden_states
+ )
+
+ if output_attentions:
+ all_attentions = all_attentions + (self_attentions,) + (cross_attentions,)
+
+ return descriptors, all_hidden_states, all_attentions
+
+
+def sigmoid_log_double_softmax(
+ similarity: torch.Tensor, matchability0: torch.Tensor, matchability1: torch.Tensor
+) -> torch.Tensor:
+ """create the log assignment matrix from logits and similarity"""
+ batch_size, num_keypoints_0, num_keypoints_1 = similarity.shape
+ certainties = nn.functional.logsigmoid(matchability0) + nn.functional.logsigmoid(matchability1).transpose(1, 2)
+ scores0 = nn.functional.log_softmax(similarity, 2)
+ scores1 = nn.functional.log_softmax(similarity.transpose(-1, -2).contiguous(), 2).transpose(-1, -2)
+ scores = similarity.new_full((batch_size, num_keypoints_0 + 1, num_keypoints_1 + 1), 0)
+ scores[:, :num_keypoints_0, :num_keypoints_1] = scores0 + scores1 + certainties
+ scores[:, :-1, -1] = nn.functional.logsigmoid(-matchability0.squeeze(-1))
+ scores[:, -1, :-1] = nn.functional.logsigmoid(-matchability1.squeeze(-1))
+ return scores
+
+
+class LightGlueMatchAssignmentLayer(nn.Module):
+ def __init__(self, config: LightGlueConfig):
+ super().__init__()
+
+ self.descriptor_dim = config.descriptor_dim
+ self.final_projection = nn.Linear(self.descriptor_dim, self.descriptor_dim, bias=True)
+ self.matchability = nn.Linear(self.descriptor_dim, 1, bias=True)
+
+ def forward(self, descriptors: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
+ batch_size, num_keypoints, descriptor_dim = descriptors.shape
+ # Final projection and similarity computation
+ m_descriptors = self.final_projection(descriptors)
+ m_descriptors = m_descriptors / torch.tensor(self.descriptor_dim, device=m_descriptors.device) ** 0.25
+ m_descriptors = m_descriptors.reshape(batch_size // 2, 2, num_keypoints, descriptor_dim)
+ m_descriptors0 = m_descriptors[:, 0]
+ m_descriptors1 = m_descriptors[:, 1]
+ similarity = m_descriptors0 @ m_descriptors1.transpose(-1, -2)
+ if mask is not None:
+ mask = mask.reshape(batch_size // 2, 2, num_keypoints)
+ mask0 = mask[:, 0].unsqueeze(-1)
+ mask1 = mask[:, 1].unsqueeze(-1).transpose(-1, -2)
+ mask = mask0 * mask1
+ similarity = similarity.masked_fill(mask == 0, torch.finfo(similarity.dtype).min)
+
+ # Compute matchability of descriptors
+ matchability = self.matchability(descriptors)
+ matchability = matchability.reshape(batch_size // 2, 2, num_keypoints, 1)
+ matchability_0 = matchability[:, 0]
+ matchability_1 = matchability[:, 1]
+
+ # Compute scores from similarity and matchability
+ scores = sigmoid_log_double_softmax(similarity, matchability_0, matchability_1)
+ return scores
+
+ def get_matchability(self, descriptors: torch.Tensor) -> torch.Tensor:
+ """Get matchability of descriptors as a probability"""
+ matchability = self.matchability(descriptors)
+ matchability = nn.functional.sigmoid(matchability).squeeze(-1)
+ return matchability
+
+
+class LightGlueTokenConfidenceLayer(nn.Module):
+ def __init__(self, config: LightGlueConfig):
+ super().__init__()
+
+ self.token = nn.Linear(config.descriptor_dim, 1)
+
+ def forward(self, descriptors: torch.Tensor) -> torch.Tensor:
+ token = self.token(descriptors.detach())
+ token = nn.functional.sigmoid(token).squeeze(-1)
+ return token
+
+
+@auto_docstring
+class LightGluePreTrainedModel(PreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config: LightGlueConfig
+ base_model_prefix = "lightglue"
+ main_input_name = "pixel_values"
+ supports_gradient_checkpointing = False
+ _supports_flash_attn = True
+ _supports_sdpa = True
+
+
+def get_matches_from_scores(scores: torch.Tensor, threshold: float) -> tuple[torch.Tensor, torch.Tensor]:
+ """obtain matches from a score matrix [Bx M+1 x N+1]"""
+ batch_size, _, _ = scores.shape
+ # For each keypoint, get the best match
+ max0 = scores[:, :-1, :-1].max(2)
+ max1 = scores[:, :-1, :-1].max(1)
+ matches0 = max0.indices
+ matches1 = max1.indices
+
+ # Mutual check for matches
+ indices0 = torch.arange(matches0.shape[1], device=matches0.device)[None]
+ indices1 = torch.arange(matches1.shape[1], device=matches1.device)[None]
+ mutual0 = indices0 == matches1.gather(1, matches0)
+ mutual1 = indices1 == matches0.gather(1, matches1)
+
+ # Get matching scores and filter based on mutual check and thresholding
+ max0 = max0.values.exp()
+ zero = max0.new_tensor(0)
+ matching_scores0 = torch.where(mutual0, max0, zero)
+ matching_scores1 = torch.where(mutual1, matching_scores0.gather(1, matches1), zero)
+ valid0 = mutual0 & (matching_scores0 > threshold)
+ valid1 = mutual1 & valid0.gather(1, matches1)
+
+ # Filter matches based on mutual check and thresholding of scores
+ matches0 = torch.where(valid0, matches0, -1)
+ matches1 = torch.where(valid1, matches1, -1)
+ matches = torch.stack([matches0, matches1]).transpose(0, 1).reshape(batch_size * 2, -1)
+ matching_scores = torch.stack([matching_scores0, matching_scores1]).transpose(0, 1).reshape(batch_size * 2, -1)
+
+ return matches, matching_scores
+
+
+def normalize_keypoints(keypoints: torch.Tensor, height: int, width: int) -> torch.Tensor:
+ """
+ Normalize keypoint locations based on the image height and width.
+
+ Args:
+ keypoints (`torch.Tensor` of shape `(batch_size, num_keypoints, 2)`):
+ Keypoints locations in (x, y) format.
+ height (`int`):
+ Image height.
+ width (`int`):
+ Image width.
+
+ Returns:
+ Normalized keypoint locations (`torch.Tensor` of shape `(batch_size, num_keypoints, 2)`).
+ """
+ size = torch.tensor([width, height], device=keypoints.device, dtype=keypoints.dtype)[None]
+ shift = size / 2
+ scale = size.max(-1).values / 2
+ keypoints = (keypoints - shift[..., None, :]) / scale[..., None, None]
+ return keypoints
+
+
+@auto_docstring(
+ custom_intro="""
+ LightGlue model taking images as inputs and outputting the matching of them.
+ """
+)
+class LightGlueForKeypointMatching(LightGluePreTrainedModel):
+ """
+ LightGlue is a model matching keypoints in images by leveraging detections from a keypoint detector such as
+ SuperPoint. It is based on the SuperGlue architecture and is designed to be lightweight and efficient.
+ It consists of:
+ 1. Keypoint Encoder
+ 2. A Graph Neural Network with self and cross attention layers
+ 3. Matching Assignment layers
+
+ The correspondence ids use -1 to indicate non-matching points.
+
+ Philipp Lindenberger, Paul-Edouard Sarlin and Marc Pollefeys. LightGlue: Local Feature Matching at Light Speed.
+ In ICCV 2023. https://huggingface.co/papers/2306.13643
+ """
+
+ def __init__(self, config: LightGlueConfig):
+ super().__init__(config)
+ self.keypoint_detector = AutoModelForKeypointDetection.from_config(
+ config.keypoint_detector_config, trust_remote_code=config.trust_remote_code
+ )
+
+ self.keypoint_detector_descriptor_dim = config.keypoint_detector_config.descriptor_decoder_dim
+ self.descriptor_dim = config.descriptor_dim
+ self.num_layers = config.num_hidden_layers
+ self.filter_threshold = config.filter_threshold
+ self.depth_confidence = config.depth_confidence
+ self.width_confidence = config.width_confidence
+
+ if self.descriptor_dim != self.keypoint_detector_descriptor_dim:
+ self.input_projection = nn.Linear(self.keypoint_detector_descriptor_dim, self.descriptor_dim, bias=True)
+ else:
+ self.input_projection = nn.Identity()
+
+ self.positional_encoder = LightGluePositionalEncoder(config)
+
+ self.transformer_layers = nn.ModuleList(
+ [LightGlueTransformerLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)]
+ )
+ self.match_assignment_layers = nn.ModuleList(
+ [LightGlueMatchAssignmentLayer(config) for _ in range(config.num_hidden_layers)]
+ )
+ self.token_confidence = nn.ModuleList(
+ [LightGlueTokenConfidenceLayer(config) for _ in range(config.num_hidden_layers - 1)]
+ )
+
+ self.post_init()
+
+ def _get_confidence_threshold(self, layer_index: int) -> float:
+ """scaled confidence threshold for a given layer"""
+ threshold = 0.8 + 0.1 * np.exp(-4.0 * layer_index / self.num_layers)
+ return np.clip(threshold, 0, 1)
+
+ def _keypoint_processing(
+ self, descriptors: torch.Tensor, keypoints: torch.Tensor, output_hidden_states: Optional[bool] = False
+ ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
+ descriptors = descriptors.detach().contiguous()
+ projected_descriptors = self.input_projection(descriptors)
+ keypoint_encoding_output = self.positional_encoder(keypoints, output_hidden_states=output_hidden_states)
+ return projected_descriptors, keypoint_encoding_output
+
+ def _get_early_stopped_image_pairs(
+ self, keypoint_confidences: torch.Tensor, layer_index: int, mask: torch.Tensor, num_points: torch.Tensor
+ ) -> torch.Tensor:
+ """evaluate whether we should stop inference based on the confidence of the keypoints"""
+ batch_size, _ = mask.shape
+ if layer_index < self.num_layers - 1:
+ # If the current layer is not the last layer, we compute the confidence of the keypoints and check
+ # if we should stop the forward pass through the transformer layers for each pair of images.
+ keypoint_confidences = keypoint_confidences.masked_fill(mask == 0, 1)
+ keypoint_confidences = keypoint_confidences.reshape(batch_size // 2, -1)
+ threshold = self._get_confidence_threshold(layer_index)
+ ratio_confident = 1.0 - (keypoint_confidences < threshold).float().sum(dim=1) / num_points
+ early_stopped_pairs = ratio_confident > self.depth_confidence
+ else:
+ # If the current layer is the last layer, we stop the forward pass through the transformer layers for
+ # all pairs of images.
+ early_stopped_pairs = torch.ones(batch_size, dtype=torch.bool)
+ return early_stopped_pairs
+
+ def _get_keypoint_matching(self, descriptors, mask, layer_index, early_stops=None):
+ if early_stops is not None:
+ descriptors = descriptors[early_stops]
+ mask = mask[early_stops]
+ scores = self.match_assignment_layers[layer_index](descriptors, mask)
+ matches, matching_scores = get_matches_from_scores(scores, self.filter_threshold)
+ return matches, matching_scores
+
+ def _get_pruning_mask(self, confidences: torch.Tensor, scores: torch.Tensor, layer_index: int) -> torch.Tensor:
+ """mask points which should be removed"""
+ keep = scores > (1 - self.width_confidence)
+ if confidences is not None: # Low-confidence points are never pruned.
+ keep |= confidences <= self._get_confidence_threshold(layer_index)
+ return keep
+
+ def _do_layer_keypoint_pruning(
+ self,
+ descriptors: torch.Tensor,
+ keypoints: torch.Tensor,
+ mask: torch.Tensor,
+ indices: torch.Tensor,
+ prune_output: torch.Tensor,
+ keypoint_confidences: torch.Tensor,
+ layer_index: int,
+ ):
+ """
+ For a given layer, prune keypoints based on the confidence of the keypoints and the matchability of the
+ descriptors.
+ """
+ batch_size, _, _ = descriptors.shape
+ descriptors_matchability = self.match_assignment_layers[layer_index].get_matchability(descriptors)
+ pruned_keypoints_mask = self._get_pruning_mask(keypoint_confidences, descriptors_matchability, layer_index)
+ pruned_keypoints_mask = pruned_keypoints_mask.masked_fill(mask == 0, torch.tensor(False))
+
+ # For each image, we extract the pruned indices and the corresponding descriptors and keypoints.
+ pruned_descriptors, pruned_keypoints_0, pruned_keypoints_1, pruned_mask, pruned_indices = (
+ [t[mask] for t, mask in zip(tensor, pruned_keypoints_mask)]
+ for tensor in [descriptors, keypoints[0], keypoints[1], pruned_keypoints_mask, indices]
+ )
+ for i in range(batch_size):
+ prune_output[i, pruned_indices[i]] += 1
+
+ # Pad the pruned descriptors, keypoints, indices and mask to have the same shape across the batch.
+ pruned_descriptors, pruned_keypoints_0, pruned_keypoints_1, pruned_mask = (
+ pad_sequence(pruned_tensor, batch_first=True)
+ for pruned_tensor in [pruned_descriptors, pruned_keypoints_0, pruned_keypoints_1, pruned_mask]
+ )
+ pruned_keypoints = (pruned_keypoints_0, pruned_keypoints_1)
+ pruned_indices = pad_sequence(pruned_indices, batch_first=True, padding_value=-1)
+
+ return pruned_descriptors, pruned_keypoints, pruned_indices, pruned_mask, prune_output
+
+ def _concat_early_stopped_outputs(
+ self,
+ early_stops_indices,
+ final_pruned_keypoints_indices,
+ final_pruned_keypoints_iterations,
+ matches,
+ matching_scores,
+ ):
+ early_stops_indices = torch.stack(early_stops_indices)
+ # Rearrange tensors to have the same order as the input batch
+ ids = torch.arange(early_stops_indices.shape[0])
+ order_indices = early_stops_indices[ids]
+ early_stops_indices = early_stops_indices[order_indices]
+ matches, final_pruned_keypoints_indices = (
+ pad_sequence(tensor, batch_first=True, padding_value=-1)
+ for tensor in [matches, final_pruned_keypoints_indices]
+ )
+ matching_scores, final_pruned_keypoints_iterations = (
+ pad_sequence(tensor, batch_first=True, padding_value=0)
+ for tensor in [matching_scores, final_pruned_keypoints_iterations]
+ )
+ matches, matching_scores, final_pruned_keypoints_indices, final_pruned_keypoints_iterations = (
+ tensor[early_stops_indices]
+ for tensor in [
+ matches,
+ matching_scores,
+ final_pruned_keypoints_indices,
+ final_pruned_keypoints_iterations,
+ ]
+ )
+ return final_pruned_keypoints_indices, final_pruned_keypoints_iterations, matches, matching_scores
+
+ def _do_final_keypoint_pruning(
+ self,
+ indices: torch.Tensor,
+ matches: torch.Tensor,
+ matching_scores: torch.Tensor,
+ num_keypoints: torch.Tensor,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ # (batch_size, num_keypoints) -> (batch_size // 2, 2, num_keypoints) -> 2 * (batch_size // 2, num_keypoints)
+ # to have one tensor per image of each pair
+ batch_size, _ = indices.shape
+ indices, matches, matching_scores = (
+ tensor.reshape(batch_size // 2, 2, -1) for tensor in [indices, matches, matching_scores]
+ )
+ indices0 = indices[:, 0]
+ indices1 = indices[:, 1]
+ matches0 = matches[:, 0]
+ matches1 = matches[:, 1]
+ matching_scores0 = matching_scores[:, 0]
+ matching_scores1 = matching_scores[:, 1]
+
+ # Prepare final matches and matching scores
+ _matches = torch.full((batch_size // 2, 2, num_keypoints), -1, device=indices.device, dtype=matches.dtype)
+ _matching_scores = torch.zeros(
+ (batch_size // 2, 2, num_keypoints), device=indices.device, dtype=matching_scores.dtype
+ )
+ # Fill the matches and matching scores for each image pair
+ for i in range(batch_size // 2):
+ _matches[i, 0, indices0[i]] = torch.where(
+ matches0[i] == -1, -1, indices1[i].gather(0, matches0[i].clamp(min=0))
+ )
+ _matches[i, 1, indices1[i]] = torch.where(
+ matches1[i] == -1, -1, indices0[i].gather(0, matches1[i].clamp(min=0))
+ )
+ _matching_scores[i, 0, indices0[i]] = matching_scores0[i]
+ _matching_scores[i, 1, indices1[i]] = matching_scores1[i]
+ return _matches, _matching_scores
+
+ def _match_image_pair(
+ self,
+ keypoints: torch.Tensor,
+ descriptors: torch.Tensor,
+ height: int,
+ width: int,
+ mask: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, tuple, tuple]:
+ all_hidden_states = () if output_hidden_states else None
+ all_attentions = () if output_attentions else None
+
+ if keypoints.shape[2] == 0: # no keypoints
+ shape = keypoints.shape[:-1]
+ return (
+ keypoints.new_full(shape, -1, dtype=torch.int),
+ keypoints.new_zeros(shape),
+ keypoints.new_zeros(shape),
+ all_hidden_states,
+ all_attentions,
+ )
+
+ device = keypoints.device
+ batch_size, _, initial_num_keypoints, _ = keypoints.shape
+ num_points_per_pair = torch.sum(mask.reshape(batch_size, -1), dim=1)
+ # (batch_size, 2, num_keypoints, 2) -> (batch_size * 2, num_keypoints, 2)
+ keypoints = keypoints.reshape(batch_size * 2, initial_num_keypoints, 2)
+ mask = mask.reshape(batch_size * 2, initial_num_keypoints) if mask is not None else None
+ descriptors = descriptors.reshape(batch_size * 2, initial_num_keypoints, self.keypoint_detector_descriptor_dim)
+ image_indices = torch.arange(batch_size * 2, device=device)
+ # Keypoint normalization
+ keypoints = normalize_keypoints(keypoints, height, width)
+
+ descriptors, keypoint_encoding_output = self._keypoint_processing(
+ descriptors, keypoints, output_hidden_states=output_hidden_states
+ )
+
+ keypoints = keypoint_encoding_output[0]
+
+ # Early stop consists of stopping the forward pass through the transformer layers when the confidence of the
+ # keypoints is above a certain threshold.
+ do_early_stop = self.depth_confidence > 0
+ # Keypoint pruning consists of removing keypoints from the input of the transformer layers when the confidence of
+ # the keypoints is below a certain threshold.
+ do_keypoint_pruning = self.width_confidence > 0
+
+ early_stops_indices = []
+ matches = []
+ matching_scores = []
+ final_pruned_keypoints_indices = []
+ final_pruned_keypoints_iterations = []
+
+ pruned_keypoints_indices = torch.arange(0, initial_num_keypoints, device=device).expand(batch_size * 2, -1)
+ pruned_keypoints_iterations = torch.ones_like(pruned_keypoints_indices)
+
+ for layer_index in range(self.num_layers):
+ input_shape = descriptors.size()
+ if mask is not None:
+ extended_attention_mask = self.get_extended_attention_mask(mask, input_shape)
+ else:
+ extended_attention_mask = torch.ones((batch_size, input_shape[-2]), device=keypoints.device)
+ layer_output = self.transformer_layers[layer_index](
+ descriptors,
+ keypoints,
+ attention_mask=extended_attention_mask,
+ output_hidden_states=output_hidden_states,
+ output_attentions=output_attentions,
+ )
+ descriptors, hidden_states, attention = layer_output
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + hidden_states
+ if output_attentions:
+ all_attentions = all_attentions + attention
+
+ if do_early_stop:
+ if layer_index < self.num_layers - 1:
+ # Get the confidence of the keypoints for the current layer
+ keypoint_confidences = self.token_confidence[layer_index](descriptors)
+
+ # Determine which pairs of images should be early stopped based on the confidence of the keypoints for
+ # the current layer.
+ early_stopped_pairs = self._get_early_stopped_image_pairs(
+ keypoint_confidences, layer_index, mask, num_points=num_points_per_pair
+ )
+ else:
+ # Early stopping always occurs at the last layer
+ early_stopped_pairs = torch.ones(batch_size, dtype=torch.bool)
+
+ if torch.any(early_stopped_pairs):
+ # If a pair of images is considered early stopped, we compute the matches for the remaining
+ # keypoints and stop the forward pass through the transformer layers for this pair of images.
+ early_stops = early_stopped_pairs.repeat_interleave(2)
+ early_stopped_image_indices = image_indices[early_stops]
+ early_stopped_matches, early_stopped_matching_scores = self._get_keypoint_matching(
+ descriptors, mask, layer_index, early_stops=early_stops
+ )
+ early_stops_indices.extend(list(early_stopped_image_indices))
+ matches.extend(list(early_stopped_matches))
+ matching_scores.extend(list(early_stopped_matching_scores))
+ if do_keypoint_pruning:
+ final_pruned_keypoints_indices.extend(list(pruned_keypoints_indices[early_stops]))
+ final_pruned_keypoints_iterations.extend(list(pruned_keypoints_iterations[early_stops]))
+
+ # Remove image pairs that have been early stopped from the forward pass
+ num_points_per_pair = num_points_per_pair[~early_stopped_pairs]
+ descriptors, keypoints_0, keypoints_1, mask, image_indices = tuple(
+ tensor[~early_stops]
+ for tensor in [descriptors, keypoints[0], keypoints[1], mask, image_indices]
+ )
+ keypoints = (keypoints_0, keypoints_1)
+ if do_keypoint_pruning:
+ pruned_keypoints_indices, pruned_keypoints_iterations, keypoint_confidences = tuple(
+ tensor[~early_stops]
+ for tensor in [
+ pruned_keypoints_indices,
+ pruned_keypoints_iterations,
+ keypoint_confidences,
+ ]
+ )
+ # If all pairs of images are early stopped, we stop the forward pass through the transformer
+ # layers for all pairs of images.
+ if torch.all(early_stopped_pairs):
+ break
+
+ if do_keypoint_pruning:
+ # Prune keypoints from the input of the transformer layers for the next iterations if the confidence of
+ # the keypoints is below a certain threshold.
+ descriptors, keypoints, pruned_keypoints_indices, mask, pruned_keypoints_iterations = (
+ self._do_layer_keypoint_pruning(
+ descriptors,
+ keypoints,
+ mask,
+ pruned_keypoints_indices,
+ pruned_keypoints_iterations,
+ keypoint_confidences,
+ layer_index,
+ )
+ )
+
+ if do_early_stop and do_keypoint_pruning:
+ # Concatenate early stopped outputs together and perform final keypoint pruning
+ final_pruned_keypoints_indices, final_pruned_keypoints_iterations, matches, matching_scores = (
+ self._concat_early_stopped_outputs(
+ early_stops_indices,
+ final_pruned_keypoints_indices,
+ final_pruned_keypoints_iterations,
+ matches,
+ matching_scores,
+ )
+ )
+ matches, matching_scores = self._do_final_keypoint_pruning(
+ final_pruned_keypoints_indices,
+ matches,
+ matching_scores,
+ initial_num_keypoints,
+ )
+ else:
+ matches, matching_scores = self._get_keypoint_matching(descriptors, mask, self.num_layers - 1)
+ final_pruned_keypoints_iterations = torch.ones_like(matching_scores) * self.num_layers
+
+ final_pruned_keypoints_iterations = final_pruned_keypoints_iterations.reshape(
+ batch_size, 2, initial_num_keypoints
+ )
+
+ return (
+ matches,
+ matching_scores,
+ final_pruned_keypoints_iterations,
+ all_hidden_states,
+ all_attentions,
+ )
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ pixel_values: torch.FloatTensor,
+ labels: Optional[torch.LongTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ ) -> Union[tuple, LightGlueKeypointMatchingOutput]:
+ loss = None
+ if labels is not None:
+ raise ValueError("LightGlue is not trainable, no labels should be provided.")
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+
+ if pixel_values.ndim != 5 or pixel_values.size(1) != 2:
+ raise ValueError("Input must be a 5D tensor of shape (batch_size, 2, num_channels, height, width)")
+
+ batch_size, _, channels, height, width = pixel_values.shape
+ pixel_values = pixel_values.reshape(batch_size * 2, channels, height, width)
+ keypoint_detections = self.keypoint_detector(pixel_values)
+
+ keypoints, _, descriptors, mask = keypoint_detections[:4]
+ keypoints = keypoints.reshape(batch_size, 2, -1, 2).to(pixel_values)
+ descriptors = descriptors.reshape(batch_size, 2, -1, self.keypoint_detector_descriptor_dim).to(pixel_values)
+ mask = mask.reshape(batch_size, 2, -1)
+
+ absolute_keypoints = keypoints.clone()
+ absolute_keypoints[:, :, :, 0] = absolute_keypoints[:, :, :, 0] * width
+ absolute_keypoints[:, :, :, 1] = absolute_keypoints[:, :, :, 1] * height
+
+ matches, matching_scores, prune, hidden_states, attentions = self._match_image_pair(
+ absolute_keypoints,
+ descriptors,
+ height,
+ width,
+ mask=mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ )
+
+ return LightGlueKeypointMatchingOutput(
+ loss=loss,
+ matches=matches,
+ matching_scores=matching_scores,
+ keypoints=keypoints,
+ prune=prune,
+ mask=mask,
+ hidden_states=hidden_states,
+ attentions=attentions,
+ )
+
+
+__all__ = ["LightGluePreTrainedModel", "LightGlueForKeypointMatching", "LightGlueConfig", "LightGlueImageProcessor"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/llava/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/llava/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..aadd45dc13ed4074437cc6f224b0348de110f292
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/llava/__init__.py
@@ -0,0 +1,29 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_llava import *
+ from .image_processing_llava_fast import *
+ from .modeling_llava import *
+ from .processing_llava import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/llava/configuration_llava.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/llava/configuration_llava.py
new file mode 100644
index 0000000000000000000000000000000000000000..9ae710c011986f422d965adcdb758b29e3113690
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/llava/configuration_llava.py
@@ -0,0 +1,137 @@
+# coding=utf-8
+# Copyright 2023 Microsoft Research & University of Wisconsin-Madison and the HuggingFace Inc. team. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Llava model configuration"""
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+from ..auto import CONFIG_MAPPING, AutoConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+class LlavaConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`LlavaForConditionalGeneration`]. It is used to instantiate a
+ Llava model according to the specified arguments, defining the model architecture. Instantiating a configuration
+ with the defaults will yield a similar configuration to that of the Llava-9B.
+
+ e.g. [llava-hf/llava-9b](https://huggingface.co/llava-hf/llava-9b)
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ vision_config (`Union[AutoConfig, dict]`, *optional*, defaults to `CLIPVisionConfig`):
+ The config object or dictionary of the vision backbone.
+ text_config (`Union[AutoConfig, dict]`, *optional*, defaults to `LlamaConfig`):
+ The config object or dictionary of the text backbone.
+ image_token_index (`int`, *optional*, defaults to 32000):
+ The image token index to encode the image prompt.
+ projector_hidden_act (`str`, *optional*, defaults to `"gelu"`):
+ The activation function used by the multimodal projector.
+ vision_feature_select_strategy (`str`, *optional*, defaults to `"default"`):
+ The feature selection strategy used to select the vision feature from the vision backbone.
+ Can be one of `"default"` or `"full"`.
+ vision_feature_layer (`Union[int, list[int]]`, *optional*, defaults to -2):
+ The index of the layer to select the vision feature. If multiple indices are provided,
+ the vision feature of the corresponding indices will be concatenated to form the
+ vision features.
+ image_seq_length (`int`, *optional*, defaults to 576):
+ Sequence length of one image embedding.
+ multimodal_projector_bias (`bool`, *optional*, defaults to `True`):
+ Whether to use bias in the multimodal projector.
+
+ Example:
+
+ ```python
+ >>> from transformers import LlavaForConditionalGeneration, LlavaConfig, CLIPVisionConfig, LlamaConfig
+
+ >>> # Initializing a CLIP-vision config
+ >>> vision_config = CLIPVisionConfig()
+
+ >>> # Initializing a Llama config
+ >>> text_config = LlamaConfig()
+
+ >>> # Initializing a Llava llava-1.5-7b style configuration
+ >>> configuration = LlavaConfig(vision_config, text_config)
+
+ >>> # Initializing a model from the llava-1.5-7b style configuration
+ >>> model = LlavaForConditionalGeneration(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "llava"
+ attribute_map = {
+ "image_token_id": "image_token_index",
+ }
+ sub_configs = {"text_config": AutoConfig, "vision_config": AutoConfig}
+
+ def __init__(
+ self,
+ vision_config=None,
+ text_config=None,
+ image_token_index=32000,
+ projector_hidden_act="gelu",
+ vision_feature_select_strategy="default",
+ vision_feature_layer=-2,
+ image_seq_length=576,
+ multimodal_projector_bias=True,
+ **kwargs,
+ ):
+ self.image_token_index = image_token_index
+ self.projector_hidden_act = projector_hidden_act
+ self.image_seq_length = image_seq_length
+
+ if vision_feature_select_strategy not in ["default", "full"]:
+ raise ValueError(
+ "vision_feature_select_strategy should be one of 'default', 'full'."
+ f"Got: {vision_feature_select_strategy}"
+ )
+
+ self.vision_feature_select_strategy = vision_feature_select_strategy
+ self.vision_feature_layer = vision_feature_layer
+
+ if isinstance(vision_config, dict):
+ vision_config["model_type"] = vision_config.get("model_type", "clip_vision_model")
+ vision_config = CONFIG_MAPPING[vision_config["model_type"]](**vision_config)
+ elif vision_config is None:
+ vision_config = CONFIG_MAPPING["clip_vision_model"](
+ intermediate_size=4096,
+ hidden_size=1024,
+ patch_size=14,
+ image_size=336,
+ num_hidden_layers=24,
+ num_attention_heads=16,
+ vocab_size=32000,
+ projection_dim=768,
+ )
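+ # These defaults mirror CLIP ViT-L/14 at 336px resolution: (336 / 14) ** 2 = 576 patches per image,
+ # which is consistent with the `image_seq_length` default of 576 above.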
+
+ self.vision_config = vision_config
+
+ if isinstance(text_config, dict):
+ text_config["model_type"] = text_config.get("model_type", "llama")
+ text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
+ elif text_config is None:
+ text_config = CONFIG_MAPPING["llama"]()
+
+ self.text_config = text_config
+ self.multimodal_projector_bias = multimodal_projector_bias
+
+ super().__init__(**kwargs)
+
+
+__all__ = ["LlavaConfig"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/llava/image_processing_llava.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/llava/image_processing_llava.py
new file mode 100644
index 0000000000000000000000000000000000000000..5420d6fe291861951c7e6b5bfc555b81c57cbc90
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/llava/image_processing_llava.py
@@ -0,0 +1,437 @@
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Image processor class for LLaVa."""
+
+from typing import Optional, Union
+
+import numpy as np
+
+from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
+from ...image_transforms import (
+ convert_to_rgb,
+ get_resize_output_image_size,
+ resize,
+ to_channel_dimension_format,
+)
+from ...image_utils import (
+ OPENAI_CLIP_MEAN,
+ OPENAI_CLIP_STD,
+ ChannelDimension,
+ ImageInput,
+ PILImageResampling,
+ get_image_size,
+ infer_channel_dimension_format,
+ is_scaled_image,
+ make_flat_list_of_images,
+ to_numpy_array,
+ valid_images,
+ validate_kwargs,
+ validate_preprocess_arguments,
+)
+from ...utils import TensorType, is_vision_available, logging
+
+
+logger = logging.get_logger(__name__)
+
+
+if is_vision_available():
+ import PIL
+
+
+class LlavaImageProcessor(BaseImageProcessor):
+ r"""
+ Constructs a LLaVa image processor.
+
+ Args:
+ do_pad (`bool`, *optional*, defaults to `False`):
+ Whether to pad the image to a square based on the longest edge.
+ The padding value is determined by the `image_mean` parameter.
+ Can be overridden by `do_pad` in the `preprocess` method.
+ do_resize (`bool`, *optional*, defaults to `True`):
+ Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by
+ `do_resize` in the `preprocess` method.
+ size (`dict[str, int]` *optional*, defaults to `{"shortest_edge": 224}`):
+ Size of the image after resizing. The shortest edge of the image is resized to size["shortest_edge"], with
+ the longest edge resized to keep the input aspect ratio. Can be overridden by `size` in the `preprocess`
+ method.
+ resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`):
+ Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method.
+ do_center_crop (`bool`, *optional*, defaults to `True`):
+ Whether to center crop the image to the specified `crop_size`. Can be overridden by `do_center_crop` in the
+ `preprocess` method.
+ crop_size (`dict[str, int]` *optional*, defaults to 224):
+ Size of the output image after applying `center_crop`. Can be overridden by `crop_size` in the `preprocess`
+ method.
+ do_rescale (`bool`, *optional*, defaults to `True`):
+ Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in
+ the `preprocess` method.
+ rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
+ Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess`
+ method.
+ do_normalize (`bool`, *optional*, defaults to `True`):
+ Whether to normalize the image. Can be overridden by `do_normalize` in the `preprocess` method.
+ image_mean (`float` or `list[float]`, *optional*, defaults to `[0.48145466, 0.4578275, 0.40821073]`):
+ Mean to use if normalizing the image. This is a float or list of floats the length of the number of
+ channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
+ image_std (`float` or `list[float]`, *optional*, defaults to `[0.26862954, 0.26130258, 0.27577711]`):
+ Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
+ number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
+ do_convert_rgb (`bool`, *optional*, defaults to `True`):
+ Whether to convert the image to RGB.
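+
+ Example (an illustrative sketch with the default settings):
+
+ ```python
+ >>> from transformers import LlavaImageProcessor
+ >>> from PIL import Image
+
+ >>> image_processor = LlavaImageProcessor()
+ >>> image = Image.new("RGB", (640, 480))
+ >>> inputs = image_processor(image, return_tensors="pt")
+ >>> list(inputs["pixel_values"].shape)  # resized to shortest_edge=224, then center-cropped to 224x224
+ [1, 3, 224, 224]
+ ```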
+ """
+
+ model_input_names = ["pixel_values"]
+
+ def __init__(
+ self,
+ do_pad: bool = False,
+ do_resize: bool = True,
+ size: Optional[dict[str, int]] = None,
+ resample: PILImageResampling = PILImageResampling.BICUBIC,
+ do_center_crop: bool = True,
+ crop_size: Optional[dict[str, int]] = None,
+ do_rescale: bool = True,
+ rescale_factor: Union[int, float] = 1 / 255,
+ do_normalize: bool = True,
+ image_mean: Optional[Union[float, list[float]]] = None,
+ image_std: Optional[Union[float, list[float]]] = None,
+ do_convert_rgb: bool = True,
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+ size = size if size is not None else {"shortest_edge": 224}
+ size = get_size_dict(size, default_to_square=False)
+ crop_size = crop_size if crop_size is not None else {"height": 224, "width": 224}
+ crop_size = get_size_dict(crop_size, default_to_square=True, param_name="crop_size")
+
+ self.do_pad = do_pad
+ self.do_resize = do_resize
+ self.size = size
+ self.resample = resample
+ self.do_center_crop = do_center_crop
+ self.crop_size = crop_size
+ self.do_rescale = do_rescale
+ self.rescale_factor = rescale_factor
+ self.do_normalize = do_normalize
+ self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN
+ self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD
+ self.do_convert_rgb = do_convert_rgb
+ self._valid_processor_keys = [
+ "images",
+ "do_pad",
+ "do_resize",
+ "size",
+ "resample",
+ "do_center_crop",
+ "crop_size",
+ "do_rescale",
+ "rescale_factor",
+ "do_normalize",
+ "image_mean",
+ "image_std",
+ "do_convert_rgb",
+ "return_tensors",
+ "data_format",
+ "input_data_format",
+ ]
+
+ def pad_to_square(
+ self,
+ image: np.ndarray,
+ background_color: Union[int, tuple[int, int, int]] = 0,
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ ) -> np.ndarray:
+ """
+ Pads an image to a square based on the longest edge.
+
+ Args:
+ image (`np.ndarray`):
+ The image to pad.
+ background_color (`int` or `tuple[int, int, int]`, *optional*, defaults to 0):
+ The color to use for the padding. Can be an integer for single-channel images or a tuple of
+ integers, one per channel, for multi-channel images. If an integer is passed for a multi-channel
+ image, only the first channel is filled with it and subsequent channels default to `0`.
+ data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format for the output image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ If unset, will use same as the input image.
+ input_data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format for the input image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ If unset, will use the inferred format of the input image.
+
+ Returns:
+ `np.ndarray`: The padded image.
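+
+ Example (illustrative): a channels-first image of shape `(3, 480, 640)` becomes `(3, 640, 640)`, with the
+ original content vertically centered and 80 rows of padding above and below it.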
+ """
+ height, width = get_image_size(image, input_data_format)
+ num_channels = image.shape[0] if input_data_format == ChannelDimension.FIRST else image.shape[-1]
+
+ if height == width:
+ image = (
+ to_channel_dimension_format(image, data_format, input_data_format)
+ if data_format is not None
+ else image
+ )
+ return image
+
+ max_dim = max(height, width)
+
+ # Ensure background_color is the correct shape
+ if isinstance(background_color, int):
+ background_color = [background_color]
+ elif len(background_color) != num_channels:
+ raise ValueError(
+ f"background_color must have no more than {num_channels} elements to match the number of channels"
+ )
+
+ if input_data_format == ChannelDimension.FIRST:
+ result = np.zeros((num_channels, max_dim, max_dim), dtype=image.dtype)
+ for i, color in enumerate(background_color):
+ result[i, :, :] = color
+ if width > height:
+ start = (max_dim - height) // 2
+ result[:, start : start + height, :] = image
+ else:
+ start = (max_dim - width) // 2
+ result[:, :, start : start + width] = image
+ else:
+ result = np.zeros((max_dim, max_dim, num_channels), dtype=image.dtype)
+ for i, color in enumerate(background_color):
+ result[:, :, i] = color
+ if width > height:
+ start = (max_dim - height) // 2
+ result[start : start + height, :, :] = image
+ else:
+ start = (max_dim - width) // 2
+ result[:, start : start + width, :] = image
+
+ image = (
+ to_channel_dimension_format(result, data_format, input_data_format) if data_format is not None else result
+ )
+ return image
+
+ # Copied from transformers.models.clip.image_processing_clip.CLIPImageProcessor.resize
+ def resize(
+ self,
+ image: np.ndarray,
+ size: dict[str, int],
+ resample: PILImageResampling = PILImageResampling.BICUBIC,
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ **kwargs,
+ ) -> np.ndarray:
+ """
+ Resize an image. The shortest edge of the image is resized to size["shortest_edge"], with the longest edge
+ resized to keep the input aspect ratio.
+
+ Args:
+ image (`np.ndarray`):
+ Image to resize.
+ size (`dict[str, int]`):
+ Size of the output image.
+ resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
+ Resampling filter to use when resizing the image.
+ data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format of the image. If not provided, it will be the same as the input image.
+ input_data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format of the input image. If not provided, it will be inferred.
+ """
+ default_to_square = True
+ if "shortest_edge" in size:
+ size = size["shortest_edge"]
+ default_to_square = False
+ elif "height" in size and "width" in size:
+ size = (size["height"], size["width"])
+ else:
+ raise ValueError("Size must contain either 'shortest_edge' or 'height' and 'width'.")
+
+ output_size = get_resize_output_image_size(
+ image,
+ size=size,
+ default_to_square=default_to_square,
+ input_data_format=input_data_format,
+ )
+ return resize(
+ image,
+ size=output_size,
+ resample=resample,
+ data_format=data_format,
+ input_data_format=input_data_format,
+ **kwargs,
+ )
+
+ def preprocess(
+ self,
+ images: ImageInput,
+ do_pad: Optional[bool] = None,
+ do_resize: Optional[bool] = None,
+ size: Optional[dict[str, int]] = None,
+ resample: Optional[PILImageResampling] = None,
+ do_center_crop: Optional[bool] = None,
+ crop_size: Optional[int] = None,
+ do_rescale: Optional[bool] = None,
+ rescale_factor: Optional[float] = None,
+ do_normalize: Optional[bool] = None,
+ image_mean: Optional[Union[float, list[float]]] = None,
+ image_std: Optional[Union[float, list[float]]] = None,
+ do_convert_rgb: Optional[bool] = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ **kwargs,
+ ) -> BatchFeature:
+ """
+ Preprocess an image or batch of images.
+
+ Args:
+ images (`ImageInput`):
+ Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
+ passing in images with pixel values between 0 and 1, set `do_rescale=False`.
+ do_pad (`bool`, *optional*, defaults to `self.do_pad`):
+ Whether to pad the image to a square based on the longest edge.
+ The padding value is determined by the `image_mean` parameter.
+ do_resize (`bool`, *optional*, defaults to `self.do_resize`):
+ Whether to resize the image.
+ size (`dict[str, int]`, *optional*, defaults to `self.size`):
+ Size of the image after resizing. Shortest edge of the image is resized to size["shortest_edge"], with
+ the longest edge resized to keep the input aspect ratio.
+ resample (`int`, *optional*, defaults to `self.resample`):
+ Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
+ has an effect if `do_resize` is set to `True`.
+ do_center_crop (`bool`, *optional*, defaults to `self.do_center_crop`):
+ Whether to center crop the image.
+ crop_size (`dict[str, int]`, *optional*, defaults to `self.crop_size`):
+ Size of the center crop. Only has an effect if `do_center_crop` is set to `True`.
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
+ Whether to rescale the image.
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
+ Rescale factor to rescale the image by if `do_rescale` is set to `True`.
+ do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
+ Whether to normalize the image.
+ image_mean (`float` or `list[float]`, *optional*, defaults to `self.image_mean`):
+ Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
+ image_std (`float` or `list[float]`, *optional*, defaults to `self.image_std`):
+ Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
+ `True`.
+ do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
+ Whether to convert the image to RGB.
+ return_tensors (`str` or `TensorType`, *optional*):
+ The type of tensors to return. Can be one of:
+ - Unset: Return a list of `np.ndarray`.
+ - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
+ - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
+ data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
+ The channel dimension format for the output image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - Unset: Use the channel dimension format of the input image.
+ input_data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
+ from the input image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
+ """
+ do_pad = do_pad if do_pad is not None else self.do_pad
+ do_resize = do_resize if do_resize is not None else self.do_resize
+ size = size if size is not None else self.size
+ size = get_size_dict(size, param_name="size", default_to_square=False)
+ resample = resample if resample is not None else self.resample
+ do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop
+ crop_size = crop_size if crop_size is not None else self.crop_size
+ crop_size = get_size_dict(crop_size, param_name="crop_size", default_to_square=True)
+ do_rescale = do_rescale if do_rescale is not None else self.do_rescale
+ rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
+ do_normalize = do_normalize if do_normalize is not None else self.do_normalize
+ image_mean = image_mean if image_mean is not None else self.image_mean
+ image_std = image_std if image_std is not None else self.image_std
+ do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
+
+ validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self._valid_processor_keys)
+
+ images = self.fetch_images(images)
+ images = make_flat_list_of_images(images)
+
+ if not valid_images(images):
+ raise ValueError(
+ "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
+ "torch.Tensor, tf.Tensor or jax.ndarray."
+ )
+ # we don't pass `do_pad` here since LLaVa uses a custom padding to a square
+ validate_preprocess_arguments(
+ do_rescale=do_rescale,
+ rescale_factor=rescale_factor,
+ do_normalize=do_normalize,
+ image_mean=image_mean,
+ image_std=image_std,
+ do_center_crop=do_center_crop,
+ crop_size=crop_size,
+ do_resize=do_resize,
+ size=size,
+ resample=resample,
+ )
+
+ if do_convert_rgb:
+ images = [convert_to_rgb(image) for image in images]
+
+ # All transformations expect numpy arrays.
+ images = [to_numpy_array(image) for image in images]
+
+ if is_scaled_image(images[0]) and do_rescale:
+ logger.warning_once(
+ "It looks like you are trying to rescale already rescaled images. If the input"
+ " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
+ )
+
+ if input_data_format is None:
+ # We assume that all images have the same channel dimension format.
+ input_data_format = infer_channel_dimension_format(images[0])
+
+ processed_images = []
+ for image in images:
+ if do_pad:
+ image = self.pad_to_square(
+ image=image,
+ background_color=tuple(int(x * 255) for x in self.image_mean),
+ input_data_format=input_data_format,
+ )
+
+ if do_resize:
+ image = self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
+
+ if do_center_crop:
+ image = self.center_crop(image=image, size=crop_size, input_data_format=input_data_format)
+
+ if do_rescale:
+ image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
+
+ if do_normalize:
+ image = self.normalize(
+ image=image, mean=image_mean, std=image_std, input_data_format=input_data_format
+ )
+
+ image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
+ processed_images.append(image)
+
+ return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)
+
+
+__all__ = ["LlavaImageProcessor"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/llava/image_processing_llava_fast.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/llava/image_processing_llava_fast.py
new file mode 100644
index 0000000000000000000000000000000000000000..5960700405494208325f970ee3a1495cfddc6a87
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/llava/image_processing_llava_fast.py
@@ -0,0 +1,171 @@
+# coding=utf-8
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Fast Image processor class for LLaVa."""
+
+from typing import Optional, Union
+
+import torch
+from torchvision.transforms.v2 import functional as F
+
+from ...image_processing_utils import BatchFeature
+from ...image_processing_utils_fast import (
+ BaseImageProcessorFast,
+ DefaultFastImageProcessorKwargs,
+ group_images_by_shape,
+ reorder_images,
+)
+from ...image_utils import (
+ OPENAI_CLIP_MEAN,
+ OPENAI_CLIP_STD,
+ ChannelDimension,
+ ImageInput,
+ PILImageResampling,
+ SizeDict,
+ get_image_size,
+)
+from ...processing_utils import Unpack
+from ...utils import (
+ TensorType,
+ auto_docstring,
+)
+
+
+class LlavaFastImageProcessorKwargs(DefaultFastImageProcessorKwargs): ...
+
+
+@auto_docstring
+class LlavaImageProcessorFast(BaseImageProcessorFast):
+ resample = PILImageResampling.BICUBIC
+ image_mean = OPENAI_CLIP_MEAN
+ image_std = OPENAI_CLIP_STD
+ size = {"shortest_edge": 224}
+ default_to_square = False
+ crop_size = {"height": 224, "width": 224}
+ do_pad = False
+ do_resize = True
+ do_center_crop = True
+ do_rescale = True
+ do_normalize = True
+ do_convert_rgb = True
+ valid_kwargs = LlavaFastImageProcessorKwargs
+
+ def __init__(self, **kwargs: Unpack[LlavaFastImageProcessorKwargs]) -> None:
+ super().__init__(**kwargs)
+
+ @auto_docstring
+ def preprocess(self, images: ImageInput, **kwargs: Unpack[LlavaFastImageProcessorKwargs]) -> BatchFeature:
+ return super().preprocess(images, **kwargs)
+
+ def pad_to_square(
+ self,
+ images: "torch.Tensor",
+ background_color: Union[int, tuple[int, int, int]] = 0,
+ ) -> "torch.Tensor":
+ """
+ Pads an image to a square based on the longest edge.
+
+ Args:
+ images (`torch.Tensor`):
+ The images to pad.
+ background_color (`int` or `tuple[int, int, int]`, *optional*, defaults to 0):
+ The color to use for the padding. Can be an integer for single-channel images or a tuple of
+ integers, one per channel, for multi-channel images. If an integer is passed for a multi-channel
+ image, only the first channel is filled with it and subsequent channels default to `0`.
+ Returns:
+ `torch.Tensor`: The padded images.
+ """
+ height, width = get_image_size(images, ChannelDimension.FIRST)
+
+ if height == width:
+ return images
+
+ num_channels = images.shape[1] if len(images.shape) == 4 else images.shape[0]
+ if isinstance(background_color, int):
+ background_color = [background_color] + [0] * (num_channels - 1)
+ elif len(background_color) != num_channels:
+ raise ValueError(
+ f"background_color must have no more than {num_channels} elements to match the number of channels"
+ )
+
+ max_dim = max(height, width)
+ paste_x_left = (max_dim - width) // 2
+ paste_y_left = (max_dim - height) // 2
+ paste_x_right = max_dim - width - paste_x_left
+ paste_y_right = max_dim - height - paste_y_left
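+ # torchvision's functional `pad` takes the padding as [left, top, right, bottom].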
+ padded_images = F.pad(
+ images, padding=[paste_x_left, paste_y_left, paste_x_right, paste_y_right], fill=background_color
+ )
+
+ return padded_images
+
+ def _preprocess(
+ self,
+ images: list["torch.Tensor"],
+ do_resize: bool,
+ size: SizeDict,
+ interpolation: Optional["F.InterpolationMode"],
+ do_pad: bool,
+ do_center_crop: bool,
+ crop_size: SizeDict,
+ do_rescale: bool,
+ rescale_factor: float,
+ do_normalize: bool,
+ image_mean: Optional[Union[float, list[float]]],
+ image_std: Optional[Union[float, list[float]]],
+ disable_grouping: Optional[bool],
+ return_tensors: Optional[Union[str, TensorType]],
+ **kwargs,
+ ) -> BatchFeature:
+ # Group images by size for batched resizing
+ grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
+ resized_images_grouped = {}
+ for shape, stacked_images in grouped_images.items():
+ if do_pad:
+ stacked_images = self.pad_to_square(
+ images=stacked_images, background_color=tuple(int(x * 255) for x in self.image_mean)
+ )
+ resized_images_grouped[shape] = stacked_images
+ padded_images = reorder_images(resized_images_grouped, grouped_images_index)
+
+ # Group images by size for batched resizing
+ # Needed in case do_pad is False, or padding returns images with different sizes
+ grouped_images, grouped_images_index = group_images_by_shape(padded_images, disable_grouping=disable_grouping)
+ resized_images_grouped = {}
+ for shape, stacked_images in grouped_images.items():
+ if do_resize:
+ stacked_images = self.resize(image=stacked_images, size=size, interpolation=interpolation)
+ resized_images_grouped[shape] = stacked_images
+ resized_images = reorder_images(resized_images_grouped, grouped_images_index)
+
+ # Group images by size for further processing
+ # Needed in case do_resize is False, or resize returns images with different sizes
+ grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping)
+ processed_images_grouped = {}
+ for shape, stacked_images in grouped_images.items():
+ if do_center_crop:
+ stacked_images = self.center_crop(stacked_images, crop_size)
+ # Fused rescale and normalize
+ stacked_images = self.rescale_and_normalize(
+ stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std
+ )
+ processed_images_grouped[shape] = stacked_images
+
+ processed_images = reorder_images(processed_images_grouped, grouped_images_index)
+ processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
+
+ return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)
+
+
+__all__ = ["LlavaImageProcessorFast"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/llava/modeling_llava.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/llava/modeling_llava.py
new file mode 100644
index 0000000000000000000000000000000000000000..bc0bb0df7c7baa60aaee0f2fd585102368bea1db
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/llava/modeling_llava.py
@@ -0,0 +1,484 @@
+# coding=utf-8
+# Copyright 2023 the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch Llava model."""
+
+from dataclasses import dataclass
+from typing import Optional, Union
+
+import torch
+from torch import nn
+
+from ...activations import ACT2FN
+from ...cache_utils import Cache
+from ...generation import GenerationMixin
+from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput
+from ...modeling_utils import PreTrainedModel
+from ...processing_utils import Unpack
+from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
+from ..auto import AutoModel
+from .configuration_llava import LlavaConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Base class for Llava outputs, with hidden states and attentions.
+ """
+)
+class LlavaModelOutputWithPast(BaseModelOutputWithPast):
+ r"""
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
+ `past_key_values` input) to speed up sequential decoding.
+ image_hidden_states (`torch.FloatTensor`, *optional*):
+ A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
+ image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
+ """
+
+ image_hidden_states: Optional[torch.FloatTensor] = None
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Base class for Llava causal language model (or autoregressive) outputs.
+ """
+)
+class LlavaCausalLMOutputWithPast(ModelOutput):
+ r"""
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Language modeling loss (for next-token prediction).
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
+ `past_key_values` input) to speed up sequential decoding.
+ image_hidden_states (`torch.FloatTensor`, *optional*):
+ A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
+ image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ logits: Optional[torch.FloatTensor] = None
+ past_key_values: Optional[Cache] = None
+ hidden_states: Optional[tuple[torch.FloatTensor]] = None
+ attentions: Optional[tuple[torch.FloatTensor]] = None
+ image_hidden_states: Optional[torch.FloatTensor] = None
+
+
+class LlavaMultiModalProjector(nn.Module):
+ def __init__(self, config: LlavaConfig):
+ super().__init__()
+ # We have hidden_size * the number of vision feature layers
+ num_feature_layers = 1 if isinstance(config.vision_feature_layer, int) else len(config.vision_feature_layer)
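+ # e.g. with a 1024-dim vision tower and k selected feature layers, linear_1 maps 1024 * k features
+ # to the language model hidden size (values illustrative).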
+ self.linear_1 = nn.Linear(
+ config.vision_config.hidden_size * num_feature_layers,
+ config.text_config.hidden_size,
+ bias=config.multimodal_projector_bias,
+ )
+ self.act = ACT2FN[config.projector_hidden_act]
+ self.linear_2 = nn.Linear(
+ config.text_config.hidden_size, config.text_config.hidden_size, bias=config.multimodal_projector_bias
+ )
+
+ def forward(self, image_features):
+ hidden_states = self.linear_1(image_features)
+ hidden_states = self.act(hidden_states)
+ hidden_states = self.linear_2(hidden_states)
+ return hidden_states
+
+
+@auto_docstring
+class LlavaPreTrainedModel(PreTrainedModel):
+ config: LlavaConfig
+ base_model_prefix = ""
+ supports_gradient_checkpointing = True
+ _skip_keys_device_placement = "past_key_values"
+
+ _supports_flash_attn = True
+ _supports_sdpa = True
+
+ _can_compile_fullgraph = True
+ _supports_flex_attn = True
+ _supports_attention_backend = True
+
+
+@auto_docstring(
+ custom_intro="""
+ The Llava model which consists of a vision backbone and a language model, without a language modeling head.
+ """
+)
+class LlavaModel(LlavaPreTrainedModel):
+ _checkpoint_conversion_mapping = {"language_model.model": "language_model"}
+
+ def __init__(self, config: LlavaConfig):
+ super().__init__(config)
+ self.vision_tower = AutoModel.from_config(config.vision_config)
+
+ self.multi_modal_projector = LlavaMultiModalProjector(config)
+ self.language_model = AutoModel.from_config(config.text_config)
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.language_model.get_input_embeddings()
+
+ def set_input_embeddings(self, value):
+ self.language_model.set_input_embeddings(value)
+
+ def set_decoder(self, decoder):
+ self.language_model = decoder
+
+ def get_decoder(self):
+ return self.language_model
+
+ def get_image_features(
+ self,
+ pixel_values: torch.FloatTensor,
+ vision_feature_layer: Optional[Union[int, list[int]]] = None,
+ vision_feature_select_strategy: Optional[str] = None,
+ **kwargs,
+ ):
+ """
+ Obtains image last hidden states from the vision tower and apply multimodal projection.
+
+ Args:
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, channels, height, width)`):
+ The tensors corresponding to the input images.
+ vision_feature_layer (`Union[int, list[int]]`, *optional*):
+ The index of the layer to select the vision feature. If multiple indices are provided,
+ the vision feature of the corresponding indices will be concatenated to form the
+ vision features.
+ vision_feature_select_strategy (`str`, *optional*):
+ The feature selection strategy used to select the vision feature from the vision backbone.
+ Can be one of `"default"` or `"full"`
+ Returns:
+ image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`.
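+
+ With the stock CLIP ViT-L/14 (336px) tower and the `"default"` strategy, the CLS token is dropped and
+ `image_length` is (336 / 14) ** 2 = 576 (illustrative of the default configuration).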
+ """
+ vision_feature_layer = (
+ vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
+ )
+ vision_feature_select_strategy = (
+ vision_feature_select_strategy
+ if vision_feature_select_strategy is not None
+ else self.config.vision_feature_select_strategy
+ )
+
+ if vision_feature_select_strategy not in ["default", "full"]:
+ raise ValueError(f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}")
+
+ kwargs = {k: v for k, v in kwargs.items() if v is not None}
+ # This is not memory efficient at all: output_hidden_states=True saves all the hidden states.
+ image_outputs = self.vision_tower(pixel_values, output_hidden_states=True, **kwargs)
+
+ # If we have one vision feature layer, return the corresponding hidden states,
+ # otherwise, select the hidden states of each feature layer and concatenate them
+ if isinstance(vision_feature_layer, int):
+ selected_image_feature = image_outputs.hidden_states[vision_feature_layer]
+ if vision_feature_select_strategy == "default":
+ selected_image_feature = selected_image_feature[:, 1:]
+ else:
+ hs_pool = [image_outputs.hidden_states[layer_idx] for layer_idx in vision_feature_layer]
+ # For default; crop CLS from each hidden state in the hidden state pool
+ if vision_feature_select_strategy == "default":
+ hs_pool = [hs[:, 1:] for hs in hs_pool]
+ selected_image_feature = torch.cat(hs_pool, dim=-1)
+
+ image_features = self.multi_modal_projector(selected_image_feature)
+
+ if "image_sizes" in kwargs:
+ split_sizes = [
+ (height // self.vision_tower.patch_size) * (width // self.vision_tower.patch_size)
+ for height, width in kwargs["image_sizes"]
+ ]
+ image_features = torch.split(image_features.squeeze(0), split_sizes)
+ else:
+ image_features = list(image_features)
+ return image_features
+
+ def get_placeholder_mask(
+ self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, image_features: torch.FloatTensor
+ ):
+ """
+ Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is
+ equal to the length of multimodal features. If the lengths are different, an error is raised.
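+
+ For example, assuming an `image_token_id` of 32000 and `input_ids = [[1, 32000, 32000, 2]]`, the mask is
+ `True` at the two placeholder positions and is expanded to the embedding dimension, and the check requires
+ the image features to provide exactly two feature vectors (2 * `embed_dim` values) for that sample.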
+ """
+ if input_ids is None:
+ special_image_mask = inputs_embeds == self.get_input_embeddings()(
+ torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
+ )
+ special_image_mask = special_image_mask.all(-1)
+ else:
+ special_image_mask = input_ids == self.config.image_token_id
+
+ n_image_tokens = special_image_mask.sum()
+ special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
+ n_image_features = image_features.shape[0] * image_features.shape[1]
+ if inputs_embeds[special_image_mask].numel() != image_features.numel():
+ raise ValueError(
+ f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
+ )
+ return special_image_mask
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ vision_feature_layer: Optional[Union[int, list[int]]] = None,
+ vision_feature_select_strategy: Optional[str] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ image_sizes: Optional[torch.Tensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> Union[tuple, LlavaModelOutputWithPast]:
+ vision_feature_layer = (
+ vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
+ )
+ vision_feature_select_strategy = (
+ vision_feature_select_strategy
+ if vision_feature_select_strategy is not None
+ else self.config.vision_feature_select_strategy
+ )
+
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ if inputs_embeds is None:
+ inputs_embeds = self.get_input_embeddings()(input_ids)
+
+ if pixel_values is not None:
+ image_features = self.get_image_features(
+ pixel_values=pixel_values,
+ vision_feature_layer=vision_feature_layer,
+ vision_feature_select_strategy=vision_feature_select_strategy,
+ image_sizes=image_sizes,
+ )
+ image_features = torch.cat(image_features, dim=0).to(inputs_embeds.device, inputs_embeds.dtype)
+ special_image_mask = self.get_placeholder_mask(
+ input_ids, inputs_embeds=inputs_embeds, image_features=image_features
+ )
+ inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
+
+ outputs = self.language_model(
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+ return LlavaModelOutputWithPast(
+ last_hidden_state=outputs.last_hidden_state,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ image_hidden_states=image_features if pixel_values is not None else None,
+ )
+
+
+@auto_docstring(
+ custom_intro="""
+ The LLAVA model which consists of a vision backbone and a language model.
+ """
+)
+class LlavaForConditionalGeneration(LlavaPreTrainedModel, GenerationMixin):
+ _checkpoint_conversion_mapping = {
+ "^language_model.model": "model.language_model",
+ "^vision_tower": "model.vision_tower",
+ "^multi_modal_projector": "model.multi_modal_projector",
+ "^language_model.lm_head": "lm_head",
+ }
+ _tied_weights_keys = ["lm_head.weight"]
+
+ def __init__(self, config: LlavaConfig):
+ super().__init__(config)
+ self.model = LlavaModel(config)
+ self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.get_input_embeddings()
+
+ def set_input_embeddings(self, value):
+ self.model.set_input_embeddings(value)
+
+ def get_output_embeddings(self) -> nn.Module:
+ return self.lm_head
+
+ def set_decoder(self, decoder):
+ self.model.set_decoder(decoder)
+
+ def get_decoder(self):
+ return self.model.get_decoder()
+
+ def get_image_features(
+ self,
+ pixel_values: torch.FloatTensor,
+ vision_feature_layer: Optional[Union[int, list[int]]] = None,
+ vision_feature_select_strategy: Optional[str] = None,
+ **kwargs,
+ ):
+ return self.model.get_image_features(
+ pixel_values=pixel_values,
+ vision_feature_layer=vision_feature_layer,
+ vision_feature_select_strategy=vision_feature_select_strategy,
+ **kwargs,
+ )
+
+ # Make modules available through conditional class for BC
+ @property
+ def language_model(self):
+ return self.model.language_model
+
+ @property
+ def vision_tower(self):
+ return self.model.vision_tower
+
+ @property
+ def multi_modal_projector(self):
+ return self.model.multi_modal_projector
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ vision_feature_layer: Optional[Union[int, list[int]]] = None,
+ vision_feature_select_strategy: Optional[str] = None,
+ labels: Optional[torch.LongTensor] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ logits_to_keep: Union[int, torch.Tensor] = 0,
+ image_sizes: Optional[torch.Tensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> Union[tuple, LlavaCausalLMOutputWithPast]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+ Example:
+
+ ```python
+ >>> from PIL import Image
+ >>> import requests
+ >>> from transformers import AutoProcessor, LlavaForConditionalGeneration
+
+ >>> model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf")
+ >>> processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
+
+ >>> prompt = "USER: \nWhat's the content of the image? ASSISTANT:"
+ >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> inputs = processor(images=image, text=prompt, return_tensors="pt")
+
+ >>> # Generate
+ >>> generate_ids = model.generate(**inputs, max_new_tokens=15)
+ >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ "USER: \nWhat's the content of the image? ASSISTANT: The image features a busy city street with a stop sign prominently displayed"
+ ```"""
+ vision_feature_layer = (
+ vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
+ )
+ vision_feature_select_strategy = (
+ vision_feature_select_strategy
+ if vision_feature_select_strategy is not None
+ else self.config.vision_feature_select_strategy
+ )
+
+ outputs = self.model(
+ input_ids=input_ids,
+ pixel_values=pixel_values,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ vision_feature_layer=vision_feature_layer,
+ vision_feature_select_strategy=vision_feature_select_strategy,
+ cache_position=cache_position,
+ image_sizes=image_sizes,
+ **kwargs,
+ )
+
+ hidden_states = outputs[0]
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(
+ logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
+ )
+
+ return LlavaCausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ image_hidden_states=outputs.image_hidden_states,
+ )
+
+ def prepare_inputs_for_generation(
+ self,
+ input_ids,
+ past_key_values=None,
+ inputs_embeds=None,
+ pixel_values=None,
+ attention_mask=None,
+ cache_position=None,
+ logits_to_keep=None,
+ **kwargs,
+ ):
+ # Overwritten -- in specific circumstances we don't want to forward image inputs to the model
+
+ model_inputs = super().prepare_inputs_for_generation(
+ input_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ cache_position=cache_position,
+ logits_to_keep=logits_to_keep,
+ **kwargs,
+ )
+
+ if cache_position[0] == 0:
+ # Pixel values are forwarded only on the first (prefill) step, when `input_ids` still contain the special image tokens.
+ # During cached decoding the image tokens are no longer part of `input_ids`, so pixel values are left as None.
+ model_inputs["pixel_values"] = pixel_values
+
+ return model_inputs
+
+
+__all__ = ["LlavaForConditionalGeneration", "LlavaPreTrainedModel", "LlavaModel"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/llava/processing_llava.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/llava/processing_llava.py
new file mode 100644
index 0000000000000000000000000000000000000000..63c07c20cbb9974fb6d2597e88e5224cd0111a52
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/llava/processing_llava.py
@@ -0,0 +1,212 @@
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Processor class for Llava.
+"""
+
+from typing import Optional, Union
+
+import numpy as np
+
+from ...feature_extraction_utils import BatchFeature
+from ...image_utils import ImageInput, get_image_size, to_numpy_array
+from ...processing_utils import (
+ MultiModalData,
+ ProcessingKwargs,
+ ProcessorMixin,
+ Unpack,
+)
+from ...tokenization_utils_base import PreTokenizedInput, TextInput
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class LlavaProcessorKwargs(ProcessingKwargs, total=False):
+ _defaults = {
+ "text_kwargs": {"padding": False, "return_mm_token_type_ids": False},
+ "images_kwargs": {},
+ }
+
+
+class LlavaProcessor(ProcessorMixin):
+ r"""
+ Constructs a LLaVa processor which wraps a LLaVa image processor and a LLaMa tokenizer into a single processor.
+
+ [`LlavaProcessor`] offers all the functionalities of [`LlavaImageProcessor`] and [`LlamaTokenizerFast`]. See the
+ [`~LlavaProcessor.__call__`] and [`~LlavaProcessor.decode`] for more information.
+
+ Args:
+ image_processor ([`LlavaImageProcessor`], *optional*):
+ The image processor is a required input.
+ tokenizer ([`LlamaTokenizerFast`], *optional*):
+ The tokenizer is a required input.
+ patch_size (`int`, *optional*):
+ Patch size from the vision tower.
+ vision_feature_select_strategy (`str`, *optional*):
+ The feature selection strategy used to select the vision feature from the vision backbone.
+ Should be the same as in the model's config.
+ chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
+ in a chat into a tokenizable string.
+ image_token (`str`, *optional*, defaults to `"<image>"`):
+ Special token used to denote image location.
+ num_additional_image_tokens (`int`, *optional*, defaults to 0):
+ Number of additional tokens added to the image embeddings, such as CLS (+1). If the backbone has no CLS or other
+ extra tokens appended, no need to set this arg.
+ """
+
+ attributes = ["image_processor", "tokenizer"]
+ image_processor_class = "AutoImageProcessor"
+ tokenizer_class = "AutoTokenizer"
+
+ def __init__(
+ self,
+ image_processor=None,
+ tokenizer=None,
+ patch_size=None,
+ vision_feature_select_strategy=None,
+ chat_template=None,
+ image_token="", # set the default and let users change if they have peculiar special tokens in rare cases
+ num_additional_image_tokens=0,
+ **kwargs,
+ ):
+ self.patch_size = patch_size
+ self.num_additional_image_tokens = num_additional_image_tokens
+ self.vision_feature_select_strategy = vision_feature_select_strategy
+ self.image_token = tokenizer.image_token if hasattr(tokenizer, "image_token") else image_token
+ self.image_token_id = tokenizer.encode(self.image_token, add_special_tokens=False)[0]
+ super().__init__(image_processor, tokenizer, chat_template=chat_template)
+
+ def __call__(
+ self,
+ images: Optional[ImageInput] = None,
+ text: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]] = None,
+ audio=None,
+ videos=None,
+ **kwargs: Unpack[LlavaProcessorKwargs],
+ ) -> BatchFeature:
+ """
+ Main method to prepare one or several sequence(s) and image(s) for the model. This method forwards the `text`
+ and `kwargs` arguments to LlamaTokenizerFast's [`~LlamaTokenizerFast.__call__`] if `text` is not `None` to encode
+ the text. To prepare the image(s), this method forwards the `images` and `kwargs` arguments to
+ CLIPImageProcessor's [`~CLIPImageProcessor.__call__`] if `images` is not `None`. Please refer to the docstring
+ of the above two methods for more information.
+
+ Args:
+ images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `list[PIL.Image.Image]`, `list[np.ndarray]`, `list[torch.Tensor]`):
+ The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
+ tensor. Both channels-first and channels-last formats are supported.
+ text (`str`, `list[str]`, `list[list[str]]`):
+ The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
+ (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
+ `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
+ return_tensors (`str` or [`~utils.TensorType`], *optional*):
+ If set, will return tensors of a particular framework. Acceptable values are:
+ - `'tf'`: Return TensorFlow `tf.constant` objects.
+ - `'pt'`: Return PyTorch `torch.Tensor` objects.
+ - `'np'`: Return NumPy `np.ndarray` objects.
+ - `'jax'`: Return JAX `jnp.ndarray` objects.
+
+ Returns:
+ [`BatchFeature`]: A [`BatchFeature`] with the following fields:
+
+ - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
+ - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
+ `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
+ `None`).
+ - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
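+
+ Example (a minimal sketch; it assumes the `llava-hf/llava-1.5-7b-hf` checkpoint and a `PIL.Image.Image`
+ named `image`):
+
+ ```python
+ >>> from transformers import AutoProcessor
+
+ >>> processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
+ >>> prompt = "USER: <image>\nDescribe the image. ASSISTANT:"
+ >>> inputs = processor(images=image, text=prompt, return_tensors="pt")
+ >>> sorted(inputs.keys())
+ ['attention_mask', 'input_ids', 'pixel_values']
+ ```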
+ """
+ if images is None and text is None:
+ raise ValueError("You have to specify at least one of `images` or `text`.")
+
+ output_kwargs = self._merge_kwargs(
+ LlavaProcessorKwargs,
+ tokenizer_init_kwargs=self.tokenizer.init_kwargs,
+ **kwargs,
+ )
+ if images is not None:
+ image_inputs = self.image_processor(images, **output_kwargs["images_kwargs"])
+ else:
+ image_inputs = {}
+
+ if isinstance(text, str):
+ text = [text]
+ elif not isinstance(text, list) and not isinstance(text[0], str):
+ raise TypeError("Invalid input text. Please provide a string, or a list of strings")
+
+ # try to expand inputs in processing if we have the necessary parts
+ prompt_strings = text
+ if image_inputs.get("pixel_values") is not None:
+ # Replace the image token with the expanded image token sequence
+ pixel_values = image_inputs["pixel_values"]
+ height, width = get_image_size(to_numpy_array(pixel_values[0]))
+ num_image_tokens = (height // self.patch_size) * (
+ width // self.patch_size
+ ) + self.num_additional_image_tokens
+ if self.vision_feature_select_strategy == "default":
+ num_image_tokens -= 1
+
+ prompt_strings = []
+ for sample in text:
+ sample = sample.replace(self.image_token, self.image_token * num_image_tokens)
+ prompt_strings.append(sample)
+
+ return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
+ return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", False)
+ text_inputs = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"], return_tensors=None)
+ self._check_special_mm_tokens(prompt_strings, text_inputs, modalities=["image"])
+
+ if return_mm_token_type_ids:
+ array_ids = np.array(text_inputs["input_ids"])
+ mm_token_type_ids = np.zeros_like(text_inputs["input_ids"])
+ mm_token_type_ids[array_ids == self.image_token_id] = 1
+ text_inputs["mm_token_type_ids"] = mm_token_type_ids.tolist()
+
+ return BatchFeature(data={**text_inputs, **image_inputs}, tensor_type=return_tensors)
+
+ def _get_num_multimodal_tokens(self, image_sizes=None, **kwargs):
+ """
+ Computes the number of placeholder tokens needed for multimodal inputs with the given sizes.
+
+ Args:
+ image_sizes (`list[list[int]]`, *optional*):
+ The input sizes formatted as (height, width) per each image.
+
+ Returns:
+ `MultiModalData`: A `MultiModalData` object holding number of tokens per each of the provided
+ input modalities, along with other useful data.
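+
+ For example, assuming a 336x336 crop size, `patch_size=14`, `num_additional_image_tokens=1` (the CLS token)
+ and the `"default"` feature selection strategy, each image maps to (336 // 14) * (336 // 14) + 1 - 1 = 576
+ placeholder tokens.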
+ """
+
+ vision_data = {}
+ if image_sizes is not None:
+ images_kwargs = dict(LlavaProcessorKwargs._defaults.get("images_kwargs", {}))  # copy so the class-level defaults are not mutated
+ images_kwargs.update(kwargs)
+ crop_size = images_kwargs.get("crop_size", None) or self.image_processor.crop_size
+ resized_height, resized_width = crop_size["height"], crop_size["width"]
+
+ num_image_tokens = (resized_height // self.patch_size) * (resized_width // self.patch_size)
+ num_image_tokens += self.num_additional_image_tokens
+ if self.vision_feature_select_strategy == "default":
+ num_image_tokens -= 1
+
+ num_image_tokens = [num_image_tokens] * len(image_sizes)
+ num_image_patches = [1] * len(image_sizes)
+ vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches})
+
+ return MultiModalData(**vision_data)
+
+
+__all__ = ["LlavaProcessor"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/llava_next/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/llava_next/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..3c8429dc7e80c1ced93d7fa79b1b36d472eec26e
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/llava_next/__init__.py
@@ -0,0 +1,30 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_llava_next import *
+ from .image_processing_llava_next import *
+ from .image_processing_llava_next_fast import *
+ from .modeling_llava_next import *
+ from .processing_llava_next import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/llava_next/configuration_llava_next.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/llava_next/configuration_llava_next.py
new file mode 100644
index 0000000000000000000000000000000000000000..17ea71b1aa6421c8e2007cbb9399a69e9894288e
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/llava_next/configuration_llava_next.py
@@ -0,0 +1,150 @@
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Llava-NeXT model configuration"""
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+from ..auto import CONFIG_MAPPING, AutoConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+class LlavaNextConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`LlavaNextForConditionalGeneration`]. It is used to instantiate an
+ Llava-NeXT model according to the specified arguments, defining the model architecture. Instantiating a configuration
+ with the defaults will yield a similar configuration to that of the [llava-hf/llava-v1.6-mistral-7b-hf](https://huggingface.co/llava-hf/llava-v1.6-mistral-7b-hf)
+ model.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ vision_config (`Union[AutoConfig, dict]`, *optional*, defaults to `CLIPVisionConfig`):
+ The config object or dictionary of the vision backbone.
+ text_config (`Union[AutoConfig, dict]`, *optional*, defaults to `LlamaConfig`):
+ The config object or dictionary of the text backbone.
+ image_token_index (`int`, *optional*, defaults to 32000):
+ The image token index to encode the image prompt.
+ projector_hidden_act (`str`, *optional*, defaults to `"gelu"`):
+ The activation function used by the multimodal projector.
+ vision_feature_select_strategy (`str`, *optional*, defaults to `"default"`):
+ The feature selection strategy used to select the vision feature from the vision backbone.
+ Can be one of `"default"` or `"full"`. If `"default"`, the CLS token is removed from the vision features.
+ If `"full"`, the full vision features are used.
+ vision_feature_layer (`Union[int, list[int]]`, *optional*, defaults to -2):
+ The index of the layer to select the vision feature. If multiple indices are provided,
+ the vision feature of the corresponding indices will be concatenated to form the
+ vision features.
+ image_grid_pinpoints (`List`, *optional*, defaults to `[[336, 672], [672, 336], [672, 672], [1008, 336], [336, 1008]]`):
+ A list of possible resolutions to use for processing high resolution images. Each item in the list should be a tuple or list
+ of the form `(height, width)`.
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+ Whether the model's input and output word embeddings should be tied.
+ image_seq_length (`int`, *optional*, defaults to 576):
+ Sequence length of one image embedding.
+ multimodal_projector_bias (`bool`, *optional*, defaults to `True`):
+ Whether to use bias in the multimodal projector.
+
+ Example:
+
+ ```python
+ >>> from transformers import LlavaNextForConditionalGeneration, LlavaNextConfig, CLIPVisionConfig, LlamaConfig
+
+ >>> # Initializing a CLIP-vision config
+ >>> vision_config = CLIPVisionConfig()
+
+ >>> # Initializing a Llama config
+ >>> text_config = LlamaConfig()
+
+ >>> # Initializing a Llava-Next llava-hf/llava-v1.6-mistral-7b-hf style configuration
+ >>> configuration = LlavaNextConfig(vision_config, text_config)
+
+ >>> # Initializing a model from the llava-hf/llava-v1.6-mistral-7b-hf style configuration
+ >>> model = LlavaNextForConditionalGeneration(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "llava_next"
+ attribute_map = {
+ "image_token_id": "image_token_index",
+ }
+ sub_configs = {"text_config": AutoConfig, "vision_config": AutoConfig}
+
+ def __init__(
+ self,
+ vision_config=None,
+ text_config=None,
+ image_token_index=32000,
+ projector_hidden_act="gelu",
+ vision_feature_select_strategy="default",
+ vision_feature_layer=-2,
+ image_grid_pinpoints=None,
+ tie_word_embeddings=False,
+ image_seq_length=576,
+ multimodal_projector_bias=True,
+ **kwargs,
+ ):
+ self.image_token_index = image_token_index
+ self.projector_hidden_act = projector_hidden_act
+ self.image_seq_length = image_seq_length
+ self.multimodal_projector_bias = multimodal_projector_bias
+
+ if vision_feature_select_strategy not in ["default", "full"]:
+ raise ValueError(
+ "vision_feature_select_strategy should be one of 'default', 'full'."
+ f"Got: {vision_feature_select_strategy}"
+ )
+
+ self.vision_feature_select_strategy = vision_feature_select_strategy
+ self.vision_feature_layer = vision_feature_layer
+ image_grid_pinpoints = (
+ image_grid_pinpoints
+ if image_grid_pinpoints is not None
+ else [[336, 672], [672, 336], [672, 672], [1008, 336], [336, 1008]]
+ )
+ self.image_grid_pinpoints = image_grid_pinpoints
+
+ if isinstance(vision_config, dict):
+ vision_config["model_type"] = vision_config.get("model_type", "clip_vision_model")
+ vision_config = CONFIG_MAPPING[vision_config["model_type"]](**vision_config)
+ elif vision_config is None:
+ vision_config = CONFIG_MAPPING["clip_vision_model"](
+ intermediate_size=4096,
+ hidden_size=1024,
+ patch_size=14,
+ image_size=336,
+ num_hidden_layers=24,
+ num_attention_heads=16,
+ vocab_size=32000,
+ projection_dim=768,
+ )
+
+ self.vision_config = vision_config
+
+ if isinstance(text_config, dict):
+ text_config["model_type"] = text_config.get("model_type", "llama")
+ text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
+ elif text_config is None:
+ text_config = CONFIG_MAPPING["llama"]()
+
+ self.text_config = text_config
+
+ super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
+
+
+__all__ = ["LlavaNextConfig"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/llava_next/image_processing_llava_next.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/llava_next/image_processing_llava_next.py
new file mode 100644
index 0000000000000000000000000000000000000000..350ce9db7dc6f53e11090bf8f222665f8fcaf9a9
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/llava_next/image_processing_llava_next.py
@@ -0,0 +1,724 @@
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Image processor class for LLaVa-NeXT."""
+
+from collections.abc import Iterable
+from typing import Optional, Union
+
+import numpy as np
+
+from ...image_processing_utils import (
+ BaseImageProcessor,
+ BatchFeature,
+ get_patch_output_size,
+ get_size_dict,
+ select_best_resolution,
+)
+from ...image_transforms import (
+ PaddingMode,
+ convert_to_rgb,
+ get_resize_output_image_size,
+ pad,
+ resize,
+ to_channel_dimension_format,
+)
+from ...image_utils import (
+ OPENAI_CLIP_MEAN,
+ OPENAI_CLIP_STD,
+ ChannelDimension,
+ ImageInput,
+ PILImageResampling,
+ get_image_size,
+ infer_channel_dimension_format,
+ is_scaled_image,
+ make_flat_list_of_images,
+ to_numpy_array,
+ valid_images,
+ validate_preprocess_arguments,
+)
+from ...utils import TensorType, is_vision_available, logging
+
+
+logger = logging.get_logger(__name__)
+
+
+if is_vision_available():
+ from PIL import Image
+
+
+def divide_to_patches(image: np.ndarray, patch_size: int, input_data_format) -> list[np.ndarray]:
+ """
+ Divides an image into patches of a specified size.
+
+ Args:
+ image (`np.ndarray`):
+ The input image.
+ patch_size (`int`):
+ The size of each patch.
+ input_data_format (`ChannelDimension` or `str`):
+ The channel dimension format of the input image.
+
+ Returns:
+ list: A list of np.ndarray representing the patches.
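+
+ Example (a small sketch):
+
+ ```python
+ >>> import numpy as np
+ >>> from transformers.image_utils import ChannelDimension
+ >>> image = np.zeros((672, 672, 3))
+ >>> patches = divide_to_patches(image, patch_size=336, input_data_format=ChannelDimension.LAST)
+ >>> len(patches)
+ 4
+ ```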
+ """
+ patches = []
+ height, width = get_image_size(image, channel_dim=input_data_format)
+ for i in range(0, height, patch_size):
+ for j in range(0, width, patch_size):
+ if input_data_format == ChannelDimension.LAST:
+ patch = image[i : i + patch_size, j : j + patch_size]
+ else:
+ patch = image[:, i : i + patch_size, j : j + patch_size]
+ patches.append(patch)
+
+ return patches
+
+
+def expand_to_square(image: np.ndarray, background_color, input_data_format) -> np.ndarray:
+ """
+ Expands an image to a square by adding a background color.
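+
+ Example (a small sketch for a channels-last image):
+
+ ```python
+ >>> import numpy as np
+ >>> from transformers.image_utils import ChannelDimension
+ >>> image = np.zeros((2, 4, 3))  # (height, width, channels)
+ >>> expand_to_square(image, background_color=0, input_data_format=ChannelDimension.LAST).shape
+ (4, 4, 3)
+ ```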
+ """
+
+ height, width = get_image_size(image, channel_dim=input_data_format)
+ if width == height:
+ return image
+ elif width > height:
+ result = np.ones((width, width, image.shape[2]), dtype=image.dtype) * background_color
+ result[(width - height) // 2 : (width - height) // 2 + height, :] = image
+ return result
+ else:
+ result = np.ones((height, height, image.shape[2]), dtype=image.dtype) * background_color
+ result[:, (height - width) // 2 : (height - width) // 2 + width] = image
+ return result
+
+
+class LlavaNextImageProcessor(BaseImageProcessor):
+ r"""
+ Constructs a LLaVa-NeXT image processor. Based on [`CLIPImageProcessor`] with incorporation of additional techniques
+ for processing high resolution images as explained in the [LLaVa paper](https://huggingface.co/papers/2310.03744).
+
+ Args:
+ do_resize (`bool`, *optional*, defaults to `True`):
+ Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by
+ `do_resize` in the `preprocess` method.
+ size (`dict[str, int]` *optional*, defaults to `{"shortest_edge": 224}`):
+ Size of the image after resizing. The shortest edge of the image is resized to size["shortest_edge"], with
+ the longest edge resized to keep the input aspect ratio. Can be overridden by `size` in the `preprocess`
+ method.
+ image_grid_pinpoints (`List` *optional*, defaults to `[[672, 336], [336, 672], [672, 672], [336, 1008], [1008, 336]]`):
+ A list of possible resolutions to use for processing high resolution images. The best resolution is selected
+ based on the original size of the image. Can be overridden by `image_grid_pinpoints` in the `preprocess`
+ method.
+ resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`):
+ Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method.
+ do_center_crop (`bool`, *optional*, defaults to `True`):
+ Whether to center crop the image to the specified `crop_size`. Can be overridden by `do_center_crop` in the
+ `preprocess` method.
+ crop_size (`dict[str, int]` *optional*, defaults to 224):
+ Size of the output image after applying `center_crop`. Can be overridden by `crop_size` in the `preprocess`
+ method.
+ do_rescale (`bool`, *optional*, defaults to `True`):
+ Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in
+ the `preprocess` method.
+ rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
+ Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess`
+ method.
+ do_normalize (`bool`, *optional*, defaults to `True`):
+ Whether to normalize the image. Can be overridden by `do_normalize` in the `preprocess` method.
+ image_mean (`float` or `list[float]`, *optional*, defaults to `[0.48145466, 0.4578275, 0.40821073]`):
+ Mean to use if normalizing the image. This is a float or list of floats the length of the number of
+ channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
+ image_std (`float` or `list[float]`, *optional*, defaults to `[0.26862954, 0.26130258, 0.27577711]`):
+ Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
+ number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
+ do_pad (`bool`, *optional*, defaults to `True`):
+ Whether to pad the image. If `True`, will pad the patch dimension of the images in the batch to the largest
+ number of patches in the batch. Padding will be applied to the bottom and right with zeros.
+ do_convert_rgb (`bool`, *optional*, defaults to `True`):
+ Whether to convert the image to RGB.
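+
+ Example (a minimal sketch; it assumes the `llava-hf/llava-v1.6-mistral-7b-hf` checkpoint is available, and
+ the exact number of patches depends on the selected grid pinpoint):
+
+ ```python
+ >>> from PIL import Image
+ >>> import numpy as np
+ >>> processor = LlavaNextImageProcessor.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
+ >>> image = Image.fromarray(np.zeros((480, 640, 3), dtype=np.uint8))
+ >>> outputs = processor(images=image, return_tensors="np")
+ >>> outputs["pixel_values"].shape  # (batch_size, num_patches, num_channels, height, width)
+ ```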
+ """
+
+ model_input_names = ["pixel_values", "image_sizes"]
+
+ def __init__(
+ self,
+ do_resize: bool = True,
+ size: Optional[dict[str, int]] = None,
+ image_grid_pinpoints: Optional[list] = None,
+ resample: PILImageResampling = PILImageResampling.BICUBIC,
+ do_center_crop: bool = True,
+ crop_size: Optional[dict[str, int]] = None,
+ do_rescale: bool = True,
+ rescale_factor: Union[int, float] = 1 / 255,
+ do_normalize: bool = True,
+ image_mean: Optional[Union[float, list[float]]] = None,
+ image_std: Optional[Union[float, list[float]]] = None,
+ do_pad: Optional[bool] = True,
+ do_convert_rgb: bool = True,
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+ size = size if size is not None else {"shortest_edge": 224}
+ size = get_size_dict(size, default_to_square=False)
+ image_grid_pinpoints = (
+ image_grid_pinpoints
+ if image_grid_pinpoints is not None
+ else [[336, 672], [672, 336], [672, 672], [1008, 336], [336, 1008]]
+ )
+ crop_size = crop_size if crop_size is not None else {"height": 224, "width": 224}
+ crop_size = get_size_dict(crop_size, default_to_square=True, param_name="crop_size")
+
+ self.do_resize = do_resize
+ self.size = size
+ self.image_grid_pinpoints = image_grid_pinpoints
+ self.resample = resample
+ self.do_center_crop = do_center_crop
+ self.crop_size = crop_size
+ self.do_rescale = do_rescale
+ self.rescale_factor = rescale_factor
+ self.do_normalize = do_normalize
+ self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN
+ self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD
+ self.do_pad = do_pad
+ self.do_convert_rgb = do_convert_rgb
+
+ # Copied from transformers.models.clip.image_processing_clip.CLIPImageProcessor.resize with CLIP->LLaVa
+ def resize(
+ self,
+ image: np.ndarray,
+ size: dict[str, int],
+ resample: PILImageResampling = PILImageResampling.BICUBIC,
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ **kwargs,
+ ) -> np.ndarray:
+ """
+ Resize an image. The shortest edge of the image is resized to size["shortest_edge"], with the longest edge
+ resized to keep the input aspect ratio.
+
+ Args:
+ image (`np.ndarray`):
+ Image to resize.
+ size (`dict[str, int]`):
+ Size of the output image.
+ resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
+ Resampling filter to use when resizing the image.
+ data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format of the image. If not provided, it will be the same as the input image.
+ input_data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format of the input image. If not provided, it will be inferred.
+ """
+ default_to_square = True
+ if "shortest_edge" in size:
+ size = size["shortest_edge"]
+ default_to_square = False
+ elif "height" in size and "width" in size:
+ size = (size["height"], size["width"])
+ else:
+ raise ValueError("Size must contain either 'shortest_edge' or 'height' and 'width'.")
+
+ output_size = get_resize_output_image_size(
+ image,
+ size=size,
+ default_to_square=default_to_square,
+ input_data_format=input_data_format,
+ )
+
+ return resize(
+ image,
+ size=output_size,
+ resample=resample,
+ data_format=data_format,
+ input_data_format=input_data_format,
+ **kwargs,
+ )
+
+ def pad(
+ self,
+ image: np.ndarray,
+ padding: Union[int, tuple[int, int], Iterable[tuple[int, int]]],
+ mode: PaddingMode = PaddingMode.CONSTANT,
+ constant_values: Union[float, Iterable[float]] = 0.0,
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ ) -> np.ndarray:
+ """
+ Pads the `image` with the specified `padding` and `mode`. Padding can be applied in the (`height`, `width`)
+ dimensions or in the (`num_patches`) dimension. In the second case, an iterable of tuples is expected
+ as input.
+
+ Args:
+ image (`np.ndarray`):
+ The image to pad.
+ padding (`int` or `tuple[int, int]` or `Iterable[tuple[int, int]]`):
+ Padding to apply to the edges of the height, width axes. Can be one of three formats:
+ - `((before_height, after_height), (before_width, after_width))` unique pad widths for each axis.
+ - `((before, after),)` yields same before and after pad for height and width.
+ - `(pad,)` or int is a shortcut for before = after = pad width for all axes.
+ mode (`PaddingMode`):
+ The padding mode to use. Can be one of:
+ - `"constant"`: pads with a constant value.
+ - `"reflect"`: pads with the reflection of the vector mirrored on the first and last values of the
+ vector along each axis.
+ - `"replicate"`: pads with the replication of the last value on the edge of the array along each axis.
+ - `"symmetric"`: pads with the reflection of the vector mirrored along the edge of the array.
+ constant_values (`float` or `Iterable[float]`, *optional*):
+ The value to use for the padding if `mode` is `"constant"`.
+ data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format for the output image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ If unset, will use same as the input image.
+ input_data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format for the input image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ If unset, will use the inferred format of the input image.
+
+ Returns:
+ `np.ndarray`: The padded image.
+
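+ Example (a small sketch of padding along the patch dimension):
+
+ ```python
+ >>> import numpy as np
+ >>> patches = np.zeros((3, 3, 336, 336))  # (num_patches, num_channels, height, width)
+ >>> image_processor = LlavaNextImageProcessor()
+ >>> image_processor.pad(patches, padding=((0, 2), (0, 0), (0, 0), (0, 0))).shape
+ (5, 3, 336, 336)
+ ```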
+ """
+
+ # call the general `pad` if padding on `height/width`, otherwise it's the `num_patches` dim
+ if isinstance(padding, int) or len(padding) != 4:
+ return pad(image, padding, mode, constant_values, data_format, input_data_format)
+
+ if input_data_format is None:
+ input_data_format = infer_channel_dimension_format(image)
+ if mode == PaddingMode.CONSTANT:
+ image = np.pad(image, padding, mode="constant", constant_values=constant_values)
+ elif mode == PaddingMode.REFLECT:
+ image = np.pad(image, padding, mode="reflect")
+ elif mode == PaddingMode.REPLICATE:
+ image = np.pad(image, padding, mode="edge")
+ elif mode == PaddingMode.SYMMETRIC:
+ image = np.pad(image, padding, mode="symmetric")
+ else:
+ raise ValueError(f"Invalid padding mode: {mode}")
+ image = (
+ to_channel_dimension_format(image, data_format, input_data_format) if data_format is not None else image
+ )
+ return image
+
+ def _preprocess(
+ self,
+ images: ImageInput,
+ do_resize: Optional[bool] = None,
+ size: Optional[dict[str, int]] = None,
+ resample: Optional[PILImageResampling] = None,
+ do_center_crop: Optional[bool] = None,
+ crop_size: Optional[int] = None,
+ do_rescale: Optional[bool] = None,
+ rescale_factor: Optional[float] = None,
+ do_normalize: Optional[bool] = None,
+ image_mean: Optional[Union[float, list[float]]] = None,
+ image_std: Optional[Union[float, list[float]]] = None,
+ data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ ) -> Image.Image:
+ """
+ Preprocess an image or batch of images. Copy of the `preprocess` method from `CLIPImageProcessor`.
+
+ Args:
+ images (`ImageInput`):
+ Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
+ passing in images with pixel values between 0 and 1, set `do_rescale=False`.
+ do_resize (`bool`, *optional*, defaults to `self.do_resize`):
+ Whether to resize the image.
+ size (`dict[str, int]`, *optional*, defaults to `self.size`):
+ Size of the image after resizing. Shortest edge of the image is resized to size["shortest_edge"], with
+ the longest edge resized to keep the input aspect ratio.
+ resample (`int`, *optional*, defaults to `self.resample`):
+ Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
+ has an effect if `do_resize` is set to `True`.
+ do_center_crop (`bool`, *optional*, defaults to `self.do_center_crop`):
+ Whether to center crop the image.
+ crop_size (`dict[str, int]`, *optional*, defaults to `self.crop_size`):
+ Size of the center crop. Only has an effect if `do_center_crop` is set to `True`.
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
+ Whether to rescale the image.
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
+ Rescale factor to rescale the image by if `do_rescale` is set to `True`.
+ do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
+ Whether to normalize the image.
+ image_mean (`float` or `list[float]`, *optional*, defaults to `self.image_mean`):
+ Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
+ image_std (`float` or `list[float]`, *optional*, defaults to `self.image_std`):
+ Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
+ `True`.
+ data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
+ The channel dimension format for the output image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - Unset: Use the channel dimension format of the input image.
+ input_data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
+ from the input image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
+ """
+ images = make_flat_list_of_images(images)
+
+ all_images = []
+ for image in images:
+ if do_resize:
+ image = self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
+
+ if do_center_crop:
+ image = self.center_crop(image=image, size=crop_size, input_data_format=input_data_format)
+
+ if do_rescale:
+ image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
+
+ if do_normalize:
+ image = self.normalize(
+ image=image, mean=image_mean, std=image_std, input_data_format=input_data_format
+ )
+
+ all_images.append(image)
+ images = [
+ to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
+ for image in all_images
+ ]
+
+ return images
+
+ def _resize_for_patching(
+ self, image: np.ndarray, target_resolution: tuple, resample, input_data_format: ChannelDimension
+ ) -> np.ndarray:
+ """
+ Resizes an image to a target resolution while maintaining aspect ratio.
+
+ Args:
+ image (np.ndarray):
+ The input image.
+ target_resolution (tuple):
+ The target resolution (height, width) of the image.
+ resample (`PILImageResampling`):
+ Resampling filter to use if resizing the image.
+ input_data_format (`ChannelDimension` or `str`):
+ The channel dimension format of the input image.
+
+ Returns:
+ np.ndarray: The resized and padded image.
+ """
+ new_height, new_width = get_patch_output_size(image, target_resolution, input_data_format)
+
+ # Resize the image
+ resized_image = resize(image, (new_height, new_width), resample=resample, input_data_format=input_data_format)
+
+ return resized_image
+
+ def _get_padding_size(self, original_resolution: tuple, target_resolution: tuple):
+ original_height, original_width = original_resolution
+ target_height, target_width = target_resolution
+ paste_x, r_x = divmod(target_width - original_width, 2)
+ paste_y, r_y = divmod(target_height - original_height, 2)
+ return (paste_y, paste_y + r_y), (paste_x, paste_x + r_x)
+
+ def _pad_for_patching(
+ self, image: np.ndarray, target_resolution: tuple, input_data_format: ChannelDimension
+ ) -> np.ndarray:
+ """
+ Pad an image to a target resolution while maintaining aspect ratio.
+ """
+ new_resolution = get_patch_output_size(image, target_resolution, input_data_format)
+ padding = self._get_padding_size(new_resolution, target_resolution)
+
+ padded_image = self.pad(image, padding=padding)
+
+ return padded_image
+
+ def get_image_patches(
+ self,
+ image: np.ndarray,
+ grid_pinpoints,
+ size: tuple,
+ patch_size: int,
+ resample: PILImageResampling,
+ data_format: ChannelDimension,
+ input_data_format: ChannelDimension,
+ ) -> list[np.ndarray]:
+ """
+ Process an image with variable resolutions by dividing it into patches.
+
+ Args:
+ image (np.ndarray):
+ The input image to be processed.
+ grid_pinpoints (List):
+ A list of possible resolutions to consider when processing high resolution images.
+ size (`tuple`):
+ Size to resize the original image to.
+ patch_size (`int`):
+ Size of the patches to divide the image into.
+ resample (`PILImageResampling`):
+ Resampling filter to use if resizing the image.
+ data_format (`ChannelDimension` or `str`):
+ The channel dimension format for the output image.
+ input_data_format (`ChannelDimension` or `str`):
+ The channel dimension format of the input image.
+
+ Returns:
+ list[np.ndarray]: A list of NumPy arrays containing the processed image patches.
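+
+ For example, if the best resolution selected from `grid_pinpoints` is `(672, 672)` and `patch_size` is 336,
+ the padded image is divided into 4 patches and the returned list has length 5: the original image resized
+ to `size`, followed by the 4 high-resolution patches.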
+ """
+ if not isinstance(grid_pinpoints, list):
+ raise TypeError("grid_pinpoints must be a list of possible resolutions.")
+
+ possible_resolutions = grid_pinpoints
+
+ image_size = get_image_size(image, channel_dim=input_data_format)
+ best_resolution = select_best_resolution(image_size, possible_resolutions)
+ resized_image = self._resize_for_patching(
+ image, best_resolution, resample=resample, input_data_format=input_data_format
+ )
+ padded_image = self._pad_for_patching(resized_image, best_resolution, input_data_format=input_data_format)
+
+ patches = divide_to_patches(padded_image, patch_size=patch_size, input_data_format=input_data_format)
+
+ # make sure that all patches are in the input data format
+ patches = [
+ to_channel_dimension_format(patch, channel_dim=data_format, input_channel_dim=input_data_format)
+ for patch in patches
+ ]
+
+ resized_original_image = resize(
+ image,
+ size=size,
+ resample=resample,
+ data_format=data_format,
+ input_data_format=input_data_format,
+ )
+
+ image_patches = [resized_original_image] + patches
+
+ return image_patches
+
+ def _pad_for_batching(
+ self,
+ pixel_values: list[np.ndarray],
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ ):
+ """
+ Pads images on the `num_of_patches` dimension with zeros to form a batch of same number of patches.
+
+ Args:
+ pixel_values (`list[np.ndarray]`):
+ A list of arrays holding the pixel values of each image, each of shape (`num_patches`, `num_channels`, `height`, `width`)
+ data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format for the output image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ If unset, will use same as the input image.
+ input_data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format for the input image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ If unset, will use the inferred format of the input image.
+
+ Returns:
+ list[`np.ndarray`]: The padded images.
+ """
+ max_patch = max(len(x) for x in pixel_values)
+ pixel_values = [
+ self.pad(
+ image,
+ padding=((0, max_patch - image.shape[0]), (0, 0), (0, 0), (0, 0)),
+ data_format=data_format,
+ input_data_format=input_data_format,
+ )
+ for image in pixel_values
+ ]
+
+ return pixel_values
+
+ def preprocess(
+ self,
+ images: ImageInput,
+ do_resize: Optional[bool] = None,
+ size: Optional[dict[str, int]] = None,
+ image_grid_pinpoints: Optional[list] = None,
+ resample: Optional[PILImageResampling] = None,
+ do_center_crop: Optional[bool] = None,
+ crop_size: Optional[int] = None,
+ do_rescale: Optional[bool] = None,
+ rescale_factor: Optional[float] = None,
+ do_normalize: Optional[bool] = None,
+ image_mean: Optional[Union[float, list[float]]] = None,
+ image_std: Optional[Union[float, list[float]]] = None,
+ do_pad: Optional[bool] = None,
+ do_convert_rgb: Optional[bool] = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ ):
+ """
+ Args:
+ images (`ImageInput`):
+ Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
+ passing in images with pixel values between 0 and 1, set `do_rescale=False`.
+ do_resize (`bool`, *optional*, defaults to `self.do_resize`):
+ Whether to resize the image.
+ size (`dict[str, int]`, *optional*, defaults to `self.size`):
+ Size of the image after resizing. Shortest edge of the image is resized to size["shortest_edge"], with
+ the longest edge resized to keep the input aspect ratio.
+ image_grid_pinpoints (`List` *optional*, defaults to `self.image_grid_pinpoints`):
+ A list of possible resolutions to use for processing high resolution images. The best resolution is
+ selected based on the original size of the image.
+ resample (`int`, *optional*, defaults to `self.resample`):
+ Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
+ has an effect if `do_resize` is set to `True`.
+ do_center_crop (`bool`, *optional*, defaults to `self.do_center_crop`):
+ Whether to center crop the image.
+ crop_size (`dict[str, int]`, *optional*, defaults to `self.crop_size`):
+ Size of the center crop. Only has an effect if `do_center_crop` is set to `True`.
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
+ Whether to rescale the image.
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
+ Rescale factor to rescale the image by if `do_rescale` is set to `True`.
+ do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
+ Whether to normalize the image.
+ image_mean (`float` or `list[float]`, *optional*, defaults to `self.image_mean`):
+ Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
+ image_std (`float` or `list[float]`, *optional*, defaults to `self.image_std`):
+ Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
+ `True`.
+ do_pad (`bool`, *optional*, defaults to `self.do_pad`):
+ Whether to pad the image. If `True`, will pad the patch dimension of the images in the batch to the largest
+ number of patches in the batch. Padding will be applied to the bottom and right with zeros.
+ do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
+ Whether to convert the image to RGB.
+ return_tensors (`str` or `TensorType`, *optional*):
+ The type of tensors to return. Can be one of:
+ - Unset: Return a list of `np.ndarray`.
+ - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
+ - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
+ data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
+ The channel dimension format for the output image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - Unset: Use the channel dimension format of the input image.
+ input_data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
+ from the input image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
+
+ """
+ do_resize = do_resize if do_resize is not None else self.do_resize
+ size = size if size is not None else self.size
+ size = get_size_dict(size, param_name="size", default_to_square=False)
+ image_grid_pinpoints = image_grid_pinpoints if image_grid_pinpoints is not None else self.image_grid_pinpoints
+ resample = resample if resample is not None else self.resample
+ do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop
+ crop_size = crop_size if crop_size is not None else self.crop_size
+ crop_size = get_size_dict(crop_size, param_name="crop_size", default_to_square=True)
+ do_rescale = do_rescale if do_rescale is not None else self.do_rescale
+ rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
+ do_normalize = do_normalize if do_normalize is not None else self.do_normalize
+ image_mean = image_mean if image_mean is not None else self.image_mean
+ image_std = image_std if image_std is not None else self.image_std
+ do_pad = do_pad if do_pad is not None else self.do_pad
+ do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
+
+ images = self.fetch_images(images)
+ images = make_flat_list_of_images(images)
+
+ if not valid_images(images):
+ raise ValueError(
+ "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
+ "torch.Tensor, tf.Tensor or jax.ndarray."
+ )
+
+ validate_preprocess_arguments(
+ do_rescale=do_rescale,
+ rescale_factor=rescale_factor,
+ do_normalize=do_normalize,
+ image_mean=image_mean,
+ image_std=image_std,
+ do_center_crop=do_center_crop,
+ crop_size=crop_size,
+ do_resize=do_resize,
+ size=size,
+ resample=resample,
+ )
+
+ if do_convert_rgb:
+ images = [convert_to_rgb(image) for image in images]
+
+ # All transformations expect numpy arrays.
+ images = [to_numpy_array(image) for image in images]
+
+ if do_rescale and is_scaled_image(images[0]):
+ logger.warning_once(
+ "It looks like you are trying to rescale already rescaled images. If the input"
+ " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
+ )
+
+ if input_data_format is None:
+ # We assume that all images have the same channel dimension format.
+ input_data_format = infer_channel_dimension_format(images[0])
+
+ processed_images = []
+ image_sizes = [get_image_size(image, channel_dim=input_data_format) for image in images]
+ for image in images:
+ # convert image into a list of patches
+ # we intentionally use the same data format as the input data format
+ image_patches = self.get_image_patches(
+ image,
+ image_grid_pinpoints,
+ size=(size["shortest_edge"], size["shortest_edge"])
+ if "shortest_edge" in size
+ else (min(size["height"], size["width"]), min(size["height"], size["width"])),
+ patch_size=crop_size["height"],
+ resample=resample,
+ data_format=input_data_format,
+ input_data_format=input_data_format,
+ )
+
+ # preprocess patches
+ pixel_values = self._preprocess(
+ image_patches,
+ do_resize=do_resize,
+ size=size,
+ resample=resample,
+ do_center_crop=do_center_crop,
+ crop_size=crop_size,
+ do_rescale=do_rescale,
+ rescale_factor=rescale_factor,
+ do_normalize=do_normalize,
+ image_mean=image_mean,
+ image_std=image_std,
+ data_format=data_format,
+ input_data_format=input_data_format,
+ )
+ pixel_values = np.array(pixel_values)
+ processed_images.append(pixel_values)
+
+ if do_pad:
+ processed_images = self._pad_for_batching(processed_images)
+
+ return BatchFeature(
+ data={"pixel_values": processed_images, "image_sizes": image_sizes}, tensor_type=return_tensors
+ )
+
+
+__all__ = ["LlavaNextImageProcessor"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/llava_next/image_processing_llava_next_fast.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/llava_next/image_processing_llava_next_fast.py
new file mode 100644
index 0000000000000000000000000000000000000000..df20e2b90e8323ff17ed2c80d6d5369aba85b428
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/llava_next/image_processing_llava_next_fast.py
@@ -0,0 +1,281 @@
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Fast Image processor class for LLaVa-NeXT."""
+
+from typing import Optional, Union
+
+import torch
+from torchvision.transforms.v2 import functional as F
+
+from ...image_processing_utils import BatchFeature, get_patch_output_size, select_best_resolution
+from ...image_processing_utils_fast import (
+ BaseImageProcessorFast,
+ DefaultFastImageProcessorKwargs,
+ divide_to_patches,
+ group_images_by_shape,
+ reorder_images,
+)
+from ...image_utils import (
+ OPENAI_CLIP_MEAN,
+ OPENAI_CLIP_STD,
+ ChannelDimension,
+ ImageInput,
+ PILImageResampling,
+ SizeDict,
+ get_image_size,
+)
+from ...processing_utils import Unpack
+from ...utils import (
+ TensorType,
+ auto_docstring,
+)
+
+
+class LlavaNextFastImageProcessorKwargs(DefaultFastImageProcessorKwargs):
+ """
+ image_grid_pinpoints (`list[list[int]]`, *optional*):
+ A list of possible resolutions to use for processing high resolution images. The best resolution is selected
+ based on the original size of the image. Can be overridden by `image_grid_pinpoints` in the `preprocess`
+ method.
+ """
+
+ image_grid_pinpoints: Optional[list[list[int]]]
+
+
+@auto_docstring
+class LlavaNextImageProcessorFast(BaseImageProcessorFast):
+ # To be checked against the slow image processor
+ # None values left after checking can be removed
+ resample = PILImageResampling.BICUBIC
+ image_mean = OPENAI_CLIP_MEAN
+ image_std = OPENAI_CLIP_STD
+ size = {"shortest_edge": 224}
+ default_to_square = False
+ crop_size = {"height": 224, "width": 224}
+ do_resize = True
+ do_center_crop = True
+ do_rescale = True
+ do_normalize = True
+ do_convert_rgb = True
+ do_pad = True
+ image_grid_pinpoints = [[336, 672], [672, 336], [672, 672], [1008, 336], [336, 1008]]
+ valid_kwargs = LlavaNextFastImageProcessorKwargs
+
+ def __init__(self, **kwargs: Unpack[LlavaNextFastImageProcessorKwargs]):
+ super().__init__(**kwargs)
+
+ @auto_docstring
+ def preprocess(self, images: ImageInput, **kwargs: Unpack[LlavaNextFastImageProcessorKwargs]) -> BatchFeature:
+ return super().preprocess(images, **kwargs)
+
+ def _resize_for_patching(
+ self,
+ image: "torch.Tensor",
+ target_resolution: tuple,
+ interpolation: "F.InterpolationMode",
+ input_data_format: ChannelDimension,
+ ) -> "torch.Tensor":
+ """
+ Resizes an image to a target resolution while maintaining aspect ratio.
+
+ Args:
+ image ("torch.Tensor"):
+ The input image.
+ target_resolution (tuple):
+ The target resolution (height, width) of the image.
+ interpolation (`InterpolationMode`):
+ Resampling filter to use if resizing the image.
+ input_data_format (`ChannelDimension` or `str`):
+ The channel dimension format of the input image.
+
+ Returns:
+ "torch.Tensor": The resized and padded image.
+ """
+ new_height, new_width = get_patch_output_size(image, target_resolution, input_data_format)
+
+ # Resize the image
+ resized_image = self.resize(
+ image=image,
+ size=SizeDict(height=new_height, width=new_width),
+ interpolation=interpolation,
+ )
+
+ return resized_image
+
+ def _get_padding_size(self, original_resolution: tuple, target_resolution: tuple):
+ original_height, original_width = original_resolution
+ target_height, target_width = target_resolution
+ paste_x, r_x = divmod(target_width - original_width, 2)
+ paste_y, r_y = divmod(target_height - original_height, 2)
+ return [paste_x, paste_y, paste_x + r_x, paste_y + r_y]
+
+ def _pad_for_patching(
+ self, image: "torch.Tensor", target_resolution: tuple, input_data_format: ChannelDimension
+ ) -> "torch.Tensor":
+ """
+ Pad an image to a target resolution while maintaining aspect ratio.
+ """
+ new_resolution = get_patch_output_size(image, target_resolution, input_data_format)
+ padding = self._get_padding_size(new_resolution, target_resolution)
+
+ padded_image = F.pad(image, padding=padding)
+
+ return padded_image
+
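+ # Illustrative example (a sketch): a 224x448 original resized for the (336, 336)
+ # pinpoint becomes a (3, 168, 336) tensor; `_get_padding_size((168, 336), (336, 336))`
+ # returns [0, 84, 0, 84] (left, top, right, bottom), so `_pad_for_patching` zero-pads
+ # it back to (3, 336, 336) with the content vertically centered.
+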
+ def _get_image_patches(
+ self,
+ image: "torch.Tensor",
+ grid_pinpoints,
+ size: tuple,
+ patch_size: int,
+ interpolation: "F.InterpolationMode",
+ ) -> list["torch.Tensor"]:
+ """
+ Process an image with variable resolutions by dividing it into patches.
+
+ Args:
+ image ("torch.Tensor"):
+ The input image to be processed.
+ grid_pinpoints (List):
+ A list of possible resolutions.
+ size (`tuple`):
+ Size to resize the original image to.
+ patch_size (`int`):
+ Size of the patches to divide the image into.
+ interpolation (`"InterpolationMode"`):
+ Resampling filter to use if resizing the image.
+
+ Returns:
+ list["torch.Tensor"]: A list of NumPy arrays containing the processed image patches.
+ """
+ if not isinstance(grid_pinpoints, list):
+ raise TypeError("grid_pinpoints must be a list of possible resolutions.")
+
+ possible_resolutions = grid_pinpoints
+
+ image_size = get_image_size(image, channel_dim=ChannelDimension.FIRST)
+ best_resolution = select_best_resolution(image_size, possible_resolutions)
+ resized_image = self._resize_for_patching(
+ image, best_resolution, interpolation=interpolation, input_data_format=ChannelDimension.FIRST
+ )
+ padded_image = self._pad_for_patching(resized_image, best_resolution, input_data_format=ChannelDimension.FIRST)
+ patches = divide_to_patches(padded_image, patch_size=patch_size)
+ resized_original_image = F.resize(image, size=size, interpolation=interpolation)
+
+ image_patches = [resized_original_image] + patches
+
+ return image_patches
+
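+ # Illustrative example (a sketch): with the default pinpoints, `size=(336, 336)` and
+ # `patch_size=336`, a 672x672 image selects the (672, 672) resolution, is divided into
+ # four 336x336 patches, and the 336x336 downscaled original is prepended, giving a
+ # list of five (3, 336, 336) tensors.
+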
+ def _pad_for_batching(
+ self,
+ pixel_values: list["torch.Tensor"],
+ ) -> list["torch.Tensor"]:
+ """
+ Pads images on the `num_of_patches` dimension with zeros to form a batch with the same number of patches.
+
+ Args:
+ pixel_values (`list[torch.Tensor]`):
+ An array of pixel values of each image of shape (`batch_size`, `num_patches`, `image_in_3D`)
+
+ Returns:
+ list[`torch.Tensor`]: The padded images.
+ """
+ max_patch = max(len(x) for x in pixel_values)
+ pixel_values = [
+ torch.nn.functional.pad(image, pad=[0, 0, 0, 0, 0, 0, 0, max_patch - image.shape[0]])
+ for image in pixel_values
+ ]
+
+ return pixel_values
+
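+ # Illustrative example (a sketch): given per-image tensors of shape (5, 3, 336, 336)
+ # and (3, 3, 336, 336), the second is zero-padded along the patch dimension to
+ # (5, 3, 336, 336) so the whole batch can later be stacked.
+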
+ def _preprocess(
+ self,
+ images: list["torch.Tensor"],
+ do_resize: bool,
+ size: SizeDict,
+ image_grid_pinpoints: list[list[int]],
+ interpolation: Optional["F.InterpolationMode"],
+ do_center_crop: bool,
+ crop_size: SizeDict,
+ do_rescale: bool,
+ rescale_factor: float,
+ do_normalize: bool,
+ image_mean: Optional[Union[float, list[float]]],
+ image_std: Optional[Union[float, list[float]]],
+ do_pad: bool,
+ disable_grouping: Optional[bool],
+ return_tensors: Optional[Union[str, TensorType]],
+ **kwargs,
+ ) -> BatchFeature:
+ processed_images = []
+ image_sizes = []
+ # Determine the size tuple
+ if size and size.height and size.width:
+ size_tuple = (size.height, size.width)
+ else:
+ size_tuple = (size.shortest_edge, size.shortest_edge)
+
+ # Determine the patch size
+ if crop_size and crop_size.height:
+ patch_size = crop_size.height
+ elif size and size.height:
+ patch_size = size.height
+ else:
+ patch_size = size.shortest_edge
+
+ for image in images:
+ image_patches = self._get_image_patches(
+ image,
+ image_grid_pinpoints,
+ size=size_tuple,
+ patch_size=patch_size,
+ interpolation=interpolation,
+ )
+
+ # Group images by size for batched processing
+ processed_image_patches_grouped = {}
+ grouped_image_patches, grouped_image_patches_index = group_images_by_shape(
+ image_patches, disable_grouping=disable_grouping
+ )
+ for shape, stacked_image_patches in grouped_image_patches.items():
+ if do_resize:
+ stacked_image_patches = self.resize(
+ image=stacked_image_patches,
+ size=size,
+ interpolation=interpolation,
+ )
+ if do_center_crop:
+ stacked_image_patches = self.center_crop(stacked_image_patches, crop_size)
+ # Fused rescale and normalize
+ stacked_image_patches = self.rescale_and_normalize(
+ stacked_image_patches, do_rescale, rescale_factor, do_normalize, image_mean, image_std
+ )
+ processed_image_patches_grouped[shape] = stacked_image_patches
+ processed_image_patches = reorder_images(processed_image_patches_grouped, grouped_image_patches_index)
+ processed_image_patches = (
+ torch.stack(processed_image_patches, dim=0) if return_tensors else processed_image_patches
+ )
+ processed_images.append(processed_image_patches)
+ image_sizes.append(get_image_size(image, ChannelDimension.FIRST))
+
+ if do_pad:
+ processed_images = self._pad_for_batching(processed_images)
+ processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
+ return BatchFeature(
+ data={"pixel_values": processed_images, "image_sizes": image_sizes}, tensor_type=return_tensors
+ )
+
+
+__all__ = ["LlavaNextImageProcessorFast"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/llava_next/modeling_llava_next.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/llava_next/modeling_llava_next.py
new file mode 100644
index 0000000000000000000000000000000000000000..a75b4b7981078cd690ae7d063e5715dce8bf6696
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/llava_next/modeling_llava_next.py
@@ -0,0 +1,793 @@
+# coding=utf-8
+# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch Llava-NeXT model."""
+
+import math
+from dataclasses import dataclass
+from typing import Optional, Union
+
+import numpy as np
+import torch
+from torch import nn
+
+from ...activations import ACT2FN
+from ...cache_utils import Cache
+from ...generation import GenerationMixin
+from ...image_processing_utils import select_best_resolution
+from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput
+from ...modeling_utils import PreTrainedModel
+from ...processing_utils import Unpack
+from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
+from ..auto import AutoModel
+from .configuration_llava_next import LlavaNextConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
+ """
+ Calculate the shape of the image patch grid after the preprocessing for images of any resolution.
+
+ Args:
+ image_size (`tuple`):
+ The size of the input image in the format (height, width).
+ grid_pinpoints (`List`):
+ A list containing possible resolutions. Each item in the list should be a tuple or list
+ of the form `(height, width)`.
+ patch_size (`int`):
+ The size of each image patch.
+
+ Returns:
+ tuple: The shape of the image patch grid in the format (height, width).
+ """
+ if not isinstance(grid_pinpoints, list):
+ raise TypeError("grid_pinpoints should be a list of tuples or lists")
+
+ # ! VERY IMPORTANT: if image_size is a tensor, it must be converted into a tuple, otherwise it will cause a wrong calculation
+ if not isinstance(image_size, (list, tuple)):
+ if not isinstance(image_size, (torch.Tensor, np.ndarray)):
+ raise TypeError(
+ f"image_size invalid type: {type(image_size)} not valid, should be either list, tuple, np.ndarray or tensor"
+ )
+ image_size = image_size.tolist()
+
+ height, width = select_best_resolution(image_size, grid_pinpoints)
+ return height // patch_size, width // patch_size
+
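+ # Illustrative example (a sketch, assuming the default LLaVA-NeXT grid pinpoints and a
+ # 336-pixel vision patch): a square 672x672 input matches the (672, 672) pinpoint
+ # exactly, so the patch grid is 2 x 2:
+ #
+ #     get_anyres_image_grid_shape(
+ #         image_size=(672, 672),
+ #         grid_pinpoints=[[336, 672], [672, 336], [672, 672], [1008, 336], [336, 1008]],
+ #         patch_size=336,
+ #     )  # -> (2, 2)
+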
+
+def image_size_to_num_patches(image_size, grid_pinpoints, patch_size: int):
+ """
+ Calculate the number of patches after the preprocessing for images of any resolution.
+
+ Args:
+ image_size (`torch.LongTensor` or `np.ndarray` or `tuple[int, int]`):
+ The size of the input image in the format (height, width).
+ grid_pinpoints (`List`):
+ A list containing possible resolutions. Each item in the list should be a tuple or list
+ of the form `(height, width)`.
+ patch_size (`int`):
+ The size of each image patch.
+
+ Returns:
+ int: the number of patches
+ """
+ if not isinstance(grid_pinpoints, list):
+ raise TypeError("grid_pinpoints should be a list of tuples or lists")
+
+ # ! VERY IMPORTANT: if image_size is a tensor, it must be converted into a tuple, otherwise it will cause a wrong calculation
+ if not isinstance(image_size, (list, tuple)):
+ if not isinstance(image_size, (torch.Tensor, np.ndarray)):
+ raise TypeError(f"image_size invalid type {type(image_size)} with value {image_size}")
+ image_size = image_size.tolist()
+
+ best_resolution = select_best_resolution(image_size, grid_pinpoints)
+ height, width = best_resolution
+ num_patches = 0
+ # consider changing to ceil(height / patch_size) * ceil(width / patch_size) + 1
+ for i in range(0, height, patch_size):
+ for j in range(0, width, patch_size):
+ num_patches += 1
+ # add the base patch
+ num_patches += 1
+ return num_patches
+
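+ # Illustrative example (same assumptions as above): the (672, 672) best resolution
+ # yields a 2 x 2 grid of 336-pixel patches plus the base patch:
+ #
+ #     image_size_to_num_patches(
+ #         image_size=(672, 672),
+ #         grid_pinpoints=[[336, 672], [672, 336], [672, 672], [1008, 336], [336, 1008]],
+ #         patch_size=336,
+ #     )  # -> 5
+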
+
+def unpad_image(tensor, original_size):
+ """
+ Unpads a PyTorch tensor of a padded and resized image.
+
+ Args:
+ tensor (`torch.Tensor`):
+ The image tensor, assumed to be of shape (num_channels, height, width).
+ original_size (`tuple`):
+ The original size of the image (height, width).
+
+ Returns:
+ `torch.Tensor`: The unpadded image tensor.
+ """
+ if not isinstance(original_size, (list, tuple)):
+ if not isinstance(original_size, (torch.Tensor, np.ndarray)):
+ raise TypeError(
+ f"image_size invalid type: {type(original_size)} not valid, should be either list, tuple, np.ndarray or tensor"
+ )
+ original_size = original_size.tolist()
+ original_height, original_width = original_size
+ current_height, current_width = tensor.shape[1:]
+
+ original_aspect_ratio = original_width / original_height
+ current_aspect_ratio = current_width / current_height
+
+ if original_aspect_ratio > current_aspect_ratio:
+ scale_factor = current_width / original_width
+ new_height = int(round(original_height * scale_factor, 7))
+ padding = (current_height - new_height) // 2
+ unpadded_tensor = tensor[:, padding : current_height - padding, :]
+ else:
+ scale_factor = current_height / original_height
+ new_width = int(round(original_width * scale_factor, 7))
+ padding = (current_width - new_width) // 2
+ unpadded_tensor = tensor[:, :, padding : current_width - padding]
+
+ return unpadded_tensor
+
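+ # Illustrative example (a sketch): a (3, 336, 336) tensor that came from a 224x448
+ # original has an original aspect ratio (2.0) larger than the current one (1.0), so
+ # 84 rows of vertical padding are stripped from the top and the bottom:
+ #
+ #     unpad_image(torch.zeros(3, 336, 336), original_size=(224, 448)).shape
+ #     # -> torch.Size([3, 168, 336])
+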
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Base class for Llava outputs, with hidden states and attentions.
+ """
+)
+class LlavaNextModelOutputWithPast(BaseModelOutputWithPast):
+ r"""
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
+ `past_key_values` input) to speed up sequential decoding.
+ image_hidden_states (`torch.FloatTensor`, *optional*):
+ A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
+ image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
+ """
+
+ image_hidden_states: Optional[torch.FloatTensor] = None
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Base class for LlavaNext causal language model (or autoregressive) outputs.
+ """
+)
+class LlavaNextCausalLMOutputWithPast(ModelOutput):
+ r"""
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Language modeling loss (for next-token prediction).
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
+ `past_key_values` input) to speed up sequential decoding.
+ image_hidden_states (`torch.FloatTensor`, *optional*):
+ A `torch.FloatTensor` of size `(batch_size * num_patches, num_images, sequence_length, hidden_size)`.
+ image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ logits: Optional[torch.FloatTensor] = None
+ past_key_values: Optional[Cache] = None
+ hidden_states: Optional[tuple[torch.FloatTensor]] = None
+ attentions: Optional[tuple[torch.FloatTensor]] = None
+ image_hidden_states: Optional[torch.FloatTensor] = None
+
+
+# Copied from transformers.models.llava.modeling_llava.LlavaMultiModalProjector with Llava->LlavaNext
+class LlavaNextMultiModalProjector(nn.Module):
+ def __init__(self, config: LlavaNextConfig):
+ super().__init__()
+ # We have hidden_size * the number of vision feature layers
+ num_feature_layers = 1 if isinstance(config.vision_feature_layer, int) else len(config.vision_feature_layer)
+ self.linear_1 = nn.Linear(
+ config.vision_config.hidden_size * num_feature_layers,
+ config.text_config.hidden_size,
+ bias=config.multimodal_projector_bias,
+ )
+ self.act = ACT2FN[config.projector_hidden_act]
+ self.linear_2 = nn.Linear(
+ config.text_config.hidden_size, config.text_config.hidden_size, bias=config.multimodal_projector_bias
+ )
+
+ def forward(self, image_features):
+ hidden_states = self.linear_1(image_features)
+ hidden_states = self.act(hidden_states)
+ hidden_states = self.linear_2(hidden_states)
+ return hidden_states
+
+
+@auto_docstring
+class LlavaNextPreTrainedModel(PreTrainedModel):
+ config: LlavaNextConfig
+ base_model_prefix = ""
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["LlamaDecoderLayer"]
+ _skip_keys_device_placement = "past_key_values"
+
+ _supports_flash_attn = True
+ _supports_sdpa = True
+
+ _can_compile_fullgraph = True
+ _supports_flex_attn = True
+ _supports_attention_backend = True
+
+ def _init_weights(self, module):
+ std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range)
+
+ if isinstance(module, nn.Linear):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, LlavaNextModel):
+ embed_std = 1 / math.sqrt(self.config.text_config.hidden_size)
+ module.image_newline.data.normal_(mean=0.0, std=embed_std)
+
+
+@auto_docstring(
+ custom_intro="""
+ The Llava-Next model which consists of a vision backbone and a language model without language modeling head.
+ """
+)
+class LlavaNextModel(LlavaNextPreTrainedModel):
+ _checkpoint_conversion_mapping = {"language_model.model": "language_model"}
+
+ def __init__(self, config: LlavaNextConfig):
+ super().__init__(config)
+ self.vision_tower = AutoModel.from_config(config.vision_config)
+
+ self.multi_modal_projector = LlavaNextMultiModalProjector(config)
+ embed_std = 1 / math.sqrt(config.text_config.hidden_size)
+ self.image_newline = nn.Parameter(torch.randn(config.text_config.hidden_size, dtype=self.dtype) * embed_std)
+
+ self.vocab_size = config.text_config.vocab_size
+ self.language_model = AutoModel.from_config(config.text_config)
+ self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.language_model.get_input_embeddings()
+
+ def set_input_embeddings(self, value):
+ self.language_model.set_input_embeddings(value)
+
+ def set_decoder(self, decoder):
+ self.language_model = decoder
+
+ def get_decoder(self):
+ return self.language_model
+
+ def pack_image_features(self, image_features, image_sizes, vision_feature_select_strategy, image_newline=None):
+ """
+ Reshape, unpad and then pack each image_feature into a single image_features tensor containing all visual vectors.
+
+ Args:
+ image_features (`list[torch.Tensor]` of length num_images, each of shape `(num_patches, image_length, embed_dim)`)
+ List of image feature tensors; each contains all the visual features of all patches.
+ image_sizes (`torch.Tensor` of shape `(num_images, 2)`)
+ Actual image size of each image (H, W).
+ vision_feature_select_strategy (`str`)
+ The feature selection strategy used to select the vision feature from the vision backbone.
+ image_newline (`torch.Tensor` of shape `(embed_dim)`)
+ New line embedding vector.
+ Returns:
+ image_features (`torch.Tensor` of shape `(all_feat_len, embed_dim)`)
+ feature_lens (`list[int]`)
+ token length of each image in image_features
+ """
+ new_image_features = []
+ feature_lens = []
+ for image_idx, image_feature in enumerate(image_features):
+ if image_feature.shape[0] > 1:
+ base_image_feature = image_feature[0]
+ image_feature = image_feature[1:]
+ height = width = self.config.vision_config.image_size // self.config.vision_config.patch_size
+
+ num_patch_height, num_patch_width = get_anyres_image_grid_shape(
+ image_sizes[image_idx],
+ self.config.image_grid_pinpoints,
+ self.config.vision_config.image_size,
+ )
+
+ if (
+ np.prod(image_feature.shape) % (num_patch_height * num_patch_width * height * width) != 0
+ and vision_feature_select_strategy == "default"
+ ):
+ logger.warning_once(
+ "Image feature shape does not line up with the provided patch size. "
+ "You may be using the `default` vision_feature_select_strategy with a"
+ " visual encoder that does not have CLS."
+ )
+
+ image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1)
+ image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
+ image_feature = image_feature.flatten(1, 2).flatten(2, 3)
+ image_feature = unpad_image(image_feature, image_sizes[image_idx])
+ if image_newline is not None:
+ image_feature = torch.cat(
+ (
+ image_feature,
+ image_newline[:, None, None]
+ .expand(*image_feature.shape[:-1], 1)
+ .to(image_feature.device, image_feature.dtype),
+ ),
+ dim=-1,
+ )
+ image_feature = image_feature.flatten(1, 2).transpose(0, 1)
+ image_feature = torch.cat((base_image_feature, image_feature), dim=0)
+ else:
+ image_feature = image_feature[0]
+ if image_newline is not None:
+ image_feature = torch.cat((image_feature, image_newline[None].to(image_feature)), dim=0)
+ new_image_features.append(image_feature)
+ feature_lens.append(image_feature.size(0))
+ feature_lens = torch.tensor(feature_lens, dtype=torch.long, device=image_features[0].device)
+ return new_image_features, feature_lens
+
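+ # Illustrative example (a sketch, assuming a CLIP ViT-L/14-336 backbone, i.e. 24x24 = 576
+ # features per 336x336 patch under the "default" strategy): a 672x672 input mapped to the
+ # (672, 672) pinpoint packs its 2x2 grid into a 48x48 feature map, nothing is unpadded for
+ # a square image, one `image_newline` column brings it to 48 x 49 = 2352 features, and the
+ # base patch adds 576 more, for 2928 visual tokens in total.
+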
+ def get_image_features(
+ self,
+ pixel_values: torch.FloatTensor,
+ image_sizes: torch.Tensor,
+ vision_feature_layer: Optional[Union[int, list[int]]] = None,
+ vision_feature_select_strategy: Optional[str] = None,
+ ):
+ """
+ Obtains image last hidden states from the vision tower and applies the multimodal projection.
+
+ Args:
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_patches, channels, height, width)`)
+ The tensors corresponding to the input images.
+ image_sizes (`torch.Tensor` of shape `(num_images, 2)`)
+ Actual image size of each image (H, W).
+ vision_feature_layer (`Union[int, list[int]]`, *optional*):
+ The index of the layer to select the vision feature. If multiple indices are provided,
+ the vision feature of the corresponding indices will be concatenated to form the
+ vision features.
+ vision_feature_select_strategy (`str`, *optional*):
+ The feature selection strategy used to select the vision feature from the vision backbone.
+ Can be one of `"default"` or `"full"`
+ Returns:
+ image_features (list[`torch.Tensor`]): List of image feature tensors; each contains all the visual features of all patches
+ and is of shape `(num_patches, image_length, embed_dim)`.
+ """
+ vision_feature_layer = (
+ vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
+ )
+ vision_feature_select_strategy = (
+ vision_feature_select_strategy
+ if vision_feature_select_strategy is not None
+ else self.config.vision_feature_select_strategy
+ )
+
+ # ! infer image_num_patches from image_sizes
+ image_num_patches = [
+ image_size_to_num_patches(
+ image_size=imsize,
+ grid_pinpoints=self.config.image_grid_pinpoints,
+ patch_size=self.config.vision_config.image_size,
+ )
+ for imsize in image_sizes
+ ]
+ if pixel_values.dim() == 5:
+ # stacked if input is (batch_size, num_patches, num_channels, height, width)
+ _pixel_values_list = [pix_val[:num_patch] for pix_val, num_patch in zip(pixel_values, image_num_patches)]
+ pixel_values = torch.cat(_pixel_values_list, dim=0)
+ elif pixel_values.dim() != 4:
+ # otherwise has to be stacked from list of (num_patches, num_channels, height, width)
+ raise ValueError(f"pixel_values of shape {pixel_values.shape}, expect to be of 4 or 5 dimensions")
+
+ image_features = self.vision_tower(pixel_values, output_hidden_states=True)
+ # If we have one vision feature layer, return the corresponding hidden states,
+ # otherwise, select the hidden states of each feature layer and concatenate them
+ if isinstance(vision_feature_layer, int):
+ selected_image_feature = image_features.hidden_states[vision_feature_layer]
+ else:
+ hs_pool = [image_features.hidden_states[layer_idx] for layer_idx in vision_feature_layer]
+ selected_image_feature = torch.cat(hs_pool, dim=-1)
+
+ if vision_feature_select_strategy == "default":
+ selected_image_feature = selected_image_feature[:, 1:]
+
+ image_features = self.multi_modal_projector(selected_image_feature)
+ image_features = torch.split(image_features, image_num_patches, dim=0)
+
+ # NOTE we only support multimodal_patch_merge_type == "spatial_unpad"
+ image_features, feature_lens = self.pack_image_features(
+ image_features,
+ image_sizes,
+ vision_feature_select_strategy=vision_feature_select_strategy,
+ image_newline=self.image_newline,
+ )
+ return image_features
+
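+ # Illustrative example (a sketch): a padded batch `pixel_values` of shape
+ # (2, 5, 3, 336, 336) with `image_num_patches = [5, 3]` is flattened to
+ # (8, 3, 336, 336) before the vision tower, dropping the two padding patches
+ # of the second image.
+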
+ def get_placeholder_mask(
+ self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, image_features: torch.FloatTensor
+ ):
+ """
+ Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is
+ equal to the length of multimodal features. If the lengths are different, an error is raised.
+ """
+ if input_ids is None:
+ special_image_mask = inputs_embeds == self.get_input_embeddings()(
+ torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
+ )
+ special_image_mask = special_image_mask.all(-1)
+ else:
+ special_image_mask = input_ids == self.config.image_token_id
+
+ n_image_tokens = special_image_mask.sum()
+ special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
+ if inputs_embeds[special_image_mask].numel() != image_features.numel():
+ raise ValueError(
+ f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {image_features.shape[0]}"
+ )
+ return special_image_mask
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ image_sizes: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ vision_feature_layer: Optional[Union[int, list[int]]] = None,
+ vision_feature_select_strategy: Optional[str] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs: Unpack[FlashAttentionKwargs],
+ ) -> Union[tuple, LlavaNextModelOutputWithPast]:
+ r"""
+ vision_feature_select_strategy (`str`, *optional*, defaults to `"default"`):
+ The feature selection strategy used to select the vision feature from the vision backbone.
+ Can be one of `"default"` or `"full"`. If `"default"`, the CLS token is removed from the vision features.
+ If `"full"`, the full vision features are used.
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ vision_feature_layer = (
+ vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
+ )
+ vision_feature_select_strategy = (
+ vision_feature_select_strategy
+ if vision_feature_select_strategy is not None
+ else self.config.vision_feature_select_strategy
+ )
+
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ if inputs_embeds is None:
+ inputs_embeds = self.get_input_embeddings()(input_ids)
+
+ if pixel_values is not None and pixel_values.size(0) > 0:
+ image_features = self.get_image_features(
+ pixel_values,
+ image_sizes,
+ vision_feature_layer=vision_feature_layer,
+ vision_feature_select_strategy=vision_feature_select_strategy,
+ )
+ image_features = torch.cat(image_features, dim=0).to(inputs_embeds.device, inputs_embeds.dtype)
+ special_image_mask = self.get_placeholder_mask(
+ input_ids, inputs_embeds=inputs_embeds, image_features=image_features
+ )
+ inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
+
+ outputs = self.language_model(
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=True,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+ return LlavaNextModelOutputWithPast(
+ last_hidden_state=outputs.last_hidden_state,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ image_hidden_states=image_features if pixel_values is not None else None,
+ )
+
+
+@auto_docstring(
+ custom_intro="""
+ The LLAVA-NeXT model which consists of a vision backbone and a language model.
+ """
+)
+class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel, GenerationMixin):
+ _checkpoint_conversion_mapping = {
+ "^language_model.model": "model.language_model",
+ "^vision_tower": "model.vision_tower",
+ "^multi_modal_projector": "model.multi_modal_projector",
+ "^image_newline": "model.image_newline",
+ "^language_model.lm_head": "lm_head",
+ }
+ _tied_weights_keys = ["lm_head.weight"]
+
+ def __init__(self, config: LlavaNextConfig):
+ super().__init__(config)
+ self.model = LlavaNextModel(config)
+ self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.get_input_embeddings()
+
+ def set_input_embeddings(self, value):
+ self.model.set_input_embeddings(value)
+
+ def get_output_embeddings(self) -> nn.Module:
+ return self.lm_head
+
+ def set_decoder(self, decoder):
+ self.model.set_decoder(decoder)
+
+ def get_decoder(self):
+ return self.model.get_decoder()
+
+ def pack_image_features(self, image_features, image_sizes, vision_feature_select_strategy, image_newline=None):
+ return self.model.pack_image_features(
+ image_features=image_features,
+ image_sizes=image_sizes,
+ vision_feature_select_strategy=vision_feature_select_strategy,
+ image_newline=image_newline,
+ )
+
+ def get_image_features(
+ self,
+ pixel_values: torch.FloatTensor,
+ image_sizes: torch.Tensor,
+ vision_feature_layer: Optional[Union[int, list[int]]] = None,
+ vision_feature_select_strategy: Optional[str] = None,
+ ):
+ return self.model.get_image_features(
+ pixel_values=pixel_values,
+ image_sizes=image_sizes,
+ vision_feature_layer=vision_feature_layer,
+ vision_feature_select_strategy=vision_feature_select_strategy,
+ )
+
+ # Make modules available through conditional class for BC
+ @property
+ def language_model(self):
+ return self.model.language_model
+
+ @property
+ def vision_tower(self):
+ return self.model.vision_tower
+
+ @property
+ def multi_modal_projector(self):
+ return self.model.multi_modal_projector
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ image_sizes: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ vision_feature_layer: Optional[Union[int, list[int]]] = None,
+ vision_feature_select_strategy: Optional[str] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ logits_to_keep: Union[int, torch.Tensor] = 0,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> Union[tuple, LlavaNextCausalLMOutputWithPast]:
+ r"""
+ vision_feature_select_strategy (`str`, *optional*, defaults to `"default"`):
+ The feature selection strategy used to select the vision feature from the vision backbone.
+ Can be one of `"default"` or `"full"`. If `"default"`, the CLS token is removed from the vision features.
+ If `"full"`, the full vision features are used.
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+ Example:
+
+ ```python
+ >>> from PIL import Image
+ >>> import requests
+ >>> from transformers import AutoProcessor, LlavaNextForConditionalGeneration
+
+ >>> model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
+ >>> processor = AutoProcessor.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
+
+ >>> prompt = "[INST] \nWhat is shown in this image? [/INST]"
+ >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> inputs = processor(images=image, text=prompt, return_tensors="pt")
+
+ >>> # Generate
+ >>> generate_ids = model.generate(**inputs, max_length=30)
+ >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ "[INST] \nWhat is shown in this image? [/INST] The image appears to be a radar chart, which is a type of multi-dimensional plot (...)"
+ ```"""
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ vision_feature_layer = (
+ vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
+ )
+ vision_feature_select_strategy = (
+ vision_feature_select_strategy
+ if vision_feature_select_strategy is not None
+ else self.config.vision_feature_select_strategy
+ )
+
+ outputs = self.model(
+ input_ids,
+ pixel_values=pixel_values,
+ image_sizes=image_sizes,
+ vision_feature_layer=vision_feature_layer,
+ vision_feature_select_strategy=vision_feature_select_strategy,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=True,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+ hidden_states = outputs[0]
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(
+ logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
+ )
+
+ return LlavaNextCausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ image_hidden_states=outputs.image_hidden_states,
+ )
+
+ def prepare_inputs_for_generation(
+ self,
+ input_ids,
+ past_key_values=None,
+ inputs_embeds=None,
+ pixel_values=None,
+ image_sizes=None,
+ attention_mask=None,
+ cache_position=None,
+ logits_to_keep=None,
+ **kwargs,
+ ):
+ # Overwritten -- in specific circumstances we don't want to forward image inputs to the model
+
+ model_inputs = super().prepare_inputs_for_generation(
+ input_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ cache_position=cache_position,
+ logits_to_keep=logits_to_keep,
+ **kwargs,
+ )
+
+ # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
+ # Otherwise we need pixel values to be passed to model
+ if cache_position[0] == 0:
+ model_inputs["pixel_values"] = pixel_values
+ model_inputs["image_sizes"] = image_sizes
+
+ return model_inputs
+
+ @staticmethod
+ # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position
+ def _prepare_4d_causal_attention_mask_with_cache_position(
+ attention_mask: torch.Tensor,
+ sequence_length: int,
+ target_length: int,
+ dtype: torch.dtype,
+ cache_position: torch.Tensor,
+ batch_size: int,
+ **kwargs,
+ ):
+ """
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
+ `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
+
+ Args:
+ attention_mask (`torch.Tensor`):
+ A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
+ `(batch_size, 1, query_length, key_value_length)`.
+ sequence_length (`int`):
+ The sequence length being processed.
+ target_length (`int`):
+ The target length: when generating with static cache, the mask should be as long as the static cache,
+ to account for the 0 padding, the part of the cache that is not filled yet.
+ dtype (`torch.dtype`):
+ The dtype to use for the 4D attention mask.
+ cache_position (`torch.Tensor`):
+ Indices depicting the position of the input sequence tokens in the sequence.
+ batch_size (`torch.Tensor`):
+ Batch size.
+ """
+ if attention_mask is not None and attention_mask.dim() == 4:
+ # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
+ causal_mask = attention_mask
+ else:
+ min_dtype = torch.finfo(dtype).min
+ causal_mask = torch.full(
+ (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
+ )
+ if sequence_length != 1:
+ causal_mask = torch.triu(causal_mask, diagonal=1)
+ causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
+ causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
+ if attention_mask is not None:
+ causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
+ mask_length = attention_mask.shape[-1]
+ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
+ causal_mask.device
+ )
+ padding_mask = padding_mask == 0
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+ padding_mask, min_dtype
+ )
+
+ return causal_mask
+
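+ # Illustrative example (a sketch): decoding 2 new tokens with 2 tokens already cached
+ # (`cache_position = [2, 3]`, `target_length = 4`, `batch_size = 1`) yields, before the
+ # 2D padding mask is merged in, a (1, 1, 2, 4) mask in which the query at position 2
+ # masks only key 3 and the query at position 3 attends to all four keys.
+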
+
+__all__ = ["LlavaNextForConditionalGeneration", "LlavaNextPreTrainedModel", "LlavaNextModel"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/llava_next/processing_llava_next.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/llava_next/processing_llava_next.py
new file mode 100644
index 0000000000000000000000000000000000000000..2574fc443519f928a1d8f14bed08ae15466950c5
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/llava_next/processing_llava_next.py
@@ -0,0 +1,266 @@
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Processor class for LLaVa-NeXT.
+"""
+
+from typing import Optional, Union
+
+import numpy as np
+
+from ...feature_extraction_utils import BatchFeature
+from ...image_processing_utils import select_best_resolution
+from ...image_utils import ImageInput, get_image_size, to_numpy_array
+from ...processing_utils import (
+ MultiModalData,
+ ProcessingKwargs,
+ ProcessorMixin,
+ Unpack,
+)
+from ...tokenization_utils_base import PreTokenizedInput, TextInput
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class LlavaNextProcessorKwargs(ProcessingKwargs, total=False):
+ _defaults = {
+ "text_kwargs": {
+ "padding": False,
+ "return_mm_token_type_ids": False,
+ },
+ "images_kwargs": {
+ "do_pad": True,
+ },
+ }
+
+
+class LlavaNextProcessor(ProcessorMixin):
+ r"""
+ Constructs a LLaVa-NeXT processor which wraps a LLaVa-NeXT image processor and a LLaMa tokenizer into a single processor.
+
+ [`LlavaNextProcessor`] offers all the functionalities of [`LlavaNextImageProcessor`] and [`LlamaTokenizerFast`]. See the
+ [`~LlavaNextProcessor.__call__`] and [`~LlavaNextProcessor.decode`] for more information.
+
+ Args:
+ image_processor ([`LlavaNextImageProcessor`], *optional*):
+ The image processor is a required input.
+ tokenizer ([`LlamaTokenizerFast`], *optional*):
+ The tokenizer is a required input.
+ patch_size (`int`, *optional*):
+ Patch size from the vision tower.
+ vision_feature_select_strategy (`str`, *optional*):
+ The feature selection strategy used to select the vision feature from the vision backbone.
+ Should be the same as in the model's config.
+ chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
+ in a chat into a tokenizable string.
+ image_token (`str`, *optional*, defaults to `"<image>"`):
+ Special token used to denote image location.
+ num_additional_image_tokens (`int`, *optional*, defaults to 0):
+ Number of additional tokens added to the image embeddings, such as CLS (+1). If the backbone has no CLS or other
+ extra tokens appended, no need to set this arg.
+ """
+
+ attributes = ["image_processor", "tokenizer"]
+ image_processor_class = "AutoImageProcessor"
+ tokenizer_class = "AutoTokenizer"
+
+ def __init__(
+ self,
+ image_processor=None,
+ tokenizer=None,
+ patch_size=None,
+ vision_feature_select_strategy=None,
+ chat_template=None,
+ image_token="", # set the default and let users change if they have peculiar special tokens in rare cases
+ num_additional_image_tokens=0,
+ **kwargs,
+ ):
+ self.patch_size = patch_size
+ self.num_additional_image_tokens = num_additional_image_tokens
+ self.vision_feature_select_strategy = vision_feature_select_strategy
+ self.image_token = tokenizer.image_token if hasattr(tokenizer, "image_token") else image_token
+ self.image_token_id = (
+ tokenizer.image_token_id
+ if getattr(tokenizer, "image_token_id", None)
+ else tokenizer.convert_tokens_to_ids(self.image_token)
+ )
+ super().__init__(image_processor, tokenizer, chat_template=chat_template)
+
+ def __call__(
+ self,
+ images: Optional[ImageInput] = None,
+ text: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]] = None,
+ audio=None,
+ videos=None,
+ **kwargs: Unpack[LlavaNextProcessorKwargs],
+ ) -> BatchFeature:
+ """
+ Main method to prepare for the model one or several sequence(s) and image(s). This method forwards the `text`
+ and `kwargs` arguments to LlamaTokenizerFast's [`~LlamaTokenizerFast.__call__`] if `text` is not `None` to encode
+ the text. To prepare the image(s), this method forwards the `images` and `kwargs` arguments to
+ LlavaNextImageProcessor's [`~LlavaNextImageProcessor.__call__`] if `images` is not `None`. Please refer to the docstring
+ of the above two methods for more information.
+
+ Args:
+ images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `list[PIL.Image.Image]`, `list[np.ndarray]`, `list[torch.Tensor]`):
+ The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
+ tensor. Both channels-first and channels-last formats are supported.
+ text (`str`, `list[str]`, `list[list[str]]`):
+ The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
+ (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
+ `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
+
+ Returns:
+ [`BatchFeature`]: A [`BatchFeature`] with the following fields:
+
+ - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
+ - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
+ `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
+ `None`).
+ - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
+ """
+ if images is None and text is None:
+ raise ValueError("You have to specify at least images or text.")
+
+ output_kwargs = self._merge_kwargs(
+ LlavaNextProcessorKwargs,
+ tokenizer_init_kwargs=self.tokenizer.init_kwargs,
+ **kwargs,
+ )
+ if images is not None:
+ image_inputs = self.image_processor(images, **output_kwargs["images_kwargs"])
+ else:
+ image_inputs = {}
+
+ if isinstance(text, str):
+ text = [text]
+ elif not isinstance(text, list) and not isinstance(text[0], str):
+ raise TypeError("Invalid input text. Please provide a string, or a list of strings")
+
+ prompt_strings = text
+ if image_inputs:
+ image_sizes = iter(image_inputs["image_sizes"])
+ height, width = get_image_size(to_numpy_array(image_inputs["pixel_values"][0][0]))
+ prompt_strings = []
+ for sample in text:
+ while self.image_token in sample:
+ image_size = next(image_sizes)
+ if not isinstance(image_size, (list, tuple)):
+ # cast to list to avoid numerical precision errors when calculating unpadding
+ image_size = image_size.tolist()
+ orig_height, orig_width = image_size
+ num_image_tokens = self._get_number_of_features(orig_height, orig_width, height, width)
+ if self.vision_feature_select_strategy == "default":
+ num_image_tokens -= 1
+ sample = sample.replace(self.image_token, "<placeholder>" * num_image_tokens, 1)
+ prompt_strings.append(sample)
+ prompt_strings = [sample.replace("", self.image_token) for sample in prompt_strings]
+
+ return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
+ return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", None)
+ text_inputs = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"])
+ self._check_special_mm_tokens(prompt_strings, text_inputs, modalities=["image"])
+
+ if return_mm_token_type_ids:
+ array_ids = np.array(text_inputs["input_ids"])
+ mm_token_type_ids = np.zeros_like(text_inputs["input_ids"])
+ mm_token_type_ids[array_ids == self.image_token_id] = 1
+ text_inputs["mm_token_type_ids"] = mm_token_type_ids.tolist()
+
+ return BatchFeature(data={**text_inputs, **image_inputs}, tensor_type=return_tensors)
+
+ def _get_number_of_features(self, orig_height: int, orig_width: int, height: int, width: int) -> int:
+ image_grid_pinpoints = self.image_processor.image_grid_pinpoints
+
+ height_best_resolution, width_best_resolution = select_best_resolution(
+ [orig_height, orig_width], image_grid_pinpoints
+ )
+ scale_height, scale_width = height_best_resolution // height, width_best_resolution // width
+
+ patches_height = height // self.patch_size
+ patches_width = width // self.patch_size
+ unpadded_features, newline_features = self._get_unpadded_features(
+ orig_height, orig_width, patches_height, patches_width, scale_height, scale_width
+ )
+ # The base patch covers the entire image (+1 for the CLS)
+ base_features = patches_height * patches_width + self.num_additional_image_tokens
+ num_image_tokens = unpadded_features + newline_features + base_features
+ return num_image_tokens
+
+ def _get_unpadded_features(self, height, width, patches_height, patches_width, scale_height, scale_width):
+ """
+ Get number of features for a given image with height/width. LLaVA-NeXT is different from LLaVA
+ because it divides each image into patches depending on its resolution. Therefore we need to calculate how many
+ patches an image is divided into and get the number of features from that.
+ """
+ current_height = patches_height * scale_height
+ current_width = patches_width * scale_width
+
+ original_aspect_ratio = width / height
+ current_aspect_ratio = current_width / current_height
+ if original_aspect_ratio > current_aspect_ratio:
+ new_height = int(round(height * (current_width / width), 7))
+ padding = (current_height - new_height) // 2
+ current_height -= padding * 2
+ else:
+ new_width = int(round(width * (current_height / height), 7))
+ padding = (current_width - new_width) // 2
+ current_width -= padding * 2
+
+ unpadded_features = current_height * current_width
+ newline_features = current_height
+ return (unpadded_features, newline_features)
+
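+ # Illustrative example (a sketch, assuming `patch_size=14`, 336x336 crops and
+ # `num_additional_image_tokens=1` for a CLS token): a 672x672 image gives
+ # `patches_height = patches_width = 24`, scale factors (2, 2), 48 * 48 = 2304 unpadded
+ # features, 48 newline features and a 577-token base patch, i.e. 2929 tokens, or 2928
+ # once `__call__` drops one under the "default" strategy, which matches the model side.
+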
+ def _get_num_multimodal_tokens(self, image_sizes=None, **kwargs):
+ """
+ Computes the number of placeholder tokens needed for multimodal inputs with the given sizes.
+ Args:
+ image_sizes (`list[list[int]]`, *optional*):
+ The input sizes formatted as (height, width) for each image.
+ Returns:
+ `MultiModalData`: A `MultiModalData` object holding number of tokens per each of the provided
+ input modalities, along with other useful data.
+ """
+ vision_data = {}
+ if image_sizes is not None:
+ images_kwargs = LlavaNextProcessorKwargs._defaults.get("images_kwargs", {})
+ images_kwargs.update(kwargs)
+
+ size = images_kwargs.get("size", None) or self.image_processor.size
+ size = (
+ (size["shortest_edge"], size["shortest_edge"])
+ if "shortest_edge" in size
+ else (min(size["height"], size["width"]), min(size["height"], size["width"]))
+ )
+ processed_height, processed_width = size
+
+ batch_num_image_tokens = []
+ num_image_patches = [1] * len(image_sizes) # llava-next doesn't batch pixels as Idefics, thus `1` patch
+ for image_size in image_sizes:
+ orig_height, orig_width = image_size
+ num_image_tokens = self._get_number_of_features(
+ orig_height, orig_width, processed_height, processed_width
+ )
+ if self.vision_feature_select_strategy == "default":
+ num_image_tokens -= 1
+ batch_num_image_tokens.append(num_image_tokens)
+ vision_data.update({"num_image_tokens": batch_num_image_tokens, "num_image_patches": num_image_patches})
+
+ return MultiModalData(**vision_data)
+
+
+__all__ = ["LlavaNextProcessor"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/longcat_flash/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/longcat_flash/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a9a9429d9d05c06395282a24be9c2b60057359ed
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/longcat_flash/__init__.py
@@ -0,0 +1,29 @@
+# coding=utf-8
+# Copyright 2025 Meituan and the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_longcat_flash import *
+ from .modeling_longcat_flash import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/longcat_flash/configuration_longcat_flash.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/longcat_flash/configuration_longcat_flash.py
new file mode 100644
index 0000000000000000000000000000000000000000..4c5930db8f3ab1e26236f6db778859f77736764f
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/longcat_flash/configuration_longcat_flash.py
@@ -0,0 +1,235 @@
+# coding=utf-8
+# Copyright 2025 Meituan and the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""LongCat Flash model configuration"""
+
+from ...configuration_utils import PretrainedConfig
+from ...modeling_rope_utils import rope_config_validation
+
+
+class LongcatFlashConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`LongcatFlashModel`]. It is used to instantiate
+ a LongCat Flash model according to the specified arguments, defining the model architecture. Instantiating a
+ configuration with the defaults will yield a similar configuration to that of the LongCat Flash architecture.
+ e.g. [meituan-longcat/LongCat-Flash-Chat](https://huggingface.co/meituan-longcat/LongCat-Flash-Chat)
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 131072):
+ Vocabulary size of the LongCat Flash model. Defines the number of different tokens that can be represented by the
+ `input_ids` passed when calling [`LongcatFlashModel`]
+ hidden_size (`int`, *optional*, defaults to 6144):
+ Dimension of the hidden representations.
+ num_hidden_layers (`int`, *optional*, defaults to 56):
+ Number of hidden layers in the Transformer decoder.
+ num_layers (`int`, *optional*, defaults to 28):
+ Number of decoder layers; each layer is composed of 2 sublayers.
+ num_attention_heads (`int`, *optional*, defaults to 64):
+ Number of attention heads for each attention layer in the Transformer decoder.
+ num_key_value_heads (`int`, *optional*):
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
+ `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
+ converting from a multi-head checkpoint to a GQA checkpoint, each group key and value head should be
+ constructed by meanpooling all the original heads within that group. For more details checkout [this
+ paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
+ `num_attention_heads`.
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+ The non-linear activation function (function or string) in the decoder.
+ max_position_embeddings (`int`, *optional*, defaults to 131072):
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
+ just in case (e.g., 512 or 1024 or 2048).
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ rms_norm_eps (`float`, *optional*, defaults to 1e-05):
+ The epsilon value used by the RMS normalization layers.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
+ relevant if `config.is_decoder=True`.
+ pad_token_id (`int`, *optional*):
+ Padding token id.
+ bos_token_id (`int`, *optional*, defaults to 1):
+ Beginning of stream token id.
+ eos_token_id (`int`, *optional*, defaults to 2):
+ End of stream token id.
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+ Whether to tie input and output embeddings.
+ rope_theta (`float`, *optional*, defaults to 10000000.0):
+ The base period of the RoPE embeddings.
+ rope_scaling (`Dict`, *optional*):
+ Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
+ strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
+ `{"type": strategy name, "factor": scaling factor}`.
+ attention_bias (`bool`, *optional*, defaults to `False`):
+ Whether to use a bias in the query, key, value and output projection layers during self-attention.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ ffn_hidden_size (`int`, *optional*, defaults to 12288):
+ Dimension of the MLP representations.
+ q_lora_rank (`int`, *optional*, defaults to 1536):
+ The rank of the query LoRA projection in MLA (Multi-head Latent Attention).
+ kv_lora_rank (`int`, *optional*, defaults to 512):
+ The rank of the key-value LoRA projection in MLA.
+ qk_nope_head_dim (`int`, *optional*, defaults to 128):
+ The dimension of the non-position encoding part of query/key heads.
+ qk_rope_head_dim (`int`, *optional*, defaults to 64):
+ The dimension of the RoPE part of query/key heads.
+ head_dim (`int`, *optional*, defaults to 64):
+ Standard dimension of qk heads, unused except for CI.
+ v_head_dim (`int`, *optional*, defaults to 128):
+ The dimension of value heads.
+ qk_head_dim (`int`, *optional*):
+ The total dimension of query/key heads. If not specified, set to `qk_nope_head_dim + qk_rope_head_dim`.
+ moe_topk (`int`, *optional*, defaults to 12):
+ Number of experts to route to for each token in the MoE layer.
+ n_routed_experts (`int`, *optional*, defaults to 512):
+ Number of routed experts in the MoE layer.
+ zero_expert_num (`int`, *optional*, defaults to 256):
+ Number of zero experts (identity function) to add to the expert pool.
+ expert_ffn_hidden_size (`int`, *optional*, defaults to 2048):
+ Hidden size of individual expert FFN layers.
+ routed_scaling_factor (`float`, *optional*, defaults to 6.0):
+ Scaling factor applied to the routing weights.
+
+ ```python
+ >>> from transformers import LongcatFlashModel, LongcatFlashConfig
+
+ >>> # Initializing a LongCat Flash style configuration
+ >>> configuration = LongcatFlashConfig()
+
+ >>> # Initializing a model from the configuration
+ >>> model = LongcatFlashModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "longcat_flash"
+ keys_to_ignore_at_inference = ["past_key_values"]
+ base_model_tp_plan = {
+ "layers.*.self_attn.*.q_b_proj": "colwise",
+ "layers.*.self_attn.*.kv_b_proj": "colwise",
+ "layers.*.self_attn.*.o_proj": "rowwise",
+ "layers.*.mlps.*.gate_proj": "colwise",
+ "layers.*.mlps.*.up_proj": "colwise",
+ "layers.*.mlps.*.down_proj": "rowwise",
+ "layers.*.mlp.experts.*.gate_proj": "colwise",
+ "layers.*.mlp.experts.*.up_proj": "colwise",
+ "layers.*.mlp.experts.*.down_proj": "rowwise",
+ }
+
+ base_model_pp_plan = {
+ "embed_tokens": (["input_ids"], ["inputs_embeds"]),
+ "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
+ "norm": (["hidden_states"], ["hidden_states"]),
+ }
+
+ def __init__(
+ self,
+ vocab_size=131072,
+ hidden_size=6144,
+ num_hidden_layers=56,
+ num_layers=28,
+ num_attention_heads=64,
+ num_key_value_heads=None,
+ hidden_act="silu",
+ max_position_embeddings=131072,
+ initializer_range=0.02,
+ rms_norm_eps=1e-5,
+ use_cache=True,
+ pad_token_id=None,
+ bos_token_id=1,
+ eos_token_id=2,
+ tie_word_embeddings=False,
+ rope_theta=10000000.0,
+ rope_scaling=None,
+ attention_bias=False,
+ attention_dropout=0.0,
+ ffn_hidden_size=12288,
+ q_lora_rank=1536,
+ kv_lora_rank=512,
+ qk_nope_head_dim=128,
+ qk_rope_head_dim=64,
+ head_dim=64,
+ v_head_dim=128,
+ qk_head_dim=None,
+ moe_topk=12,
+ n_routed_experts=512,
+ zero_expert_num=256,
+ expert_ffn_hidden_size=2048,
+ routed_scaling_factor=6.0,
+ **kwargs,
+ ):
+ if num_key_value_heads is None:
+ num_key_value_heads = num_attention_heads
+
+ if qk_head_dim is None:
+ qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
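+ # e.g. with the default MLA dims, qk_head_dim = qk_nope_head_dim + qk_rope_head_dim = 128 + 64 = 192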
+
+ self.vocab_size = vocab_size
+ self.max_position_embeddings = max_position_embeddings
+ self.hidden_size = hidden_size
+ self.num_layers = num_layers
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.num_key_value_heads = num_key_value_heads
+ self.hidden_act = hidden_act
+ self.initializer_range = initializer_range
+ self.rms_norm_eps = rms_norm_eps
+ self.use_cache = use_cache
+ self.rope_theta = rope_theta
+ self.rope_scaling = rope_scaling
+ self.attention_bias = attention_bias
+ self.attention_dropout = attention_dropout
+
+ self.ffn_hidden_size = ffn_hidden_size
+
+ self.q_lora_rank = q_lora_rank
+ self.kv_lora_rank = kv_lora_rank
+ self.qk_nope_head_dim = qk_nope_head_dim
+ self.qk_rope_head_dim = qk_rope_head_dim
+ self.v_head_dim = v_head_dim
+ self.qk_head_dim = qk_head_dim
+ self.head_dim = head_dim
+
+ self.moe_topk = moe_topk
+ self.n_routed_experts = n_routed_experts
+ self.zero_expert_num = zero_expert_num
+ self.expert_ffn_hidden_size = expert_ffn_hidden_size
+ self.routed_scaling_factor = routed_scaling_factor
+
+ if self.rope_scaling is not None and "type" in self.rope_scaling:
+ self.rope_scaling["rope_type"] = self.rope_scaling["type"]
+
+ if self.rope_scaling is not None:
+ for key in ["beta_fast", "beta_slow", "factor"]:
+ if key in self.rope_scaling:
+ self.rope_scaling[key] = float(self.rope_scaling[key])
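+ # Illustrative example: rope_scaling={"type": "yarn", "factor": 4.0, "beta_fast": 32} is normalized above
+ # to {"type": "yarn", "rope_type": "yarn", "factor": 4.0, "beta_fast": 32.0} before validation below.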
+
+ rope_config_validation(self)
+
+ super().__init__(
+ pad_token_id=pad_token_id,
+ bos_token_id=bos_token_id,
+ eos_token_id=eos_token_id,
+ tie_word_embeddings=tie_word_embeddings,
+ **kwargs,
+ )
+
+
+__all__ = ["LongcatFlashConfig"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/longcat_flash/modeling_longcat_flash.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/longcat_flash/modeling_longcat_flash.py
new file mode 100644
index 0000000000000000000000000000000000000000..4681cfb60e53ff6cf89a0588a2e220740bda71ac
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/longcat_flash/modeling_longcat_flash.py
@@ -0,0 +1,684 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/longcat_flash/modular_longcat_flash.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_longcat_flash.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2025 Meituan and the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from typing import Callable, Optional, Union
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from ...activations import ACT2FN
+from ...cache_utils import Cache, DynamicCache
+from ...generation import GenerationMixin
+from ...integrations import use_kernel_forward_from_hub
+from ...masking_utils import create_causal_mask
+from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
+from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
+from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
+from ...utils.deprecation import deprecate_kwarg
+from ...utils.generic import check_model_inputs
+from .configuration_longcat_flash import LongcatFlashConfig
+
+
+@use_kernel_forward_from_hub("RMSNorm")
+class LongcatFlashRMSNorm(nn.Module):
+ def __init__(self, hidden_size, eps=1e-6):
+ """
+ LongcatFlashRMSNorm is equivalent to T5LayerNorm
+ """
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+ self.variance_epsilon = eps
+
+ def forward(self, hidden_states):
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+ return self.weight * hidden_states.to(input_dtype)
+
+ def extra_repr(self):
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
+
+
+class LongcatFlashRotaryEmbedding(nn.Module):
+ inv_freq: torch.Tensor # fix linting for `register_buffer`
+
+ def __init__(self, config: LongcatFlashConfig, device=None):
+ super().__init__()
+ # BC: "rope_type" was originally "type"
+ if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
+ else:
+ self.rope_type = "default"
+ self.max_seq_len_cached = config.max_position_embeddings
+ self.original_max_seq_len = config.max_position_embeddings
+
+ self.config = config
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
+ self.original_inv_freq = self.inv_freq
+
+ @torch.no_grad()
+ @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
+ def forward(self, x, position_ids):
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
+ position_ids_expanded = position_ids[:, None, :].float()
+
+ device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
+ with torch.autocast(device_type=device_type, enabled=False): # Force float32
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+ emb = torch.cat((freqs, freqs), dim=-1)
+ cos = emb.cos() * self.attention_scaling
+ sin = emb.sin() * self.attention_scaling
+
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
+class LongcatFlashMLP(nn.Module):
+ def __init__(self, config, hidden_size=None, intermediate_size=None):
+ super().__init__()
+ self.config = config
+ self.hidden_size = config.hidden_size if hidden_size is None else hidden_size
+ self.intermediate_size = config.ffn_hidden_size if intermediate_size is None else intermediate_size
+
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+ self.act_fn = ACT2FN[config.hidden_act]
+
+ def forward(self, x):
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+ return down_proj
+
+
+class LongcatFlashTopkRouter(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+
+ self.top_k = config.moe_topk
+ self.n_routed_experts = config.n_routed_experts + (config.zero_expert_num or 0)
+ self.routed_scaling_factor = config.routed_scaling_factor
+ self.register_buffer("e_score_correction_bias", torch.zeros(self.n_routed_experts))
+ self.router_bias = getattr(config, "router_bias", False)
+ self.classifier = nn.Linear(config.hidden_size, self.n_routed_experts, bias=self.router_bias)
+
+ @torch.no_grad()
+ def get_topk_indices(self, scores):
+ scores_for_choice = scores.view(-1, self.n_routed_experts) + self.e_score_correction_bias.unsqueeze(0)
+ topk_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=False)[1]
+ return topk_indices
+
+ def forward(self, hidden_states):
+ hidden_states = hidden_states.view(-1, self.config.hidden_size)
+ router_logits = F.linear(hidden_states.type(torch.float32), self.classifier.weight.type(torch.float32))
+ scores = router_logits.softmax(dim=-1)
+ topk_indices = self.get_topk_indices(scores)
+ topk_weights = scores.gather(1, topk_indices)
+ topk_weights = topk_weights * self.routed_scaling_factor
+ return topk_indices, topk_weights
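+ # Routing sketch (added note, not original code): with hidden_states flattened to (batch * seq_len, hidden_size),
+ # router_logits and scores have shape (batch * seq_len, n_routed_experts); topk_indices / topk_weights have shape
+ # (batch * seq_len, top_k), i.e. (batch * seq_len, 12) with the defaults. e_score_correction_bias only biases
+ # expert *selection*; the returned weights come from the unbiased softmax scores, scaled by routed_scaling_factor.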
+
+
+class LongcatFlashMoE(nn.Module):
+ """
+ A mixed expert module containing zero compute (identity) experts.
+ """
+
+ def __init__(self, config):
+ super().__init__()
+ self.intermediate_size = config.expert_ffn_hidden_size
+ self.config = config
+
+ self.experts = nn.ModuleList(
+ [LongcatFlashMLP(config, intermediate_size=self.intermediate_size) for _ in range(config.n_routed_experts)]
+ + [nn.Identity() for _ in range(config.zero_expert_num)]
+ )
+
+ self.router = LongcatFlashTopkRouter(config)
+
+ def moe(self, hidden_states: torch.Tensor, topk_indices: torch.Tensor, topk_weights: torch.Tensor):
+ r"""
+ CALL FOR CONTRIBUTION! This is not optimised yet: expert weights need to be fused so we don't have to
+ loop over every expert here (deepseek has 256 experts, so the loop is expensive).
+ """
+ final_hidden_states = torch.zeros_like(hidden_states, dtype=topk_weights.dtype)
+ expert_mask = torch.nn.functional.one_hot(topk_indices, num_classes=len(self.experts))
+ expert_mask = expert_mask.permute(2, 0, 1)
+
+ for expert_idx in range(len(self.experts)):
+ expert = self.experts[expert_idx]
+ mask = expert_mask[expert_idx]
+ token_indices, weight_indices = torch.where(mask)
+
+ if token_indices.numel() > 0:
+ expert_weights = topk_weights[token_indices, weight_indices]
+ expert_input = hidden_states[token_indices]
+ expert_output = expert(expert_input)
+ weighted_output = expert_output * expert_weights.unsqueeze(-1)
+ final_hidden_states.index_add_(0, token_indices, weighted_output)
+
+ # in the original deepseek, the outputs of the experts are gathered once we leave this module,
+ # thus the moe module is itself an IsolatedParallel module,
+ # and all experts are "local", meaning we shard but do not gather
+ return final_hidden_states.type(hidden_states.dtype)
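+ # Dispatch sketch (added note): expert_mask is a one-hot tensor of shape (num_experts, num_tokens, top_k);
+ # for each expert we select the tokens routed to it, run the expert (zero experts are nn.Identity, so they
+ # pass tokens through unchanged), weight the output by the routing weight, and scatter-add it back into
+ # final_hidden_states with index_add_.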
+
+ def forward(self, hidden_states):
+ orig_shape = hidden_states.shape
+ topk_indices, topk_weights = self.router(hidden_states)
+ hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
+ hidden_states = self.moe(hidden_states, topk_indices, topk_weights).view(*orig_shape)
+ return hidden_states
+
+
+def rotate_half(x):
+ """Rotates half the hidden dims of the input."""
+ x1 = x[..., : x.shape[-1] // 2]
+ x2 = x[..., x.shape[-1] // 2 :]
+ return torch.cat((-x2, x1), dim=-1)
+
+
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+ """
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+ """
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+ if n_rep == 1:
+ return hidden_states
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
+def eager_attention_forward(
+ module: nn.Module,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_mask: Optional[torch.Tensor],
+ scaling: float,
+ dropout: float = 0.0,
+ **kwargs: Unpack[TransformersKwargs],
+):
+ key_states = repeat_kv(key, module.num_key_value_groups)
+ value_states = repeat_kv(value, module.num_key_value_groups)
+
+ attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
+ if attention_mask is not None:
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+ attn_weights = attn_weights + causal_mask
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+ attn_output = torch.matmul(attn_weights, value_states)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ return attn_output, attn_weights
+
+
+def apply_rotary_pos_emb_interleave(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+ r"""
+ TODO: use the original freqs_cis computation so we can drop the view +
+ transpose + reshape below; this is not optimized.
+ Applies Rotary Position Embedding to the query and key tensors.
+
+ Args:
+ q (`torch.Tensor`): The query tensor.
+ k (`torch.Tensor`): The key tensor.
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
+ position_ids (`torch.Tensor`):
+ The position indices of the tokens corresponding to the query and key tensors. For example, this can be
+ used to pass offsetted position ids when working with a KV-cache.
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+ Returns:
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+ """
+ cos = cos.unsqueeze(unsqueeze_dim)
+ sin = sin.unsqueeze(unsqueeze_dim)
+
+ b, h, s, d = q.shape
+ q = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)
+
+ b, h, s, d = k.shape
+ k = k.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)
+
+ q_embed = (q * cos) + (rotate_half(q) * sin)
+ k_embed = (k * cos) + (rotate_half(k) * sin)
+ return q_embed, k_embed
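+ # Layout note (added): the view/transpose/reshape above converts an interleaved head dim [x0, x1, x2, x3, ...]
+ # into [x0, x2, ..., x1, x3, ...] so that the standard rotate_half-based rotation can be applied.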
+
+
+def yarn_get_mscale(scale=1, mscale=1):
+ if scale <= 1:
+ return 1.0
+ return 0.1 * mscale * math.log(scale) + 1.0
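+ # e.g. yarn_get_mscale(scale=4.0, mscale=1.0) = 0.1 * 1.0 * ln(4.0) + 1.0 ≈ 1.139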
+
+
+class LongcatFlashMLA(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(self, config, layer_idx: int):
+ super().__init__()
+ self.config = config
+ self.layer_idx = layer_idx
+ self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
+ self.attention_dropout = config.attention_dropout
+ self.num_heads = config.num_attention_heads
+ self.rope_theta = config.rope_theta
+ self.q_lora_rank = config.q_lora_rank
+ self.qk_rope_head_dim = config.qk_rope_head_dim
+ self.kv_lora_rank = config.kv_lora_rank
+ self.v_head_dim = config.v_head_dim
+ self.qk_nope_head_dim = config.qk_nope_head_dim
+ self.qk_head_dim = config.qk_head_dim
+
+ self.is_causal = True
+ if self.q_lora_rank is None:
+ self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.qk_head_dim, bias=False)
+ else:
+ self.q_a_proj = nn.Linear(config.hidden_size, config.q_lora_rank, bias=config.attention_bias)
+ self.q_a_layernorm = LongcatFlashRMSNorm(config.q_lora_rank)
+ self.q_b_proj = nn.Linear(config.q_lora_rank, self.num_heads * self.qk_head_dim, bias=False)
+
+ self.kv_a_proj_with_mqa = nn.Linear(
+ config.hidden_size,
+ self.kv_lora_rank + self.qk_rope_head_dim,
+ bias=config.attention_bias,
+ )
+ self.kv_a_layernorm = LongcatFlashRMSNorm(self.kv_lora_rank)
+ self.kv_b_proj = nn.Linear(
+ self.kv_lora_rank,
+ self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
+ bias=False,
+ )
+
+ self.o_proj = nn.Linear(
+ self.num_heads * self.v_head_dim,
+ config.hidden_size,
+ bias=config.attention_bias,
+ )
+
+ self.scaling = self.qk_head_dim ** (-0.5)
+ if self.config.rope_scaling is not None:
+ mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0)
+ scaling_factor = self.config.rope_scaling["factor"]
+ if mscale_all_dim:
+ mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)
+ self.scaling = self.scaling * mscale * mscale
+
+ self.mla_scale_q_lora = (config.hidden_size / self.q_lora_rank) ** 0.5
+ self.mla_scale_kv_lora = (config.hidden_size / self.kv_lora_rank) ** 0.5
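+ # With the default config (qk_head_dim=192, no rope_scaling): scaling = 192 ** -0.5 ≈ 0.0722,
+ # mla_scale_q_lora = (6144 / 1536) ** 0.5 = 2.0, mla_scale_kv_lora = (6144 / 512) ** 0.5 ≈ 3.46.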
+
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
+ attention_mask: Optional[torch.Tensor],
+ past_key_values: Optional[Cache] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs: Unpack[FlashAttentionKwargs],
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
+ batch_size, seq_length = hidden_states.shape[:-1]
+ query_shape = (batch_size, seq_length, -1, self.qk_head_dim)
+ key_shape = (batch_size, seq_length, -1, self.qk_nope_head_dim + self.v_head_dim)
+ # we always do a lora for queries as well
+ q_states = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
+ q_states = q_states.view(query_shape).transpose(1, 2)
+ q_pass, q_rot = torch.split(q_states, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
+
+ compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
+ k_pass, k_rot = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
+ k_pass = self.kv_a_layernorm(k_pass)
+
+ # apply LoRA scaling
+ q_pass = q_pass * self.mla_scale_q_lora
+ q_rot = q_rot * self.mla_scale_q_lora
+ k_pass = k_pass * self.mla_scale_kv_lora
+
+ k_pass = self.kv_b_proj(k_pass).view(key_shape).transpose(1, 2)
+ k_pass, value_states = torch.split(k_pass, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
+
+ k_rot = k_rot.view(batch_size, 1, seq_length, self.qk_rope_head_dim)
+
+ cos, sin = position_embeddings
+ q_rot, k_rot = apply_rotary_pos_emb_interleave(q_rot, k_rot, cos, sin)
+ k_rot = k_rot.expand(*k_pass.shape[:-1], -1)
+
+ query_states = torch.cat((q_pass, q_rot), dim=-1)
+ key_states = torch.cat((k_pass, k_rot), dim=-1)
+
+ if past_key_values is not None:
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ if self.config._attn_implementation == "flash_attention_2" and self.qk_head_dim != self.v_head_dim:
+ value_states = F.pad(value_states, [0, self.qk_head_dim - self.v_head_dim])
+
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ attn_output, attn_weights = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ dropout=0.0 if not self.training else self.attention_dropout,
+ scaling=self.scaling,
+ **kwargs,
+ )
+
+ if self.config._attn_implementation == "flash_attention_2" and self.qk_head_dim != self.v_head_dim:
+ attn_output = attn_output[:, :, :, : self.v_head_dim]
+
+ attn_output = attn_output.reshape(batch_size, seq_length, -1).contiguous()
+ attn_output = self.o_proj(attn_output)
+ return attn_output, attn_weights
+
+
+class LongcatFlashDecoderLayer(GradientCheckpointingLayer):
+ """
+ LongCat decoder layer with dual-sublayer + shortcut MoE architecture.
+
+ Each logical layer contains:
+ - 2 attention sublayers (with layer indices: layer_idx*2, layer_idx*2+1)
+ - 2 MLP sublayers
+ - 1 shortcut MoE connection
+ """
+
+ def __init__(self, config, layer_idx: int):
+ super().__init__()
+ self.layer_idx = layer_idx
+ self.hidden_size = config.hidden_size
+
+ self.mlp = LongcatFlashMoE(config)
+
+ self.self_attn = nn.ModuleList([LongcatFlashMLA(config=config, layer_idx=layer_idx * 2 + i) for i in [0, 1]])
+ self.mlps = nn.ModuleList([LongcatFlashMLP(config) for _ in [0, 1]])
+ self.input_layernorm = nn.ModuleList(
+ [LongcatFlashRMSNorm(config.hidden_size, eps=config.rms_norm_eps) for _ in [0, 1]]
+ )
+ self.post_attention_layernorm = nn.ModuleList(
+ [LongcatFlashRMSNorm(config.hidden_size, eps=config.rms_norm_eps) for _ in [0, 1]]
+ )
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ use_cache: Optional[bool] = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
+ **kwargs: Unpack[FlashAttentionKwargs],
+ ) -> torch.Tensor:
+ residual = hidden_states
+ hidden_states = self.input_layernorm[0](hidden_states)
+
+ hidden_states, _ = self.self_attn[0](
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ **kwargs,
+ )
+ hidden_states = residual + hidden_states
+
+ residual = hidden_states
+ hidden_states = self.post_attention_layernorm[0](hidden_states)
+
+ shortcut_mlp_output = self.mlp(hidden_states)
+ hidden_states = self.mlps[0](hidden_states)
+ hidden_states = residual + hidden_states
+
+ # shortcut connection after second sublayer
+ residual = hidden_states
+ hidden_states = self.input_layernorm[1](hidden_states)
+
+ hidden_states, _ = self.self_attn[1](
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ **kwargs,
+ )
+ hidden_states = residual + hidden_states
+
+ residual = hidden_states
+ hidden_states = self.post_attention_layernorm[1](hidden_states)
+
+ hidden_states = self.mlps[1](hidden_states)
+ hidden_states = residual + hidden_states + shortcut_mlp_output
+
+ return hidden_states
+
+
+@auto_docstring
+class LongcatFlashPreTrainedModel(PreTrainedModel):
+ config: LongcatFlashConfig
+ base_model_prefix = "model"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["LongcatFlashDecoderLayer"]
+ _skip_keys_device_placement = ["past_key_values"]
+ _supports_flash_attn = True
+ _supports_sdpa = True
+ _supports_flex_attn = True
+ _can_compile_fullgraph = False
+ _supports_attention_backend = True
+ _can_record_outputs = {
+ "hidden_states": LongcatFlashDecoderLayer,
+ "attentions": LongcatFlashMLA,
+ }
+
+ def _init_weights(self, module):
+ super()._init_weights(module)
+ if isinstance(module, LongcatFlashTopkRouter):
+ module.classifier.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+
+
+@auto_docstring
+class LongcatFlashModel(LongcatFlashPreTrainedModel):
+ _keys_to_ignore_on_load_unexpected = [r"model\.mtp.*"]
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.padding_idx = config.pad_token_id
+ self.vocab_size = config.vocab_size
+
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+ self.layers = nn.ModuleList(
+ [LongcatFlashDecoderLayer(config, layer_idx) for layer_idx in range(config.num_layers)]
+ )
+ self.norm = LongcatFlashRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.rotary_emb = LongcatFlashRotaryEmbedding(config=config)
+ self.gradient_checkpointing = False
+ # Each layer above has 2 attention sublayers; num_hidden_layers is patched below so caches are sized correctly (avoids a checkpoint change)
+ self.head_dim = config.head_dim # For CI happiness (we didn't convert so head_dim is not directly used)
+
+ self.config.num_hidden_layers = 2 * config.num_layers
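+ # e.g. with the default num_layers=28, caches are sized for num_hidden_layers = 56 attention sublayers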
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @check_model_inputs
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> BaseModelOutputWithPast:
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ if inputs_embeds is None:
+ inputs_embeds: torch.Tensor = self.embed_tokens(input_ids)
+
+ if use_cache and past_key_values is None:
+ past_key_values = DynamicCache(config=self.config)
+
+ if cache_position is None:
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ cache_position: torch.Tensor = torch.arange(
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+ )
+
+ if position_ids is None:
+ position_ids = cache_position.unsqueeze(0)
+
+ causal_mask = create_causal_mask(
+ config=self.config,
+ input_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ cache_position=cache_position,
+ past_key_values=past_key_values,
+ position_ids=position_ids,
+ )
+
+ hidden_states = inputs_embeds
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+ for decoder_layer in self.layers[: self.config.num_layers]:
+ hidden_states = decoder_layer(
+ hidden_states,
+ attention_mask=causal_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ **kwargs,
+ )
+
+ hidden_states = self.norm(hidden_states)
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=past_key_values,
+ hidden_states=None,
+ attentions=None,
+ )
+
+
+@auto_docstring
+class LongcatFlashForCausalLM(LongcatFlashPreTrainedModel, GenerationMixin):
+ _tied_weights_keys = ["lm_head.weight"]
+ _tp_plan = {"lm_head": "colwise_rep"}
+ _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
+ _keys_to_ignore_on_load_unexpected = [r"model\.mtp.*"]
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.model = LongcatFlashModel(config)
+ self.vocab_size = config.vocab_size
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ logits_to_keep: Union[int, torch.Tensor] = 0,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> CausalLMOutputWithPast:
+ r"""
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, LongcatFlashForCausalLM
+
+ >>> model = LongcatFlashForCausalLM.from_pretrained("meituan-longcat/LongCat-Flash-Chat")
+ >>> tokenizer = AutoTokenizer.from_pretrained("meituan-longcat/LongCat-Flash-Chat")
+
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+ >>> # Generate
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+ ```"""
+ outputs: BaseModelOutputWithPast = self.model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+ hidden_states = outputs.last_hidden_state
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
+
+ return CausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+__all__ = ["LongcatFlashPreTrainedModel", "LongcatFlashModel", "LongcatFlashForCausalLM"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/longcat_flash/modular_longcat_flash.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/longcat_flash/modular_longcat_flash.py
new file mode 100644
index 0000000000000000000000000000000000000000..60c93239d2c4d7eb3a4f5bc787b73d4155eb3e10
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/longcat_flash/modular_longcat_flash.py
@@ -0,0 +1,382 @@
+# coding=utf-8
+# Copyright 2025 Meituan and the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Callable, Optional
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from ...cache_utils import Cache, DynamicCache
+from ...masking_utils import create_causal_mask
+from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_outputs import BaseModelOutputWithPast
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
+from ...utils import TransformersKwargs, logging
+from ..deepseek_v3.modeling_deepseek_v3 import (
+ DeepseekV3Attention,
+ DeepseekV3ForCausalLM,
+ DeepseekV3MLP,
+ DeepseekV3Model,
+ DeepseekV3MoE,
+ DeepseekV3PreTrainedModel,
+ DeepseekV3RMSNorm,
+ DeepseekV3RotaryEmbedding,
+ DeepseekV3TopkRouter,
+ apply_rotary_pos_emb_interleave,
+ eager_attention_forward,
+)
+
+
+logger = logging.get_logger(__name__)
+
+
+class LongcatFlashRMSNorm(DeepseekV3RMSNorm):
+ pass
+
+
+class LongcatFlashRotaryEmbedding(DeepseekV3RotaryEmbedding):
+ pass
+
+
+# TODO remap config key ffn_hidden_size -> intermediate_size
+class LongcatFlashMLP(DeepseekV3MLP):
+ def __init__(self, config, hidden_size=None, intermediate_size=None):
+ super().__init__()
+ self.intermediate_size = config.ffn_hidden_size if intermediate_size is None else intermediate_size
+
+
+# TODO remap config key moe_topk -> num_experts_per_tok
+class LongcatFlashTopkRouter(DeepseekV3TopkRouter):
+ def __init__(self, config):
+ super().__init__(config)
+ del self.n_group
+ del self.topk_group
+ del self.weight
+ del self.norm_topk_prob
+
+ self.top_k = config.moe_topk
+ self.n_routed_experts = config.n_routed_experts + (config.zero_expert_num or 0)
+ self.routed_scaling_factor = config.routed_scaling_factor
+ self.register_buffer("e_score_correction_bias", torch.zeros(self.n_routed_experts))
+ self.router_bias = getattr(config, "router_bias", False)
+ self.classifier = nn.Linear(config.hidden_size, self.n_routed_experts, bias=self.router_bias)
+
+ @torch.no_grad()
+ def get_topk_indices(self, scores):
+ scores_for_choice = scores.view(-1, self.n_routed_experts) + self.e_score_correction_bias.unsqueeze(0)
+ topk_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=False)[1]
+ return topk_indices
+
+ def forward(self, hidden_states):
+ hidden_states = hidden_states.view(-1, self.config.hidden_size)
+ router_logits = F.linear(hidden_states.type(torch.float32), self.classifier.weight.type(torch.float32))
+ scores = router_logits.softmax(dim=-1)
+ topk_indices = self.get_topk_indices(scores)
+ topk_weights = scores.gather(1, topk_indices)
+ topk_weights = topk_weights * self.routed_scaling_factor
+ return topk_indices, topk_weights
+
+
+# remap config key expert_ffn_hidden_size -> moe_intermediate_size
+class LongcatFlashMoE(DeepseekV3MoE):
+ """
+ A mixed expert module containing zero compute (identity) experts.
+ """
+
+ def __init__(self, config):
+ self.intermediate_size = config.expert_ffn_hidden_size
+ super().__init__(config)
+ del self.gate
+ del self.shared_experts
+
+ self.experts = nn.ModuleList(
+ [LongcatFlashMLP(config, intermediate_size=self.intermediate_size) for _ in range(config.n_routed_experts)]
+ + [nn.Identity() for _ in range(config.zero_expert_num)]
+ )
+
+ self.router = LongcatFlashTopkRouter(config)
+
+ def forward(self, hidden_states):
+ orig_shape = hidden_states.shape
+ topk_indices, topk_weights = self.router(hidden_states)
+ hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
+ hidden_states = self.moe(hidden_states, topk_indices, topk_weights).view(*orig_shape)
+ return hidden_states
+
+
+class LongcatFlashMLA(DeepseekV3Attention):
+ def __init__(self, config, layer_idx: int):
+ super().__init__(config, layer_idx)
+
+ self.mla_scale_q_lora = (config.hidden_size / self.q_lora_rank) ** 0.5
+ self.mla_scale_kv_lora = (config.hidden_size / self.kv_lora_rank) ** 0.5
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
+ attention_mask: Optional[torch.Tensor],
+ past_key_values: Optional[Cache] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs: Unpack[FlashAttentionKwargs],
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
+ batch_size, seq_length = hidden_states.shape[:-1]
+ query_shape = (batch_size, seq_length, -1, self.qk_head_dim)
+ key_shape = (batch_size, seq_length, -1, self.qk_nope_head_dim + self.v_head_dim)
+ # we always do a lora for queries as well
+ q_states = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
+ q_states = q_states.view(query_shape).transpose(1, 2)
+ q_pass, q_rot = torch.split(q_states, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
+
+ compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
+ k_pass, k_rot = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
+ k_pass = self.kv_a_layernorm(k_pass)
+
+ # apply LoRA scaling
+ q_pass = q_pass * self.mla_scale_q_lora
+ q_rot = q_rot * self.mla_scale_q_lora
+ k_pass = k_pass * self.mla_scale_kv_lora
+
+ k_pass = self.kv_b_proj(k_pass).view(key_shape).transpose(1, 2)
+ k_pass, value_states = torch.split(k_pass, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
+
+ k_rot = k_rot.view(batch_size, 1, seq_length, self.qk_rope_head_dim)
+
+ cos, sin = position_embeddings
+ q_rot, k_rot = apply_rotary_pos_emb_interleave(q_rot, k_rot, cos, sin)
+ k_rot = k_rot.expand(*k_pass.shape[:-1], -1)
+
+ query_states = torch.cat((q_pass, q_rot), dim=-1)
+ key_states = torch.cat((k_pass, k_rot), dim=-1)
+
+ if past_key_values is not None:
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ if self.config._attn_implementation == "flash_attention_2" and self.qk_head_dim != self.v_head_dim:
+ value_states = F.pad(value_states, [0, self.qk_head_dim - self.v_head_dim])
+
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ attn_output, attn_weights = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ dropout=0.0 if not self.training else self.attention_dropout,
+ scaling=self.scaling,
+ **kwargs,
+ )
+
+ if self.config._attn_implementation == "flash_attention_2" and self.qk_head_dim != self.v_head_dim:
+ attn_output = attn_output[:, :, :, : self.v_head_dim]
+
+ attn_output = attn_output.reshape(batch_size, seq_length, -1).contiguous()
+ attn_output = self.o_proj(attn_output)
+ return attn_output, attn_weights
+
+
+class LongcatFlashDecoderLayer(GradientCheckpointingLayer):
+ """
+ LongCat decoder layer with dual-sublayer + shortcut MoE architecture.
+
+ Each logical layer contains:
+ - 2 attention sublayers (with layer indices: layer_idx*2, layer_idx*2+1)
+ - 2 MLP sublayers
+ - 1 shortcut MoE connection
+ """
+
+ def __init__(self, config, layer_idx: int):
+ super().__init__()
+ self.layer_idx = layer_idx
+ self.hidden_size = config.hidden_size
+
+ self.mlp = LongcatFlashMoE(config)
+
+ self.self_attn = nn.ModuleList([LongcatFlashMLA(config=config, layer_idx=layer_idx * 2 + i) for i in [0, 1]])
+ self.mlps = nn.ModuleList([LongcatFlashMLP(config) for _ in [0, 1]])
+ self.input_layernorm = nn.ModuleList(
+ [LongcatFlashRMSNorm(config.hidden_size, eps=config.rms_norm_eps) for _ in [0, 1]]
+ )
+ self.post_attention_layernorm = nn.ModuleList(
+ [LongcatFlashRMSNorm(config.hidden_size, eps=config.rms_norm_eps) for _ in [0, 1]]
+ )
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ use_cache: Optional[bool] = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
+ **kwargs: Unpack[FlashAttentionKwargs],
+ ) -> torch.Tensor:
+ residual = hidden_states
+ hidden_states = self.input_layernorm[0](hidden_states)
+
+ hidden_states, _ = self.self_attn[0](
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ **kwargs,
+ )
+ hidden_states = residual + hidden_states
+
+ residual = hidden_states
+ hidden_states = self.post_attention_layernorm[0](hidden_states)
+
+ shortcut_mlp_output = self.mlp(hidden_states)
+ hidden_states = self.mlps[0](hidden_states)
+ hidden_states = residual + hidden_states
+
+ # shortcut connection after second sublayer
+ residual = hidden_states
+ hidden_states = self.input_layernorm[1](hidden_states)
+
+ hidden_states, _ = self.self_attn[1](
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ **kwargs,
+ )
+ hidden_states = residual + hidden_states
+
+ residual = hidden_states
+ hidden_states = self.post_attention_layernorm[1](hidden_states)
+
+ hidden_states = self.mlps[1](hidden_states)
+ hidden_states = residual + hidden_states + shortcut_mlp_output
+
+ return hidden_states
+
+
+class LongcatFlashPreTrainedModel(DeepseekV3PreTrainedModel):
+ _can_record_outputs = {
+ "hidden_states": LongcatFlashDecoderLayer,
+ "attentions": LongcatFlashMLA,
+ }
+
+ def _init_weights(self, module):
+ PreTrainedModel._init_weights(self, module)
+ if isinstance(module, LongcatFlashTopkRouter):
+ module.classifier.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+
+
+class LongcatFlashModel(DeepseekV3Model):
+ _keys_to_ignore_on_load_unexpected = [r"model\.mtp.*"]
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.layers = nn.ModuleList(
+ [LongcatFlashDecoderLayer(config, layer_idx) for layer_idx in range(config.num_layers)]
+ )
+ # Each layer above has 2 attention sublayers; num_hidden_layers is patched below so caches are sized correctly (avoids a checkpoint change)
+ self.head_dim = config.head_dim # For CI happiness (we didn't convert so head_dim is not directly used)
+
+ self.config.num_hidden_layers = 2 * config.num_layers
+ self.norm = LongcatFlashRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.rotary_emb = LongcatFlashRotaryEmbedding(config=config)
+ self.gradient_checkpointing = False
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ):
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ if inputs_embeds is None:
+ inputs_embeds: torch.Tensor = self.embed_tokens(input_ids)
+
+ if use_cache and past_key_values is None:
+ past_key_values = DynamicCache(config=self.config)
+
+ if cache_position is None:
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ cache_position: torch.Tensor = torch.arange(
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+ )
+
+ if position_ids is None:
+ position_ids = cache_position.unsqueeze(0)
+
+ causal_mask = create_causal_mask(
+ config=self.config,
+ input_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ cache_position=cache_position,
+ past_key_values=past_key_values,
+ position_ids=position_ids,
+ )
+
+ hidden_states = inputs_embeds
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+ for decoder_layer in self.layers[: self.config.num_layers]:
+ hidden_states = decoder_layer(
+ hidden_states,
+ attention_mask=causal_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ **kwargs,
+ )
+
+ hidden_states = self.norm(hidden_states)
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=past_key_values,
+ hidden_states=None,
+ attentions=None,
+ )
+
+
+class LongcatFlashForCausalLM(DeepseekV3ForCausalLM):
+ _keys_to_ignore_on_load_unexpected = [r"model\.mtp.*"]
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.model = LongcatFlashModel(config)
+
+
+__all__ = ["LongcatFlashPreTrainedModel", "LongcatFlashModel", "LongcatFlashForCausalLM"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/longformer/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/longformer/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..87f53105424b76e5c18bd740ecfdd37a5b29d0d4
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/longformer/__init__.py
@@ -0,0 +1,30 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_longformer import *
+ from .modeling_longformer import *
+ from .modeling_tf_longformer import *
+ from .tokenization_longformer import *
+ from .tokenization_longformer_fast import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/longformer/configuration_longformer.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/longformer/configuration_longformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..207cc18394796352a0d42bf28082baf04d11ff8e
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/longformer/configuration_longformer.py
@@ -0,0 +1,207 @@
+# coding=utf-8
+# Copyright 2020 The Allen Institute for AI team and The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Longformer configuration"""
+
+from collections import OrderedDict
+from collections.abc import Mapping
+from typing import TYPE_CHECKING, Any, Optional, Union
+
+from ...configuration_utils import PretrainedConfig
+from ...onnx import OnnxConfig
+from ...utils import TensorType, logging
+
+
+if TYPE_CHECKING:
+ from ...onnx.config import PatchingSpec
+ from ...tokenization_utils_base import PreTrainedTokenizerBase
+
+
+logger = logging.get_logger(__name__)
+
+
+class LongformerConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`LongformerModel`] or a [`TFLongformerModel`]. It
+ is used to instantiate a Longformer model according to the specified arguments, defining the model architecture.
+ Instantiating a configuration with the defaults will yield a similar configuration to that of the LongFormer
+ [allenai/longformer-base-4096](https://huggingface.co/allenai/longformer-base-4096) architecture with a sequence
+ length of 4,096.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 30522):
+ Vocabulary size of the Longformer model. Defines the number of different tokens that can be represented by
+ the `inputs_ids` passed when calling [`LongformerModel`] or [`TFLongformerModel`].
+ hidden_size (`int`, *optional*, defaults to 768):
+ Dimensionality of the encoder layers and the pooler layer.
+ num_hidden_layers (`int`, *optional*, defaults to 12):
+ Number of hidden layers in the Transformer encoder.
+ num_attention_heads (`int`, *optional*, defaults to 12):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ intermediate_size (`int`, *optional*, defaults to 3072):
+ Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder.
+ hidden_act (`str` or `Callable`, *optional*, defaults to `"gelu"`):
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+ `"relu"`, `"silu"` and `"gelu_new"` are supported.
+ hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+ attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
+ The dropout ratio for the attention probabilities.
+ max_position_embeddings (`int`, *optional*, defaults to 512):
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
+ just in case (e.g., 512 or 1024 or 2048).
+ type_vocab_size (`int`, *optional*, defaults to 2):
+ The vocabulary size of the `token_type_ids` passed when calling [`LongformerModel`] or
+ [`TFLongformerModel`].
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ layer_norm_eps (`float`, *optional*, defaults to 1e-12):
+ The epsilon used by the layer normalization layers.
+ attention_window (`int` or `list[int]`, *optional*, defaults to 512):
+ Size of an attention window around each token. If an `int`, use the same size for all layers. To specify a
+ different window size for each layer, use a `list[int]` where `len(attention_window) == num_hidden_layers`.
+
+ Example:
+
+ ```python
+ >>> from transformers import LongformerConfig, LongformerModel
+
+ >>> # Initializing a Longformer configuration
+ >>> configuration = LongformerConfig()
+
+ >>> # Initializing a model from the configuration
+ >>> model = LongformerModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
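+
+ >>> # A per-layer window sketch (illustrative): `attention_window` may also be a list with one
+ >>> # entry per hidden layer
+ >>> configuration = LongformerConfig(num_hidden_layers=4, attention_window=[64, 64, 128, 128])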
+ ```"""
+
+ model_type = "longformer"
+
+ def __init__(
+ self,
+ attention_window: Union[list[int], int] = 512,
+ sep_token_id: int = 2,
+ pad_token_id: int = 1,
+ bos_token_id: int = 0,
+ eos_token_id: int = 2,
+ vocab_size: int = 30522,
+ hidden_size: int = 768,
+ num_hidden_layers: int = 12,
+ num_attention_heads: int = 12,
+ intermediate_size: int = 3072,
+ hidden_act: str = "gelu",
+ hidden_dropout_prob: float = 0.1,
+ attention_probs_dropout_prob: float = 0.1,
+ max_position_embeddings: int = 512,
+ type_vocab_size: int = 2,
+ initializer_range: float = 0.02,
+ layer_norm_eps: float = 1e-12,
+ onnx_export: bool = False,
+ **kwargs,
+ ):
+ """Constructs LongformerConfig."""
+ super().__init__(pad_token_id=pad_token_id, **kwargs)
+
+ self.attention_window = attention_window
+ self.sep_token_id = sep_token_id
+ self.bos_token_id = bos_token_id
+ self.eos_token_id = eos_token_id
+ self.vocab_size = vocab_size
+ self.hidden_size = hidden_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.hidden_act = hidden_act
+ self.intermediate_size = intermediate_size
+ self.hidden_dropout_prob = hidden_dropout_prob
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
+ self.max_position_embeddings = max_position_embeddings
+ self.type_vocab_size = type_vocab_size
+ self.initializer_range = initializer_range
+ self.layer_norm_eps = layer_norm_eps
+ self.onnx_export = onnx_export
+
+
+class LongformerOnnxConfig(OnnxConfig):
+ def __init__(
+ self, config: "PretrainedConfig", task: str = "default", patching_specs: "Optional[list[PatchingSpec]]" = None
+ ):
+ super().__init__(config, task, patching_specs)
+ config.onnx_export = True
+
+ @property
+ def inputs(self) -> Mapping[str, Mapping[int, str]]:
+ if self.task == "multiple-choice":
+ dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
+ else:
+ dynamic_axis = {0: "batch", 1: "sequence"}
+ return OrderedDict(
+ [
+ ("input_ids", dynamic_axis),
+ ("attention_mask", dynamic_axis),
+ ("global_attention_mask", dynamic_axis),
+ ]
+ )
+
+ @property
+ def outputs(self) -> Mapping[str, Mapping[int, str]]:
+ outputs = super().outputs
+ if self.task == "default":
+ outputs["pooler_output"] = {0: "batch"}
+ return outputs
+
+ @property
+ def atol_for_validation(self) -> float:
+ """
+ What absolute tolerance value to use during model conversion validation.
+
+ Returns:
+ Float absolute tolerance value.
+ """
+ return 1e-4
+
+ @property
+ def default_onnx_opset(self) -> int:
+ # needs to be >= 14 to support tril operator
+ return max(super().default_onnx_opset, 14)
+
+ def generate_dummy_inputs(
+ self,
+ tokenizer: "PreTrainedTokenizerBase",
+ batch_size: int = -1,
+ seq_length: int = -1,
+ is_pair: bool = False,
+ framework: Optional[TensorType] = None,
+ ) -> Mapping[str, Any]:
+ inputs = super().generate_dummy_inputs(
+ preprocessor=tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework
+ )
+ import torch
+
+ # for some reason, replacing this code by inputs["global_attention_mask"] = torch.randint(2, inputs["input_ids"].shape, dtype=torch.int64)
+ # makes the export fail randomly
+ inputs["global_attention_mask"] = torch.zeros_like(inputs["input_ids"])
+ # make every second token global
+ inputs["global_attention_mask"][:, ::2] = 1
+
+ return inputs
+
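+# A minimal sketch (illustrative only; assumes a Longformer tokenizer object is available as
+# `tokenizer`) of how this ONNX config produces dummy inputs, with every second token marked as
+# global, exactly as `generate_dummy_inputs` above does:
+#
+#     onnx_config = LongformerOnnxConfig(LongformerConfig())
+#     dummy_inputs = onnx_config.generate_dummy_inputs(
+#         tokenizer, batch_size=2, seq_length=16, framework=TensorType.PYTORCH
+#     )
+#     assert set(dummy_inputs) >= {"input_ids", "attention_mask", "global_attention_mask"}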
+
+__all__ = ["LongformerConfig", "LongformerOnnxConfig"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/longformer/modeling_longformer.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/longformer/modeling_longformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..cdc7089249674b901662d644f94ae8ee8521d625
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/longformer/modeling_longformer.py
@@ -0,0 +1,2222 @@
+# coding=utf-8
+# Copyright 2020 The Allen Institute for AI team and The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch Longformer model."""
+
+import math
+from dataclasses import dataclass
+from typing import Optional, Union
+
+import torch
+from torch import nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+
+from ...activations import ACT2FN, gelu
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
+from ...utils import ModelOutput, auto_docstring, logging
+from .configuration_longformer import LongformerConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Base class for Longformer's outputs, with potential hidden states, local and global attentions.
+ """
+)
+class LongformerBaseModelOutput(ModelOutput):
+ r"""
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, x +
+ attention_window + 1)`, where `x` is the number of tokens with global attention mask.
+
+ Local attention weights after the attention softmax, used to compute the weighted average in the
+ self-attention heads. Those are the attention weights from every token in the sequence to every token with
+ global attention (first `x` values) and to every token in the attention window (remaining `attention_window
+ + 1` values). Note that the first `x` values refer to tokens with fixed positions in the text, but the
+ remaining `attention_window + 1` values refer to tokens with relative positions: the attention weight of a
+ token to itself is located at index `x + attention_window / 2` and the `attention_window / 2` preceding
+ (succeeding) values are the attention weights to the `attention_window / 2` preceding (succeeding) tokens.
+ If the attention window contains a token with global attention, the attention weight at the corresponding
+ index is set to 0; the value should be accessed from the first `x` attention weights. If a token has global
+ attention, the attention weights to all other tokens in `attentions` are set to 0; the values should be
+ accessed from `global_attentions`.
+ global_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, x)`,
+ where `x` is the number of tokens with global attention mask.
+
+ Global attention weights after the attention softmax, used to compute the weighted average in the
+ self-attention heads. Those are the attention weights from every token with global attention to every token
+ in the sequence.
+ """
+
+ last_hidden_state: torch.FloatTensor
+ hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+ attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+ global_attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Base class for Longformer's outputs that also contains a pooling of the last hidden states.
+ """
+)
+class LongformerBaseModelOutputWithPooling(ModelOutput):
+ r"""
+ pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`):
+ Last layer hidden-state of the first token of the sequence (classification token) further processed by a
+ Linear layer and a Tanh activation function. The Linear layer weights are trained from the next sentence
+ prediction (classification) objective during pretraining.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, x +
+ attention_window + 1)`, where `x` is the number of tokens with global attention mask.
+
+ Local attention weights after the attention softmax, used to compute the weighted average in the
+ self-attention heads. Those are the attention weights from every token in the sequence to every token with
+ global attention (first `x` values) and to every token in the attention window (remaining `attention_window
+ + 1` values). Note that the first `x` values refer to tokens with fixed positions in the text, but the
+ remaining `attention_window + 1` values refer to tokens with relative positions: the attention weight of a
+ token to itself is located at index `x + attention_window / 2` and the `attention_window / 2` preceding
+ (succeeding) values are the attention weights to the `attention_window / 2` preceding (succeeding) tokens.
+ If the attention window contains a token with global attention, the attention weight at the corresponding
+ index is set to 0; the value should be accessed from the first `x` attention weights. If a token has global
+ attention, the attention weights to all other tokens in `attentions` are set to 0; the values should be
+ accessed from `global_attentions`.
+ global_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, x)`,
+ where `x` is the number of tokens with global attention mask.
+
+ Global attention weights after the attention softmax, used to compute the weighted average in the
+ self-attention heads. Those are the attention weights from every token with global attention to every token
+ in the sequence.
+ """
+
+ last_hidden_state: torch.FloatTensor
+ pooler_output: Optional[torch.FloatTensor] = None
+ hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+ attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+ global_attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Base class for masked language models outputs.
+ """
+)
+class LongformerMaskedLMOutput(ModelOutput):
+ r"""
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Masked language modeling (MLM) loss.
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, x +
+ attention_window + 1)`, where `x` is the number of tokens with global attention mask.
+
+ Local attention weights after the attention softmax, used to compute the weighted average in the
+ self-attention heads. Those are the attention weights from every token in the sequence to every token with
+ global attention (first `x` values) and to every token in the attention window (remaining `attention_window
+ + 1` values). Note that the first `x` values refer to tokens with fixed positions in the text, but the
+ remaining `attention_window + 1` values refer to tokens with relative positions: the attention weight of a
+ token to itself is located at index `x + attention_window / 2` and the `attention_window / 2` preceding
+ (succeeding) values are the attention weights to the `attention_window / 2` preceding (succeeding) tokens.
+ If the attention window contains a token with global attention, the attention weight at the corresponding
+ index is set to 0; the value should be accessed from the first `x` attention weights. If a token has global
+ attention, the attention weights to all other tokens in `attentions` are set to 0; the values should be
+ accessed from `global_attentions`.
+ global_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, x)`,
+ where `x` is the number of tokens with global attention mask.
+
+ Global attention weights after the attention softmax, used to compute the weighted average in the
+ self-attention heads. Those are the attention weights from every token with global attention to every token
+ in the sequence.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ logits: Optional[torch.FloatTensor] = None
+ hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+ attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+ global_attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Base class for outputs of question answering Longformer models.
+ """
+)
+class LongformerQuestionAnsweringModelOutput(ModelOutput):
+ r"""
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Total span extraction loss is the sum of a Cross-Entropy for the start and end positions.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, x +
+ attention_window + 1)`, where `x` is the number of tokens with global attention mask.
+
+ Local attention weights after the attention softmax, used to compute the weighted average in the
+ self-attention heads. Those are the attention weights from every token in the sequence to every token with
+ global attention (first `x` values) and to every token in the attention window (remaining `attention_window
+ + 1` values). Note that the first `x` values refer to tokens with fixed positions in the text, but the
+ remaining `attention_window + 1` values refer to tokens with relative positions: the attention weight of a
+ token to itself is located at index `x + attention_window / 2` and the `attention_window / 2` preceding
+ (succeeding) values are the attention weights to the `attention_window / 2` preceding (succeeding) tokens.
+ If the attention window contains a token with global attention, the attention weight at the corresponding
+ index is set to 0; the value should be accessed from the first `x` attention weights. If a token has global
+ attention, the attention weights to all other tokens in `attentions` are set to 0; the values should be
+ accessed from `global_attentions`.
+ global_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, x)`,
+ where `x` is the number of tokens with global attention mask.
+
+ Global attention weights after the attention softmax, used to compute the weighted average in the
+ self-attention heads. Those are the attention weights from every token with global attention to every token
+ in the sequence.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ start_logits: Optional[torch.FloatTensor] = None
+ end_logits: Optional[torch.FloatTensor] = None
+ hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+ attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+ global_attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Base class for outputs of sentence classification models.
+ """
+)
+class LongformerSequenceClassifierOutput(ModelOutput):
+ r"""
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Classification (or regression if config.num_labels==1) loss.
+ logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
+ Classification (or regression if config.num_labels==1) scores (before SoftMax).
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, x +
+ attention_window + 1)`, where `x` is the number of tokens with global attention mask.
+
+ Local attention weights after the attention softmax, used to compute the weighted average in the
+ self-attention heads. Those are the attention weights from every token in the sequence to every token with
+ global attention (first `x` values) and to every token in the attention window (remaining `attention_window
+ + 1` values). Note that the first `x` values refer to tokens with fixed positions in the text, but the
+ remaining `attention_window + 1` values refer to tokens with relative positions: the attention weight of a
+ token to itself is located at index `x + attention_window / 2` and the `attention_window / 2` preceding
+ (succeeding) values are the attention weights to the `attention_window / 2` preceding (succeeding) tokens.
+ If the attention window contains a token with global attention, the attention weight at the corresponding
+ index is set to 0; the value should be accessed from the first `x` attention weights. If a token has global
+ attention, the attention weights to all other tokens in `attentions` are set to 0; the values should be
+ accessed from `global_attentions`.
+ global_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, x)`,
+ where `x` is the number of tokens with global attention mask.
+
+ Global attention weights after the attention softmax, used to compute the weighted average in the
+ self-attention heads. Those are the attention weights from every token with global attention to every token
+ in the sequence.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ logits: Optional[torch.FloatTensor] = None
+ hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+ attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+ global_attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Base class for outputs of multiple choice Longformer models.
+ """
+)
+class LongformerMultipleChoiceModelOutput(ModelOutput):
+ r"""
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Classification loss.
+ logits (`torch.FloatTensor` of shape `(batch_size, num_choices)`):
+ *num_choices* is the second dimension of the input tensors (see *input_ids* above).
+
+ Classification scores (before SoftMax).
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, x +
+ attention_window + 1)`, where `x` is the number of tokens with global attention mask.
+
+ Local attention weights after the attention softmax, used to compute the weighted average in the
+ self-attention heads. Those are the attention weights from every token in the sequence to every token with
+ global attention (first `x` values) and to every token in the attention window (remaining `attention_window
+ + 1` values). Note that the first `x` values refer to tokens with fixed positions in the text, but the
+ remaining `attention_window + 1` values refer to tokens with relative positions: the attention weight of a
+ token to itself is located at index `x + attention_window / 2` and the `attention_window / 2` preceding
+ (succeeding) values are the attention weights to the `attention_window / 2` preceding (succeeding) tokens.
+ If the attention window contains a token with global attention, the attention weight at the corresponding
+ index is set to 0; the value should be accessed from the first `x` attention weights. If a token has global
+ attention, the attention weights to all other tokens in `attentions` are set to 0; the values should be
+ accessed from `global_attentions`.
+ global_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, x)`,
+ where `x` is the number of tokens with global attention mask.
+
+ Global attention weights after the attention softmax, used to compute the weighted average in the
+ self-attention heads. Those are the attention weights from every token with global attention to every token
+ in the sequence.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ logits: Optional[torch.FloatTensor] = None
+ hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+ attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+ global_attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Base class for outputs of token classification models.
+ """
+)
+class LongformerTokenClassifierOutput(ModelOutput):
+ r"""
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Classification loss.
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_labels)`):
+ Classification scores (before SoftMax).
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, x +
+ attention_window + 1)`, where `x` is the number of tokens with global attention mask.
+
+ Local attention weights after the attention softmax, used to compute the weighted average in the
+ self-attention heads. Those are the attention weights from every token in the sequence to every token with
+ global attention (first `x` values) and to every token in the attention window (remaining `attention_window
+ + 1` values). Note that the first `x` values refer to tokens with fixed positions in the text, but the
+ remaining `attention_window + 1` values refer to tokens with relative positions: the attention weight of a
+ token to itself is located at index `x + attention_window / 2` and the `attention_window / 2` preceding
+ (succeeding) values are the attention weights to the `attention_window / 2` preceding (succeeding) tokens.
+ If the attention window contains a token with global attention, the attention weight at the corresponding
+ index is set to 0; the value should be accessed from the first `x` attention weights. If a token has global
+ attention, the attention weights to all other tokens in `attentions` are set to 0; the values should be
+ accessed from `global_attentions`.
+ global_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, x)`,
+ where `x` is the number of tokens with global attention mask.
+
+ Global attention weights after the attention softmax, used to compute the weighted average in the
+ self-attention heads. Those are the attention weights from every token with global attention to every token
+ in the sequence.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ logits: Optional[torch.FloatTensor] = None
+ hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+ attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+ global_attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+
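+# A minimal usage sketch (illustrative only; assumes the public `transformers` API and the
+# "allenai/longformer-base-4096" checkpoint): positions marked with 1 in `global_attention_mask`
+# attend to, and are attended by, every token, while the remaining positions use the sliding-window
+# (local) pattern described by the output classes above.
+#
+#     from transformers import AutoTokenizer, LongformerModel
+#
+#     model = LongformerModel.from_pretrained("allenai/longformer-base-4096")
+#     tokenizer = AutoTokenizer.from_pretrained("allenai/longformer-base-4096")
+#     inputs = tokenizer("A very long document ...", return_tensors="pt")
+#     global_attention_mask = torch.zeros_like(inputs["input_ids"])
+#     global_attention_mask[:, 0] = 1  # e.g. give the <s> token global attention
+#     outputs = model(**inputs, global_attention_mask=global_attention_mask)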
+
+def _get_question_end_index(input_ids, sep_token_id):
+ """
+ Computes the index of the first occurrence of `sep_token_id`.
+ """
+
+ sep_token_indices = (input_ids == sep_token_id).nonzero()
+ batch_size = input_ids.shape[0]
+
+ assert sep_token_indices.shape[1] == 2, "`input_ids` should have two dimensions"
+ assert sep_token_indices.shape[0] == 3 * batch_size, (
+ f"There should be exactly three separator tokens: {sep_token_id} in every sample for questions answering. You"
+ " might also consider to set `global_attention_mask` manually in the forward function to avoid this error."
+ )
+ return sep_token_indices.view(batch_size, 3, 2)[:, 0, 1]
+
+
+def _compute_global_attention_mask(input_ids, sep_token_id, before_sep_token=True):
+ """
+ Computes the global attention mask by putting attention on all tokens before `sep_token_id` if
+ `before_sep_token` is `True`, else on all tokens after `sep_token_id`.
+ """
+ question_end_index = _get_question_end_index(input_ids, sep_token_id)
+ question_end_index = question_end_index.unsqueeze(dim=1) # size: batch_size x 1
+ # bool attention mask with True in locations of global attention
+ attention_mask = torch.arange(input_ids.shape[1], device=input_ids.device)
+ if before_sep_token is True:
+ attention_mask = (attention_mask.expand_as(input_ids) < question_end_index).to(torch.bool)
+ else:
+ # the last token is a separator token and should not be counted, and there are two separator tokens in the middle
+ attention_mask = (attention_mask.expand_as(input_ids) > (question_end_index + 1)).to(torch.bool) * (
+ attention_mask.expand_as(input_ids) < input_ids.shape[-1]
+ ).to(torch.bool)
+
+ return attention_mask
+
+
+def create_position_ids_from_input_ids(input_ids, padding_idx):
+ """
+ Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
+ are ignored. This is modified from fairseq's `utils.make_positions`.
+
+ Args:
+ input_ids (`torch.Tensor`): Tensor of input token ids.
+ padding_idx (`int`): Id of the padding token.
+
+ Returns: torch.Tensor
+ """
+ # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
+ mask = input_ids.ne(padding_idx).int()
+ incremental_indices = torch.cumsum(mask, dim=1).type_as(mask) * mask
+ return incremental_indices.long() + padding_idx
+
+
+class LongformerEmbeddings(nn.Module):
+ """
+ Same as BertEmbeddings with a tiny tweak for positional embeddings indexing.
+ """
+
+ def __init__(self, config):
+ super().__init__()
+ self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
+ self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
+
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
+ # any TensorFlow checkpoint file
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ self.padding_idx = config.pad_token_id
+ self.position_embeddings = nn.Embedding(
+ config.max_position_embeddings, config.hidden_size, padding_idx=self.padding_idx
+ )
+
+ def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
+ if position_ids is None:
+ if input_ids is not None:
+ # Create the position ids from the input token ids. Any padded tokens remain padded.
+ position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx).to(input_ids.device)
+ else:
+ position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds)
+
+ if input_ids is not None:
+ input_shape = input_ids.size()
+ else:
+ input_shape = inputs_embeds.size()[:-1]
+
+ if token_type_ids is None:
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=position_ids.device)
+
+ if inputs_embeds is None:
+ inputs_embeds = self.word_embeddings(input_ids)
+ position_embeddings = self.position_embeddings(position_ids)
+ token_type_embeddings = self.token_type_embeddings(token_type_ids)
+
+ embeddings = inputs_embeds + position_embeddings + token_type_embeddings
+ embeddings = self.LayerNorm(embeddings)
+ embeddings = self.dropout(embeddings)
+ return embeddings
+
+ def create_position_ids_from_inputs_embeds(self, inputs_embeds):
+ """
+ We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids.
+
+ Args:
+ inputs_embeds (`torch.Tensor`): Directly provided input embeddings of shape `(batch_size, sequence_length, hidden_size)`.
+
+ Returns: torch.Tensor
+ """
+ input_shape = inputs_embeds.size()[:-1]
+ sequence_length = input_shape[1]
+
+ position_ids = torch.arange(
+ self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device
+ )
+ return position_ids.unsqueeze(0).expand(input_shape)
+
+
+class LongformerSelfAttention(nn.Module):
+ def __init__(self, config, layer_id):
+ super().__init__()
+ if config.hidden_size % config.num_attention_heads != 0:
+ raise ValueError(
+ f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
+ f"heads ({config.num_attention_heads})"
+ )
+ self.num_heads = config.num_attention_heads
+ self.head_dim = int(config.hidden_size / config.num_attention_heads)
+ self.embed_dim = config.hidden_size
+
+ self.query = nn.Linear(config.hidden_size, self.embed_dim)
+ self.key = nn.Linear(config.hidden_size, self.embed_dim)
+ self.value = nn.Linear(config.hidden_size, self.embed_dim)
+
+ # separate projection layers for tokens with global attention
+ self.query_global = nn.Linear(config.hidden_size, self.embed_dim)
+ self.key_global = nn.Linear(config.hidden_size, self.embed_dim)
+ self.value_global = nn.Linear(config.hidden_size, self.embed_dim)
+
+ self.dropout = config.attention_probs_dropout_prob
+
+ self.layer_id = layer_id
+ attention_window = config.attention_window[self.layer_id]
+ assert attention_window % 2 == 0, (
+ f"`attention_window` for layer {self.layer_id} has to be an even value. Given {attention_window}"
+ )
+ assert attention_window > 0, (
+ f"`attention_window` for layer {self.layer_id} has to be positive. Given {attention_window}"
+ )
+
+ self.one_sided_attn_window_size = attention_window // 2
+
+ self.config = config
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask=None,
+ layer_head_mask=None,
+ is_index_masked=None,
+ is_index_global_attn=None,
+ is_global_attn=None,
+ output_attentions=False,
+ ):
+ """
+ [`LongformerSelfAttention`] expects *len(hidden_states)* to be a multiple of *attention_window*. Padding to
+ *attention_window* happens in [`LongformerModel.forward`] to avoid redoing the padding on each layer.
+
+ The *attention_mask* is changed in [`LongformerModel.forward`] from 0, 1, 2 to:
+
+ - -10000: no attention
+ - 0: local attention
+ - +10000: global attention
+ """
+ hidden_states = hidden_states.transpose(0, 1)
+
+ # project hidden states
+ query_vectors = self.query(hidden_states)
+ key_vectors = self.key(hidden_states)
+ value_vectors = self.value(hidden_states)
+
+ seq_len, batch_size, embed_dim = hidden_states.size()
+ assert embed_dim == self.embed_dim, (
+ f"hidden_states should have embed_dim = {self.embed_dim}, but has {embed_dim}"
+ )
+
+ # normalize query
+ query_vectors /= math.sqrt(self.head_dim)
+
+ query_vectors = query_vectors.view(seq_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1)
+ key_vectors = key_vectors.view(seq_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1)
+
+ attn_scores = self._sliding_chunks_query_key_matmul(
+ query_vectors, key_vectors, self.one_sided_attn_window_size
+ )
+
+ # values to pad for attention probs
+ remove_from_windowed_attention_mask = (attention_mask != 0)[:, :, None, None]
+
+ # cast to fp32/fp16 then replace 1's with -inf
+ float_mask = remove_from_windowed_attention_mask.type_as(query_vectors).masked_fill(
+ remove_from_windowed_attention_mask, torch.finfo(query_vectors.dtype).min
+ )
+ # diagonal mask with zeros everywhere and -inf inplace of padding
+ diagonal_mask = self._sliding_chunks_query_key_matmul(
+ float_mask.new_ones(size=float_mask.size()), float_mask, self.one_sided_attn_window_size
+ )
+
+ # pad local attention probs
+ attn_scores += diagonal_mask
+
+ assert list(attn_scores.size()) == [
+ batch_size,
+ seq_len,
+ self.num_heads,
+ self.one_sided_attn_window_size * 2 + 1,
+ ], (
+ f"local_attn_probs should be of size ({batch_size}, {seq_len}, {self.num_heads},"
+ f" {self.one_sided_attn_window_size * 2 + 1}), but is of size {attn_scores.size()}"
+ )
+
+ # compute local attention probs from global attention keys and concat over window dim
+ if is_global_attn:
+ # compute global attn indices required throughout the forward fn
+ (
+ max_num_global_attn_indices,
+ is_index_global_attn_nonzero,
+ is_local_index_global_attn_nonzero,
+ is_local_index_no_global_attn_nonzero,
+ ) = self._get_global_attn_indices(is_index_global_attn)
+ # calculate global attn probs from global key
+
+ global_key_attn_scores = self._concat_with_global_key_attn_probs(
+ query_vectors=query_vectors,
+ key_vectors=key_vectors,
+ max_num_global_attn_indices=max_num_global_attn_indices,
+ is_index_global_attn_nonzero=is_index_global_attn_nonzero,
+ is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero,
+ is_local_index_no_global_attn_nonzero=is_local_index_no_global_attn_nonzero,
+ )
+ # concat to local_attn_probs
+ # (batch_size, seq_len, num_heads, extra attention count + 2*window+1)
+ attn_scores = torch.cat((global_key_attn_scores, attn_scores), dim=-1)
+
+ # free memory
+ del global_key_attn_scores
+
+ attn_probs = nn.functional.softmax(
+ attn_scores, dim=-1, dtype=torch.float32
+ ) # use fp32 for numerical stability
+
+ if layer_head_mask is not None:
+ assert layer_head_mask.size() == (self.num_heads,), (
+ f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
+ )
+ attn_probs = layer_head_mask.view(1, 1, -1, 1) * attn_probs
+
+ # softmax sometimes inserts NaN if all positions are masked, replace them with 0
+ attn_probs = torch.masked_fill(attn_probs, is_index_masked[:, :, None, None], 0.0)
+ attn_probs = attn_probs.type_as(attn_scores)
+
+ # free memory
+ del attn_scores
+
+ # apply dropout
+ attn_probs = nn.functional.dropout(attn_probs, p=self.dropout, training=self.training)
+
+ value_vectors = value_vectors.view(seq_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1)
+
+ # compute local attention output with global attention value and add
+ if is_global_attn:
+ # compute sum of global and local attn
+ attn_output = self._compute_attn_output_with_global_indices(
+ value_vectors=value_vectors,
+ attn_probs=attn_probs,
+ max_num_global_attn_indices=max_num_global_attn_indices,
+ is_index_global_attn_nonzero=is_index_global_attn_nonzero,
+ is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero,
+ )
+ else:
+ # compute local attn only
+ attn_output = self._sliding_chunks_matmul_attn_probs_value(
+ attn_probs, value_vectors, self.one_sided_attn_window_size
+ )
+
+ assert attn_output.size() == (batch_size, seq_len, self.num_heads, self.head_dim), "Unexpected size"
+ attn_output = attn_output.transpose(0, 1).reshape(seq_len, batch_size, embed_dim).contiguous()
+
+ # compute value for global attention and overwrite to attention output
+ # TODO: remove the redundant computation
+ if is_global_attn:
+ global_attn_output, global_attn_probs = self._compute_global_attn_output_from_hidden(
+ hidden_states=hidden_states,
+ max_num_global_attn_indices=max_num_global_attn_indices,
+ layer_head_mask=layer_head_mask,
+ is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero,
+ is_index_global_attn_nonzero=is_index_global_attn_nonzero,
+ is_local_index_no_global_attn_nonzero=is_local_index_no_global_attn_nonzero,
+ is_index_masked=is_index_masked,
+ )
+
+ # get only non zero global attn output
+ nonzero_global_attn_output = global_attn_output[
+ is_local_index_global_attn_nonzero[0], :, is_local_index_global_attn_nonzero[1]
+ ]
+
+ # overwrite values with global attention
+ attn_output[is_index_global_attn_nonzero[::-1]] = nonzero_global_attn_output.view(
+ len(is_local_index_global_attn_nonzero[0]), -1
+ )
+ # The attention weights for tokens with global attention are
+ # just filler values, they were never used to compute the output.
+ # Fill with 0 now, the correct values are in 'global_attn_probs'.
+ attn_probs[is_index_global_attn_nonzero] = 0
+
+ outputs = (attn_output.transpose(0, 1),)
+
+ if output_attentions:
+ outputs += (attn_probs,)
+
+ return outputs + (global_attn_probs,) if (is_global_attn and output_attentions) else outputs
+
+ @staticmethod
+ def _pad_and_transpose_last_two_dims(hidden_states_padded, padding):
+ """pads rows and then flips rows and columns"""
+ hidden_states_padded = nn.functional.pad(
+ hidden_states_padded, padding
+ ) # padding value is not important because it will be overwritten
+ hidden_states_padded = hidden_states_padded.view(
+ *hidden_states_padded.size()[:-2], hidden_states_padded.size(-1), hidden_states_padded.size(-2)
+ )
+ return hidden_states_padded
+
+ @staticmethod
+ def _pad_and_diagonalize(chunked_hidden_states):
+ """
+ shift every row 1 step right, converting columns into diagonals.
+
+ Example:
+
+ ```python
+ chunked_hidden_states (window_overlap = num_rows = 4):
+ [  0.4983,  2.6918, -0.0071,  1.0492,
+   -1.8348,  0.7672,  0.2986,  0.0285,
+   -0.7584,  0.4206, -0.0405,  0.1599,
+    2.0514, -1.1600,  0.5372,  0.2629 ]
+ ```
+
+ (pad & diagonalize) =>
+ [  0.4983,  2.6918, -0.0071,  1.0492,  0.0000,  0.0000,  0.0000,
+    0.0000, -1.8348,  0.7672,  0.2986,  0.0285,  0.0000,  0.0000,
+    0.0000,  0.0000, -0.7584,  0.4206, -0.0405,  0.1599,  0.0000,
+    0.0000,  0.0000,  0.0000,  2.0514, -1.1600,  0.5372,  0.2629 ]
+ """
+ total_num_heads, num_chunks, window_overlap, hidden_dim = chunked_hidden_states.size()
+ chunked_hidden_states = nn.functional.pad(
+ chunked_hidden_states, (0, window_overlap + 1)
+ ) # total_num_heads x num_chunks x window_overlap x (hidden_dim+window_overlap+1). Padding value is not important because it'll be overwritten
+ chunked_hidden_states = chunked_hidden_states.view(
+ total_num_heads, num_chunks, -1
+ ) # total_num_heads x num_chunks x window_overlap*window_overlap+window_overlap
+ chunked_hidden_states = chunked_hidden_states[
+ :, :, :-window_overlap
+ ] # total_num_heads x num_chunks x window_overlap*window_overlap
+ chunked_hidden_states = chunked_hidden_states.view(
+ total_num_heads, num_chunks, window_overlap, window_overlap + hidden_dim
+ )
+ chunked_hidden_states = chunked_hidden_states[:, :, :, :-1]
+ return chunked_hidden_states
+
+ @staticmethod
+ def _chunk(hidden_states, window_overlap, onnx_export: bool = False):
+ """convert into overlapping chunks. Chunk size = 2w, overlap size = w"""
+ if not onnx_export:
+ # non-overlapping chunks of size = 2w
+ hidden_states = hidden_states.view(
+ hidden_states.size(0),
+ torch.div(hidden_states.size(1), (window_overlap * 2), rounding_mode="trunc"),
+ window_overlap * 2,
+ hidden_states.size(2),
+ )
+ # use `as_strided` to make the chunks overlap with an overlap size = window_overlap
+ chunk_size = list(hidden_states.size())
+ chunk_size[1] = chunk_size[1] * 2 - 1
+
+ chunk_stride = list(hidden_states.stride())
+ chunk_stride[1] = chunk_stride[1] // 2
+ return hidden_states.as_strided(size=chunk_size, stride=chunk_stride)
+
+ # When exporting to ONNX, use this separate logic
+ # have to use slow implementation since as_strided, unfold and 2d-tensor indexing aren't supported (yet) in ONNX export
+
+ # TODO replace this with
+ # > return hidden_states.unfold(dimension=1, size=window_overlap * 2, step=window_overlap).transpose(2, 3)
+ # once `unfold` is supported
+ # the case hidden_states.size(1) == window_overlap * 2 can also simply return hidden_states.unsqueeze(1), but that's control flow
+
+ chunk_size = [
+ hidden_states.size(0),
+ torch.div(hidden_states.size(1), window_overlap, rounding_mode="trunc") - 1,
+ window_overlap * 2,
+ hidden_states.size(2),
+ ]
+
+ overlapping_chunks = torch.empty(chunk_size, device=hidden_states.device)
+ for chunk in range(chunk_size[1]):
+ overlapping_chunks[:, chunk, :, :] = hidden_states[
+ :, chunk * window_overlap : chunk * window_overlap + 2 * window_overlap, :
+ ]
+ return overlapping_chunks
+
+ @staticmethod
+ def _mask_invalid_locations(input_tensor, affected_seq_len) -> torch.Tensor:
+ beginning_mask_2d = input_tensor.new_ones(affected_seq_len, affected_seq_len + 1).tril().flip(dims=[0])
+ beginning_mask = beginning_mask_2d[None, :, None, :]
+ ending_mask = beginning_mask.flip(dims=(1, 3))
+ beginning_input = input_tensor[:, :affected_seq_len, :, : affected_seq_len + 1]
+ beginning_mask = beginning_mask.expand(beginning_input.size())
+ input_tensor[:, :affected_seq_len, :, : affected_seq_len + 1] = torch.full_like(
+ beginning_input, -float("inf")
+ ).where(beginning_mask.bool(), beginning_input)
+ ending_input = input_tensor[:, -affected_seq_len:, :, -(affected_seq_len + 1) :]
+ ending_mask = ending_mask.expand(ending_input.size())
+ input_tensor[:, -affected_seq_len:, :, -(affected_seq_len + 1) :] = torch.full_like(
+ ending_input, -float("inf")
+ ).where(ending_mask.bool(), ending_input)
+
+ def _sliding_chunks_query_key_matmul(self, query: torch.Tensor, key: torch.Tensor, window_overlap: int):
+ """
+ Matrix multiplication of query and key tensors using a sliding window attention pattern. This
+ implementation splits the input into overlapping chunks of size 2w (e.g. 512 for pretrained Longformer) with an
+ overlap of size window_overlap
+ """
+ batch_size, seq_len, num_heads, head_dim = query.size()
+ assert seq_len % (window_overlap * 2) == 0, (
+ f"Sequence length should be multiple of {window_overlap * 2}. Given {seq_len}"
+ )
+ assert query.size() == key.size()
+
+ chunks_count = torch.div(seq_len, window_overlap, rounding_mode="trunc") - 1
+
+ # group batch_size and num_heads dimensions into one, then chunk seq_len into chunks of size window_overlap * 2
+ query = query.transpose(1, 2).reshape(batch_size * num_heads, seq_len, head_dim)
+ key = key.transpose(1, 2).reshape(batch_size * num_heads, seq_len, head_dim)
+
+ query = self._chunk(query, window_overlap, getattr(self.config, "onnx_export", False))
+ key = self._chunk(key, window_overlap, getattr(self.config, "onnx_export", False))
+
+ # matrix multiplication
+ # bcxd: batch_size * num_heads x chunks x 2window_overlap x head_dim
+ # bcyd: batch_size * num_heads x chunks x 2window_overlap x head_dim
+ # bcxy: batch_size * num_heads x chunks x 2window_overlap x 2window_overlap
+ diagonal_chunked_attention_scores = torch.einsum("bcxd,bcyd->bcxy", (query, key)) # multiply
+
+ # convert diagonals into columns
+ diagonal_chunked_attention_scores = self._pad_and_transpose_last_two_dims(
+ diagonal_chunked_attention_scores, padding=(0, 0, 0, 1)
+ )
+
+ # allocate space for the overall attention matrix where the chunks are combined. The last dimension
+ # has (window_overlap * 2 + 1) columns. The first (window_overlap) columns are the window_overlap lower triangles (attention from a word to
+ # window_overlap previous words). The following column is attention score from each word to itself, then
+ # followed by window_overlap columns for the upper triangle.
+
+ diagonal_attention_scores = diagonal_chunked_attention_scores.new_zeros(
+ (batch_size * num_heads, chunks_count + 1, window_overlap, window_overlap * 2 + 1)
+ )
+
+ # copy parts from diagonal_chunked_attention_scores into the combined matrix of attentions
+ # - copying the main diagonal and the upper triangle
+ diagonal_attention_scores[:, :-1, :, window_overlap:] = diagonal_chunked_attention_scores[
+ :, :, :window_overlap, : window_overlap + 1
+ ]
+ diagonal_attention_scores[:, -1, :, window_overlap:] = diagonal_chunked_attention_scores[
+ :, -1, window_overlap:, : window_overlap + 1
+ ]
+ # - copying the lower triangle
+ diagonal_attention_scores[:, 1:, :, :window_overlap] = diagonal_chunked_attention_scores[
+ :, :, -(window_overlap + 1) : -1, window_overlap + 1 :
+ ]
+
+ diagonal_attention_scores[:, 0, 1:window_overlap, 1:window_overlap] = diagonal_chunked_attention_scores[
+ :, 0, : window_overlap - 1, 1 - window_overlap :
+ ]
+
+ # separate batch_size and num_heads dimensions again
+ diagonal_attention_scores = diagonal_attention_scores.view(
+ batch_size, num_heads, seq_len, 2 * window_overlap + 1
+ ).transpose(2, 1)
+
+ self._mask_invalid_locations(diagonal_attention_scores, window_overlap)
+ return diagonal_attention_scores
+
+ def _sliding_chunks_matmul_attn_probs_value(
+ self, attn_probs: torch.Tensor, value: torch.Tensor, window_overlap: int
+ ):
+ """
+ Same as _sliding_chunks_query_key_matmul but for attn_probs and value tensors. Returned tensor will be of the
+ same shape as `value`
+ """
+ batch_size, seq_len, num_heads, head_dim = value.size()
+
+ assert seq_len % (window_overlap * 2) == 0
+ assert attn_probs.size()[:3] == value.size()[:3]
+ assert attn_probs.size(3) == 2 * window_overlap + 1
+ chunks_count = torch.div(seq_len, window_overlap, rounding_mode="trunc") - 1
+ # group batch_size and num_heads dimensions into one, then chunk seq_len into chunks of size 2 window overlap
+
+ chunked_attn_probs = attn_probs.transpose(1, 2).reshape(
+ batch_size * num_heads,
+ torch.div(seq_len, window_overlap, rounding_mode="trunc"),
+ window_overlap,
+ 2 * window_overlap + 1,
+ )
+
+ # group batch_size and num_heads dimensions into one
+ value = value.transpose(1, 2).reshape(batch_size * num_heads, seq_len, head_dim)
+
+ # pad seq_len with window_overlap values at the beginning of the sequence and another window_overlap at the end
+ padded_value = nn.functional.pad(value, (0, 0, window_overlap, window_overlap), value=-1)
+
+ # chunk padded_value into chunks of size 3 window overlap and an overlap of size window overlap
+ chunked_value_size = (batch_size * num_heads, chunks_count + 1, 3 * window_overlap, head_dim)
+ chunked_value_stride = padded_value.stride()
+ chunked_value_stride = (
+ chunked_value_stride[0],
+ window_overlap * chunked_value_stride[1],
+ chunked_value_stride[1],
+ chunked_value_stride[2],
+ )
+ chunked_value = padded_value.as_strided(size=chunked_value_size, stride=chunked_value_stride)
+
+ chunked_attn_probs = self._pad_and_diagonalize(chunked_attn_probs)
+
+ context = torch.einsum("bcwd,bcdh->bcwh", (chunked_attn_probs, chunked_value))
+ return context.view(batch_size, num_heads, seq_len, head_dim).transpose(1, 2)
+
+ @staticmethod
+ def _get_global_attn_indices(is_index_global_attn):
+ """compute global attn indices required throughout forward pass"""
+ # helper variable
+ num_global_attn_indices = is_index_global_attn.long().sum(dim=1)
+
+ # max number of global attn indices in batch
+ max_num_global_attn_indices = num_global_attn_indices.max()
+
+ # indices of global attn
+ is_index_global_attn_nonzero = is_index_global_attn.nonzero(as_tuple=True)
+
+ # helper variable
+ is_local_index_global_attn = torch.arange(
+ max_num_global_attn_indices, device=is_index_global_attn.device
+ ) < num_global_attn_indices.unsqueeze(dim=-1)
+
+ # location of the non-padding values within global attention indices
+ is_local_index_global_attn_nonzero = is_local_index_global_attn.nonzero(as_tuple=True)
+
+ # location of the padding values within global attention indices
+ is_local_index_no_global_attn_nonzero = (is_local_index_global_attn == 0).nonzero(as_tuple=True)
+ return (
+ max_num_global_attn_indices,
+ is_index_global_attn_nonzero,
+ is_local_index_global_attn_nonzero,
+ is_local_index_no_global_attn_nonzero,
+ )
+
+ def _concat_with_global_key_attn_probs(
+ self,
+ key_vectors,
+ query_vectors,
+ max_num_global_attn_indices,
+ is_index_global_attn_nonzero,
+ is_local_index_global_attn_nonzero,
+ is_local_index_no_global_attn_nonzero,
+ ):
+ batch_size = key_vectors.shape[0]
+
+ # create only global key vectors
+ key_vectors_only_global = key_vectors.new_zeros(
+ batch_size, max_num_global_attn_indices, self.num_heads, self.head_dim
+ )
+
+ key_vectors_only_global[is_local_index_global_attn_nonzero] = key_vectors[is_index_global_attn_nonzero]
+
+ # (batch_size, seq_len, num_heads, max_num_global_attn_indices)
+ attn_probs_from_global_key = torch.einsum("blhd,bshd->blhs", (query_vectors, key_vectors_only_global))
+
+ # need to transpose since ONNX export only supports consecutive indexing: https://pytorch.org/docs/stable/onnx.html#writes-sets
+ attn_probs_from_global_key = attn_probs_from_global_key.transpose(1, 3)
+ attn_probs_from_global_key[
+ is_local_index_no_global_attn_nonzero[0], is_local_index_no_global_attn_nonzero[1], :, :
+ ] = torch.finfo(attn_probs_from_global_key.dtype).min
+ attn_probs_from_global_key = attn_probs_from_global_key.transpose(1, 3)
+
+ return attn_probs_from_global_key
+
+ def _compute_attn_output_with_global_indices(
+ self,
+ value_vectors,
+ attn_probs,
+ max_num_global_attn_indices,
+ is_index_global_attn_nonzero,
+ is_local_index_global_attn_nonzero,
+ ):
+ batch_size = attn_probs.shape[0]
+
+ # cut local attn probs to global only
+ attn_probs_only_global = attn_probs.narrow(-1, 0, max_num_global_attn_indices)
+ # get value vectors for global only
+ value_vectors_only_global = value_vectors.new_zeros(
+ batch_size, max_num_global_attn_indices, self.num_heads, self.head_dim
+ )
+ value_vectors_only_global[is_local_index_global_attn_nonzero] = value_vectors[is_index_global_attn_nonzero]
+
+ # use `matmul` because `einsum` crashes sometimes with fp16
+ # attn = torch.einsum('blhs,bshd->blhd', (selected_attn_probs, selected_v))
+ # compute attn output only global
+ attn_output_only_global = torch.matmul(
+ attn_probs_only_global.transpose(1, 2).clone(), value_vectors_only_global.transpose(1, 2).clone()
+ ).transpose(1, 2)
+
+ # reshape attn probs
+ attn_probs_without_global = attn_probs.narrow(
+ -1, max_num_global_attn_indices, attn_probs.size(-1) - max_num_global_attn_indices
+ ).contiguous()
+
+ # compute attn output with global
+ attn_output_without_global = self._sliding_chunks_matmul_attn_probs_value(
+ attn_probs_without_global, value_vectors, self.one_sided_attn_window_size
+ )
+ return attn_output_only_global + attn_output_without_global
+
+ def _compute_global_attn_output_from_hidden(
+ self,
+ hidden_states,
+ max_num_global_attn_indices,
+ layer_head_mask,
+ is_local_index_global_attn_nonzero,
+ is_index_global_attn_nonzero,
+ is_local_index_no_global_attn_nonzero,
+ is_index_masked,
+ ):
+ seq_len, batch_size = hidden_states.shape[:2]
+
+ # prepare global hidden states
+ global_attn_hidden_states = hidden_states.new_zeros(max_num_global_attn_indices, batch_size, self.embed_dim)
+ global_attn_hidden_states[is_local_index_global_attn_nonzero[::-1]] = hidden_states[
+ is_index_global_attn_nonzero[::-1]
+ ]
+
+ # global key, query, value
+ global_query_vectors_only_global = self.query_global(global_attn_hidden_states)
+ global_key_vectors = self.key_global(hidden_states)
+ global_value_vectors = self.value_global(hidden_states)
+
+ # normalize
+ global_query_vectors_only_global /= math.sqrt(self.head_dim)
+
+ # reshape
+ global_query_vectors_only_global = (
+ global_query_vectors_only_global.contiguous()
+ .view(max_num_global_attn_indices, batch_size * self.num_heads, self.head_dim)
+ .transpose(0, 1)
+ ) # (batch_size * self.num_heads, max_num_global_attn_indices, head_dim)
+ global_key_vectors = (
+ global_key_vectors.contiguous().view(-1, batch_size * self.num_heads, self.head_dim).transpose(0, 1)
+ ) # (batch_size * self.num_heads, seq_len, head_dim)
+ global_value_vectors = (
+ global_value_vectors.contiguous().view(-1, batch_size * self.num_heads, self.head_dim).transpose(0, 1)
+ ) # (batch_size * self.num_heads, seq_len, head_dim)
+
+ # compute attn scores
+ global_attn_scores = torch.bmm(global_query_vectors_only_global, global_key_vectors.transpose(1, 2))
+
+ assert list(global_attn_scores.size()) == [
+ batch_size * self.num_heads,
+ max_num_global_attn_indices,
+ seq_len,
+ ], (
+ "global_attn_scores have the wrong size. Size should be"
+ f" {(batch_size * self.num_heads, max_num_global_attn_indices, seq_len)}, but is"
+ f" {global_attn_scores.size()}."
+ )
+
+ global_attn_scores = global_attn_scores.view(batch_size, self.num_heads, max_num_global_attn_indices, seq_len)
+
+ # need to transpose since ONNX export only supports consecutive indexing: https://pytorch.org/docs/stable/onnx.html#writes-sets
+ global_attn_scores = global_attn_scores.transpose(1, 2)
+ global_attn_scores[
+ is_local_index_no_global_attn_nonzero[0], is_local_index_no_global_attn_nonzero[1], :, :
+ ] = torch.finfo(global_attn_scores.dtype).min
+ global_attn_scores = global_attn_scores.transpose(1, 2)
+
+ global_attn_scores = global_attn_scores.masked_fill(
+ is_index_masked[:, None, None, :],
+ torch.finfo(global_attn_scores.dtype).min,
+ )
+
+ global_attn_scores = global_attn_scores.view(batch_size * self.num_heads, max_num_global_attn_indices, seq_len)
+
+ # compute global attn probs
+ global_attn_probs_float = nn.functional.softmax(
+ global_attn_scores, dim=-1, dtype=torch.float32
+ ) # use fp32 for numerical stability
+
+ # apply layer head masking
+ if layer_head_mask is not None:
+ assert layer_head_mask.size() == (self.num_heads,), (
+ f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
+ )
+ global_attn_probs_float = layer_head_mask.view(1, -1, 1, 1) * global_attn_probs_float.view(
+ batch_size, self.num_heads, max_num_global_attn_indices, seq_len
+ )
+ global_attn_probs_float = global_attn_probs_float.view(
+ batch_size * self.num_heads, max_num_global_attn_indices, seq_len
+ )
+
+ global_attn_probs = nn.functional.dropout(
+ global_attn_probs_float.type_as(global_attn_scores), p=self.dropout, training=self.training
+ )
+
+ # global attn output
+ global_attn_output = torch.bmm(global_attn_probs, global_value_vectors)
+
+ assert list(global_attn_output.size()) == [
+ batch_size * self.num_heads,
+ max_num_global_attn_indices,
+ self.head_dim,
+ ], (
+ "global_attn_output tensor has the wrong size. Size should be"
+ f" {(batch_size * self.num_heads, max_num_global_attn_indices, self.head_dim)}, but is"
+ f" {global_attn_output.size()}."
+ )
+
+ global_attn_probs = global_attn_probs.view(batch_size, self.num_heads, max_num_global_attn_indices, seq_len)
+ global_attn_output = global_attn_output.view(
+ batch_size, self.num_heads, max_num_global_attn_indices, self.head_dim
+ )
+ return global_attn_output, global_attn_probs
+
+
+# Copied from transformers.models.bert.modeling_bert.BertSelfOutput
+class LongformerSelfOutput(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
+ return hidden_states
+
+
+class LongformerAttention(nn.Module):
+ def __init__(self, config, layer_id=0):
+ super().__init__()
+ self.self = LongformerSelfAttention(config, layer_id)
+ self.output = LongformerSelfOutput(config)
+ self.pruned_heads = set()
+
+ def prune_heads(self, heads):
+ if len(heads) == 0:
+ return
+ heads, index = find_pruneable_heads_and_indices(
+ heads, self.self.num_heads, self.self.head_dim, self.pruned_heads
+ )
+
+ # Prune linear layers
+ self.self.query = prune_linear_layer(self.self.query, index)
+ self.self.key = prune_linear_layer(self.self.key, index)
+ self.self.value = prune_linear_layer(self.self.value, index)
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
+
+ # Update hyper params and store pruned heads
+ self.self.num_heads = self.self.num_heads - len(heads)
+ self.self.embed_dim = self.self.head_dim * self.self.num_heads
+ self.pruned_heads = self.pruned_heads.union(heads)
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask=None,
+ layer_head_mask=None,
+ is_index_masked=None,
+ is_index_global_attn=None,
+ is_global_attn=None,
+ output_attentions=False,
+ ):
+ self_outputs = self.self(
+ hidden_states,
+ attention_mask=attention_mask,
+ layer_head_mask=layer_head_mask,
+ is_index_masked=is_index_masked,
+ is_index_global_attn=is_index_global_attn,
+ is_global_attn=is_global_attn,
+ output_attentions=output_attentions,
+ )
+ attn_output = self.output(self_outputs[0], hidden_states)
+ outputs = (attn_output,) + self_outputs[1:]
+ return outputs
+
+
+# Copied from transformers.models.bert.modeling_bert.BertIntermediate
+class LongformerIntermediate(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
+ if isinstance(config.hidden_act, str):
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
+ else:
+ self.intermediate_act_fn = config.hidden_act
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.intermediate_act_fn(hidden_states)
+ return hidden_states
+
+
+# Copied from transformers.models.bert.modeling_bert.BertOutput
+class LongformerOutput(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
+ return hidden_states
+
+
+class LongformerLayer(GradientCheckpointingLayer):
+ def __init__(self, config, layer_id=0):
+ super().__init__()
+ self.attention = LongformerAttention(config, layer_id)
+ self.intermediate = LongformerIntermediate(config)
+ self.output = LongformerOutput(config)
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
+ self.seq_len_dim = 1
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask=None,
+ layer_head_mask=None,
+ is_index_masked=None,
+ is_index_global_attn=None,
+ is_global_attn=None,
+ output_attentions=False,
+ ):
+ self_attn_outputs = self.attention(
+ hidden_states,
+ attention_mask=attention_mask,
+ layer_head_mask=layer_head_mask,
+ is_index_masked=is_index_masked,
+ is_index_global_attn=is_index_global_attn,
+ is_global_attn=is_global_attn,
+ output_attentions=output_attentions,
+ )
+ attn_output = self_attn_outputs[0]
+ outputs = self_attn_outputs[1:]
+
+ layer_output = apply_chunking_to_forward(
+ self.ff_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attn_output
+ )
+ outputs = (layer_output,) + outputs
+ return outputs
+
+ def ff_chunk(self, attn_output):
+ intermediate_output = self.intermediate(attn_output)
+ layer_output = self.output(intermediate_output, attn_output)
+ return layer_output
+
+
+class LongformerEncoder(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.layer = nn.ModuleList([LongformerLayer(config, layer_id=i) for i in range(config.num_hidden_layers)])
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask=None,
+ head_mask=None,
+ padding_len=0,
+ output_attentions=False,
+ output_hidden_states=False,
+ return_dict=True,
+ ):
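+ # At this point `attention_mask` is the sliced extended mask: roughly 0 for tokens with local
+ # attention, a large negative value for padding/masked tokens, and a large positive value for
+ # tokens with global attention, so the sign checks below recover each token's role.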
+ is_index_masked = attention_mask < 0
+ is_index_global_attn = attention_mask > 0
+
+ # Record `is_global_attn == True` to enable ONNX export
+ is_global_attn = is_index_global_attn.flatten().any().item()
+
+ all_hidden_states = () if output_hidden_states else None
+ all_attentions = () if output_attentions else None # All local attentions.
+ all_global_attentions = () if (output_attentions and is_global_attn) else None
+
+ # check if head_mask has a correct number of layers specified if desired
+ if head_mask is not None:
+ assert head_mask.size()[0] == (len(self.layer)), (
+ f"The head_mask should be specified for {len(self.layer)} layers, but it is for {head_mask.size()[0]}."
+ )
+ for idx, layer_module in enumerate(self.layer):
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ layer_outputs = layer_module(
+ hidden_states,
+ attention_mask=attention_mask,
+ layer_head_mask=head_mask[idx] if head_mask is not None else None,
+ is_index_masked=is_index_masked,
+ is_index_global_attn=is_index_global_attn,
+ is_global_attn=is_global_attn,
+ output_attentions=output_attentions,
+ )
+ hidden_states = layer_outputs[0]
+
+ if output_attentions:
+ # bsz x seq_len x num_attn_heads x (num_global_attn + attention_window_len + 1) => bsz x num_attn_heads x seq_len x (num_global_attn + attention_window_len + 1)
+ all_attentions = all_attentions + (layer_outputs[1].transpose(1, 2),)
+
+ if is_global_attn:
+ # bsz x num_attn_heads x num_global_attn x seq_len => bsz x num_attn_heads x seq_len x num_global_attn
+ all_global_attentions = all_global_attentions + (layer_outputs[2].transpose(2, 3),)
+
+ # Add last layer
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ # undo padding if necessary
+ # unpad `hidden_states` because the calling function is expecting a length == input_ids.size(1)
+ hidden_states = hidden_states[:, : hidden_states.shape[1] - padding_len]
+ if output_hidden_states:
+ all_hidden_states = tuple(state[:, : state.shape[1] - padding_len] for state in all_hidden_states)
+
+ if output_attentions:
+ all_attentions = tuple(state[:, :, : state.shape[2] - padding_len, :] for state in all_attentions)
+
+ if not return_dict:
+ return tuple(
+ v for v in [hidden_states, all_hidden_states, all_attentions, all_global_attentions] if v is not None
+ )
+ return LongformerBaseModelOutput(
+ last_hidden_state=hidden_states,
+ hidden_states=all_hidden_states,
+ attentions=all_attentions,
+ global_attentions=all_global_attentions,
+ )
+
+
+# Copied from transformers.models.bert.modeling_bert.BertPooler
+class LongformerPooler(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ self.activation = nn.Tanh()
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ # We "pool" the model by simply taking the hidden state corresponding
+ # to the first token.
+ first_token_tensor = hidden_states[:, 0]
+ pooled_output = self.dense(first_token_tensor)
+ pooled_output = self.activation(pooled_output)
+ return pooled_output
+
+
+# Copied from transformers.models.roberta.modeling_roberta.RobertaLMHead with Roberta->Longformer
+class LongformerLMHead(nn.Module):
+ """Longformer Head for masked language modeling."""
+
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+ self.decoder = nn.Linear(config.hidden_size, config.vocab_size)
+ self.bias = nn.Parameter(torch.zeros(config.vocab_size))
+ self.decoder.bias = self.bias
+
+ def forward(self, features, **kwargs):
+ x = self.dense(features)
+ x = gelu(x)
+ x = self.layer_norm(x)
+
+ # project back to size of vocabulary with bias
+ x = self.decoder(x)
+
+ return x
+
+ def _tie_weights(self):
+ # To tie those two weights if they get disconnected (on TPU or when the bias is resized)
+ # For accelerate compatibility and to not break backward compatibility
+ if self.decoder.bias.device.type == "meta":
+ self.decoder.bias = self.bias
+ else:
+ self.bias = self.decoder.bias
+
+
+@auto_docstring
+class LongformerPreTrainedModel(PreTrainedModel):
+ config: LongformerConfig
+ base_model_prefix = "longformer"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["LongformerSelfAttention"]
+
+ def _init_weights(self, module):
+ """Initialize the weights"""
+ if isinstance(module, nn.Linear):
+ # Slightly different from the TF version which uses truncated_normal for initialization
+ # cf https://github.com/pytorch/pytorch/pull/5617
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+ elif isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+
+
+@auto_docstring
+class LongformerModel(LongformerPreTrainedModel):
+ """
+ This class copies code from [`RobertaModel`] and overrides the standard self-attention with Longformer self-attention
+ to provide the ability to process long sequences following the self-attention approach described in [Longformer:
+ the Long-Document Transformer](https://huggingface.co/papers/2004.05150) by Iz Beltagy, Matthew E. Peters, and Arman Cohan.
+ Longformer self-attention combines a local (sliding window) and global attention to extend to long documents
+ without the O(n^2) increase in memory and compute.
+
+ The self-attention module `LongformerSelfAttention` implemented here supports the combination of local and global
+ attention but it lacks support for autoregressive attention and dilated attention. Autoregressive and dilated
+ attention are more relevant for autoregressive language modeling than finetuning on downstream tasks. A future
+ release will add support for autoregressive attention, but support for dilated attention requires a custom CUDA
+ kernel to be memory and compute efficient.
+
+ """
+
+ def __init__(self, config, add_pooling_layer=True):
+ r"""
+ add_pooling_layer (bool, *optional*, defaults to `True`):
+ Whether to add a pooling layer
+ """
+ super().__init__(config)
+ self.config = config
+
+ if isinstance(config.attention_window, int):
+ assert config.attention_window % 2 == 0, "`config.attention_window` has to be an even value"
+ assert config.attention_window > 0, "`config.attention_window` has to be positive"
+ config.attention_window = [config.attention_window] * config.num_hidden_layers # one value per layer
+ else:
+ assert len(config.attention_window) == config.num_hidden_layers, (
+ "`len(config.attention_window)` should equal `config.num_hidden_layers`. "
+ f"Expected {config.num_hidden_layers}, given {len(config.attention_window)}"
+ )
+
+ self.embeddings = LongformerEmbeddings(config)
+ self.encoder = LongformerEncoder(config)
+ self.pooler = LongformerPooler(config) if add_pooling_layer else None
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.embeddings.word_embeddings
+
+ def set_input_embeddings(self, value):
+ self.embeddings.word_embeddings = value
+
+ def _prune_heads(self, heads_to_prune):
+ """
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer}. See the
+ base class `PreTrainedModel`.
+ """
+ for layer, heads in heads_to_prune.items():
+ self.encoder.layer[layer].attention.prune_heads(heads)
+
+ def _pad_to_window_size(
+ self,
+ input_ids: torch.Tensor,
+ attention_mask: torch.Tensor,
+ token_type_ids: torch.Tensor,
+ position_ids: torch.Tensor,
+ inputs_embeds: torch.Tensor,
+ pad_token_id: int,
+ ):
+ """A helper function to pad tokens and mask to work with implementation of Longformer self-attention."""
+ # padding
+ attention_window = (
+ self.config.attention_window
+ if isinstance(self.config.attention_window, int)
+ else max(self.config.attention_window)
+ )
+
+ assert attention_window % 2 == 0, f"`attention_window` should be an even value. Given {attention_window}"
+ input_shape = input_ids.shape if input_ids is not None else inputs_embeds.shape
+ batch_size, seq_len = input_shape[:2]
+
+ padding_len = (attention_window - seq_len % attention_window) % attention_window
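+ # e.g. (illustrative numbers) with seq_len = 4002 and attention_window = 512, 4002 % 512 = 418,
+ # so padding_len = (512 - 418) % 512 = 94 and the padded length becomes 4096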
+
+ # this path should be recorded in the ONNX export, it is fine with padding_len == 0 as well
+ if padding_len > 0:
+ logger.warning_once(
+ f"Input ids are automatically padded to be a multiple of `config.attention_window`: {attention_window}"
+ )
+ if input_ids is not None:
+ input_ids = nn.functional.pad(input_ids, (0, padding_len), value=pad_token_id)
+ if position_ids is not None:
+ # pad with position_id = pad_token_id as in modeling_roberta.RobertaEmbeddings
+ position_ids = nn.functional.pad(position_ids, (0, padding_len), value=pad_token_id)
+ if inputs_embeds is not None:
+ input_ids_padding = inputs_embeds.new_full(
+ (batch_size, padding_len),
+ self.config.pad_token_id,
+ dtype=torch.long,
+ )
+ inputs_embeds_padding = self.embeddings(input_ids_padding)
+ inputs_embeds = torch.cat([inputs_embeds, inputs_embeds_padding], dim=-2)
+
+ attention_mask = nn.functional.pad(
+ attention_mask, (0, padding_len), value=0
+ ) # no attention on the padding tokens
+ token_type_ids = nn.functional.pad(token_type_ids, (0, padding_len), value=0) # pad with token_type_id = 0
+
+ return padding_len, input_ids, attention_mask, token_type_ids, position_ids, inputs_embeds
+
+ def _merge_to_attention_mask(self, attention_mask: torch.Tensor, global_attention_mask: torch.Tensor):
+ # longformer self attention expects attention mask to have 0 (no attn), 1 (local attn), 2 (global attn)
+ # (global_attention_mask + 1) => 1 for local attention, 2 for global attention
+ # => final attention_mask => 0 for no attention, 1 for local attention 2 for global attention
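+ # e.g. (illustrative) attention_mask = [1, 1, 1, 0] and global_attention_mask = [1, 0, 0, 0]
+ # merge to attention_mask * (global_attention_mask + 1) = [2, 1, 1, 0]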
+ if attention_mask is not None:
+ attention_mask = attention_mask * (global_attention_mask + 1)
+ else:
+ # simply use `global_attention_mask` as `attention_mask`
+ # if no `attention_mask` is given
+ attention_mask = global_attention_mask + 1
+ return attention_mask
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ global_attention_mask: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ token_type_ids: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[tuple, LongformerBaseModelOutputWithPooling]:
+ r"""
+ global_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to decide the attention given on each token, local attention or global attention. Tokens with global
+ attention attend to all other tokens, and all other tokens attend to them. This is important for
+ task-specific finetuning because it makes the model more flexible at representing the task. For example,
+ for classification, the <s> token should be given global attention. For QA, all question tokens should also
+ have global attention. Please refer to the [Longformer paper](https://huggingface.co/papers/2004.05150) for more
+ details. Mask values selected in `[0, 1]`:
+
+ - 0 for local attention (a sliding window attention),
+ - 1 for global attention (tokens that attend to all other tokens, and all other tokens attend to them).
+
+ Examples:
+
+ ```python
+ >>> import torch
+ >>> from transformers import LongformerModel, AutoTokenizer
+
+ >>> model = LongformerModel.from_pretrained("allenai/longformer-base-4096")
+ >>> tokenizer = AutoTokenizer.from_pretrained("allenai/longformer-base-4096")
+
+ >>> SAMPLE_TEXT = " ".join(["Hello world! "] * 1000) # long input document
+ >>> input_ids = torch.tensor(tokenizer.encode(SAMPLE_TEXT)).unsqueeze(0) # batch of size 1
+
+ >>> attention_mask = torch.ones(
+ ... input_ids.shape, dtype=torch.long, device=input_ids.device
+ ... ) # initialize to local attention
+ >>> global_attention_mask = torch.zeros(
+ ... input_ids.shape, dtype=torch.long, device=input_ids.device
+ ... ) # initialize with global attention deactivated for all tokens
+ >>> global_attention_mask[
+ ... :,
+ ... [
+ ... 1,
+ ... 4,
+ ... 21,
+ ... ],
+ ... ] = 1 # Set global attention to random tokens for the sake of this example
+ >>> # Usually, set global attention based on the task. For example,
+ >>> # classification: the <s> token
+ >>> # QA: question tokens
+ >>> # LM: potentially on the beginning of sentences and paragraphs
+ >>> outputs = model(input_ids, attention_mask=attention_mask, global_attention_mask=global_attention_mask)
+ >>> sequence_output = outputs.last_hidden_state
+ >>> pooled_output = outputs.pooler_output
+ ```
+ """
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+ elif input_ids is not None:
+ self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
+ input_shape = input_ids.size()
+ elif inputs_embeds is not None:
+ input_shape = inputs_embeds.size()[:-1]
+ else:
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
+
+ if attention_mask is None:
+ attention_mask = torch.ones(input_shape, device=device)
+ if token_type_ids is None:
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
+
+ # merge `global_attention_mask` and `attention_mask`
+ if global_attention_mask is not None:
+ attention_mask = self._merge_to_attention_mask(attention_mask, global_attention_mask)
+
+ padding_len, input_ids, attention_mask, token_type_ids, position_ids, inputs_embeds = self._pad_to_window_size(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ inputs_embeds=inputs_embeds,
+ pad_token_id=self.config.pad_token_id,
+ )
+
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
+ # ourselves in which case we just need to make it broadcastable to all heads.
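+ # Longformer self-attention consumes a per-token (2D) mask rather than the usual broadcastable 4D mask,
+ # so only the last dimension of the extended mask is kept below.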
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)[
+ :, 0, 0, :
+ ]
+
+ embedding_output = self.embeddings(
+ input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
+ )
+
+ encoder_outputs = self.encoder(
+ embedding_output,
+ attention_mask=extended_attention_mask,
+ head_mask=head_mask,
+ padding_len=padding_len,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ sequence_output = encoder_outputs[0]
+ pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
+
+ if not return_dict:
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
+
+ return LongformerBaseModelOutputWithPooling(
+ last_hidden_state=sequence_output,
+ pooler_output=pooled_output,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ global_attentions=encoder_outputs.global_attentions,
+ )
+
+
+@auto_docstring
+class LongformerForMaskedLM(LongformerPreTrainedModel):
+ _tied_weights_keys = ["lm_head.decoder"]
+
+ def __init__(self, config):
+ super().__init__(config)
+
+ self.longformer = LongformerModel(config, add_pooling_layer=False)
+ self.lm_head = LongformerLMHead(config)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_output_embeddings(self):
+ return self.lm_head.decoder
+
+ def set_output_embeddings(self, new_embeddings):
+ self.lm_head.decoder = new_embeddings
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ global_attention_mask: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ token_type_ids: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[tuple, LongformerMaskedLMOutput]:
+ r"""
+ global_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to decide the attention given on each token, local attention or global attention. Tokens with global
+ attention attend to all other tokens, and all other tokens attend to them. This is important for
+ task-specific finetuning because it makes the model more flexible at representing the task. For example,
+ for classification, the <s> token should be given global attention. For QA, all question tokens should also
+ have global attention. Please refer to the [Longformer paper](https://huggingface.co/papers/2004.05150) for more
+ details. Mask values selected in `[0, 1]`:
+
+ - 0 for local attention (a sliding window attention),
+ - 1 for global attention (tokens that attend to all other tokens, and all other tokens attend to them).
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
+ config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
+ loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
+
+ Mask filling example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, LongformerForMaskedLM
+
+ >>> tokenizer = AutoTokenizer.from_pretrained("allenai/longformer-base-4096")
+ >>> model = LongformerForMaskedLM.from_pretrained("allenai/longformer-base-4096")
+ ```
+
+ Let's try a very long input.
+
+ ```python
+ >>> TXT = (
+ ... "My friends are but they eat too many carbs."
+ ... + " That's why I decide not to eat with them." * 300
+ ... )
+ >>> input_ids = tokenizer([TXT], return_tensors="pt")["input_ids"]
+ >>> logits = model(input_ids).logits
+
+ >>> masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero().item()
+ >>> probs = logits[0, masked_index].softmax(dim=0)
+ >>> values, predictions = probs.topk(5)
+
+ >>> tokenizer.decode(predictions).split()
+ ['healthy', 'skinny', 'thin', 'good', 'vegetarian']
+ ```
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.longformer(
+ input_ids,
+ attention_mask=attention_mask,
+ global_attention_mask=global_attention_mask,
+ head_mask=head_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ sequence_output = outputs[0]
+ prediction_scores = self.lm_head(sequence_output)
+
+ masked_lm_loss = None
+ if labels is not None:
+ loss_fct = CrossEntropyLoss()
+
+ labels = labels.to(prediction_scores.device)
+ masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
+
+ if not return_dict:
+ output = (prediction_scores,) + outputs[2:]
+ return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
+
+ return LongformerMaskedLMOutput(
+ loss=masked_lm_loss,
+ logits=prediction_scores,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ global_attentions=outputs.global_attentions,
+ )
+
+
+@auto_docstring(
+ custom_intro="""
+ Longformer Model transformer with a sequence classification/regression head on top (a linear layer on top of the
+ pooled output) e.g. for GLUE tasks.
+ """
+)
+class LongformerForSequenceClassification(LongformerPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+ self.config = config
+
+ self.longformer = LongformerModel(config, add_pooling_layer=False)
+ self.classifier = LongformerClassificationHead(config)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ global_attention_mask: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ token_type_ids: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[tuple, LongformerSequenceClassifierOutput]:
+ r"""
+ global_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to decide the attention given on each token, local attention or global attention. Tokens with global
+ attention attend to all other tokens, and all other tokens attend to them. This is important for
+ task-specific finetuning because it makes the model more flexible at representing the task. For example,
+ for classification, the <s> token should be given global attention. For QA, all question tokens should also
+ have global attention. Please refer to the [Longformer paper](https://huggingface.co/papers/2004.05150) for more
+ details. Mask values selected in `[0, 1]`:
+
+ - 0 for local attention (a sliding window attention),
+ - 1 for global attention (tokens that attend to all other tokens, and all other tokens attend to them).
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
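+
+ Example (a minimal sketch; `allenai/longformer-base-4096` has no fine-tuned classification head, so the
+ predicted class here is illustrative only):
+
+ ```python
+ >>> from transformers import AutoTokenizer, LongformerForSequenceClassification
+
+ >>> tokenizer = AutoTokenizer.from_pretrained("allenai/longformer-base-4096")
+ >>> model = LongformerForSequenceClassification.from_pretrained("allenai/longformer-base-4096")
+
+ >>> inputs = tokenizer("A very long document ...", return_tensors="pt")
+ >>> logits = model(**inputs).logits # global attention is put on the <s> token automatically
+ >>> predicted_class_id = int(logits.argmax(-1))
+ ```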
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if global_attention_mask is None:
+ logger.warning_once("Initializing global attention on CLS token...")
+ global_attention_mask = torch.zeros_like(input_ids)
+ # global attention on cls token
+ global_attention_mask[:, 0] = 1
+
+ outputs = self.longformer(
+ input_ids,
+ attention_mask=attention_mask,
+ global_attention_mask=global_attention_mask,
+ head_mask=head_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ sequence_output = outputs[0]
+ logits = self.classifier(sequence_output)
+
+ loss = None
+ if labels is not None:
+ labels = labels.to(logits.device)
+
+ if self.config.problem_type is None:
+ if self.num_labels == 1:
+ self.config.problem_type = "regression"
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+ self.config.problem_type = "single_label_classification"
+ else:
+ self.config.problem_type = "multi_label_classification"
+
+ if self.config.problem_type == "regression":
+ loss_fct = MSELoss()
+ if self.num_labels == 1:
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
+ else:
+ loss = loss_fct(logits, labels)
+ elif self.config.problem_type == "single_label_classification":
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+ elif self.config.problem_type == "multi_label_classification":
+ loss_fct = BCEWithLogitsLoss()
+ loss = loss_fct(logits, labels)
+
+ if not return_dict:
+ output = (logits,) + outputs[2:]
+ return ((loss,) + output) if loss is not None else output
+
+ return LongformerSequenceClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ global_attentions=outputs.global_attentions,
+ )
+
+
+class LongformerClassificationHead(nn.Module):
+ """Head for sentence-level classification tasks."""
+
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+ self.out_proj = nn.Linear(config.hidden_size, config.num_labels)
+
+ def forward(self, hidden_states, **kwargs):
+ hidden_states = hidden_states[:, 0, :] # take <s> token (equiv. to [CLS])
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.dense(hidden_states)
+ hidden_states = torch.tanh(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ output = self.out_proj(hidden_states)
+ return output
+
+
+@auto_docstring
+class LongformerForQuestionAnswering(LongformerPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+
+ self.longformer = LongformerModel(config, add_pooling_layer=False)
+ self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ global_attention_mask: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ token_type_ids: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ start_positions: Optional[torch.Tensor] = None,
+ end_positions: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[tuple, LongformerQuestionAnsweringModelOutput]:
+ r"""
+ global_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to decide the attention given on each token, local attention or global attention. Tokens with global
+ attention attend to all other tokens, and all other tokens attend to them. This is important for
+ task-specific finetuning because it makes the model more flexible at representing the task. For example,
+ for classification, the <s> token should be given global attention. For QA, all question tokens should also
+ have global attention. Please refer to the [Longformer paper](https://huggingface.co/papers/2004.05150) for more
+ details. Mask values selected in `[0, 1]`:
+
+ - 0 for local attention (a sliding window attention),
+ - 1 for global attention (tokens that attend to all other tokens, and all other tokens attend to them).
+
+ Examples:
+
+ ```python
+ >>> from transformers import AutoTokenizer, LongformerForQuestionAnswering
+ >>> import torch
+
+ >>> tokenizer = AutoTokenizer.from_pretrained("allenai/longformer-large-4096-finetuned-triviaqa")
+ >>> model = LongformerForQuestionAnswering.from_pretrained("allenai/longformer-large-4096-finetuned-triviaqa")
+
+ >>> question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
+ >>> encoding = tokenizer(question, text, return_tensors="pt")
+ >>> input_ids = encoding["input_ids"]
+
+ >>> # default is local attention everywhere
+ >>> # the forward method will automatically set global attention on question tokens
+ >>> attention_mask = encoding["attention_mask"]
+
+ >>> outputs = model(input_ids, attention_mask=attention_mask)
+ >>> start_logits = outputs.start_logits
+ >>> end_logits = outputs.end_logits
+ >>> all_tokens = tokenizer.convert_ids_to_tokens(input_ids[0].tolist())
+
+ >>> answer_tokens = all_tokens[torch.argmax(start_logits) : torch.argmax(end_logits) + 1]
+ >>> answer = tokenizer.decode(
+ ... tokenizer.convert_tokens_to_ids(answer_tokens)
+ ... ) # remove space prepending space token
+ ```"""
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if global_attention_mask is None:
+ if input_ids is None:
+ logger.warning(
+ "It is not possible to automatically generate the `global_attention_mask` because input_ids is"
+ " None. Please make sure that it is correctly set."
+ )
+ else:
+ # set global attention on question tokens automatically
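+ # for a RoBERTa-style input "<s> question </s></s> context </s>", this is expected to mark every
+ # token up to the first "</s>" (i.e. the question) for global attention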
+ global_attention_mask = _compute_global_attention_mask(input_ids, self.config.sep_token_id)
+
+ outputs = self.longformer(
+ input_ids,
+ attention_mask=attention_mask,
+ global_attention_mask=global_attention_mask,
+ head_mask=head_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ sequence_output = outputs[0]
+
+ logits = self.qa_outputs(sequence_output)
+ start_logits, end_logits = logits.split(1, dim=-1)
+ start_logits = start_logits.squeeze(-1).contiguous()
+ end_logits = end_logits.squeeze(-1).contiguous()
+
+ total_loss = None
+ if start_positions is not None and end_positions is not None:
+ # If we are on multi-GPU, splitting can add an extra dimension; squeeze it
+ if len(start_positions.size()) > 1:
+ start_positions = start_positions.squeeze(-1)
+ if len(end_positions.size()) > 1:
+ end_positions = end_positions.squeeze(-1)
+ # sometimes the start/end positions are outside our model inputs, we ignore these terms
+ ignored_index = start_logits.size(1)
+ start_positions = start_positions.clamp(0, ignored_index)
+ end_positions = end_positions.clamp(0, ignored_index)
+
+ loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
+ start_loss = loss_fct(start_logits, start_positions)
+ end_loss = loss_fct(end_logits, end_positions)
+ total_loss = (start_loss + end_loss) / 2
+
+ if not return_dict:
+ output = (start_logits, end_logits) + outputs[2:]
+ return ((total_loss,) + output) if total_loss is not None else output
+
+ return LongformerQuestionAnsweringModelOutput(
+ loss=total_loss,
+ start_logits=start_logits,
+ end_logits=end_logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ global_attentions=outputs.global_attentions,
+ )
+
+
+@auto_docstring
+class LongformerForTokenClassification(LongformerPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+
+ self.longformer = LongformerModel(config, add_pooling_layer=False)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ global_attention_mask: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ token_type_ids: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[tuple, LongformerTokenClassifierOutput]:
+ r"""
+ global_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to decide the attention given on each token, local attention or global attention. Tokens with global
+ attention attend to all other tokens, and all other tokens attend to them. This is important for
+ task-specific finetuning because it makes the model more flexible at representing the task. For example,
+ for classification, the <s> token should be given global attention. For QA, all question tokens should also
+ have global attention. Please refer to the [Longformer paper](https://huggingface.co/papers/2004.05150) for more
+ details. Mask values selected in `[0, 1]`:
+
+ - 0 for local attention (a sliding window attention),
+ - 1 for global attention (tokens that attend to all other tokens, and all other tokens attend to them).
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
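+
+ Example (a minimal sketch; `allenai/longformer-base-4096` has a randomly initialized token-classification
+ head, so the predicted tags here are illustrative only):
+
+ ```python
+ >>> from transformers import AutoTokenizer, LongformerForTokenClassification
+
+ >>> tokenizer = AutoTokenizer.from_pretrained("allenai/longformer-base-4096")
+ >>> model = LongformerForTokenClassification.from_pretrained("allenai/longformer-base-4096")
+
+ >>> inputs = tokenizer("HuggingFace is based in New York City", return_tensors="pt")
+ >>> logits = model(**inputs).logits
+ >>> predicted_token_class_ids = logits.argmax(-1)
+ ```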
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.longformer(
+ input_ids,
+ attention_mask=attention_mask,
+ global_attention_mask=global_attention_mask,
+ head_mask=head_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ sequence_output = outputs[0]
+
+ sequence_output = self.dropout(sequence_output)
+ logits = self.classifier(sequence_output)
+
+ loss = None
+ if labels is not None:
+ loss_fct = CrossEntropyLoss()
+
+ labels = labels.to(logits.device)
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+
+ if not return_dict:
+ output = (logits,) + outputs[2:]
+ return ((loss,) + output) if loss is not None else output
+
+ return LongformerTokenClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ global_attentions=outputs.global_attentions,
+ )
+
+
+@auto_docstring
+class LongformerForMultipleChoice(LongformerPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+
+ self.longformer = LongformerModel(config)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+ self.classifier = nn.Linear(config.hidden_size, 1)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ token_type_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ global_attention_mask: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[tuple, LongformerMultipleChoiceModelOutput]:
+ r"""
+ input_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ token_type_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
+ Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
+ 1]`:
+
+ - 0 corresponds to a *sentence A* token,
+ - 1 corresponds to a *sentence B* token.
+
+ [What are token type IDs?](../glossary#token-type-ids)
+ global_attention_mask (`torch.FloatTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
+ Mask to decide the attention given on each token, local attention or global attention. Tokens with global
+ attention attend to all other tokens, and all other tokens attend to them. This is important for
+ task-specific finetuning because it makes the model more flexible at representing the task. For example,
+ for classification, the <s> token should be given global attention. For QA, all question tokens should also
+ have global attention. Please refer to the [Longformer paper](https://huggingface.co/papers/2004.05150) for more
+ details. Mask values selected in `[0, 1]`:
+
+ - 0 for local attention (a sliding window attention),
+ - 1 for global attention (tokens that attend to all other tokens, and all other tokens attend to them).
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
+ num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
+ `input_ids` above)
+ position_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+ config.max_position_embeddings - 1]`.
+
+ [What are position IDs?](../glossary#position-ids)
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_choices, sequence_length, hidden_size)`, *optional*):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
+ model's internal embedding lookup matrix.
+ """
+ num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # set global attention on question tokens
+ if global_attention_mask is None and input_ids is not None:
+ logger.warning_once("Initializing global attention on multiple choice...")
+ # put global attention on all tokens after `config.sep_token_id`
+ global_attention_mask = torch.stack(
+ [
+ _compute_global_attention_mask(input_ids[:, i], self.config.sep_token_id, before_sep_token=False)
+ for i in range(num_choices)
+ ],
+ dim=1,
+ )
+
+ flat_input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
+ flat_position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
+ flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
+ flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
+ flat_global_attention_mask = (
+ global_attention_mask.view(-1, global_attention_mask.size(-1))
+ if global_attention_mask is not None
+ else None
+ )
+ flat_inputs_embeds = (
+ inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
+ if inputs_embeds is not None
+ else None
+ )
+
+ outputs = self.longformer(
+ flat_input_ids,
+ position_ids=flat_position_ids,
+ token_type_ids=flat_token_type_ids,
+ attention_mask=flat_attention_mask,
+ global_attention_mask=flat_global_attention_mask,
+ head_mask=head_mask,
+ inputs_embeds=flat_inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ pooled_output = outputs[1]
+
+ pooled_output = self.dropout(pooled_output)
+ logits = self.classifier(pooled_output)
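+ # (batch_size * num_choices, 1) -> (batch_size, num_choices)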
+ reshaped_logits = logits.view(-1, num_choices)
+
+ loss = None
+ if labels is not None:
+ loss_fct = CrossEntropyLoss()
+
+ labels = labels.to(reshaped_logits.device)
+ loss = loss_fct(reshaped_logits, labels)
+
+ if not return_dict:
+ output = (reshaped_logits,) + outputs[2:]
+ return ((loss,) + output) if loss is not None else output
+
+ return LongformerMultipleChoiceModelOutput(
+ loss=loss,
+ logits=reshaped_logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ global_attentions=outputs.global_attentions,
+ )
+
+
+__all__ = [
+ "LongformerForMaskedLM",
+ "LongformerForMultipleChoice",
+ "LongformerForQuestionAnswering",
+ "LongformerForSequenceClassification",
+ "LongformerForTokenClassification",
+ "LongformerModel",
+ "LongformerPreTrainedModel",
+ "LongformerSelfAttention",
+]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/longformer/modeling_tf_longformer.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/longformer/modeling_tf_longformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..891f5d76c95ce1f572ea33baf6db51cf21278a80
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/longformer/modeling_tf_longformer.py
@@ -0,0 +1,2783 @@
+# coding=utf-8
+# Copyright 2020 The Allen Institute for AI team and The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tensorflow Longformer model."""
+
+from __future__ import annotations
+
+import warnings
+from dataclasses import dataclass
+
+import numpy as np
+import tensorflow as tf
+
+from ...activations_tf import get_tf_activation
+from ...modeling_tf_utils import (
+ TFMaskedLanguageModelingLoss,
+ TFModelInputType,
+ TFMultipleChoiceLoss,
+ TFPreTrainedModel,
+ TFQuestionAnsweringLoss,
+ TFSequenceClassificationLoss,
+ TFTokenClassificationLoss,
+ get_initializer,
+ keras,
+ keras_serializable,
+ unpack_inputs,
+)
+from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax
+from ...utils import (
+ ModelOutput,
+ add_code_sample_docstrings,
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ logging,
+)
+from .configuration_longformer import LongformerConfig
+
+
+logger = logging.get_logger(__name__)
+
+_CHECKPOINT_FOR_DOC = "allenai/longformer-base-4096"
+_CONFIG_FOR_DOC = "LongformerConfig"
+
+LARGE_NEGATIVE = -1e8
+
+
+@dataclass
+class TFLongformerBaseModelOutput(ModelOutput):
+ """
+ Base class for Longformer's outputs, with potential hidden states, local and global attentions.
+
+ Args:
+ last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`):
+ Sequence of hidden-states at the output of the last layer of the model.
+ hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
+ `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+ attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, x +
+ attention_window + 1)`, where `x` is the number of tokens with global attention mask.
+
+ Local attention weights after the attention softmax, used to compute the weighted average in the
+ self-attention heads. Those are the attention weights from every token in the sequence to every token with
+ global attention (first `x` values) and to every token in the attention window (remaining `attention_window
+ + 1` values). Note that the first `x` values refer to tokens with fixed positions in the text, but the
+ remaining `attention_window + 1` values refer to tokens with relative positions: the attention weight of a
+ token to itself is located at index `x + attention_window / 2` and the `attention_window / 2` preceding
+ (succeeding) values are the attention weights to the `attention_window / 2` preceding (succeeding) tokens.
+ If the attention window contains a token with global attention, the attention weight at the corresponding
+ index is set to 0; the value should be accessed from the first `x` attention weights. If a token has global
+ attention, the attention weights to all other tokens in `attentions` are set to 0; the values should be
+ accessed from `global_attentions`.
+ global_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, x)`, where `x`
+ is the number of tokens with global attention mask.
+
+ Global attention weights after the attention softmax, used to compute the weighted average in the
+ self-attention heads. Those are the attention weights from every token with global attention to every token
+ in the sequence.
+ """
+
+ last_hidden_state: tf.Tensor | None = None
+ hidden_states: tuple[tf.Tensor, ...] | None = None
+ attentions: tuple[tf.Tensor, ...] | None = None
+ global_attentions: tuple[tf.Tensor, ...] | None = None
+
+
+@dataclass
+class TFLongformerBaseModelOutputWithPooling(ModelOutput):
+ """
+ Base class for Longformer's outputs that also contains a pooling of the last hidden states.
+
+ Args:
+ last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`):
+ Sequence of hidden-states at the output of the last layer of the model.
+ pooler_output (`tf.Tensor` of shape `(batch_size, hidden_size)`):
+ Last layer hidden-state of the first token of the sequence (classification token) further processed by a
+ Linear layer and a Tanh activation function. The Linear layer weights are trained from the next sentence
+ prediction (classification) objective during pretraining.
+ hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
+ `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+ attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, x +
+ attention_window + 1)`, where `x` is the number of tokens with global attention mask.
+
+ Local attention weights after the attention softmax, used to compute the weighted average in the
+ self-attention heads. Those are the attention weights from every token in the sequence to every token with
+ global attention (first `x` values) and to every token in the attention window (remaining `attention_window
+ + 1` values). Note that the first `x` values refer to tokens with fixed positions in the text, but the
+ remaining `attention_window + 1` values refer to tokens with relative positions: the attention weight of a
+ token to itself is located at index `x + attention_window / 2` and the `attention_window / 2` preceding
+ (succeeding) values are the attention weights to the `attention_window / 2` preceding (succeeding) tokens.
+ If the attention window contains a token with global attention, the attention weight at the corresponding
+ index is set to 0; the value should be accessed from the first `x` attention weights. If a token has global
+ attention, the attention weights to all other tokens in `attentions` are set to 0; the values should be
+ accessed from `global_attentions`.
+ global_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, x)`, where `x`
+ is the number of tokens with global attention mask.
+
+ Global attention weights after the attention softmax, used to compute the weighted average in the
+ self-attention heads. Those are the attention weights from every token with global attention to every token
+ in the sequence.
+ """
+
+ last_hidden_state: tf.Tensor | None = None
+ pooler_output: tf.Tensor | None = None
+ hidden_states: tuple[tf.Tensor, ...] | None = None
+ attentions: tuple[tf.Tensor, ...] | None = None
+ global_attentions: tuple[tf.Tensor, ...] | None = None
+
+
+@dataclass
+class TFLongformerMaskedLMOutput(ModelOutput):
+ """
+ Base class for masked language models outputs.
+
+ Args:
+ loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Masked language modeling (MLM) loss.
+ logits (`tf.Tensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+ hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
+ `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+ attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, x +
+ attention_window + 1)`, where `x` is the number of tokens with global attention mask.
+
+ Local attention weights after the attention softmax, used to compute the weighted average in the
+ self-attention heads. Those are the attention weights from every token in the sequence to every token with
+ global attention (first `x` values) and to every token in the attention window (remaining `attention_window
+ + 1` values). Note that the first `x` values refer to tokens with fixed positions in the text, but the
+ remaining `attention_window + 1` values refer to tokens with relative positions: the attention weight of a
+ token to itself is located at index `x + attention_window / 2` and the `attention_window / 2` preceding
+ (succeeding) values are the attention weights to the `attention_window / 2` preceding (succeeding) tokens.
+ If the attention window contains a token with global attention, the attention weight at the corresponding
+ index is set to 0; the value should be accessed from the first `x` attention weights. If a token has global
+ attention, the attention weights to all other tokens in `attentions` are set to 0; the values should be
+ accessed from `global_attentions`.
+ global_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, x)`, where `x`
+ is the number of tokens with global attention mask.
+
+ Global attention weights after the attention softmax, used to compute the weighted average in the
+ self-attention heads. Those are the attention weights from every token with global attention to every token
+ in the sequence.
+ """
+
+ loss: tf.Tensor | None = None
+ logits: tf.Tensor | None = None
+ hidden_states: tuple[tf.Tensor, ...] | None = None
+ attentions: tuple[tf.Tensor, ...] | None = None
+ global_attentions: tuple[tf.Tensor, ...] | None = None
+
+
+@dataclass
+class TFLongformerQuestionAnsweringModelOutput(ModelOutput):
+ """
+ Base class for outputs of question answering Longformer models.
+
+ Args:
+ loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Total span extraction loss is the sum of a Cross-Entropy for the start and end positions.
+ start_logits (`tf.Tensor` of shape `(batch_size, sequence_length)`):
+ Span-start scores (before SoftMax).
+ end_logits (`tf.Tensor` of shape `(batch_size, sequence_length)`):
+ Span-end scores (before SoftMax).
+ hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
+ `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+ attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, x +
+ attention_window + 1)`, where `x` is the number of tokens with global attention mask.
+
+ Local attention weights after the attention softmax, used to compute the weighted average in the
+ self-attention heads. Those are the attention weights from every token in the sequence to every token with
+ global attention (first `x` values) and to every token in the attention window (remaining `attention_window
+ + 1` values). Note that the first `x` values refer to tokens with fixed positions in the text, but the
+ remaining `attention_window + 1` values refer to tokens with relative positions: the attention weight of a
+ token to itself is located at index `x + attention_window / 2` and the `attention_window / 2` preceding
+ (succeeding) values are the attention weights to the `attention_window / 2` preceding (succeeding) tokens.
+ If the attention window contains a token with global attention, the attention weight at the corresponding
+ index is set to 0; the value should be accessed from the first `x` attention weights. If a token has global
+ attention, the attention weights to all other tokens in `attentions` are set to 0; the values should be
+ accessed from `global_attentions`.
+ global_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, x)`, where `x`
+ is the number of tokens with global attention mask.
+
+ Global attention weights after the attention softmax, used to compute the weighted average in the
+ self-attention heads. Those are the attention weights from every token with global attention to every token
+ in the sequence.
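+
+ Example (a minimal illustrative sketch, not taken from the model code itself; the shapes below are assumptions
+ used only to show the expected fields):
+
+ ```python
+ import tensorflow as tf
+
+ # Build the output dataclass by hand just to illustrate the field shapes.
+ outputs = TFLongformerQuestionAnsweringModelOutput(
+ start_logits=tf.zeros((2, 4096)), # (batch_size, sequence_length)
+ end_logits=tf.zeros((2, 4096)),
+ )
+ answer_start = tf.argmax(outputs.start_logits, axis=-1) # best start index per example
+ ```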
+ """
+
+ loss: tf.Tensor | None = None
+ start_logits: tf.Tensor | None = None
+ end_logits: tf.Tensor | None = None
+ hidden_states: tuple[tf.Tensor, ...] | None = None
+ attentions: tuple[tf.Tensor, ...] | None = None
+ global_attentions: tuple[tf.Tensor, ...] | None = None
+
+
+@dataclass
+class TFLongformerSequenceClassifierOutput(ModelOutput):
+ """
+ Base class for outputs of sentence classification models.
+
+ Args:
+ loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Classification (or regression if config.num_labels==1) loss.
+ logits (`tf.Tensor` of shape `(batch_size, config.num_labels)`):
+ Classification (or regression if config.num_labels==1) scores (before SoftMax).
+ hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
+ `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+ attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, x +
+ attention_window + 1)`, where `x` is the number of tokens with global attention mask.
+
+ Local attention weights after the attention softmax, used to compute the weighted average in the
+ self-attention heads. Those are the attention weights from every token in the sequence to every token with
+ global attention (first `x` values) and to every token in the attention window (remaining `attention_window
+ + 1` values). Note that the first `x` values refer to tokens with fixed positions in the text, but the
+ remaining `attention_window + 1` values refer to tokens with relative positions: the attention weight of a
+ token to itself is located at index `x + attention_window / 2` and the `attention_window / 2` preceding
+ (succeeding) values are the attention weights to the `attention_window / 2` preceding (succeeding) tokens.
+ If the attention window contains a token with global attention, the attention weight at the corresponding
+ index is set to 0; the value should be accessed from the first `x` attention weights. If a token has global
+ attention, the attention weights to all other tokens in `attentions` are set to 0; the values should be
+ accessed from `global_attentions`.
+ global_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, x)`, where `x`
+ is the number of tokens with global attention mask.
+
+ Global attention weights after the attention softmax, used to compute the weighted average in the
+ self-attention heads. Those are the attention weights from every token with global attention to every token
+ in the sequence.
+ """
+
+ loss: tf.Tensor | None = None
+ logits: tf.Tensor | None = None
+ hidden_states: tuple[tf.Tensor, ...] | None = None
+ attentions: tuple[tf.Tensor, ...] | None = None
+ global_attentions: tuple[tf.Tensor, ...] | None = None
+
+
+@dataclass
+class TFLongformerMultipleChoiceModelOutput(ModelOutput):
+ """
+ Base class for outputs of multiple choice models.
+
+ Args:
+ loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Classification loss.
+ logits (`tf.Tensor` of shape `(batch_size, num_choices)`):
+ *num_choices* is the second dimension of the input tensors (see *input_ids* above).
+
+ Classification scores (before SoftMax).
+ hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
+ `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+ attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, x +
+ attention_window + 1)`, where `x` is the number of tokens with global attention mask.
+
+ Local attention weights after the attention softmax, used to compute the weighted average in the
+ self-attention heads. Those are the attention weights from every token in the sequence to every token with
+ global attention (first `x` values) and to every token in the attention window (remaining `attention_window
+ + 1` values). Note that the first `x` values refer to tokens with fixed positions in the text, but the
+ remaining `attention_window + 1` values refer to tokens with relative positions: the attention weight of a
+ token to itself is located at index `x + attention_window / 2` and the `attention_window / 2` preceding
+ (succeeding) values are the attention weights to the `attention_window / 2` preceding (succeeding) tokens.
+ If the attention window contains a token with global attention, the attention weight at the corresponding
+ index is set to 0; the value should be accessed from the first `x` attention weights. If a token has global
+ attention, the attention weights to all other tokens in `attentions` are set to 0; the values should be
+ accessed from `global_attentions`.
+ global_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, x)`, where `x`
+ is the number of tokens with global attention mask.
+
+ Global attention weights after the attention softmax, used to compute the weighted average in the
+ self-attention heads. Those are the attention weights from every token with global attention to every token
+ in the sequence.
+ """
+
+ loss: tf.Tensor | None = None
+ logits: tf.Tensor | None = None
+ hidden_states: tuple[tf.Tensor, ...] | None = None
+ attentions: tuple[tf.Tensor, ...] | None = None
+ global_attentions: tuple[tf.Tensor, ...] | None = None
+
+
+@dataclass
+class TFLongformerTokenClassifierOutput(ModelOutput):
+ """
+ Base class for outputs of token classification models.
+
+ Args:
+ loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Classification loss.
+ logits (`tf.Tensor` of shape `(batch_size, sequence_length, config.num_labels)`):
+ Classification scores (before SoftMax).
+ hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
+ `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+ attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, x +
+ attention_window + 1)`, where `x` is the number of tokens with global attention mask.
+
+ Local attention weights after the attention softmax, used to compute the weighted average in the
+ self-attention heads. Those are the attention weights from every token in the sequence to every token with
+ global attention (first `x` values) and to every token in the attention window (remaining `attention_window
+ + 1` values). Note that the first `x` values refer to tokens with fixed positions in the text, but the
+ remaining `attention_window + 1` values refer to tokens with relative positions: the attention weight of a
+ token to itself is located at index `x + attention_window / 2` and the `attention_window / 2` preceding
+ (succeeding) values are the attention weights to the `attention_window / 2` preceding (succeeding) tokens.
+ If the attention window contains a token with global attention, the attention weight at the corresponding
+ index is set to 0; the value should be accessed from the first `x` attention weights. If a token has global
+ attention, the attention weights to all other tokens in `attentions` are set to 0; the values should be
+ accessed from `global_attentions`.
+ global_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, x)`, where `x`
+ is the number of tokens with global attention mask.
+
+ Global attention weights after the attention softmax, used to compute the weighted average in the
+ self-attention heads. Those are the attention weights from every token with global attention to every token
+ in the sequence.
+ """
+
+ loss: tf.Tensor | None = None
+ logits: tf.Tensor | None = None
+ hidden_states: tuple[tf.Tensor, ...] | None = None
+ attentions: tuple[tf.Tensor, ...] | None = None
+ global_attentions: tuple[tf.Tensor, ...] | None = None
+
+
+def _compute_global_attention_mask(input_ids_shape, sep_token_indices, before_sep_token=True):
+ """
+ Computes the global attention mask by putting global attention on all tokens before `sep_token_id` if
+ `before_sep_token` is `True`, and on all tokens after `sep_token_id` otherwise.
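+
+ A minimal sketch with assumed toy values (one sequence of length 8 with separator tokens at positions 3, 4
+ and 7):
+
+ ```python
+ import tensorflow as tf
+
+ sep_token_indices = tf.constant([[0, 3], [0, 4], [0, 7]], dtype=tf.int64)
+ mask = _compute_global_attention_mask((1, 8), sep_token_indices, before_sep_token=True)
+ # mask -> [[1, 1, 1, 0, 0, 0, 0, 0]]: global attention on the tokens before the first separator
+ ```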
+ """
+ assert shape_list(sep_token_indices)[1] == 2, "`sep_token_indices` should have two dimensions"
+ question_end_index = tf.reshape(sep_token_indices, (input_ids_shape[0], 3, 2))[:, 0, 1][:, None]
+ # bool attention mask with True in locations of global attention
+ attention_mask = tf.expand_dims(tf.range(input_ids_shape[1], dtype=tf.int64), axis=0)
+ attention_mask = tf.tile(attention_mask, (input_ids_shape[0], 1))
+ if before_sep_token is True:
+ question_end_index = tf.tile(question_end_index, (1, input_ids_shape[1]))
+ attention_mask = tf.cast(attention_mask < question_end_index, dtype=question_end_index.dtype)
+ else:
+ # the last token is a separator token and should not be counted, and there are two separator tokens in the middle
+ question_end_index = tf.tile(question_end_index + 1, (1, input_ids_shape[1]))
+ attention_mask = tf.cast(
+ attention_mask > question_end_index,
+ dtype=question_end_index.dtype,
+ ) * tf.cast(attention_mask < input_ids_shape[-1], dtype=question_end_index.dtype)
+
+ return attention_mask
+
+
+# Copied from transformers.models.roberta.modeling_tf_roberta.TFRobertaLMHead with Roberta->Longformer
+class TFLongformerLMHead(keras.layers.Layer):
+ """Longformer Head for masked language modeling."""
+
+ def __init__(self, config, input_embeddings, **kwargs):
+ super().__init__(**kwargs)
+
+ self.config = config
+ self.hidden_size = config.hidden_size
+ self.dense = keras.layers.Dense(
+ config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
+ )
+ self.layer_norm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm")
+ self.act = get_tf_activation("gelu")
+
+ # The output weights are the same as the input embeddings, but there is
+ # an output-only bias for each token.
+ self.decoder = input_embeddings
+
+ def build(self, input_shape=None):
+ self.bias = self.add_weight(shape=(self.config.vocab_size,), initializer="zeros", trainable=True, name="bias")
+
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "dense", None) is not None:
+ with tf.name_scope(self.dense.name):
+ self.dense.build([None, None, self.config.hidden_size])
+ if getattr(self, "layer_norm", None) is not None:
+ with tf.name_scope(self.layer_norm.name):
+ self.layer_norm.build([None, None, self.config.hidden_size])
+
+ def get_output_embeddings(self):
+ return self.decoder
+
+ def set_output_embeddings(self, value):
+ self.decoder.weight = value
+ self.decoder.vocab_size = shape_list(value)[0]
+
+ def get_bias(self):
+ return {"bias": self.bias}
+
+ def set_bias(self, value):
+ self.bias = value["bias"]
+ self.config.vocab_size = shape_list(value["bias"])[0]
+
+ def call(self, hidden_states):
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.act(hidden_states)
+ hidden_states = self.layer_norm(hidden_states)
+
+ # project back to size of vocabulary with bias
+ seq_length = shape_list(tensor=hidden_states)[1]
+ hidden_states = tf.reshape(tensor=hidden_states, shape=[-1, self.hidden_size])
+ hidden_states = tf.matmul(a=hidden_states, b=self.decoder.weight, transpose_b=True)
+ hidden_states = tf.reshape(tensor=hidden_states, shape=[-1, seq_length, self.config.vocab_size])
+ hidden_states = tf.nn.bias_add(value=hidden_states, bias=self.bias)
+
+ return hidden_states
+
+
+class TFLongformerEmbeddings(keras.layers.Layer):
+ """
+ Same as BertEmbeddings with a tiny tweak for positional embeddings indexing and some extra casting.
+ """
+
+ def __init__(self, config, **kwargs):
+ super().__init__(**kwargs)
+
+ self.padding_idx = 1
+ self.config = config
+ self.hidden_size = config.hidden_size
+ self.max_position_embeddings = config.max_position_embeddings
+ self.initializer_range = config.initializer_range
+ self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
+ self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob)
+
+ def build(self, input_shape=None):
+ with tf.name_scope("word_embeddings"):
+ self.weight = self.add_weight(
+ name="weight",
+ shape=[self.config.vocab_size, self.hidden_size],
+ initializer=get_initializer(self.initializer_range),
+ )
+
+ with tf.name_scope("token_type_embeddings"):
+ self.token_type_embeddings = self.add_weight(
+ name="embeddings",
+ shape=[self.config.type_vocab_size, self.hidden_size],
+ initializer=get_initializer(self.initializer_range),
+ )
+
+ with tf.name_scope("position_embeddings"):
+ self.position_embeddings = self.add_weight(
+ name="embeddings",
+ shape=[self.max_position_embeddings, self.hidden_size],
+ initializer=get_initializer(self.initializer_range),
+ )
+
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "LayerNorm", None) is not None:
+ with tf.name_scope(self.LayerNorm.name):
+ self.LayerNorm.build([None, None, self.config.hidden_size])
+
+ def create_position_ids_from_input_ids(self, input_ids, past_key_values_length=0):
+ """
+ Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding
+ symbols are ignored. This is modified from fairseq's `utils.make_positions`.
+
+ Args:
+ input_ids: tf.Tensor
+ Returns: tf.Tensor
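+
+ Equivalent computation for illustration (assumed toy ids, `padding_idx = 1`, no past key values):
+
+ ```python
+ import tensorflow as tf
+
+ padding_idx = 1
+ input_ids = tf.constant([[0, 31414, 232, 2, 1, 1]])
+ mask = tf.cast(tf.math.not_equal(input_ids, padding_idx), dtype=input_ids.dtype)
+ position_ids = tf.math.cumsum(mask, axis=1) * mask + padding_idx
+ # position_ids -> [[2, 3, 4, 5, 1, 1]] (padding positions stay at padding_idx)
+ ```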
+ """
+ mask = tf.cast(tf.math.not_equal(input_ids, self.padding_idx), dtype=input_ids.dtype)
+ incremental_indices = (tf.math.cumsum(mask, axis=1) + past_key_values_length) * mask
+
+ return incremental_indices + self.padding_idx
+
+ def call(
+ self,
+ input_ids=None,
+ position_ids=None,
+ token_type_ids=None,
+ inputs_embeds=None,
+ past_key_values_length=0,
+ training=False,
+ ):
+ """
+ Applies embedding based on inputs tensor.
+
+ Returns:
+ final_embeddings (`tf.Tensor`): output embedding tensor.
+ """
+ assert not (input_ids is None and inputs_embeds is None)
+
+ if input_ids is not None:
+ check_embeddings_within_bounds(input_ids, self.config.vocab_size)
+ inputs_embeds = tf.gather(params=self.weight, indices=input_ids)
+
+ input_shape = shape_list(inputs_embeds)[:-1]
+
+ if token_type_ids is None:
+ token_type_ids = tf.cast(tf.fill(dims=input_shape, value=0), tf.int64)
+
+ if position_ids is None:
+ if input_ids is not None:
+ # Create the position ids from the input token ids. Any padded tokens remain padded.
+ position_ids = self.create_position_ids_from_input_ids(
+ input_ids=input_ids, past_key_values_length=past_key_values_length
+ )
+ else:
+ position_ids = tf.expand_dims(
+ tf.range(start=self.padding_idx + 1, limit=input_shape[-1] + self.padding_idx + 1, dtype=tf.int64),
+ axis=0,
+ )
+
+ position_embeds = tf.gather(params=self.position_embeddings, indices=position_ids)
+ token_type_embeds = tf.gather(params=self.token_type_embeddings, indices=token_type_ids)
+ final_embeddings = inputs_embeds + position_embeds + token_type_embeds
+ final_embeddings = self.LayerNorm(inputs=final_embeddings)
+ final_embeddings = self.dropout(inputs=final_embeddings, training=training)
+
+ return final_embeddings
+
+
+# Copied from transformers.models.bert.modeling_tf_bert.TFBertIntermediate with Bert->Longformer
+class TFLongformerIntermediate(keras.layers.Layer):
+ def __init__(self, config: LongformerConfig, **kwargs):
+ super().__init__(**kwargs)
+
+ self.dense = keras.layers.Dense(
+ units=config.intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
+ )
+
+ if isinstance(config.hidden_act, str):
+ self.intermediate_act_fn = get_tf_activation(config.hidden_act)
+ else:
+ self.intermediate_act_fn = config.hidden_act
+ self.config = config
+
+ def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
+ hidden_states = self.dense(inputs=hidden_states)
+ hidden_states = self.intermediate_act_fn(hidden_states)
+
+ return hidden_states
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "dense", None) is not None:
+ with tf.name_scope(self.dense.name):
+ self.dense.build([None, None, self.config.hidden_size])
+
+
+# Copied from transformers.models.bert.modeling_tf_bert.TFBertOutput with Bert->Longformer
+class TFLongformerOutput(keras.layers.Layer):
+ def __init__(self, config: LongformerConfig, **kwargs):
+ super().__init__(**kwargs)
+
+ self.dense = keras.layers.Dense(
+ units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
+ )
+ self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
+ self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob)
+ self.config = config
+
+ def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor:
+ hidden_states = self.dense(inputs=hidden_states)
+ hidden_states = self.dropout(inputs=hidden_states, training=training)
+ hidden_states = self.LayerNorm(inputs=hidden_states + input_tensor)
+
+ return hidden_states
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "dense", None) is not None:
+ with tf.name_scope(self.dense.name):
+ self.dense.build([None, None, self.config.intermediate_size])
+ if getattr(self, "LayerNorm", None) is not None:
+ with tf.name_scope(self.LayerNorm.name):
+ self.LayerNorm.build([None, None, self.config.hidden_size])
+
+
+# Copied from transformers.models.bert.modeling_tf_bert.TFBertPooler with Bert->Longformer
+class TFLongformerPooler(keras.layers.Layer):
+ def __init__(self, config: LongformerConfig, **kwargs):
+ super().__init__(**kwargs)
+
+ self.dense = keras.layers.Dense(
+ units=config.hidden_size,
+ kernel_initializer=get_initializer(config.initializer_range),
+ activation="tanh",
+ name="dense",
+ )
+ self.config = config
+
+ def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
+ # We "pool" the model by simply taking the hidden state corresponding
+ # to the first token.
+ first_token_tensor = hidden_states[:, 0]
+ pooled_output = self.dense(inputs=first_token_tensor)
+
+ return pooled_output
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "dense", None) is not None:
+ with tf.name_scope(self.dense.name):
+ self.dense.build([None, None, self.config.hidden_size])
+
+
+# Copied from transformers.models.bert.modeling_tf_bert.TFBertSelfOutput with Bert->Longformer
+class TFLongformerSelfOutput(keras.layers.Layer):
+ def __init__(self, config: LongformerConfig, **kwargs):
+ super().__init__(**kwargs)
+
+ self.dense = keras.layers.Dense(
+ units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
+ )
+ self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
+ self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob)
+ self.config = config
+
+ def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor:
+ hidden_states = self.dense(inputs=hidden_states)
+ hidden_states = self.dropout(inputs=hidden_states, training=training)
+ hidden_states = self.LayerNorm(inputs=hidden_states + input_tensor)
+
+ return hidden_states
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "dense", None) is not None:
+ with tf.name_scope(self.dense.name):
+ self.dense.build([None, None, self.config.hidden_size])
+ if getattr(self, "LayerNorm", None) is not None:
+ with tf.name_scope(self.LayerNorm.name):
+ self.LayerNorm.build([None, None, self.config.hidden_size])
+
+
+class TFLongformerSelfAttention(keras.layers.Layer):
+ def __init__(self, config, layer_id, **kwargs):
+ super().__init__(**kwargs)
+ self.config = config
+
+ if config.hidden_size % config.num_attention_heads != 0:
+ raise ValueError(
+ f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
+ f"heads ({config.num_attention_heads}"
+ )
+
+ self.num_heads = config.num_attention_heads
+ self.head_dim = int(config.hidden_size / config.num_attention_heads)
+ self.embed_dim = config.hidden_size
+ self.query = keras.layers.Dense(
+ self.embed_dim,
+ kernel_initializer=get_initializer(config.initializer_range),
+ name="query",
+ )
+ self.key = keras.layers.Dense(
+ self.embed_dim,
+ kernel_initializer=get_initializer(config.initializer_range),
+ name="key",
+ )
+ self.value = keras.layers.Dense(
+ self.embed_dim,
+ kernel_initializer=get_initializer(config.initializer_range),
+ name="value",
+ )
+
+ # separate projection layers for tokens with global attention
+ self.query_global = keras.layers.Dense(
+ self.embed_dim,
+ kernel_initializer=get_initializer(config.initializer_range),
+ name="query_global",
+ )
+ self.key_global = keras.layers.Dense(
+ self.embed_dim,
+ kernel_initializer=get_initializer(config.initializer_range),
+ name="key_global",
+ )
+ self.value_global = keras.layers.Dense(
+ self.embed_dim,
+ kernel_initializer=get_initializer(config.initializer_range),
+ name="value_global",
+ )
+ self.dropout = keras.layers.Dropout(config.attention_probs_dropout_prob)
+ self.global_dropout = keras.layers.Dropout(config.attention_probs_dropout_prob)
+ self.layer_id = layer_id
+ attention_window = config.attention_window[self.layer_id]
+
+ assert attention_window % 2 == 0, (
+ f"`attention_window` for layer {self.layer_id} has to be an even value. Given {attention_window}"
+ )
+ assert attention_window > 0, (
+ f"`attention_window` for layer {self.layer_id} has to be positive. Given {attention_window}"
+ )
+
+ self.one_sided_attn_window_size = attention_window // 2
+
+ def build(self, input_shape=None):
+ if not self.built:
+ with tf.name_scope("query_global"):
+ self.query_global.build((self.config.hidden_size,))
+ with tf.name_scope("key_global"):
+ self.key_global.build((self.config.hidden_size,))
+ with tf.name_scope("value_global"):
+ self.value_global.build((self.config.hidden_size,))
+
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "query", None) is not None:
+ with tf.name_scope(self.query.name):
+ self.query.build([None, None, self.config.hidden_size])
+ if getattr(self, "key", None) is not None:
+ with tf.name_scope(self.key.name):
+ self.key.build([None, None, self.config.hidden_size])
+ if getattr(self, "value", None) is not None:
+ with tf.name_scope(self.value.name):
+ self.value.build([None, None, self.config.hidden_size])
+ if getattr(self, "query_global", None) is not None:
+ with tf.name_scope(self.query_global.name):
+ self.query_global.build([None, None, self.config.hidden_size])
+ if getattr(self, "key_global", None) is not None:
+ with tf.name_scope(self.key_global.name):
+ self.key_global.build([None, None, self.config.hidden_size])
+ if getattr(self, "value_global", None) is not None:
+ with tf.name_scope(self.value_global.name):
+ self.value_global.build([None, None, self.config.hidden_size])
+
+ def call(
+ self,
+ inputs,
+ training=False,
+ ):
+ """
+ LongformerSelfAttention expects *len(hidden_states)* to be a multiple of *attention_window*. Padding to
+ *attention_window* happens in LongformerModel.forward to avoid redoing the padding on each layer.
+
+ The *attention_mask* is changed in [`LongformerModel.forward`] from 0, 1, 2 to:
+
+ - -10000: no attention
+ - 0: local attention
+ - +10000: global attention
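+
+ For example, an incoming mask row `[0, 1, 1, 2]` (padding, local, local, global) becomes
+ `[-10000., 0., 0., +10000.]` after that conversion.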
+ """
+ # retrieve input args
+ (
+ hidden_states,
+ attention_mask,
+ layer_head_mask,
+ is_index_masked,
+ is_index_global_attn,
+ is_global_attn,
+ ) = inputs
+
+ # project hidden states
+ query_vectors = self.query(hidden_states)
+ key_vectors = self.key(hidden_states)
+ value_vectors = self.value(hidden_states)
+ batch_size, seq_len, embed_dim = shape_list(hidden_states)
+
+ tf.debugging.assert_equal(
+ embed_dim,
+ self.embed_dim,
+ message=f"hidden_states should have embed_dim = {self.embed_dim}, but has {embed_dim}",
+ )
+
+ # normalize query
+ query_vectors /= tf.math.sqrt(tf.cast(self.head_dim, dtype=query_vectors.dtype))
+ query_vectors = tf.reshape(query_vectors, (batch_size, seq_len, self.num_heads, self.head_dim))
+ key_vectors = tf.reshape(key_vectors, (batch_size, seq_len, self.num_heads, self.head_dim))
+
+ # attn_probs = (batch_size, seq_len, num_heads, window*2+1)
+ attn_scores = self._sliding_chunks_query_key_matmul(
+ query_vectors, key_vectors, self.one_sided_attn_window_size
+ )
+
+ # values to pad for attention probs
+ remove_from_windowed_attention_mask = attention_mask != 0
+ # cast to fp32/fp16 then replace 1's with -inf
+ float_mask = tf.cast(remove_from_windowed_attention_mask, dtype=query_vectors.dtype) * LARGE_NEGATIVE
+
+ # diagonal mask with zeros everywhere and -inf in place of padding
+ diagonal_mask = self._sliding_chunks_query_key_matmul(
+ tf.ones(shape_list(attention_mask)),
+ float_mask,
+ self.one_sided_attn_window_size,
+ )
+
+ # pad local attention probs
+ attn_scores += diagonal_mask
+
+ tf.debugging.assert_equal(
+ shape_list(attn_scores),
+ [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + 1],
+ message=(
+ f"attn_probs should be of size ({batch_size}, {seq_len}, {self.num_heads},"
+ f" {self.one_sided_attn_window_size * 2 + 1}), but is of size {shape_list(attn_scores)}"
+ ),
+ )
+
+ # compute global attn indices required throughout the forward fn
+ (
+ max_num_global_attn_indices,
+ is_index_global_attn_nonzero,
+ is_local_index_global_attn_nonzero,
+ is_local_index_no_global_attn_nonzero,
+ ) = self._get_global_attn_indices(is_index_global_attn)
+
+ # this block is only relevant when there is global attention
+ if is_global_attn:
+ attn_scores = self._concat_with_global_key_attn_probs(
+ attn_scores=attn_scores,
+ query_vectors=query_vectors,
+ key_vectors=key_vectors,
+ max_num_global_attn_indices=max_num_global_attn_indices,
+ is_index_global_attn_nonzero=is_index_global_attn_nonzero,
+ is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero,
+ is_local_index_no_global_attn_nonzero=is_local_index_no_global_attn_nonzero,
+ )
+
+ attn_probs = stable_softmax(attn_scores, axis=-1)
+
+ # softmax sometimes inserts NaN if all positions are masked, replace them with 0
+ # Make sure to create a mask with the proper shape:
+ # if is_global_attn==True => [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + max_num_global_attn_indices + 1]
+ # if is_global_attn==False => [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + 1]
+ if is_global_attn:
+ masked_index = tf.tile(
+ is_index_masked[:, :, None, None],
+ (1, 1, self.num_heads, self.one_sided_attn_window_size * 2 + max_num_global_attn_indices + 1),
+ )
+ else:
+ masked_index = tf.tile(
+ is_index_masked[:, :, None, None],
+ (1, 1, self.num_heads, self.one_sided_attn_window_size * 2 + 1),
+ )
+ attn_probs = tf.where(
+ masked_index,
+ tf.zeros(shape_list(masked_index), dtype=attn_probs.dtype),
+ attn_probs,
+ )
+
+ if layer_head_mask is not None:
+ tf.debugging.assert_equal(
+ shape_list(layer_head_mask),
+ [self.num_heads],
+ message=(
+ f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
+ f" {shape_list(layer_head_mask)}"
+ ),
+ )
+
+ attn_probs = tf.reshape(layer_head_mask, (1, 1, -1, 1)) * attn_probs
+
+ # apply dropout
+ attn_probs = self.dropout(attn_probs, training=training)
+ value_vectors = tf.reshape(value_vectors, (batch_size, seq_len, self.num_heads, self.head_dim))
+
+ # if global attention, compute sum of global and local attn
+
+ if is_global_attn:
+ attn_output = self._compute_attn_output_with_global_indices(
+ value_vectors=value_vectors,
+ attn_probs=attn_probs,
+ max_num_global_attn_indices=max_num_global_attn_indices,
+ is_index_global_attn_nonzero=is_index_global_attn_nonzero,
+ is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero,
+ )
+ else:
+ attn_output = self._sliding_chunks_matmul_attn_probs_value(
+ attn_probs, value_vectors, self.one_sided_attn_window_size
+ )
+
+ tf.debugging.assert_equal(
+ shape_list(attn_output), [batch_size, seq_len, self.num_heads, self.head_dim], message="Unexpected size"
+ )
+
+ attn_output = tf.reshape(attn_output, (batch_size, seq_len, embed_dim))
+
+ # compute value for global attention and overwrite to attention output
+ if is_global_attn:
+ attn_output, global_attn_probs = self._compute_global_attn_output_from_hidden(
+ attn_output=attn_output,
+ hidden_states=hidden_states,
+ max_num_global_attn_indices=max_num_global_attn_indices,
+ layer_head_mask=layer_head_mask,
+ is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero,
+ is_index_global_attn_nonzero=is_index_global_attn_nonzero,
+ is_local_index_no_global_attn_nonzero=is_local_index_no_global_attn_nonzero,
+ is_index_masked=is_index_masked,
+ training=training,
+ )
+ else:
+ # Leave attn_output unchanged
+ global_attn_probs = tf.zeros((batch_size, self.num_heads, max_num_global_attn_indices, seq_len))
+
+ # make sure that local attention probabilities are set to 0 for indices of global attn
+ # Make sure to create a mask with the proper shape:
+ # if is_global_attn==True => [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + max_num_global_attn_indices + 1]
+ # if is_global_attn==False => [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + 1]
+ if is_global_attn:
+ masked_global_attn_index = tf.tile(
+ is_index_global_attn[:, :, None, None],
+ (1, 1, self.num_heads, self.one_sided_attn_window_size * 2 + max_num_global_attn_indices + 1),
+ )
+ else:
+ masked_global_attn_index = tf.tile(
+ is_index_global_attn[:, :, None, None],
+ (1, 1, self.num_heads, self.one_sided_attn_window_size * 2 + 1),
+ )
+ attn_probs = tf.where(
+ masked_global_attn_index,
+ tf.zeros(shape_list(masked_global_attn_index), dtype=attn_probs.dtype),
+ attn_probs,
+ )
+
+ outputs = (attn_output, attn_probs, global_attn_probs)
+
+ return outputs
+
+ def _sliding_chunks_query_key_matmul(self, query, key, window_overlap):
+ """
+ Matrix multiplication of query and key tensors using a sliding window attention pattern. This implementation
+ splits the input into overlapping chunks of size 2w (e.g. 512 for pretrained Longformer) with an overlap of
+ size window_overlap.
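+
+ For example (assumed toy sizes), with `seq_len = 8` and `window_overlap = 2` the sequence is split into
+ `seq_len // window_overlap - 1 = 3` chunks of length `2 * window_overlap = 4`, and the returned scores have
+ shape `(batch_size, seq_len, num_heads, 2 * window_overlap + 1)`.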
+ """
+ batch_size, seq_len, num_heads, head_dim = shape_list(query)
+
+ tf.debugging.assert_equal(
+ seq_len % (window_overlap * 2),
+ 0,
+ message=f"Sequence length should be multiple of {window_overlap * 2}. Given {seq_len}",
+ )
+ tf.debugging.assert_equal(
+ shape_list(query),
+ shape_list(key),
+ message=(
+ f"Shape of query and key should be equal, but got query: {shape_list(query)} and key:"
+ f" {shape_list(key)}"
+ ),
+ )
+
+ chunks_count = seq_len // window_overlap - 1
+
+ # group batch_size and num_heads dimensions into one, then chunk seq_len into chunks of size window_overlap * 2
+ query = tf.reshape(
+ tf.transpose(query, (0, 2, 1, 3)),
+ (batch_size * num_heads, seq_len, head_dim),
+ )
+ key = tf.reshape(tf.transpose(key, (0, 2, 1, 3)), (batch_size * num_heads, seq_len, head_dim))
+ chunked_query = self._chunk(query, window_overlap)
+ chunked_key = self._chunk(key, window_overlap)
+
+ # matrix multiplication
+ # bcxd: batch_size * num_heads x chunks x 2window_overlap x head_dim
+ # bcyd: batch_size * num_heads x chunks x 2window_overlap x head_dim
+ # bcxy: batch_size * num_heads x chunks x 2window_overlap x 2window_overlap
+ chunked_query = tf.cast(chunked_query, dtype=chunked_key.dtype)
+ chunked_attention_scores = tf.einsum("bcxd,bcyd->bcxy", chunked_query, chunked_key) # multiply
+
+ # convert diagonals into columns
+ paddings = tf.convert_to_tensor([[0, 0], [0, 0], [0, 1], [0, 0]])
+ diagonal_chunked_attention_scores = self._pad_and_transpose_last_two_dims(chunked_attention_scores, paddings)
+
+ # allocate space for the overall attention matrix where the chunks are combined. The last dimension
+ # has (window_overlap * 2 + 1) columns. The first (window_overlap) columns are the window_overlap lower triangles (attention from a word to
+ # window_overlap previous words). The following column is attention score from each word to itself, then
+ # followed by window_overlap columns for the upper triangle.
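+ # Illustrative layout of one row for window_overlap = 2 (an assumed toy size): the 2 * window_overlap + 1 = 5
+ # columns hold, in order, the scores from token t to tokens t-2, t-1, t (itself), t+1 and t+2.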
+
+ # copy parts from diagonal_chunked_attention_scores into the combined matrix of attentions
+ # - copying the main diagonal and the upper triangle
+ # TODO: This code is most likely not very efficient and should be improved
+ diagonal_attn_scores_up_triang = tf.concat(
+ [
+ diagonal_chunked_attention_scores[:, :, :window_overlap, : window_overlap + 1],
+ diagonal_chunked_attention_scores[:, -1:, window_overlap:, : window_overlap + 1],
+ ],
+ axis=1,
+ )
+
+ # - copying the lower triangle
+ diagonal_attn_scores_low_triang = tf.concat(
+ [
+ tf.zeros(
+ (batch_size * num_heads, 1, window_overlap, window_overlap),
+ dtype=diagonal_chunked_attention_scores.dtype,
+ ),
+ diagonal_chunked_attention_scores[:, :, -(window_overlap + 1) : -1, window_overlap + 1 :],
+ ],
+ axis=1,
+ )
+ diagonal_attn_scores_first_chunk = tf.concat(
+ [
+ tf.roll(
+ diagonal_chunked_attention_scores,
+ shift=[1, window_overlap],
+ axis=[2, 3],
+ )[:, :, :window_overlap, :window_overlap],
+ tf.zeros(
+ (batch_size * num_heads, 1, window_overlap, window_overlap),
+ dtype=diagonal_chunked_attention_scores.dtype,
+ ),
+ ],
+ axis=1,
+ )
+ first_chunk_mask = (
+ tf.tile(
+ tf.range(chunks_count + 1, dtype=tf.int64)[None, :, None, None],
+ (batch_size * num_heads, 1, window_overlap, window_overlap),
+ )
+ < 1
+ )
+ diagonal_attn_scores_low_triang = tf.where(
+ first_chunk_mask,
+ diagonal_attn_scores_first_chunk,
+ diagonal_attn_scores_low_triang,
+ )
+
+ # merging upper and lower triangle
+ diagonal_attention_scores = tf.concat(
+ [diagonal_attn_scores_low_triang, diagonal_attn_scores_up_triang], axis=-1
+ )
+
+ # separate batch_size and num_heads dimensions again
+ diagonal_attention_scores = tf.transpose(
+ tf.reshape(
+ diagonal_attention_scores,
+ (batch_size, num_heads, seq_len, 2 * window_overlap + 1),
+ ),
+ (0, 2, 1, 3),
+ )
+
+ diagonal_attention_scores = self._mask_invalid_locations(diagonal_attention_scores, window_overlap)
+
+ return diagonal_attention_scores
+
+ @staticmethod
+ def _mask_invalid_locations(input_tensor, window_overlap):
+ # create correct upper triangle bool mask
+ mask_2d_upper = tf.reverse(
+ tf.linalg.band_part(tf.ones(shape=(window_overlap, window_overlap + 1)), -1, 0),
+ axis=[0],
+ )
+
+ # pad to full matrix
+ padding = tf.convert_to_tensor(
+ [[0, shape_list(input_tensor)[1] - window_overlap], [0, shape_list(input_tensor)[3] - window_overlap - 1]]
+ )
+
+ # create lower mask
+ mask_2d = tf.pad(mask_2d_upper, padding)
+
+ # combine with upper mask
+ mask_2d = mask_2d + tf.reverse(mask_2d, axis=[0, 1])
+
+ # broadcast to full matrix
+ mask_4d = tf.tile(mask_2d[None, :, None, :], (shape_list(input_tensor)[0], 1, 1, 1))
+
+ # inf tensor used for masking
+ inf_tensor = -float("inf") * tf.ones_like(input_tensor)
+
+ # mask
+ input_tensor = tf.where(tf.math.greater(mask_4d, 0), inf_tensor, input_tensor)
+
+ return input_tensor
+
+ def _sliding_chunks_matmul_attn_probs_value(self, attn_probs, value, window_overlap):
+ """
+ Same as _sliding_chunks_query_key_matmul but for attn_probs and value tensors. The returned tensor has the
+ same shape as `value`, i.e. `(batch_size, seq_len, num_heads, head_dim)`.
+ """
+
+ batch_size, seq_len, num_heads, head_dim = shape_list(value)
+
+ tf.debugging.assert_equal(
+ seq_len % (window_overlap * 2), 0, message="Seq_len has to be a multiple of 2 * window_overlap"
+ )
+ tf.debugging.assert_equal(
+ shape_list(attn_probs)[:3],
+ shape_list(value)[:3],
+ message="value and attn_probs must have same dims (except head_dim)",
+ )
+ tf.debugging.assert_equal(
+ shape_list(attn_probs)[3],
+ 2 * window_overlap + 1,
+ message="attn_probs last dim has to be 2 * window_overlap + 1",
+ )
+
+ chunks_count = seq_len // window_overlap - 1
+
+ # group batch_size and num_heads dimensions into one, then chunk seq_len into chunks of size 2 window overlap
+ chunked_attn_probs = tf.reshape(
+ tf.transpose(attn_probs, (0, 2, 1, 3)),
+ (
+ batch_size * num_heads,
+ seq_len // window_overlap,
+ window_overlap,
+ 2 * window_overlap + 1,
+ ),
+ )
+
+ # group batch_size and num_heads dimensions into one
+ value = tf.reshape(
+ tf.transpose(value, (0, 2, 1, 3)),
+ (batch_size * num_heads, seq_len, head_dim),
+ )
+
+ # pad seq_len with window_overlap values at the beginning and at the end of the sequence
+ paddings = tf.convert_to_tensor([[0, 0], [window_overlap, window_overlap], [0, 0]])
+ padded_value = tf.pad(value, paddings, constant_values=-1)
+
+ # chunk padded_value into chunks of size 3 window overlap and an overlap of size window overlap
+ frame_size = 3 * window_overlap * head_dim
+ frame_hop_size = (shape_list(padded_value)[1] * head_dim - frame_size) // chunks_count
+ chunked_value = tf.signal.frame(
+ tf.reshape(padded_value, (batch_size * num_heads, -1)),
+ frame_size,
+ frame_hop_size,
+ )
+ chunked_value = tf.reshape(
+ chunked_value,
+ (batch_size * num_heads, chunks_count + 1, 3 * window_overlap, head_dim),
+ )
+
+ tf.debugging.assert_equal(
+ shape_list(chunked_value),
+ [batch_size * num_heads, chunks_count + 1, 3 * window_overlap, head_dim],
+ message="Chunked value has the wrong shape",
+ )
+
+ chunked_attn_probs = self._pad_and_diagonalize(chunked_attn_probs)
+ context = tf.einsum("bcwd,bcdh->bcwh", chunked_attn_probs, chunked_value)
+ context = tf.transpose(
+ tf.reshape(context, (batch_size, num_heads, seq_len, head_dim)),
+ (0, 2, 1, 3),
+ )
+
+ return context
+
+ @staticmethod
+ def _pad_and_transpose_last_two_dims(hidden_states_padded, paddings):
+ """pads rows and then flips rows and columns"""
+ hidden_states_padded = tf.pad(
+ hidden_states_padded, paddings
+ ) # padding value is not important because it will be overwritten
+ batch_size, chunk_size, seq_length, hidden_dim = shape_list(hidden_states_padded)
+ hidden_states_padded = tf.reshape(hidden_states_padded, (batch_size, chunk_size, hidden_dim, seq_length))
+
+ return hidden_states_padded
+
+ @staticmethod
+ def _pad_and_diagonalize(chunked_hidden_states):
+ """
+ shift every row 1 step right, converting columns into diagonals.
+
+ Example:
+
+ ```python
+ chunked_hidden_states = [
+ [ 0.4983, 2.6918, -0.0071, 1.0492],
+ [-1.8348, 0.7672, 0.2986, 0.0285],
+ [-0.7584, 0.4206, -0.0405, 0.1599],
+ [ 2.0514, -1.1600, 0.5372, 0.2629],
+ ]
+ window_overlap = num_rows = 4
+ ```
+
+ (pad & diagonalize) =>
+ [ 0.4983, 2.6918, -0.0071, 1.0492, 0.0000, 0.0000, 0.0000]
+ [ 0.0000, -1.8348, 0.7672, 0.2986, 0.0285, 0.0000, 0.0000]
+ [ 0.0000, 0.0000, -0.7584, 0.4206, -0.0405, 0.1599, 0.0000]
+ [ 0.0000, 0.0000, 0.0000, 2.0514, -1.1600, 0.5372, 0.2629]
+ """
+ total_num_heads, num_chunks, window_overlap, hidden_dim = shape_list(chunked_hidden_states)
+ paddings = tf.convert_to_tensor([[0, 0], [0, 0], [0, 0], [0, window_overlap + 1]])
+ chunked_hidden_states = tf.pad(
+ chunked_hidden_states, paddings
+ ) # total_num_heads x num_chunks x window_overlap x (hidden_dim+window_overlap+1). Padding value is not important because it'll be overwritten
+ chunked_hidden_states = tf.reshape(
+ chunked_hidden_states, (total_num_heads, num_chunks, -1)
+ ) # total_num_heads x num_chunks x window_overlap * (hidden_dim + window_overlap + 1)
+ chunked_hidden_states = chunked_hidden_states[
+ :, :, :-window_overlap
+ ] # total_num_heads x num_chunks x window_overlap * (hidden_dim + window_overlap)
+ chunked_hidden_states = tf.reshape(
+ chunked_hidden_states,
+ (total_num_heads, num_chunks, window_overlap, window_overlap + hidden_dim),
+ ) # total_num_heads x num_chunks x window_overlap x (hidden_dim + window_overlap)
+ chunked_hidden_states = chunked_hidden_states[:, :, :, :-1]
+
+ return chunked_hidden_states
+
+ @staticmethod
+ def _chunk(hidden_states, window_overlap):
+ """convert into overlapping chunks. Chunk size = 2w, overlap size = w"""
+ batch_size, seq_length, hidden_dim = shape_list(hidden_states)
+ num_output_chunks = 2 * (seq_length // (2 * window_overlap)) - 1
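+ # e.g. (assumed toy sizes) for seq_length = 8 and window_overlap = 2: frames of length 2 * window_overlap = 4
+ # with a hop of window_overlap = 2 cover token positions [0:4], [2:6] and [4:8], i.e. num_output_chunks = 3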
+
+ # define frame size and frame stride (similar to convolution)
+ frame_hop_size = window_overlap * hidden_dim
+ frame_size = 2 * frame_hop_size
+ hidden_states = tf.reshape(hidden_states, (batch_size, seq_length * hidden_dim))
+
+ # chunk with overlap
+ chunked_hidden_states = tf.signal.frame(hidden_states, frame_size, frame_hop_size)
+
+ tf.debugging.assert_equal(
+ shape_list(chunked_hidden_states),
+ [batch_size, num_output_chunks, frame_size],
+ message=(
+ "Make sure chunking is correctly applied. `Chunked hidden states should have output dimension"
+ f" {[batch_size, frame_size, num_output_chunks]}, but got {shape_list(chunked_hidden_states)}."
+ ),
+ )
+
+ chunked_hidden_states = tf.reshape(
+ chunked_hidden_states,
+ (batch_size, num_output_chunks, 2 * window_overlap, hidden_dim),
+ )
+
+ return chunked_hidden_states
+
+ @staticmethod
+ def _get_global_attn_indices(is_index_global_attn):
+ """compute global attn indices required throughout forward pass"""
+ # helper variable
+ num_global_attn_indices = tf.math.count_nonzero(is_index_global_attn, axis=1)
+ num_global_attn_indices = tf.cast(num_global_attn_indices, dtype=tf.constant(1).dtype)
+
+ # max number of global attn indices in batch
+ max_num_global_attn_indices = tf.reduce_max(num_global_attn_indices)
+
+ # indices of global attn
+ is_index_global_attn_nonzero = tf.where(is_index_global_attn)
+
+ # helper variable
+ is_local_index_global_attn = tf.range(max_num_global_attn_indices) < tf.expand_dims(
+ num_global_attn_indices, axis=-1
+ )
+
+ # location of the non-padding values within global attention indices
+ is_local_index_global_attn_nonzero = tf.where(is_local_index_global_attn)
+
+ # location of the padding values within global attention indices
+ is_local_index_no_global_attn_nonzero = tf.where(tf.math.logical_not(is_local_index_global_attn))
+
+ return (
+ max_num_global_attn_indices,
+ is_index_global_attn_nonzero,
+ is_local_index_global_attn_nonzero,
+ is_local_index_no_global_attn_nonzero,
+ )
+
+ def _concat_with_global_key_attn_probs(
+ self,
+ attn_scores,
+ key_vectors,
+ query_vectors,
+ max_num_global_attn_indices,
+ is_index_global_attn_nonzero,
+ is_local_index_global_attn_nonzero,
+ is_local_index_no_global_attn_nonzero,
+ ):
+ batch_size = shape_list(key_vectors)[0]
+
+ # select global key vectors
+ global_key_vectors = tf.gather_nd(key_vectors, is_index_global_attn_nonzero)
+
+ # create only global key vectors
+ key_vectors_only_global = tf.scatter_nd(
+ is_local_index_global_attn_nonzero,
+ global_key_vectors,
+ shape=(
+ batch_size,
+ max_num_global_attn_indices,
+ self.num_heads,
+ self.head_dim,
+ ),
+ )
+
+ # (batch_size, seq_len, num_heads, max_num_global_attn_indices)
+ attn_probs_from_global_key = tf.einsum("blhd,bshd->blhs", query_vectors, key_vectors_only_global)
+
+ # (batch_size, max_num_global_attn_indices, seq_len, num_heads)
+ attn_probs_from_global_key_trans = tf.transpose(attn_probs_from_global_key, (0, 3, 1, 2))
+ mask_shape = (shape_list(is_local_index_no_global_attn_nonzero)[0],) + tuple(
+ shape_list(attn_probs_from_global_key_trans)[-2:]
+ )
+ mask = tf.ones(mask_shape) * -10000.0
+ mask = tf.cast(mask, dtype=attn_probs_from_global_key_trans.dtype)
+
+ # scatter mask
+ attn_probs_from_global_key_trans = tf.tensor_scatter_nd_update(
+ attn_probs_from_global_key_trans,
+ is_local_index_no_global_attn_nonzero,
+ mask,
+ )
+
+ # (batch_size, seq_len, num_heads, max_num_global_attn_indices)
+ attn_probs_from_global_key = tf.transpose(attn_probs_from_global_key_trans, (0, 2, 3, 1))
+
+ # concat to attn_probs
+ # (batch_size, seq_len, num_heads, extra attention count + 2*window+1)
+ attn_scores = tf.concat((attn_probs_from_global_key, attn_scores), axis=-1)
+
+ return attn_scores
+
+ def _compute_attn_output_with_global_indices(
+ self,
+ value_vectors,
+ attn_probs,
+ max_num_global_attn_indices,
+ is_index_global_attn_nonzero,
+ is_local_index_global_attn_nonzero,
+ ):
+ batch_size = shape_list(attn_probs)[0]
+
+ # cut local attn probs to global only
+ attn_probs_only_global = attn_probs[:, :, :, :max_num_global_attn_indices]
+
+ # select global value vectors
+ global_value_vectors = tf.gather_nd(value_vectors, is_index_global_attn_nonzero)
+
+ # create only global value vectors
+ value_vectors_only_global = tf.scatter_nd(
+ is_local_index_global_attn_nonzero,
+ global_value_vectors,
+ shape=(
+ batch_size,
+ max_num_global_attn_indices,
+ self.num_heads,
+ self.head_dim,
+ ),
+ )
+
+ # compute attn output only global
+ attn_output_only_global = tf.einsum("blhs,bshd->blhd", attn_probs_only_global, value_vectors_only_global)
+
+ # cut off the global attn probs, keeping only the local sliding-window probs
+ attn_probs_without_global = attn_probs[:, :, :, max_num_global_attn_indices:]
+
+ # compute attn output without global
+ attn_output_without_global = self._sliding_chunks_matmul_attn_probs_value(
+ attn_probs_without_global, value_vectors, self.one_sided_attn_window_size
+ )
+
+ return attn_output_only_global + attn_output_without_global
+
+ def _compute_global_attn_output_from_hidden(
+ self,
+ attn_output,
+ hidden_states,
+ max_num_global_attn_indices,
+ layer_head_mask,
+ is_local_index_global_attn_nonzero,
+ is_index_global_attn_nonzero,
+ is_local_index_no_global_attn_nonzero,
+ is_index_masked,
+ training,
+ ):
+ batch_size, seq_len = shape_list(hidden_states)[:2]
+
+ # prepare global hidden states
+ global_attn_hidden_states = tf.gather_nd(hidden_states, is_index_global_attn_nonzero)
+ global_attn_hidden_states = tf.scatter_nd(
+ is_local_index_global_attn_nonzero,
+ global_attn_hidden_states,
+ shape=(batch_size, max_num_global_attn_indices, self.embed_dim),
+ )
+
+ # global key, query, value
+ global_query_vectors_only_global = self.query_global(global_attn_hidden_states)
+ global_key_vectors = self.key_global(hidden_states)
+ global_value_vectors = self.value_global(hidden_states)
+
+ # normalize
+ global_query_vectors_only_global /= tf.math.sqrt(
+ tf.cast(self.head_dim, dtype=global_query_vectors_only_global.dtype)
+ )
+ global_query_vectors_only_global = self.reshape_and_transpose(global_query_vectors_only_global, batch_size)
+ global_key_vectors = self.reshape_and_transpose(global_key_vectors, batch_size)
+ global_value_vectors = self.reshape_and_transpose(global_value_vectors, batch_size)
+
+ # compute attn scores
+ global_attn_scores = tf.matmul(global_query_vectors_only_global, global_key_vectors, transpose_b=True)
+
+ tf.debugging.assert_equal(
+ shape_list(global_attn_scores),
+ [batch_size * self.num_heads, max_num_global_attn_indices, seq_len],
+ message=(
+ "global_attn_scores have the wrong size. Size should be"
+ f" {(batch_size * self.num_heads, max_num_global_attn_indices, seq_len)}, but is"
+ f" {shape_list(global_attn_scores)}."
+ ),
+ )
+
+ global_attn_scores = tf.reshape(
+ global_attn_scores,
+ (batch_size, self.num_heads, max_num_global_attn_indices, seq_len),
+ )
+ global_attn_scores_trans = tf.transpose(global_attn_scores, (0, 2, 1, 3))
+ mask_shape = (shape_list(is_local_index_no_global_attn_nonzero)[0],) + tuple(
+ shape_list(global_attn_scores_trans)[-2:]
+ )
+ global_attn_mask = tf.ones(mask_shape) * -10000.0
+ global_attn_mask = tf.cast(global_attn_mask, dtype=global_attn_scores_trans.dtype)
+
+ # scatter mask
+ global_attn_scores_trans = tf.tensor_scatter_nd_update(
+ global_attn_scores_trans,
+ is_local_index_no_global_attn_nonzero,
+ global_attn_mask,
+ )
+ global_attn_scores = tf.transpose(global_attn_scores_trans, (0, 2, 1, 3))
+
+ # mask global attn scores
+ attn_mask = tf.tile(is_index_masked[:, None, None, :], (1, shape_list(global_attn_scores)[1], 1, 1))
+ global_attn_scores = tf.where(attn_mask, -10000.0, global_attn_scores)
+ global_attn_scores = tf.reshape(
+ global_attn_scores,
+ (batch_size * self.num_heads, max_num_global_attn_indices, seq_len),
+ )
+
+ # compute global attn probs
+ global_attn_probs_float = stable_softmax(global_attn_scores, axis=-1)
+
+ # apply layer head masking
+ if layer_head_mask is not None:
+ tf.debugging.assert_equal(
+ shape_list(layer_head_mask),
+ [self.num_heads],
+ message=(
+ f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
+ f" {shape_list(layer_head_mask)}"
+ ),
+ )
+ global_attn_probs_float = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
+ global_attn_probs_float, (batch_size, self.num_heads, max_num_global_attn_indices, seq_len)
+ )
+ global_attn_probs_float = tf.reshape(
+ global_attn_probs_float, (batch_size * self.num_heads, max_num_global_attn_indices, seq_len)
+ )
+
+ # dropout
+ global_attn_probs = self.global_dropout(global_attn_probs_float, training=training)
+
+ # global attn output
+ global_attn_output = tf.matmul(global_attn_probs, global_value_vectors)
+
+ tf.debugging.assert_equal(
+ shape_list(global_attn_output),
+ [batch_size * self.num_heads, max_num_global_attn_indices, self.head_dim],
+ message=(
+ "global_attn_output tensor has the wrong size. Size should be"
+ f" {(batch_size * self.num_heads, max_num_global_attn_indices, self.head_dim)}, but is"
+ f" {shape_list(global_attn_output)}."
+ ),
+ )
+
+ global_attn_output = tf.reshape(
+ global_attn_output,
+ (batch_size, self.num_heads, max_num_global_attn_indices, self.head_dim),
+ )
+
+ # get only non zero global attn output
+ nonzero_global_attn_output = tf.gather_nd(
+ tf.transpose(global_attn_output, (0, 2, 1, 3)),
+ is_local_index_global_attn_nonzero,
+ )
+ nonzero_global_attn_output = tf.reshape(
+ nonzero_global_attn_output,
+ (shape_list(is_local_index_global_attn_nonzero)[0], -1),
+ )
+
+ # overwrite values with global attention
+ attn_output = tf.tensor_scatter_nd_update(
+ attn_output, is_index_global_attn_nonzero, nonzero_global_attn_output
+ )
+
+ global_attn_probs = tf.reshape(
+ global_attn_probs, (batch_size, self.num_heads, max_num_global_attn_indices, seq_len)
+ )
+
+ return attn_output, global_attn_probs
+
+ def reshape_and_transpose(self, vector, batch_size):
+ return tf.reshape(
+ tf.transpose(
+ tf.reshape(vector, (batch_size, -1, self.num_heads, self.head_dim)),
+ (0, 2, 1, 3),
+ ),
+ (batch_size * self.num_heads, -1, self.head_dim),
+ )
+
+
+class TFLongformerAttention(keras.layers.Layer):
+ def __init__(self, config, layer_id=0, **kwargs):
+ super().__init__(**kwargs)
+
+ self.self_attention = TFLongformerSelfAttention(config, layer_id, name="self")
+ self.dense_output = TFLongformerSelfOutput(config, name="output")
+
+ def prune_heads(self, heads):
+ raise NotImplementedError
+
+ def call(self, inputs, training=False):
+ (
+ hidden_states,
+ attention_mask,
+ layer_head_mask,
+ is_index_masked,
+ is_index_global_attn,
+ is_global_attn,
+ ) = inputs
+
+ self_outputs = self.self_attention(
+ [hidden_states, attention_mask, layer_head_mask, is_index_masked, is_index_global_attn, is_global_attn],
+ training=training,
+ )
+ attention_output = self.dense_output(self_outputs[0], hidden_states, training=training)
+ outputs = (attention_output,) + self_outputs[1:]
+
+ return outputs
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "self_attention", None) is not None:
+ with tf.name_scope(self.self_attention.name):
+ self.self_attention.build(None)
+ if getattr(self, "dense_output", None) is not None:
+ with tf.name_scope(self.dense_output.name):
+ self.dense_output.build(None)
+
+
+class TFLongformerLayer(keras.layers.Layer):
+ def __init__(self, config, layer_id=0, **kwargs):
+ super().__init__(**kwargs)
+
+ self.attention = TFLongformerAttention(config, layer_id, name="attention")
+ self.intermediate = TFLongformerIntermediate(config, name="intermediate")
+ self.longformer_output = TFLongformerOutput(config, name="output")
+
+ def call(self, inputs, training=False):
+ (
+ hidden_states,
+ attention_mask,
+ layer_head_mask,
+ is_index_masked,
+ is_index_global_attn,
+ is_global_attn,
+ ) = inputs
+
+ attention_outputs = self.attention(
+ [hidden_states, attention_mask, layer_head_mask, is_index_masked, is_index_global_attn, is_global_attn],
+ training=training,
+ )
+ attention_output = attention_outputs[0]
+ intermediate_output = self.intermediate(attention_output)
+ layer_output = self.longformer_output(intermediate_output, attention_output, training=training)
+ outputs = (layer_output,) + attention_outputs[1:] # add attentions if we output them
+
+ return outputs
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "attention", None) is not None:
+ with tf.name_scope(self.attention.name):
+ self.attention.build(None)
+ if getattr(self, "intermediate", None) is not None:
+ with tf.name_scope(self.intermediate.name):
+ self.intermediate.build(None)
+ if getattr(self, "longformer_output", None) is not None:
+ with tf.name_scope(self.longformer_output.name):
+ self.longformer_output.build(None)
+
+
+class TFLongformerEncoder(keras.layers.Layer):
+ def __init__(self, config, **kwargs):
+ super().__init__(**kwargs)
+
+ self.output_hidden_states = config.output_hidden_states
+ self.output_attentions = config.output_attentions
+ self.layer = [TFLongformerLayer(config, i, name=f"layer_._{i}") for i in range(config.num_hidden_layers)]
+
+ def call(
+ self,
+ hidden_states,
+ attention_mask=None,
+ head_mask=None,
+ padding_len=0,
+ is_index_masked=None,
+ is_index_global_attn=None,
+ is_global_attn=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ training=False,
+ ):
+ all_hidden_states = () if output_hidden_states else None
+ all_attentions = all_global_attentions = () if output_attentions else None
+
+ for idx, layer_module in enumerate(self.layer):
+ if output_hidden_states:
+ hidden_states_to_add = hidden_states[:, :-padding_len] if padding_len > 0 else hidden_states
+ all_hidden_states = all_hidden_states + (hidden_states_to_add,)
+
+ layer_outputs = layer_module(
+ [
+ hidden_states,
+ attention_mask,
+ head_mask[idx] if head_mask is not None else None,
+ is_index_masked,
+ is_index_global_attn,
+ is_global_attn,
+ ],
+ training=training,
+ )
+ hidden_states = layer_outputs[0]
+
+ if output_attentions:
+ # bsz x seq_len x num_attn_heads x (num_global_attn + attention_window_len + 1) => bsz x num_attn_heads x seq_len x (num_global_attn + attention_window_len + 1)
+ all_attentions = all_attentions + (tf.transpose(layer_outputs[1], (0, 2, 1, 3)),)
+
+ # bsz x num_attn_heads x num_global_attn x seq_len => bsz x num_attn_heads x seq_len x num_global_attn
+ all_global_attentions = all_global_attentions + (tf.transpose(layer_outputs[2], (0, 1, 3, 2)),)
+
+ # Add last layer
+ if output_hidden_states:
+ hidden_states_to_add = hidden_states[:, :-padding_len] if padding_len > 0 else hidden_states
+ all_hidden_states = all_hidden_states + (hidden_states_to_add,)
+
+ # undo padding
+ # unpad `hidden_states` because the calling function is expecting a length == input_ids.size(1)
+ hidden_states = hidden_states[:, :-padding_len] if padding_len > 0 else hidden_states
+ if output_attentions:
+ all_attentions = (
+ tuple(state[:, :, :-padding_len, :] for state in all_attentions) if padding_len > 0 else all_attentions
+ )
+
+ if not return_dict:
+ return tuple(
+ v for v in [hidden_states, all_hidden_states, all_attentions, all_global_attentions] if v is not None
+ )
+
+ return TFLongformerBaseModelOutput(
+ last_hidden_state=hidden_states,
+ hidden_states=all_hidden_states,
+ attentions=all_attentions,
+ global_attentions=all_global_attentions,
+ )
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "layer", None) is not None:
+ for layer in self.layer:
+ with tf.name_scope(layer.name):
+ layer.build(None)
+
+
+@keras_serializable
+class TFLongformerMainLayer(keras.layers.Layer):
+ config_class = LongformerConfig
+
+ def __init__(self, config, add_pooling_layer=True, **kwargs):
+ super().__init__(**kwargs)
+
+ if isinstance(config.attention_window, int):
+ assert config.attention_window % 2 == 0, "`config.attention_window` has to be an even value"
+ assert config.attention_window > 0, "`config.attention_window` has to be positive"
+ config.attention_window = [config.attention_window] * config.num_hidden_layers # one value per layer
+ else:
+ assert len(config.attention_window) == config.num_hidden_layers, (
+ "`len(config.attention_window)` should equal `config.num_hidden_layers`. "
+ f"Expected {config.num_hidden_layers}, given {len(config.attention_window)}"
+ )
+
+ self.config = config
+ self.num_hidden_layers = config.num_hidden_layers
+ self.initializer_range = config.initializer_range
+ self.output_attentions = config.output_attentions
+ self.output_hidden_states = config.output_hidden_states
+ self.return_dict = config.use_return_dict
+ self.pad_token_id = config.pad_token_id
+ self.attention_window = config.attention_window
+ self.embeddings = TFLongformerEmbeddings(config, name="embeddings")
+ self.encoder = TFLongformerEncoder(config, name="encoder")
+ self.pooler = TFLongformerPooler(config, name="pooler") if add_pooling_layer else None
+
+ def get_input_embeddings(self):
+ return self.embeddings
+
+ def set_input_embeddings(self, value):
+ self.embeddings.weight = value
+ self.embeddings.vocab_size = shape_list(value)[0]
+
+ def _prune_heads(self, heads_to_prune):
+ """
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
+ class PreTrainedModel
+ """
+ raise NotImplementedError
+
+ @unpack_inputs
+ def call(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ head_mask=None,
+ global_attention_mask=None,
+ token_type_ids=None,
+ position_ids=None,
+ inputs_embeds=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ training=False,
+ ):
+ if input_ids is not None and not isinstance(input_ids, tf.Tensor):
+ input_ids = tf.convert_to_tensor(input_ids, dtype=tf.int64)
+ elif input_ids is not None:
+ input_ids = tf.cast(input_ids, tf.int64)
+
+ if attention_mask is not None and not isinstance(attention_mask, tf.Tensor):
+ attention_mask = tf.convert_to_tensor(attention_mask, dtype=tf.int64)
+ elif attention_mask is not None:
+ attention_mask = tf.cast(attention_mask, tf.int64)
+
+ if global_attention_mask is not None and not isinstance(global_attention_mask, tf.Tensor):
+ global_attention_mask = tf.convert_to_tensor(global_attention_mask, dtype=tf.int64)
+ elif global_attention_mask is not None:
+ global_attention_mask = tf.cast(global_attention_mask, tf.int64)
+
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+ elif input_ids is not None:
+ input_shape = shape_list(input_ids)
+ elif inputs_embeds is not None:
+ input_shape = shape_list(inputs_embeds)[:-1]
+ else:
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+ if attention_mask is None:
+ attention_mask = tf.cast(tf.fill(input_shape, 1), tf.int64)
+
+ if token_type_ids is None:
+ token_type_ids = tf.cast(tf.fill(input_shape, 0), tf.int64)
+
+ # merge `global_attention_mask` and `attention_mask`
+ if global_attention_mask is not None:
+ attention_mask = self._merge_to_attention_mask(attention_mask, global_attention_mask)
+
+ (
+ padding_len,
+ input_ids,
+ attention_mask,
+ token_type_ids,
+ position_ids,
+ inputs_embeds,
+ ) = self._pad_to_window_size(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ inputs_embeds=inputs_embeds,
+ pad_token_id=self.pad_token_id,
+ )
+
+ # is index masked or global attention
+ is_index_masked = tf.math.less(attention_mask, 1)
+ is_index_global_attn = tf.math.greater(attention_mask, 1)
+ is_global_attn = tf.math.reduce_any(is_index_global_attn)
+
+ # We create a 3D attention mask from a 2D tensor mask.
+ # Sizes are [batch_size, to_seq_length, 1, 1]
+ # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
+ # this attention mask is simpler than the triangular masking of causal attention
+ # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
+ attention_mask_shape = shape_list(attention_mask)
+ extended_attention_mask = tf.reshape(attention_mask, (attention_mask_shape[0], attention_mask_shape[1], 1, 1))
+
+ # Since attention_mask is 1.0 for positions we want to attend locally and 0.0 for
+ # masked and global attn positions, this operation will create a tensor which is 0.0 for
+ # positions we want to attend and -10000.0 for masked positions.
+ # Since we are adding it to the raw scores before the softmax, this is
+ # effectively the same as removing these entirely.
+ extended_attention_mask = tf.cast(tf.math.abs(1 - extended_attention_mask), tf.dtypes.float32) * -10000.0
+ embedding_output = self.embeddings(
+ input_ids,
+ position_ids,
+ token_type_ids,
+ inputs_embeds,
+ training=training,
+ )
+ encoder_outputs = self.encoder(
+ embedding_output,
+ attention_mask=extended_attention_mask,
+ head_mask=head_mask,
+ padding_len=padding_len,
+ is_index_masked=is_index_masked,
+ is_index_global_attn=is_index_global_attn,
+ is_global_attn=is_global_attn,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ training=training,
+ )
+ sequence_output = encoder_outputs[0]
+ pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
+
+ if not return_dict:
+ return (
+ sequence_output,
+ pooled_output,
+ ) + encoder_outputs[1:]
+
+ return TFLongformerBaseModelOutputWithPooling(
+ last_hidden_state=sequence_output,
+ pooler_output=pooled_output,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ global_attentions=encoder_outputs.global_attentions,
+ )
+
+ def _pad_to_window_size(
+ self,
+ input_ids,
+ attention_mask,
+ token_type_ids,
+ position_ids,
+ inputs_embeds,
+ pad_token_id,
+ ):
+ """A helper function to pad tokens and mask to work with implementation of Longformer selfattention."""
+ # padding
+ attention_window = (
+ self.attention_window if isinstance(self.attention_window, int) else max(self.attention_window)
+ )
+
+ assert attention_window % 2 == 0, f"`attention_window` should be an even value. Given {attention_window}"
+
+ input_shape = shape_list(input_ids) if input_ids is not None else shape_list(inputs_embeds)
+ batch_size, seq_len = input_shape[:2]
+ padding_len = (attention_window - seq_len % attention_window) % attention_window
+
+ paddings = tf.convert_to_tensor([[0, 0], [0, padding_len]])
+
+ if input_ids is not None:
+ input_ids = tf.pad(input_ids, paddings, constant_values=pad_token_id)
+
+ if position_ids is not None:
+ # pad with position_id = pad_token_id as in modeling_roberta.RobertaEmbeddings
+ position_ids = tf.pad(position_ids, paddings, constant_values=pad_token_id)
+
+ if inputs_embeds is not None:
+ if padding_len > 0:
+ input_ids_padding = tf.cast(tf.fill((batch_size, padding_len), self.pad_token_id), tf.int64)
+ inputs_embeds_padding = self.embeddings(input_ids_padding)
+ inputs_embeds = tf.concat([inputs_embeds, inputs_embeds_padding], axis=-2)
+
+ attention_mask = tf.pad(attention_mask, paddings, constant_values=False) # no attention on the padding tokens
+ token_type_ids = tf.pad(token_type_ids, paddings, constant_values=0) # pad with token_type_id = 0
+
+ return (
+ padding_len,
+ input_ids,
+ attention_mask,
+ token_type_ids,
+ position_ids,
+ inputs_embeds,
+ )
+
+ @staticmethod
+ def _merge_to_attention_mask(attention_mask: tf.Tensor, global_attention_mask: tf.Tensor):
+ # longformer self attention expects attention mask to have 0 (no attn), 1 (local attn), 2 (global attn)
+ # (global_attention_mask + 1) => 1 for local attention, 2 for global attention
+ # => final attention_mask => 0 for no attention, 1 for local attention 2 for global attention
+ if attention_mask is not None:
+ attention_mask = attention_mask * (global_attention_mask + 1)
+ else:
+ # simply use `global_attention_mask` as `attention_mask`
+ # if no `attention_mask` is given
+ attention_mask = global_attention_mask + 1
+
+ return attention_mask
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "embeddings", None) is not None:
+ with tf.name_scope(self.embeddings.name):
+ self.embeddings.build(None)
+ if getattr(self, "encoder", None) is not None:
+ with tf.name_scope(self.encoder.name):
+ self.encoder.build(None)
+ if getattr(self, "pooler", None) is not None:
+ with tf.name_scope(self.pooler.name):
+ self.pooler.build(None)
+
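
As a reading aid for the two helpers above, here is a minimal sketch of the mask arithmetic used by `_merge_to_attention_mask` and the padding length computed in `_pad_to_window_size`. The tensor values and the attention window of 8 are illustrative assumptions, not values taken from the library.

```python
import tensorflow as tf

# Illustrative values: one sequence with three real tokens and one padding token.
attention_mask = tf.constant([[1, 1, 1, 0]])         # 1 = attend, 0 = padding
global_attention_mask = tf.constant([[1, 0, 0, 0]])  # 1 = give this token global attention

# Same arithmetic as `_merge_to_attention_mask`: 0 -> no attention, 1 -> local, 2 -> global.
merged = attention_mask * (global_attention_mask + 1)
print(merged.numpy())  # [[2 1 1 0]]

# Same formula as `_pad_to_window_size`, for an assumed attention window of 8.
attention_window, seq_len = 8, 4
padding_len = (attention_window - seq_len % attention_window) % attention_window
print(padding_len)  # 4 -> the sequence is padded from length 4 up to length 8
```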
+
+class TFLongformerPreTrainedModel(TFPreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = LongformerConfig
+ base_model_prefix = "longformer"
+
+ @property
+ def input_signature(self):
+ sig = super().input_signature
+ sig["global_attention_mask"] = tf.TensorSpec((None, None), tf.int32, name="global_attention_mask")
+ return sig
+
+
+LONGFORMER_START_DOCSTRING = r"""
+
+ This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
+ etc.)
+
+ This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it
+ as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matters related to general usage and
+ behavior.
+
+
+
+ TensorFlow models and layers in `transformers` accept two formats as input:
+
+ - having all inputs as keyword arguments (like PyTorch models), or
+ - having all inputs as a list, tuple or dict in the first positional argument.
+
+ The reason the second format is supported is that Keras methods prefer this format when passing inputs to models
+ and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just
+ pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second
+ format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with
+ the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first
+ positional argument:
+
+ - a single Tensor with `input_ids` only and nothing else: `model(input_ids)`
+ - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
+ `model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])`
+ - a dictionary with one or several input Tensors associated to the input names given in the docstring:
+ `model({"input_ids": input_ids, "token_type_ids": token_type_ids})`
+
+ Note that when creating models and layers with
+ [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry
+ about any of this, as you can just pass inputs like you would to any other Python function!
+
+
+
+ Parameters:
+ config ([`LongformerConfig`]): Model configuration class with all the parameters of the model.
+ Initializing with a config file does not load the weights associated with the model, only the
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+
+LONGFORMER_INPUTS_DOCSTRING = r"""
+ Args:
+ input_ids (`np.ndarray` or `tf.Tensor` of shape `({0})`):
+ Indices of input sequence tokens in the vocabulary.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.__call__`] and
+ [`PreTrainedTokenizer.encode`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ attention_mask (`np.ndarray` or `tf.Tensor` of shape `({0})`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ head_mask (`np.ndarray` or `tf.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
+ Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ global_attention_mask (`np.ndarray` or `tf.Tensor` of shape `({0})`, *optional*):
+ Mask to decide the attention given on each token, local attention or global attention. Tokens with global
+ attention attend to all other tokens, and all other tokens attend to them. This is important for
+ task-specific finetuning because it makes the model more flexible at representing the task. For example,
+ for classification, the `<s>` token should be given global attention. For QA, all question tokens should also
+ have global attention. Please refer to the [Longformer paper](https://huggingface.co/papers/2004.05150) for more
+ details. Mask values selected in `[0, 1]`:
+
+ - 0 for local attention (a sliding window attention),
+ - 1 for global attention (tokens that attend to all other tokens, and all other tokens attend to them).
+
+ token_type_ids (`np.ndarray` or `tf.Tensor` of shape `({0})`, *optional*):
+ Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
+ 1]`:
+
+ - 0 corresponds to a *sentence A* token,
+ - 1 corresponds to a *sentence B* token.
+
+ [What are token type IDs?](../glossary#token-type-ids)
+ position_ids (`np.ndarray` or `tf.Tensor` of shape `({0})`, *optional*):
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+ config.max_position_embeddings - 1]`.
+
+ [What are position IDs?](../glossary#position-ids)
+ inputs_embeds (`np.ndarray` or `tf.Tensor` of shape `({0}, hidden_size)`, *optional*):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
+ model's internal embedding lookup matrix.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the
+ config will be used instead.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail. This argument can be used only in eager mode, in graph mode the value in the config will be
+ used instead.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in
+ eager mode, in graph mode the value will always be set to True.
+ training (`bool`, *optional*, defaults to `False`):
+ Whether or not to use the model in training mode (some modules like dropout modules have different
+ behaviors between training and evaluation).
+"""
+
+
+@add_start_docstrings(
+ "The bare Longformer Model outputting raw hidden-states without any specific head on top.",
+ LONGFORMER_START_DOCSTRING,
+)
+class TFLongformerModel(TFLongformerPreTrainedModel):
+ """
+
+ This class copies code from [`TFRobertaModel`] and overwrites standard self-attention with longformer
+ self-attention to provide the ability to process long sequences following the self-attention approach described in
+ [Longformer: the Long-Document Transformer](https://huggingface.co/papers/2004.05150) by Iz Beltagy, Matthew E. Peters, and
+ Arman Cohan. Longformer self-attention combines a local (sliding window) and global attention to extend to long
+ documents without the O(n^2) increase in memory and compute.
+
+ The self-attention module `TFLongformerSelfAttention` implemented here supports the combination of local and global
+ attention but it lacks support for autoregressive attention and dilated attention. Autoregressive and dilated
+ attention are more relevant for autoregressive language modeling than for finetuning on downstream tasks. A future
+ release will add support for autoregressive attention, but support for dilated attention requires a custom CUDA
+ kernel to be memory and compute efficient.
+
+ """
+
+ def __init__(self, config, *inputs, **kwargs):
+ super().__init__(config, *inputs, **kwargs)
+
+ self.longformer = TFLongformerMainLayer(config, name="longformer")
+
+ @unpack_inputs
+ @add_start_docstrings_to_model_forward(LONGFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+ def call(
+ self,
+ input_ids: TFModelInputType | None = None,
+ attention_mask: np.ndarray | tf.Tensor | None = None,
+ head_mask: np.ndarray | tf.Tensor | None = None,
+ global_attention_mask: np.ndarray | tf.Tensor | None = None,
+ token_type_ids: np.ndarray | tf.Tensor | None = None,
+ position_ids: np.ndarray | tf.Tensor | None = None,
+ inputs_embeds: np.ndarray | tf.Tensor | None = None,
+ output_attentions: bool | None = None,
+ output_hidden_states: bool | None = None,
+ return_dict: bool | None = None,
+ training: bool | None = False,
+ ) -> TFLongformerBaseModelOutputWithPooling | tuple[tf.Tensor]:
+ outputs = self.longformer(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ head_mask=head_mask,
+ global_attention_mask=global_attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ training=training,
+ )
+
+ return outputs
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "longformer", None) is not None:
+ with tf.name_scope(self.longformer.name):
+ self.longformer.build(None)
+
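
A hedged end-to-end sketch of the bare model with an explicit `global_attention_mask` and attention outputs enabled; only the checkpoint name is taken from this file, the input text is illustrative:

```python
import tensorflow as tf
from transformers import AutoTokenizer, TFLongformerModel

tokenizer = AutoTokenizer.from_pretrained("allenai/longformer-base-4096")
model = TFLongformerModel.from_pretrained("allenai/longformer-base-4096")

inputs = tokenizer("A long document " * 200, return_tensors="tf")
# Global attention on the first (<s>) token, local attention everywhere else.
global_attention_mask = tf.concat(
    [tf.ones_like(inputs["input_ids"][:, :1]), tf.zeros_like(inputs["input_ids"][:, 1:])], axis=-1
)
outputs = model(**inputs, global_attention_mask=global_attention_mask, output_attentions=True)
print(outputs.last_hidden_state.shape)                           # (batch_size, sequence_length, hidden_size)
print(len(outputs.attentions), len(outputs.global_attentions))   # one tensor per layer each
```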
+
+@add_start_docstrings(
+ """Longformer Model with a `language modeling` head on top.""",
+ LONGFORMER_START_DOCSTRING,
+)
+class TFLongformerForMaskedLM(TFLongformerPreTrainedModel, TFMaskedLanguageModelingLoss):
+ # names with a '.' represent the authorized unexpected/missing layers when a TF model is loaded from a PT model
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
+
+ def __init__(self, config, *inputs, **kwargs):
+ super().__init__(config, *inputs, **kwargs)
+
+ self.longformer = TFLongformerMainLayer(config, add_pooling_layer=False, name="longformer")
+ self.lm_head = TFLongformerLMHead(config, self.longformer.embeddings, name="lm_head")
+
+ def get_lm_head(self):
+ return self.lm_head
+
+ def get_prefix_bias_name(self):
+ warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning)
+ return self.name + "/" + self.lm_head.name
+
+ @unpack_inputs
+ @add_start_docstrings_to_model_forward(LONGFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+ @add_code_sample_docstrings(
+ checkpoint="allenai/longformer-base-4096",
+ output_type=TFLongformerMaskedLMOutput,
+ config_class=_CONFIG_FOR_DOC,
+ mask="",
+ expected_output="' Paris'",
+ expected_loss=0.44,
+ )
+ def call(
+ self,
+ input_ids: TFModelInputType | None = None,
+ attention_mask: np.ndarray | tf.Tensor | None = None,
+ head_mask: np.ndarray | tf.Tensor | None = None,
+ global_attention_mask: np.ndarray | tf.Tensor | None = None,
+ token_type_ids: np.ndarray | tf.Tensor | None = None,
+ position_ids: np.ndarray | tf.Tensor | None = None,
+ inputs_embeds: np.ndarray | tf.Tensor | None = None,
+ output_attentions: bool | None = None,
+ output_hidden_states: bool | None = None,
+ return_dict: bool | None = None,
+ labels: np.ndarray | tf.Tensor | None = None,
+ training: bool | None = False,
+ ) -> TFLongformerMaskedLMOutput | tuple[tf.Tensor]:
+ r"""
+ labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
+ config.vocab_size]` (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked); the
+ loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+ """
+
+ outputs = self.longformer(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ head_mask=head_mask,
+ global_attention_mask=global_attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ training=training,
+ )
+ sequence_output = outputs[0]
+ prediction_scores = self.lm_head(sequence_output, training=training)
+ loss = None if labels is None else self.hf_compute_loss(labels, prediction_scores)
+
+ if not return_dict:
+ output = (prediction_scores,) + outputs[2:]
+
+ return ((loss,) + output) if loss is not None else output
+
+ return TFLongformerMaskedLMOutput(
+ loss=loss,
+ logits=prediction_scores,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ global_attentions=outputs.global_attentions,
+ )
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "longformer", None) is not None:
+ with tf.name_scope(self.longformer.name):
+ self.longformer.build(None)
+ if getattr(self, "lm_head", None) is not None:
+ with tf.name_scope(self.lm_head.name):
+ self.lm_head.build(None)
+
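
A hedged usage sketch matching the code-sample metadata above (checkpoint `allenai/longformer-base-4096`, mask token `<mask>`, expected completion `" Paris"`):

```python
import tensorflow as tf
from transformers import AutoTokenizer, TFLongformerForMaskedLM

tokenizer = AutoTokenizer.from_pretrained("allenai/longformer-base-4096")
model = TFLongformerForMaskedLM.from_pretrained("allenai/longformer-base-4096")

inputs = tokenizer(f"The capital of France is {tokenizer.mask_token}.", return_tensors="tf")
logits = model(**inputs).logits

# Locate the masked position and take the highest-scoring vocabulary id there.
mask_position = int(tf.where(inputs["input_ids"][0] == tokenizer.mask_token_id)[0, 0])
predicted_id = int(tf.argmax(logits[0, mask_position]))
print(tokenizer.decode([predicted_id]))  # expected to decode to " Paris" with this checkpoint
```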
+
+@add_start_docstrings(
+ """
+ Longformer Model with a span classification head on top for extractive question-answering tasks like SQuAD /
+ TriviaQA (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
+ """,
+ LONGFORMER_START_DOCSTRING,
+)
+class TFLongformerForQuestionAnswering(TFLongformerPreTrainedModel, TFQuestionAnsweringLoss):
+ # names with a '.' represent the authorized unexpected/missing layers when a TF model is loaded from a PT model
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
+
+ def __init__(self, config, *inputs, **kwargs):
+ super().__init__(config, *inputs, **kwargs)
+
+ self.num_labels = config.num_labels
+ self.longformer = TFLongformerMainLayer(config, add_pooling_layer=False, name="longformer")
+ self.qa_outputs = keras.layers.Dense(
+ config.num_labels,
+ kernel_initializer=get_initializer(config.initializer_range),
+ name="qa_outputs",
+ )
+ self.config = config
+
+ @unpack_inputs
+ @add_start_docstrings_to_model_forward(LONGFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+ @add_code_sample_docstrings(
+ checkpoint="allenai/longformer-large-4096-finetuned-triviaqa",
+ output_type=TFLongformerQuestionAnsweringModelOutput,
+ config_class=_CONFIG_FOR_DOC,
+ expected_output="' puppet'",
+ expected_loss=0.96,
+ )
+ def call(
+ self,
+ input_ids: TFModelInputType | None = None,
+ attention_mask: np.ndarray | tf.Tensor | None = None,
+ head_mask: np.ndarray | tf.Tensor | None = None,
+ global_attention_mask: np.ndarray | tf.Tensor | None = None,
+ token_type_ids: np.ndarray | tf.Tensor | None = None,
+ position_ids: np.ndarray | tf.Tensor | None = None,
+ inputs_embeds: np.ndarray | tf.Tensor | None = None,
+ output_attentions: bool | None = None,
+ output_hidden_states: bool | None = None,
+ return_dict: bool | None = None,
+ start_positions: np.ndarray | tf.Tensor | None = None,
+ end_positions: np.ndarray | tf.Tensor | None = None,
+ training: bool | None = False,
+ ) -> TFLongformerQuestionAnsweringModelOutput | tuple[tf.Tensor]:
+ r"""
+ start_positions (`tf.Tensor` of shape `(batch_size,)`, *optional*):
+ Labels for position (index) of the start of the labelled span for computing the token classification loss.
+ Positions are clamped to the length of the sequence (*sequence_length*). Positions outside of the sequence
+ are not taken into account for computing the loss.
+ end_positions (`tf.Tensor` of shape `(batch_size,)`, *optional*):
+ Labels for position (index) of the end of the labelled span for computing the token classification loss.
+ Positions are clamped to the length of the sequence (*sequence_length*). Positions outside of the sequence
+ are not taken into account for computing the loss.
+ """
+
+ if input_ids is not None and not isinstance(input_ids, tf.Tensor):
+ input_ids = tf.convert_to_tensor(input_ids, dtype=tf.int64)
+ elif input_ids is not None:
+ input_ids = tf.cast(input_ids, tf.int64)
+
+ if attention_mask is not None and not isinstance(attention_mask, tf.Tensor):
+ attention_mask = tf.convert_to_tensor(attention_mask, dtype=tf.int64)
+ elif attention_mask is not None:
+ attention_mask = tf.cast(attention_mask, tf.int64)
+
+ if global_attention_mask is not None and not isinstance(global_attention_mask, tf.Tensor):
+ global_attention_mask = tf.convert_to_tensor(global_attention_mask, dtype=tf.int64)
+ elif global_attention_mask is not None:
+ global_attention_mask = tf.cast(global_attention_mask, tf.int64)
+
+ # set global attention on question tokens
+ if global_attention_mask is None and input_ids is not None:
+ if shape_list(tf.where(input_ids == self.config.sep_token_id))[0] != 3 * shape_list(input_ids)[0]:
+ logger.warning(
+ f"There should be exactly three separator tokens: {self.config.sep_token_id} in every sample for"
+ " questions answering. You might also consider to set `global_attention_mask` manually in the"
+ " forward function to avoid this. This is most likely an error. The global attention is disabled"
+ " for this forward pass."
+ )
+ global_attention_mask = tf.cast(tf.fill(shape_list(input_ids), value=0), tf.int64)
+ else:
+ logger.warning_once("Initializing global attention on question tokens...")
+ # put global attention on all tokens until `config.sep_token_id` is reached
+ sep_token_indices = tf.where(input_ids == self.config.sep_token_id)
+ sep_token_indices = tf.cast(sep_token_indices, dtype=tf.int64)
+ global_attention_mask = _compute_global_attention_mask(shape_list(input_ids), sep_token_indices)
+
+ outputs = self.longformer(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ head_mask=head_mask,
+ global_attention_mask=global_attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ training=training,
+ )
+ sequence_output = outputs[0]
+ logits = self.qa_outputs(sequence_output)
+ start_logits, end_logits = tf.split(logits, 2, axis=-1)
+ start_logits = tf.squeeze(start_logits, axis=-1)
+ end_logits = tf.squeeze(end_logits, axis=-1)
+ loss = None
+
+ if start_positions is not None and end_positions is not None:
+ labels = {"start_position": start_positions}
+ labels["end_position"] = end_positions
+ loss = self.hf_compute_loss(labels, (start_logits, end_logits))
+
+ if not return_dict:
+ output = (start_logits, end_logits) + outputs[2:]
+
+ return ((loss,) + output) if loss is not None else output
+
+ return TFLongformerQuestionAnsweringModelOutput(
+ loss=loss,
+ start_logits=start_logits,
+ end_logits=end_logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ global_attentions=outputs.global_attentions,
+ )
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "longformer", None) is not None:
+ with tf.name_scope(self.longformer.name):
+ self.longformer.build(None)
+ if getattr(self, "qa_outputs", None) is not None:
+ with tf.name_scope(self.qa_outputs.name):
+ self.qa_outputs.build([None, None, self.config.hidden_size])
+
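
A hedged usage sketch for extractive QA with the TriviaQA-finetuned checkpoint referenced above; when `global_attention_mask` is not passed, the model places global attention on the question tokens automatically, as implemented in `call`:

```python
import tensorflow as tf
from transformers import AutoTokenizer, TFLongformerForQuestionAnswering

ckpt = "allenai/longformer-large-4096-finetuned-triviaqa"
tokenizer = AutoTokenizer.from_pretrained(ckpt)
model = TFLongformerForQuestionAnswering.from_pretrained(ckpt)

question, context = "Who was Jim Henson?", "Jim Henson was a nice puppet."
inputs = tokenizer(question, context, return_tensors="tf")
outputs = model(**inputs)

start = int(tf.argmax(outputs.start_logits, axis=-1)[0])
end = int(tf.argmax(outputs.end_logits, axis=-1)[0])
answer_ids = inputs["input_ids"][0, start : end + 1]
print(tokenizer.decode(answer_ids))  # expected: " puppet"
```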
+
+class TFLongformerClassificationHead(keras.layers.Layer):
+ """Head for sentence-level classification tasks."""
+
+ def __init__(self, config, **kwargs):
+ super().__init__(**kwargs)
+ self.dense = keras.layers.Dense(
+ config.hidden_size,
+ kernel_initializer=get_initializer(config.initializer_range),
+ activation="tanh",
+ name="dense",
+ )
+ self.dropout = keras.layers.Dropout(config.hidden_dropout_prob)
+ self.out_proj = keras.layers.Dense(
+ config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="out_proj"
+ )
+ self.config = config
+
+ def call(self, hidden_states, training=False):
+ hidden_states = hidden_states[:, 0, :]  # take <s> token (equiv. to [CLS])
+ hidden_states = self.dropout(hidden_states, training=training)
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states, training=training)
+ output = self.out_proj(hidden_states)
+ return output
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "dense", None) is not None:
+ with tf.name_scope(self.dense.name):
+ self.dense.build([None, None, self.config.hidden_size])
+ if getattr(self, "out_proj", None) is not None:
+ with tf.name_scope(self.out_proj.name):
+ self.out_proj.build([None, None, self.config.hidden_size])
+
+
+@add_start_docstrings(
+ """
+ Longformer Model transformer with a sequence classification/regression head on top (a linear layer on top of the
+ pooled output) e.g. for GLUE tasks.
+ """,
+ LONGFORMER_START_DOCSTRING,
+)
+class TFLongformerForSequenceClassification(TFLongformerPreTrainedModel, TFSequenceClassificationLoss):
+ # names with a '.' represent the authorized unexpected/missing layers when a TF model is loaded from a PT model
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
+
+ def __init__(self, config, *inputs, **kwargs):
+ super().__init__(config, *inputs, **kwargs)
+
+ self.num_labels = config.num_labels
+
+ self.longformer = TFLongformerMainLayer(config, add_pooling_layer=False, name="longformer")
+ self.classifier = TFLongformerClassificationHead(config, name="classifier")
+
+ @unpack_inputs
+ @add_start_docstrings_to_model_forward(LONGFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+ @add_code_sample_docstrings(
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=TFLongformerSequenceClassifierOutput,
+ config_class=_CONFIG_FOR_DOC,
+ )
+ def call(
+ self,
+ input_ids: TFModelInputType | None = None,
+ attention_mask: np.ndarray | tf.Tensor | None = None,
+ head_mask: np.ndarray | tf.Tensor | None = None,
+ token_type_ids: np.ndarray | tf.Tensor | None = None,
+ position_ids: np.ndarray | tf.Tensor | None = None,
+ global_attention_mask: np.ndarray | tf.Tensor | None = None,
+ inputs_embeds: np.ndarray | tf.Tensor | None = None,
+ output_attentions: bool | None = None,
+ output_hidden_states: bool | None = None,
+ return_dict: bool | None = None,
+ labels: np.ndarray | tf.Tensor | None = None,
+ training: bool | None = False,
+ ) -> TFLongformerSequenceClassifierOutput | tuple[tf.Tensor]:
+ if input_ids is not None and not isinstance(input_ids, tf.Tensor):
+ input_ids = tf.convert_to_tensor(input_ids, dtype=tf.int64)
+ elif input_ids is not None:
+ input_ids = tf.cast(input_ids, tf.int64)
+
+ if attention_mask is not None and not isinstance(attention_mask, tf.Tensor):
+ attention_mask = tf.convert_to_tensor(attention_mask, dtype=tf.int64)
+ elif attention_mask is not None:
+ attention_mask = tf.cast(attention_mask, tf.int64)
+
+ if global_attention_mask is not None and not isinstance(global_attention_mask, tf.Tensor):
+ global_attention_mask = tf.convert_to_tensor(global_attention_mask, dtype=tf.int64)
+ elif global_attention_mask is not None:
+ global_attention_mask = tf.cast(global_attention_mask, tf.int64)
+
+ if global_attention_mask is None and input_ids is not None:
+ logger.warning_once("Initializing global attention on CLS token...")
+ # global attention on cls token
+ global_attention_mask = tf.zeros_like(input_ids)
+ updates = tf.ones(shape_list(input_ids)[0], dtype=tf.int64)
+ indices = tf.pad(
+ tensor=tf.expand_dims(tf.range(shape_list(input_ids)[0], dtype=tf.int64), axis=1),
+ paddings=[[0, 0], [0, 1]],
+ constant_values=0,
+ )
+ global_attention_mask = tf.tensor_scatter_nd_update(
+ global_attention_mask,
+ indices,
+ updates,
+ )
+
+ outputs = self.longformer(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ head_mask=head_mask,
+ global_attention_mask=global_attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ training=training,
+ )
+ sequence_output = outputs[0]
+ logits = self.classifier(sequence_output)
+
+ loss = None if labels is None else self.hf_compute_loss(labels, logits)
+
+ if not return_dict:
+ output = (logits,) + outputs[2:]
+ return ((loss,) + output) if loss is not None else output
+
+ return TFLongformerSequenceClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ global_attentions=outputs.global_attentions,
+ )
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "longformer", None) is not None:
+ with tf.name_scope(self.longformer.name):
+ self.longformer.build(None)
+ if getattr(self, "classifier", None) is not None:
+ with tf.name_scope(self.classifier.name):
+ self.classifier.build(None)
+
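
A hedged sketch of sequence classification; `num_labels=2` is an assumption, and the classification head on the base checkpoint is randomly initialized, so the prediction is only meaningful after fine-tuning:

```python
import tensorflow as tf
from transformers import AutoTokenizer, TFLongformerForSequenceClassification

tokenizer = AutoTokenizer.from_pretrained("allenai/longformer-base-4096")
model = TFLongformerForSequenceClassification.from_pretrained("allenai/longformer-base-4096", num_labels=2)

inputs = tokenizer("A very long review that easily exceeds a normal BERT context window ...", return_tensors="tf")
logits = model(**inputs).logits          # global attention on the first (<s>) token is set automatically
predicted_class = int(tf.argmax(logits, axis=-1)[0])
print(predicted_class)
```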
+
+@add_start_docstrings(
+ """
+ Longformer Model with a multiple choice classification head on top (a linear layer on top of the pooled output and
+ a softmax) e.g. for RocStories/SWAG tasks.
+ """,
+ LONGFORMER_START_DOCSTRING,
+)
+class TFLongformerForMultipleChoice(TFLongformerPreTrainedModel, TFMultipleChoiceLoss):
+ # names with a '.' represent the authorized unexpected/missing layers when a TF model is loaded from a PT model
+ _keys_to_ignore_on_load_missing = [r"dropout"]
+
+ def __init__(self, config, *inputs, **kwargs):
+ super().__init__(config, *inputs, **kwargs)
+
+ self.longformer = TFLongformerMainLayer(config, name="longformer")
+ self.dropout = keras.layers.Dropout(config.hidden_dropout_prob)
+ self.classifier = keras.layers.Dense(
+ 1, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
+ )
+ self.config = config
+
+ @property
+ def input_signature(self):
+ return {
+ "input_ids": tf.TensorSpec((None, None, None), tf.int32, name="input_ids"),
+ "attention_mask": tf.TensorSpec((None, None, None), tf.int32, name="attention_mask"),
+ "global_attention_mask": tf.TensorSpec((None, None, None), tf.int32, name="global_attention_mask"),
+ }
+
+ @unpack_inputs
+ @add_start_docstrings_to_model_forward(
+ LONGFORMER_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")
+ )
+ @add_code_sample_docstrings(
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=TFLongformerMultipleChoiceModelOutput,
+ config_class=_CONFIG_FOR_DOC,
+ )
+ def call(
+ self,
+ input_ids: TFModelInputType | None = None,
+ attention_mask: np.ndarray | tf.Tensor | None = None,
+ head_mask: np.ndarray | tf.Tensor | None = None,
+ token_type_ids: np.ndarray | tf.Tensor | None = None,
+ position_ids: np.ndarray | tf.Tensor | None = None,
+ global_attention_mask: np.ndarray | tf.Tensor | None = None,
+ inputs_embeds: np.ndarray | tf.Tensor | None = None,
+ output_attentions: bool | None = None,
+ output_hidden_states: bool | None = None,
+ return_dict: bool | None = None,
+ labels: np.ndarray | tf.Tensor | None = None,
+ training: bool | None = False,
+ ) -> TFLongformerMultipleChoiceModelOutput | tuple[tf.Tensor]:
+ r"""
+ labels (`tf.Tensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., num_choices]`
+ where `num_choices` is the size of the second dimension of the input tensors. (See `input_ids` above)
+ """
+
+ if input_ids is not None:
+ num_choices = shape_list(input_ids)[1]
+ seq_length = shape_list(input_ids)[2]
+ else:
+ num_choices = shape_list(inputs_embeds)[1]
+ seq_length = shape_list(inputs_embeds)[2]
+
+ flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None
+ flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
+ flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None
+ flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None
+ flat_global_attention_mask = (
+ tf.reshape(global_attention_mask, (-1, shape_list(global_attention_mask)[-1]))
+ if global_attention_mask is not None
+ else None
+ )
+ flat_inputs_embeds = (
+ tf.reshape(inputs_embeds, (-1, seq_length, shape_list(inputs_embeds)[3]))
+ if inputs_embeds is not None
+ else None
+ )
+
+ outputs = self.longformer(
+ flat_input_ids,
+ position_ids=flat_position_ids,
+ token_type_ids=flat_token_type_ids,
+ attention_mask=flat_attention_mask,
+ head_mask=head_mask,
+ global_attention_mask=flat_global_attention_mask,
+ inputs_embeds=flat_inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ training=training,
+ )
+ pooled_output = outputs[1]
+
+ pooled_output = self.dropout(pooled_output)
+ logits = self.classifier(pooled_output)
+ reshaped_logits = tf.reshape(logits, (-1, num_choices))
+
+ loss = None if labels is None else self.hf_compute_loss(labels, reshaped_logits)
+
+ if not return_dict:
+ output = (reshaped_logits,) + outputs[2:]
+ return ((loss,) + output) if loss is not None else output
+
+ return TFLongformerMultipleChoiceModelOutput(
+ loss=loss,
+ logits=reshaped_logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ global_attentions=outputs.global_attentions,
+ )
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "longformer", None) is not None:
+ with tf.name_scope(self.longformer.name):
+ self.longformer.build(None)
+ if getattr(self, "classifier", None) is not None:
+ with tf.name_scope(self.classifier.name):
+ self.classifier.build([None, None, self.config.hidden_size])
+
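
A hedged sketch of the `(batch_size, num_choices, sequence_length)` input layout this head expects; the prompt and choices are illustrative, and the classifier on the base checkpoint is untrained:

```python
import tensorflow as tf
from transformers import AutoTokenizer, TFLongformerForMultipleChoice

tokenizer = AutoTokenizer.from_pretrained("allenai/longformer-base-4096")
model = TFLongformerForMultipleChoice.from_pretrained("allenai/longformer-base-4096")

prompt = "The Eiffel Tower is located in"
choices = ["Paris.", "Rome."]
encoding = tokenizer([prompt, prompt], choices, return_tensors="tf", padding=True)
inputs = {k: tf.expand_dims(v, 0) for k, v in encoding.items()}  # add the num_choices dimension
logits = model(**inputs).logits  # shape (1, 2): one score per choice
```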
+
+@add_start_docstrings(
+ """
+ Longformer Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g.
+ for Named-Entity-Recognition (NER) tasks.
+ """,
+ LONGFORMER_START_DOCSTRING,
+)
+class TFLongformerForTokenClassification(TFLongformerPreTrainedModel, TFTokenClassificationLoss):
+ # names with a '.' represent the authorized unexpected/missing layers when a TF model is loaded from a PT model
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
+ _keys_to_ignore_on_load_missing = [r"dropout"]
+
+ def __init__(self, config, *inputs, **kwargs):
+ super().__init__(config, *inputs, **kwargs)
+
+ self.num_labels = config.num_labels
+ self.longformer = TFLongformerMainLayer(config=config, add_pooling_layer=False, name="longformer")
+ self.dropout = keras.layers.Dropout(config.hidden_dropout_prob)
+ self.classifier = keras.layers.Dense(
+ config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
+ )
+ self.config = config
+
+ @unpack_inputs
+ @add_start_docstrings_to_model_forward(LONGFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+ @add_code_sample_docstrings(
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=TFLongformerTokenClassifierOutput,
+ config_class=_CONFIG_FOR_DOC,
+ )
+ def call(
+ self,
+ input_ids: TFModelInputType | None = None,
+ attention_mask: np.ndarray | tf.Tensor | None = None,
+ head_mask: np.ndarray | tf.Tensor | None = None,
+ token_type_ids: np.ndarray | tf.Tensor | None = None,
+ position_ids: np.ndarray | tf.Tensor | None = None,
+ global_attention_mask: np.ndarray | tf.Tensor | None = None,
+ inputs_embeds: np.ndarray | tf.Tensor | None = None,
+ output_attentions: bool | None = None,
+ output_hidden_states: bool | None = None,
+ return_dict: bool | None = None,
+ labels: np.array | tf.Tensor | None = None,
+ training: bool | None = False,
+ ) -> TFLongformerTokenClassifierOutput | tuple[tf.Tensor]:
+ r"""
+ labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
+ """
+
+ outputs = self.longformer(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ head_mask=head_mask,
+ global_attention_mask=global_attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ training=training,
+ )
+ sequence_output = outputs[0]
+ sequence_output = self.dropout(sequence_output)
+ logits = self.classifier(sequence_output)
+ loss = None if labels is None else self.hf_compute_loss(labels, logits)
+
+ if not return_dict:
+ output = (logits,) + outputs[2:]
+ return ((loss,) + output) if loss is not None else output
+
+ return TFLongformerTokenClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ global_attentions=outputs.global_attentions,
+ )
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "longformer", None) is not None:
+ with tf.name_scope(self.longformer.name):
+ self.longformer.build(None)
+ if getattr(self, "classifier", None) is not None:
+ with tf.name_scope(self.classifier.name):
+ self.classifier.build([None, None, self.config.hidden_size])
+
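
A hedged sketch of token classification; `num_labels=5` is an assumption and the head on the base checkpoint is randomly initialized:

```python
import tensorflow as tf
from transformers import AutoTokenizer, TFLongformerForTokenClassification

tokenizer = AutoTokenizer.from_pretrained("allenai/longformer-base-4096")
model = TFLongformerForTokenClassification.from_pretrained("allenai/longformer-base-4096", num_labels=5)

inputs = tokenizer("HuggingFace is based in New York City.", return_tensors="tf")
logits = model(**inputs).logits                       # (batch_size, sequence_length, num_labels)
predicted_label_ids = tf.argmax(logits, axis=-1)[0]   # one label id per token
print(predicted_label_ids.numpy())
```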
+
+__all__ = [
+ "TFLongformerForMaskedLM",
+ "TFLongformerForMultipleChoice",
+ "TFLongformerForQuestionAnswering",
+ "TFLongformerForSequenceClassification",
+ "TFLongformerForTokenClassification",
+ "TFLongformerModel",
+ "TFLongformerPreTrainedModel",
+ "TFLongformerSelfAttention",
+]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/longformer/tokenization_longformer.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/longformer/tokenization_longformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..104bdd7a9b99f5a4809557dcee97cda4873cea12
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/longformer/tokenization_longformer.py
@@ -0,0 +1,402 @@
+# coding=utf-8
+# Copyright 2020 The Allen Institute for AI team and The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import os
+from functools import lru_cache
+from typing import Optional
+
+import regex as re
+
+from ...tokenization_utils import AddedToken, PreTrainedTokenizer
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+VOCAB_FILES_NAMES = {"vocab_file": "vocab.json", "merges_file": "merges.txt"}
+
+
+@lru_cache
+# Copied from transformers.models.roberta.tokenization_roberta.bytes_to_unicode
+def bytes_to_unicode():
+ """
+ Returns a mapping from utf-8 bytes to unicode strings. We specifically avoid mapping to whitespace/control
+ characters that the bpe code barfs on.
+
+ The reversible bpe codes work on unicode strings. This means you need a large # of unicode characters in your vocab
+ if you want to avoid UNKs. When you're at something like a 10B token dataset you end up needing around 5K for
+ decent coverage. This is a significant percentage of your normal, say, 32K bpe vocab. To avoid that, we want lookup
+ tables between utf-8 bytes and unicode strings.
+ """
+ bs = (
+ list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
+ )
+ cs = bs[:]
+ n = 0
+ for b in range(2**8):
+ if b not in bs:
+ bs.append(b)
+ cs.append(2**8 + n)
+ n += 1
+ cs = [chr(n) for n in cs]
+ return dict(zip(bs, cs))
+
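
A small illustration of the table returned by `bytes_to_unicode` above (assuming this module is imported so the function is in scope):

```python
table = bytes_to_unicode()
print(table[ord("A")])  # 'A'  -> printable bytes map to themselves
print(table[ord(" ")])  # 'Ġ'  -> the space byte is shifted to an otherwise unused code point
print(len(table))       # 256  -> every possible byte value gets a unicode stand-in
```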
+
+# Copied from transformers.models.roberta.tokenization_roberta.get_pairs
+def get_pairs(word):
+ """
+ Return set of symbol pairs in a word.
+
+ Word is represented as tuple of symbols (symbols being variable-length strings).
+ """
+ pairs = set()
+ prev_char = word[0]
+ for char in word[1:]:
+ pairs.add((prev_char, char))
+ prev_char = char
+ return pairs
+
+
+# Copied from transformers.models.roberta.tokenization_roberta.RobertaTokenizer with FacebookAI/roberta-base->allenai/longformer-base-4096, RoBERTa->Longformer all-casing, RobertaTokenizer->LongformerTokenizer
+class LongformerTokenizer(PreTrainedTokenizer):
+ """
+ Constructs a Longformer tokenizer, derived from the GPT-2 tokenizer, using byte-level Byte-Pair-Encoding.
+
+ This tokenizer has been trained to treat spaces like parts of the tokens (a bit like sentencepiece) so a word will
+ be encoded differently whether it is at the beginning of the sentence (without space) or not:
+
+ ```python
+ >>> from transformers import LongformerTokenizer
+
+ >>> tokenizer = LongformerTokenizer.from_pretrained("allenai/longformer-base-4096")
+ >>> tokenizer("Hello world")["input_ids"]
+ [0, 31414, 232, 2]
+
+ >>> tokenizer(" Hello world")["input_ids"]
+ [0, 20920, 232, 2]
+ ```
+
+ You can get around that behavior by passing `add_prefix_space=True` when instantiating this tokenizer or when you
+ call it on some text, but since the model was not pretrained this way, it might yield a decrease in performance.
+
+
+
+ When used with `is_split_into_words=True`, this tokenizer will add a space before each word (even the first one).
+
+
+
+ This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
+ this superclass for more information regarding those methods.
+
+ Args:
+ vocab_file (`str`):
+ Path to the vocabulary file.
+ merges_file (`str`):
+ Path to the merges file.
+ errors (`str`, *optional*, defaults to `"replace"`):
+ Paradigm to follow when decoding bytes to UTF-8. See
+ [bytes.decode](https://docs.python.org/3/library/stdtypes.html#bytes.decode) for more information.
+ bos_token (`str`, *optional*, defaults to `"<s>"`):
+ The beginning of sequence token that was used during pretraining. Can be used as a sequence classifier token.
+
+
+
+ When building a sequence using special tokens, this is not the token that is used for the beginning of
+ sequence. The token used is the `cls_token`.
+
+
+
+ eos_token (`str`, *optional*, defaults to `"</s>"`):
+ The end of sequence token.
+
+
+
+ When building a sequence using special tokens, this is not the token that is used for the end of sequence.
+ The token used is the `sep_token`.
+
+
+
+ sep_token (`str`, *optional*, defaults to `"</s>"`):
+ The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
+ sequence classification or for a text and a question for question answering. It is also used as the last
+ token of a sequence built with special tokens.
+ cls_token (`str`, *optional*, defaults to `"<s>"`):
+ The classifier token which is used when doing sequence classification (classification of the whole sequence
+ instead of per-token classification). It is the first token of the sequence when built with special tokens.
+ unk_token (`str`, *optional*, defaults to `"<unk>"`):
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+ token instead.
+ pad_token (`str`, *optional*, defaults to `"<pad>"`):
+ The token used for padding, for example when batching sequences of different lengths.
+ mask_token (`str`, *optional*, defaults to `"<mask>"`):
+ The token used for masking values. This is the token used when training this model with masked language
+ modeling. This is the token which the model will try to predict.
+ add_prefix_space (`bool`, *optional*, defaults to `False`):
+ Whether or not to add an initial space to the input. This allows treating the leading word just like any
+ other word. (The Longformer tokenizer detects the beginning of words by the preceding space.)
+ """
+
+ vocab_files_names = VOCAB_FILES_NAMES
+ model_input_names = ["input_ids", "attention_mask"]
+
+ def __init__(
+ self,
+ vocab_file,
+ merges_file,
+ errors="replace",
+ bos_token="",
+ eos_token="",
+ sep_token="",
+ cls_token="",
+ unk_token="",
+ pad_token="",
+ mask_token="",
+ add_prefix_space=False,
+ **kwargs,
+ ):
+ bos_token = AddedToken(bos_token, lstrip=False, rstrip=False) if isinstance(bos_token, str) else bos_token
+ pad_token = AddedToken(pad_token, lstrip=False, rstrip=False) if isinstance(pad_token, str) else pad_token
+ eos_token = AddedToken(eos_token, lstrip=False, rstrip=False) if isinstance(eos_token, str) else eos_token
+ unk_token = AddedToken(unk_token, lstrip=False, rstrip=False) if isinstance(unk_token, str) else unk_token
+ sep_token = AddedToken(sep_token, lstrip=False, rstrip=False) if isinstance(sep_token, str) else sep_token
+ cls_token = AddedToken(cls_token, lstrip=False, rstrip=False) if isinstance(cls_token, str) else cls_token
+
+ # Mask token behaves like a normal word, i.e. includes the space before it
+ mask_token = (
+ AddedToken(mask_token, lstrip=True, rstrip=False, normalized=False)
+ if isinstance(mask_token, str)
+ else mask_token
+ )
+
+ # these special tokens are not part of the vocab.json, let's add them in the correct order
+
+ with open(vocab_file, encoding="utf-8") as vocab_handle:
+ self.encoder = json.load(vocab_handle)
+ self.decoder = {v: k for k, v in self.encoder.items()}
+ self.errors = errors # how to handle errors in decoding
+ self.byte_encoder = bytes_to_unicode()
+ self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
+ with open(merges_file, encoding="utf-8") as merges_handle:
+ bpe_merges = merges_handle.read().split("\n")[1:-1]
+ bpe_merges = [tuple(merge.split()) for merge in bpe_merges]
+ self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
+ self.cache = {}
+ self.add_prefix_space = add_prefix_space
+
+ # Should have added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
+ self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
+
+ super().__init__(
+ errors=errors,
+ bos_token=bos_token,
+ eos_token=eos_token,
+ unk_token=unk_token,
+ sep_token=sep_token,
+ cls_token=cls_token,
+ pad_token=pad_token,
+ mask_token=mask_token,
+ add_prefix_space=add_prefix_space,
+ **kwargs,
+ )
+
+ @property
+ def vocab_size(self):
+ return len(self.encoder)
+
+ def get_vocab(self):
+ vocab = dict(self.encoder).copy()
+ vocab.update(self.added_tokens_encoder)
+ return vocab
+
+ def bpe(self, token):
+ if token in self.cache:
+ return self.cache[token]
+ word = tuple(token)
+ pairs = get_pairs(word)
+
+ if not pairs:
+ return token
+
+ while True:
+ bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
+ if bigram not in self.bpe_ranks:
+ break
+ first, second = bigram
+ new_word = []
+ i = 0
+ while i < len(word):
+ try:
+ j = word.index(first, i)
+ except ValueError:
+ new_word.extend(word[i:])
+ break
+ else:
+ new_word.extend(word[i:j])
+ i = j
+
+ if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
+ new_word.append(first + second)
+ i += 2
+ else:
+ new_word.append(word[i])
+ i += 1
+ new_word = tuple(new_word)
+ word = new_word
+ if len(word) == 1:
+ break
+ else:
+ pairs = get_pairs(word)
+ word = " ".join(word)
+ self.cache[token] = word
+ return word
+
+ def _tokenize(self, text):
+ """Tokenize a string."""
+ bpe_tokens = []
+ for token in re.findall(self.pat, text):
+ token = "".join(
+ self.byte_encoder[b] for b in token.encode("utf-8")
+ ) # Maps all our bytes to unicode strings, avoiding control tokens of the BPE (spaces in our case)
+ bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(" "))
+ return bpe_tokens
+
+ def _convert_token_to_id(self, token):
+ """Converts a token (str) in an id using the vocab."""
+ return self.encoder.get(token, self.encoder.get(self.unk_token))
+
+ def _convert_id_to_token(self, index):
+ """Converts an index (integer) in a token (str) using the vocab."""
+ return self.decoder.get(index)
+
+ def convert_tokens_to_string(self, tokens):
+ """Converts a sequence of tokens (string) in a single string."""
+ text = "".join(tokens)
+ text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors=self.errors)
+ return text
+
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]:
+ if not os.path.isdir(save_directory):
+ logger.error(f"Vocabulary path ({save_directory}) should be a directory")
+ return
+ vocab_file = os.path.join(
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
+ )
+ merge_file = os.path.join(
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["merges_file"]
+ )
+
+ with open(vocab_file, "w", encoding="utf-8") as f:
+ f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
+
+ index = 0
+ with open(merge_file, "w", encoding="utf-8") as writer:
+ writer.write("#version: 0.2\n")
+ for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
+ if index != token_index:
+ logger.warning(
+ f"Saving vocabulary to {merge_file}: BPE merge indices are not consecutive."
+ " Please check that the tokenizer is not corrupted!"
+ )
+ index = token_index
+ writer.write(" ".join(bpe_tokens) + "\n")
+ index += 1
+
+ return vocab_file, merge_file
+
+ def build_inputs_with_special_tokens(
+ self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None
+ ) -> list[int]:
+ """
+ Build model inputs from a sequence or a pair of sequences for sequence classification tasks by concatenating and
+ adding special tokens. A Longformer sequence has the following format:
+
+ - single sequence: `<s> X </s>`
+ - pair of sequences: `<s> A </s></s> B </s>`
+
+ Args:
+ token_ids_0 (`list[int]`):
+ List of IDs to which the special tokens will be added.
+ token_ids_1 (`list[int]`, *optional*):
+ Optional second list of IDs for sequence pairs.
+
+ Returns:
+ `list[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
+ """
+ if token_ids_1 is None:
+ return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
+ cls = [self.cls_token_id]
+ sep = [self.sep_token_id]
+ return cls + token_ids_0 + sep + sep + token_ids_1 + sep
+
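
For illustration, the two formats above translate to the following (a hedged sketch; `tokenizer` is assumed to be a loaded `LongformerTokenizer`, and the integer ids are arbitrary):

```python
single = tokenizer.build_inputs_with_special_tokens([10, 11])
# -> [tokenizer.cls_token_id, 10, 11, tokenizer.sep_token_id]        i.e. <s> X </s>
pair = tokenizer.build_inputs_with_special_tokens([10, 11], [12, 13])
# -> [cls, 10, 11, sep, sep, 12, 13, sep]                            i.e. <s> A </s></s> B </s>
```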
+ def get_special_tokens_mask(
+ self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None, already_has_special_tokens: bool = False
+ ) -> list[int]:
+ """
+ Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
+ special tokens using the tokenizer `prepare_for_model` method.
+
+ Args:
+ token_ids_0 (`list[int]`):
+ List of IDs.
+ token_ids_1 (`list[int]`, *optional*):
+ Optional second list of IDs for sequence pairs.
+ already_has_special_tokens (`bool`, *optional*, defaults to `False`):
+ Whether or not the token list is already formatted with special tokens for the model.
+
+ Returns:
+ `list[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
+ """
+ if already_has_special_tokens:
+ return super().get_special_tokens_mask(
+ token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
+ )
+
+ if token_ids_1 is None:
+ return [1] + ([0] * len(token_ids_0)) + [1]
+ return [1] + ([0] * len(token_ids_0)) + [1, 1] + ([0] * len(token_ids_1)) + [1]
+
+ def create_token_type_ids_from_sequences(
+ self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None
+ ) -> list[int]:
+ """
+ Create a mask from the two sequences passed to be used in a sequence-pair classification task. Longformer does not
+ make use of token type ids; therefore, a list of zeros is returned.
+
+ Args:
+ token_ids_0 (`list[int]`):
+ List of IDs.
+ token_ids_1 (`list[int]`, *optional*):
+ Optional second list of IDs for sequence pairs.
+
+ Returns:
+ `list[int]`: List of zeros.
+ """
+ sep = [self.sep_token_id]
+ cls = [self.cls_token_id]
+
+ if token_ids_1 is None:
+ return len(cls + token_ids_0 + sep) * [0]
+ return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0]
+
+ def prepare_for_tokenization(self, text, is_split_into_words=False, **kwargs):
+ add_prefix_space = kwargs.pop("add_prefix_space", self.add_prefix_space)
+ if (is_split_into_words or add_prefix_space) and (len(text) > 0 and not text[0].isspace()):
+ text = " " + text
+ return (text, kwargs)
+
+
+__all__ = ["LongformerTokenizer"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/longformer/tokenization_longformer_fast.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/longformer/tokenization_longformer_fast.py
new file mode 100644
index 0000000000000000000000000000000000000000..bde6bb55fec609ae3550c5cf01b43f0dd3bc866c
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/longformer/tokenization_longformer_fast.py
@@ -0,0 +1,265 @@
+# coding=utf-8
+# Copyright 2020 The Allen Institute for AI team and The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Fast Tokenization classes for Longformer."""
+
+import json
+from typing import Optional
+
+from tokenizers import processors
+
+from ...tokenization_utils_base import AddedToken, BatchEncoding
+from ...tokenization_utils_fast import PreTrainedTokenizerFast
+from ...utils import logging
+from .tokenization_longformer import LongformerTokenizer
+
+
+logger = logging.get_logger(__name__)
+
+VOCAB_FILES_NAMES = {"vocab_file": "vocab.json", "merges_file": "merges.txt", "tokenizer_file": "tokenizer.json"}
+
+
+# Copied from transformers.models.roberta.tokenization_roberta_fast.RobertaTokenizerFast with FacebookAI/roberta-base->allenai/longformer-base-4096, RoBERTa->Longformer all-casing, Roberta->Longformer
+class LongformerTokenizerFast(PreTrainedTokenizerFast):
+ """
+ Construct a "fast" Longformer tokenizer (backed by HuggingFace's *tokenizers* library), derived from the GPT-2
+ tokenizer, using byte-level Byte-Pair-Encoding.
+
+ This tokenizer has been trained to treat spaces like parts of the tokens (a bit like sentencepiece) so a word will
+ be encoded differently whether it is at the beginning of the sentence (without space) or not:
+
+ ```python
+ >>> from transformers import LongformerTokenizerFast
+
+ >>> tokenizer = LongformerTokenizerFast.from_pretrained("allenai/longformer-base-4096")
+ >>> tokenizer("Hello world")["input_ids"]
+ [0, 31414, 232, 2]
+
+ >>> tokenizer(" Hello world")["input_ids"]
+ [0, 20920, 232, 2]
+ ```
+
+ You can get around that behavior by passing `add_prefix_space=True` when instantiating this tokenizer or when you
+ call it on some text, but since the model was not pretrained this way, it might yield a decrease in performance.
+
+ <Tip>
+
+ When used with `is_split_into_words=True`, this tokenizer needs to be instantiated with `add_prefix_space=True`.
+
+ </Tip>
+
+ This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should
+ refer to this superclass for more information regarding those methods.
+
+ Args:
+ vocab_file (`str`):
+ Path to the vocabulary file.
+ merges_file (`str`):
+ Path to the merges file.
+ errors (`str`, *optional*, defaults to `"replace"`):
+ Paradigm to follow when decoding bytes to UTF-8. See
+ [bytes.decode](https://docs.python.org/3/library/stdtypes.html#bytes.decode) for more information.
+ bos_token (`str`, *optional*, defaults to `"<s>"`):
+ The beginning of sequence token that was used during pretraining. Can be used as a sequence classifier token.
+
+ <Tip>
+
+ When building a sequence using special tokens, this is not the token that is used for the beginning of
+ sequence. The token used is the `cls_token`.
+
+ </Tip>
+
+ eos_token (`str`, *optional*, defaults to `"</s>"`):
+ The end of sequence token.
+
+ <Tip>
+
+ When building a sequence using special tokens, this is not the token that is used for the end of sequence.
+ The token used is the `sep_token`.
+
+ </Tip>
+
+ sep_token (`str`, *optional*, defaults to `"</s>"`):
+ The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
+ sequence classification or for a text and a question for question answering. It is also used as the last
+ token of a sequence built with special tokens.
+ cls_token (`str`, *optional*, defaults to `"<s>"`):
+ The classifier token which is used when doing sequence classification (classification of the whole sequence
+ instead of per-token classification). It is the first token of the sequence when built with special tokens.
+ unk_token (`str`, *optional*, defaults to `"<unk>"`):
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+ token instead.
+ pad_token (`str`, *optional*, defaults to `"<pad>"`):
+ The token used for padding, for example when batching sequences of different lengths.
+ mask_token (`str`, *optional*, defaults to `"<mask>"`):
+ The token used for masking values. This is the token used when training this model with masked language
+ modeling. This is the token which the model will try to predict.
+ add_prefix_space (`bool`, *optional*, defaults to `False`):
+ Whether or not to add an initial space to the input. This allows to treat the leading word just as any
+ other word. (The Longformer tokenizer detects the beginning of words by the preceding space.)
+ trim_offsets (`bool`, *optional*, defaults to `True`):
+ Whether the post processing step should trim offsets to avoid including whitespaces.
+ """
+
+ vocab_files_names = VOCAB_FILES_NAMES
+ model_input_names = ["input_ids", "attention_mask"]
+ slow_tokenizer_class = LongformerTokenizer
+
+ def __init__(
+ self,
+ vocab_file=None,
+ merges_file=None,
+ tokenizer_file=None,
+ errors="replace",
+ bos_token="",
+ eos_token="",
+ sep_token="",
+ cls_token="",
+ unk_token="",
+ pad_token="",
+ mask_token="",
+ add_prefix_space=False,
+ trim_offsets=True,
+ **kwargs,
+ ):
+ mask_token = (
+ AddedToken(mask_token, lstrip=True, rstrip=False, normalized=False)
+ if isinstance(mask_token, str)
+ else mask_token
+ )
+ super().__init__(
+ vocab_file,
+ merges_file,
+ tokenizer_file=tokenizer_file,
+ errors=errors,
+ bos_token=bos_token,
+ eos_token=eos_token,
+ sep_token=sep_token,
+ cls_token=cls_token,
+ unk_token=unk_token,
+ pad_token=pad_token,
+ mask_token=mask_token,
+ add_prefix_space=add_prefix_space,
+ trim_offsets=trim_offsets,
+ **kwargs,
+ )
+
+ tokenizer_component = "post_processor"
+ tokenizer_component_instance = getattr(self.backend_tokenizer, tokenizer_component, None)
+ if tokenizer_component_instance:
+ state = json.loads(tokenizer_component_instance.__getstate__())
+
+ # The lists 'sep' and 'cls' must be cast to tuples for the object `post_processor_class`
+ if "sep" in state:
+ state["sep"] = tuple(state["sep"])
+ if "cls" in state:
+ state["cls"] = tuple(state["cls"])
+
+ changes_to_apply = False
+
+ if state.get("add_prefix_space", add_prefix_space) != add_prefix_space:
+ state["add_prefix_space"] = add_prefix_space
+ changes_to_apply = True
+
+ if state.get("trim_offsets", trim_offsets) != trim_offsets:
+ state["trim_offsets"] = trim_offsets
+ changes_to_apply = True
+
+ if changes_to_apply:
+ component_class = getattr(processors, state.pop("type"))
+ new_value = component_class(**state)
+ setattr(self.backend_tokenizer, tokenizer_component, new_value)
+
+ @property
+ def mask_token(self) -> str:
+ """
+ `str`: Mask token, to use when training a model with masked-language modeling. Log an error if used while not
+ having been set.
+
+ Longformer tokenizer has a special mask token to be usable in the fill-mask pipeline. The mask token will greedily
+ comprise the space before the *<mask>*.
+ """
+ if self._mask_token is None:
+ if self.verbose:
+ logger.error("Using mask_token, but it is not set yet.")
+ return None
+ return str(self._mask_token)
+
+ @mask_token.setter
+ def mask_token(self, value):
+ """
+ Overriding the default behavior of the mask token to have it eat the space before it.
+
+ This is needed to preserve backward compatibility with all the previously used models based on Longformer.
+ """
+ # Mask token behave like a normal word, i.e. include the space before it
+ # So we set lstrip to True
+ value = AddedToken(value, lstrip=True, rstrip=False) if isinstance(value, str) else value
+ self._mask_token = value
+
+ def _batch_encode_plus(self, *args, **kwargs) -> BatchEncoding:
+ is_split_into_words = kwargs.get("is_split_into_words", False)
+ assert self.add_prefix_space or not is_split_into_words, (
+ f"You need to instantiate {self.__class__.__name__} with add_prefix_space=True "
+ "to use it with pretokenized inputs."
+ )
+
+ return super()._batch_encode_plus(*args, **kwargs)
+
+ def _encode_plus(self, *args, **kwargs) -> BatchEncoding:
+ is_split_into_words = kwargs.get("is_split_into_words", False)
+
+ assert self.add_prefix_space or not is_split_into_words, (
+ f"You need to instantiate {self.__class__.__name__} with add_prefix_space=True "
+ "to use it with pretokenized inputs."
+ )
+
+ return super()._encode_plus(*args, **kwargs)
+
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]:
+ files = self._tokenizer.model.save(save_directory, name=filename_prefix)
+ return tuple(files)
+
+ def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
+ output = [self.bos_token_id] + token_ids_0 + [self.eos_token_id]
+ if token_ids_1 is None:
+ return output
+
+ return output + [self.eos_token_id] + token_ids_1 + [self.eos_token_id]
+
+ def create_token_type_ids_from_sequences(
+ self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None
+ ) -> list[int]:
+ """
+ Create a mask from the two sequences passed to be used in a sequence-pair classification task. Longformer does not
+ make use of token type ids; therefore, a list of zeros is returned.
+
+ Args:
+ token_ids_0 (`list[int]`):
+ List of IDs.
+ token_ids_1 (`list[int]`, *optional*):
+ Optional second list of IDs for sequence pairs.
+
+ Returns:
+ `list[int]`: List of zeros.
+ """
+ sep = [self.sep_token_id]
+ cls = [self.cls_token_id]
+
+ if token_ids_1 is None:
+ return len(cls + token_ids_0 + sep) * [0]
+ return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0]
+
+
+__all__ = ["LongformerTokenizerFast"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/maskformer/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/maskformer/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..3a91c136c2f32c434a52161342fb8df6f7b0011d
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/maskformer/__init__.py
@@ -0,0 +1,32 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_maskformer import *
+ from .configuration_maskformer_swin import *
+ from .feature_extraction_maskformer import *
+ from .image_processing_maskformer import *
+ from .image_processing_maskformer_fast import *
+ from .modeling_maskformer import *
+ from .modeling_maskformer_swin import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/maskformer/configuration_maskformer.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/maskformer/configuration_maskformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..d988acb45e95fae720c178165c911dc410ad1748
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/maskformer/configuration_maskformer.py
@@ -0,0 +1,235 @@
+# coding=utf-8
+# Copyright 2022 Meta Platforms, Inc.and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""MaskFormer model configuration"""
+
+from typing import Optional
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+from ...utils.backbone_utils import verify_backbone_config_arguments
+from ..auto import CONFIG_MAPPING
+from ..detr import DetrConfig
+from ..swin import SwinConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+class MaskFormerConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`MaskFormerModel`]. It is used to instantiate a
+ MaskFormer model according to the specified arguments, defining the model architecture. Instantiating a
+ configuration with the defaults will yield a similar configuration to that of the MaskFormer
+ [facebook/maskformer-swin-base-ade](https://huggingface.co/facebook/maskformer-swin-base-ade) architecture trained
+ on [ADE20k-150](https://huggingface.co/datasets/scene_parse_150).
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Currently, MaskFormer only supports the [Swin Transformer](swin) as backbone.
+
+ Args:
+ mask_feature_size (`int`, *optional*, defaults to 256):
+ The masks' feature size; this value will also be used to specify the Feature Pyramid Network features'
+ size.
+ no_object_weight (`float`, *optional*, defaults to 0.1):
+ Weight to apply to the null (no object) class.
+ use_auxiliary_loss (`bool`, *optional*, defaults to `False`):
+ If `True` [`MaskFormerForInstanceSegmentationOutput`] will contain the auxiliary losses computed using the
+ logits from each decoder's stage.
+ backbone_config (`Dict`, *optional*):
+ The configuration passed to the backbone, if unset, the configuration corresponding to
+ `swin-base-patch4-window12-384` will be used.
+ backbone (`str`, *optional*):
+ Name of backbone to use when `backbone_config` is `None`. If `use_pretrained_backbone` is `True`, this
+ will load the corresponding pretrained weights from the timm or transformers library. If `use_pretrained_backbone`
+ is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights.
+ use_pretrained_backbone (`bool`, *optional*, defaults to `False`):
+ Whether to use pretrained weights for the backbone.
+ use_timm_backbone (`bool`, *optional*, defaults to `False`):
+ Whether to load `backbone` from the timm library. If `False`, the backbone is loaded from the transformers
+ library.
+ backbone_kwargs (`dict`, *optional*):
+ Keyword arguments to be passed to AutoBackbone when loading from a checkpoint
+ e.g. `{'out_indices': (0, 1, 2, 3)}`. Cannot be specified if `backbone_config` is set.
+ decoder_config (`Dict`, *optional*):
+ The configuration passed to the transformer decoder model, if unset the base config for `detr-resnet-50`
+ will be used.
+ init_std (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ init_xavier_std (`float`, *optional*, defaults to 1):
+ The scaling factor used for the Xavier initialization gain in the HM Attention map module.
+ dice_weight (`float`, *optional*, defaults to 1.0):
+ The weight for the dice loss.
+ cross_entropy_weight (`float`, *optional*, defaults to 1.0):
+ The weight for the cross entropy loss.
+ mask_weight (`float`, *optional*, defaults to 20.0):
+ The weight for the mask loss.
+ output_auxiliary_logits (`bool`, *optional*):
+ Should the model output its `auxiliary_logits` or not.
+
+ Raises:
+ `ValueError`:
+ Raised if the backbone model type selected is not in `["swin"]` or the decoder model type selected is not
+ in `["detr"]`
+
+ Examples:
+
+ ```python
+ >>> from transformers import MaskFormerConfig, MaskFormerModel
+
+ >>> # Initializing a MaskFormer facebook/maskformer-swin-base-ade configuration
+ >>> configuration = MaskFormerConfig()
+
+ >>> # Initializing a model (with random weights) from the facebook/maskformer-swin-base-ade style configuration
+ >>> model = MaskFormerModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```
+
+ """
+
+ model_type = "maskformer"
+ attribute_map = {"hidden_size": "mask_feature_size"}
+ backbones_supported = ["resnet", "swin"]
+ decoders_supported = ["detr"]
+
+ def __init__(
+ self,
+ fpn_feature_size: int = 256,
+ mask_feature_size: int = 256,
+ no_object_weight: float = 0.1,
+ use_auxiliary_loss: bool = False,
+ backbone_config: Optional[dict] = None,
+ decoder_config: Optional[dict] = None,
+ init_std: float = 0.02,
+ init_xavier_std: float = 1.0,
+ dice_weight: float = 1.0,
+ cross_entropy_weight: float = 1.0,
+ mask_weight: float = 20.0,
+ output_auxiliary_logits: Optional[bool] = None,
+ backbone: Optional[str] = None,
+ use_pretrained_backbone: bool = False,
+ use_timm_backbone: bool = False,
+ backbone_kwargs: Optional[dict] = None,
+ **kwargs,
+ ):
+ if backbone_config is None and backbone is None:
+ # fall back to https://huggingface.co/microsoft/swin-base-patch4-window12-384-in22k
+ backbone_config = SwinConfig(
+ image_size=384,
+ num_channels=3,
+ patch_size=4,
+ embed_dim=128,
+ depths=[2, 2, 18, 2],
+ num_heads=[4, 8, 16, 32],
+ window_size=12,
+ drop_path_rate=0.3,
+ out_features=["stage1", "stage2", "stage3", "stage4"],
+ )
+ elif isinstance(backbone_config, dict):
+ backbone_model_type = backbone_config.pop("model_type")
+ config_class = CONFIG_MAPPING[backbone_model_type]
+ backbone_config = config_class.from_dict(backbone_config)
+
+ verify_backbone_config_arguments(
+ use_timm_backbone=use_timm_backbone,
+ use_pretrained_backbone=use_pretrained_backbone,
+ backbone=backbone,
+ backbone_config=backbone_config,
+ backbone_kwargs=backbone_kwargs,
+ )
+ # verify that the backbone is supported
+ if backbone_config is not None and backbone_config.model_type not in self.backbones_supported:
+ logger.warning_once(
+ f"Backbone {backbone_config.model_type} is not a supported model and may not be compatible with MaskFormer. "
+ f"Supported model types: {','.join(self.backbones_supported)}"
+ )
+
+ if decoder_config is None:
+ # fall back to https://huggingface.co/facebook/detr-resnet-50
+ decoder_config = DetrConfig()
+ else:
+ # verify that the decoder is supported
+ decoder_type = (
+ decoder_config.pop("model_type") if isinstance(decoder_config, dict) else decoder_config.model_type
+ )
+ if decoder_type not in self.decoders_supported:
+ raise ValueError(
+ f"Transformer Decoder {decoder_type} not supported, please use one of"
+ f" {','.join(self.decoders_supported)}"
+ )
+ if isinstance(decoder_config, dict):
+ config_class = CONFIG_MAPPING[decoder_type]
+ decoder_config = config_class.from_dict(decoder_config)
+
+ self.backbone_config = backbone_config
+ self.decoder_config = decoder_config
+ # main feature dimension for the model
+ self.fpn_feature_size = fpn_feature_size
+ self.mask_feature_size = mask_feature_size
+ # initializer
+ self.init_std = init_std
+ self.init_xavier_std = init_xavier_std
+ # Hungarian matcher && loss
+ self.cross_entropy_weight = cross_entropy_weight
+ self.dice_weight = dice_weight
+ self.mask_weight = mask_weight
+ self.use_auxiliary_loss = use_auxiliary_loss
+ self.no_object_weight = no_object_weight
+ self.output_auxiliary_logits = output_auxiliary_logits
+
+ self.num_attention_heads = self.decoder_config.encoder_attention_heads
+ self.num_hidden_layers = self.decoder_config.num_hidden_layers
+ self.backbone = backbone
+ self.use_pretrained_backbone = use_pretrained_backbone
+ self.use_timm_backbone = use_timm_backbone
+ self.backbone_kwargs = backbone_kwargs
+ super().__init__(**kwargs)
+
+ @property
+ def sub_configs(self):
+ sub_configs = {}
+ if self.backbone_config is not None and self.backbone_config != {}:
+ sub_configs["backbone_config"] = type(self.backbone_config)
+ if self.decoder_config is not None and self.decoder_config != {}:
+ sub_configs["decoder_config"] = type(self.decoder_config)
+ return sub_configs
+
+ @classmethod
+ def from_backbone_and_decoder_configs(
+ cls, backbone_config: PretrainedConfig, decoder_config: PretrainedConfig, **kwargs
+ ):
+ """Instantiate a [`MaskFormerConfig`] (or a derived class) from a pre-trained backbone model configuration and DETR model
+ configuration.
+
+ Args:
+ backbone_config ([`PretrainedConfig`]):
+ The backbone configuration.
+ decoder_config ([`PretrainedConfig`]):
+ The transformer decoder configuration to use.
+
+ Returns:
+ [`MaskFormerConfig`]: An instance of a configuration object
+ """
+ return cls(
+ backbone_config=backbone_config,
+ decoder_config=decoder_config,
+ **kwargs,
+ )
+
+
+__all__ = ["MaskFormerConfig"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/maskformer/configuration_maskformer_swin.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/maskformer/configuration_maskformer_swin.py
new file mode 100644
index 0000000000000000000000000000000000000000..84157117bbf2a388fb7dbd1b4aa227e1fa016513
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/maskformer/configuration_maskformer_swin.py
@@ -0,0 +1,153 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""MaskFormer Swin Transformer model configuration"""
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+from ...utils.backbone_utils import BackboneConfigMixin, get_aligned_output_features_output_indices
+
+
+logger = logging.get_logger(__name__)
+
+
+class MaskFormerSwinConfig(BackboneConfigMixin, PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`MaskFormerSwinModel`]. It is used to instantiate
+ a MaskFormer Swin model according to the specified arguments, defining the model architecture. Instantiating a configuration
+ with the defaults will yield a similar configuration to that of the Swin
+ [microsoft/swin-tiny-patch4-window7-224](https://huggingface.co/microsoft/swin-tiny-patch4-window7-224)
+ architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ image_size (`int`, *optional*, defaults to 224):
+ The size (resolution) of each image.
+ patch_size (`int`, *optional*, defaults to 4):
+ The size (resolution) of each patch.
+ num_channels (`int`, *optional*, defaults to 3):
+ The number of input channels.
+ embed_dim (`int`, *optional*, defaults to 96):
+ Dimensionality of patch embedding.
+ depths (`list[int]`, *optional*, defaults to `[2, 2, 6, 2]`):
+ Depth of each layer in the Transformer encoder.
+ num_heads (`list[int]`, *optional*, defaults to `[3, 6, 12, 24]`):
+ Number of attention heads in each layer of the Transformer encoder.
+ window_size (`int`, *optional*, defaults to 7):
+ Size of windows.
+ mlp_ratio (`float`, *optional*, defaults to 4.0):
+ Ratio of MLP hidden dimensionality to embedding dimensionality.
+ qkv_bias (`bool`, *optional*, defaults to True):
+ Whether or not a learnable bias should be added to the queries, keys and values.
+ hidden_dropout_prob (`float`, *optional*, defaults to 0.0):
+ The dropout probability for all fully connected layers in the embeddings and encoder.
+ attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ drop_path_rate (`float`, *optional*, defaults to 0.1):
+ Stochastic depth rate.
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
+ The non-linear activation function (function or string) in the encoder. If string, `"gelu"`, `"relu"`,
+ `"selu"` and `"gelu_new"` are supported.
+ use_absolute_embeddings (`bool`, *optional*, defaults to False):
+ Whether or not to add absolute position embeddings to the patch embeddings.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ layer_norm_eps (`float`, *optional*, defaults to 1e-5):
+ The epsilon used by the layer normalization layers.
+ out_features (`list[str]`, *optional*):
+ If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc.
+ (depending on how many stages the model has). If unset and `out_indices` is set, will default to the
+ corresponding stages. If unset and `out_indices` is unset, will default to the last stage. Must be in the
+ same order as defined in the `stage_names` attribute.
+ out_indices (`list[int]`, *optional*):
+ If used as backbone, list of indices of features to output. Can be any of 0, 1, 2, etc. (depending on how
+ many stages the model has). If unset and `out_features` is set, will default to the corresponding stages.
+ If unset and `out_features` is unset, will default to the last stage. Must be in the
+ same order as defined in the `stage_names` attribute.
+
+ Example:
+
+ ```python
+ >>> from transformers import MaskFormerSwinConfig, MaskFormerSwinModel
+
+ >>> # Initializing a microsoft/swin-tiny-patch4-window7-224 style configuration
+ >>> configuration = MaskFormerSwinConfig()
+
+ >>> # Initializing a model (with random weights) from the microsoft/swin-tiny-patch4-window7-224 style configuration
+ >>> model = MaskFormerSwinModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "maskformer-swin"
+
+ attribute_map = {
+ "num_attention_heads": "num_heads",
+ "num_hidden_layers": "num_layers",
+ }
+
+ def __init__(
+ self,
+ image_size=224,
+ patch_size=4,
+ num_channels=3,
+ embed_dim=96,
+ depths=[2, 2, 6, 2],
+ num_heads=[3, 6, 12, 24],
+ window_size=7,
+ mlp_ratio=4.0,
+ qkv_bias=True,
+ hidden_dropout_prob=0.0,
+ attention_probs_dropout_prob=0.0,
+ drop_path_rate=0.1,
+ hidden_act="gelu",
+ use_absolute_embeddings=False,
+ initializer_range=0.02,
+ layer_norm_eps=1e-5,
+ out_features=None,
+ out_indices=None,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+
+ self.image_size = image_size
+ self.patch_size = patch_size
+ self.num_channels = num_channels
+ self.embed_dim = embed_dim
+ self.depths = depths
+ self.num_layers = len(depths)
+ self.num_heads = num_heads
+ self.window_size = window_size
+ self.mlp_ratio = mlp_ratio
+ self.qkv_bias = qkv_bias
+ self.hidden_dropout_prob = hidden_dropout_prob
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
+ self.drop_path_rate = drop_path_rate
+ self.hidden_act = hidden_act
+ self.use_absolute_embeddings = use_absolute_embeddings
+ self.layer_norm_eps = layer_norm_eps
+ self.initializer_range = initializer_range
+ # we set the hidden_size attribute in order to make Swin work with VisionEncoderDecoderModel
+ # this indicates the channel dimension after the last stage of the model
+ self.hidden_size = int(embed_dim * 2 ** (len(depths) - 1))
+ self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, len(depths) + 1)]
+ self._out_features, self._out_indices = get_aligned_output_features_output_indices(
+ out_features=out_features, out_indices=out_indices, stage_names=self.stage_names
+ )
+
+
+__all__ = ["MaskFormerSwinConfig"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/maskformer/feature_extraction_maskformer.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/maskformer/feature_extraction_maskformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..98f7075fab83882bb84ab955bd763c2f1c1f067b
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/maskformer/feature_extraction_maskformer.py
@@ -0,0 +1,38 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Feature extractor class for MaskFormer."""
+
+import warnings
+
+from ...utils import logging
+from ...utils.import_utils import requires
+from .image_processing_maskformer import MaskFormerImageProcessor
+
+
+logger = logging.get_logger(__name__)
+
+
+@requires(backends=("vision",))
+class MaskFormerFeatureExtractor(MaskFormerImageProcessor):
+ def __init__(self, *args, **kwargs) -> None:
+ warnings.warn(
+ "The class MaskFormerFeatureExtractor is deprecated and will be removed in version 5 of Transformers."
+ " Please use MaskFormerImageProcessor instead.",
+ FutureWarning,
+ )
+ super().__init__(*args, **kwargs)
+
+
+__all__ = ["MaskFormerFeatureExtractor"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/maskformer/image_processing_maskformer.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/maskformer/image_processing_maskformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..c2f9aee70167fe956e1f13028c4b6daa44e91ac4
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/maskformer/image_processing_maskformer.py
@@ -0,0 +1,1323 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Image processor class for MaskFormer."""
+
+import math
+import warnings
+from collections.abc import Iterable
+from typing import TYPE_CHECKING, Any, Optional, Union
+
+import numpy as np
+
+from ...image_processing_utils import INIT_SERVICE_KWARGS, BaseImageProcessor, BatchFeature, get_size_dict
+from ...image_transforms import (
+ PaddingMode,
+ get_resize_output_image_size,
+ pad,
+ rescale,
+ resize,
+ to_channel_dimension_format,
+)
+from ...image_utils import (
+ ChannelDimension,
+ ImageInput,
+ PILImageResampling,
+ get_image_size,
+ infer_channel_dimension_format,
+ is_scaled_image,
+ make_flat_list_of_images,
+ to_numpy_array,
+ valid_images,
+ validate_preprocess_arguments,
+)
+from ...utils import (
+ IMAGENET_DEFAULT_MEAN,
+ IMAGENET_DEFAULT_STD,
+ TensorType,
+ filter_out_non_signature_kwargs,
+ is_torch_available,
+ is_torch_tensor,
+ logging,
+)
+from ...utils.import_utils import requires
+
+
+logger = logging.get_logger(__name__)
+
+
+if TYPE_CHECKING:
+ from transformers import MaskFormerForInstanceSegmentationOutput
+
+
+if is_torch_available():
+ import torch
+ from torch import nn
+
+
+# Copied from transformers.models.detr.image_processing_detr.get_size_with_aspect_ratio
+def get_size_with_aspect_ratio(image_size, size, max_size=None) -> tuple[int, int]:
+ """
+ Computes the output image size given the input image size and the desired output size.
+
+ Args:
+ image_size (`tuple[int, int]`):
+ The input image size.
+ size (`int`):
+ The desired output size.
+ max_size (`int`, *optional*):
+ The maximum allowed output size.
+ """
+ height, width = image_size
+ raw_size = None
+ if max_size is not None:
+ min_original_size = float(min((height, width)))
+ max_original_size = float(max((height, width)))
+ if max_original_size / min_original_size * size > max_size:
+ raw_size = max_size * min_original_size / max_original_size
+ size = int(round(raw_size))
+
+ if (height <= width and height == size) or (width <= height and width == size):
+ oh, ow = height, width
+ elif width < height:
+ ow = size
+ if max_size is not None and raw_size is not None:
+ oh = int(raw_size * height / width)
+ else:
+ oh = int(size * height / width)
+ else:
+ oh = size
+ if max_size is not None and raw_size is not None:
+ ow = int(raw_size * width / height)
+ else:
+ ow = int(size * width / height)
+
+ return (oh, ow)
+
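+# Worked example (editorial note, hypothetical numbers): for image_size=(480, 640),
+# size=800 and max_size=1333, the cap is not hit (640 / 480 * 800 ≈ 1067 <= 1333),
+# the shorter side becomes 800 and the longer side int(800 * 640 / 480) = 1066,
+# so the function returns (800, 1066).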
+
+# Copied from transformers.models.detr.image_processing_detr.max_across_indices
+def max_across_indices(values: Iterable[Any]) -> list[Any]:
+ """
+ Return the maximum value across all indices of an iterable of values.
+ """
+ return [max(values_i) for values_i in zip(*values)]
+
+
+# Copied from transformers.models.detr.image_processing_detr.get_max_height_width
+def get_max_height_width(
+ images: list[np.ndarray], input_data_format: Optional[Union[str, ChannelDimension]] = None
+) -> tuple[int, int]:
+ """
+ Get the maximum height and width across all images in a batch.
+ """
+ if input_data_format is None:
+ input_data_format = infer_channel_dimension_format(images[0])
+
+ if input_data_format == ChannelDimension.FIRST:
+ _, max_height, max_width = max_across_indices([img.shape for img in images])
+ elif input_data_format == ChannelDimension.LAST:
+ max_height, max_width, _ = max_across_indices([img.shape for img in images])
+ else:
+ raise ValueError(f"Invalid channel dimension format: {input_data_format}")
+ return (max_height, max_width)
+
+
+# Copied from transformers.models.detr.image_processing_detr.make_pixel_mask
+def make_pixel_mask(
+ image: np.ndarray, output_size: tuple[int, int], input_data_format: Optional[Union[str, ChannelDimension]] = None
+) -> np.ndarray:
+ """
+ Make a pixel mask for the image, where 1 indicates a valid pixel and 0 indicates padding.
+
+ Args:
+ image (`np.ndarray`):
+ Image to make the pixel mask for.
+ output_size (`tuple[int, int]`):
+ Output size of the mask.
+ """
+ input_height, input_width = get_image_size(image, channel_dim=input_data_format)
+ mask = np.zeros(output_size, dtype=np.int64)
+ mask[:input_height, :input_width] = 1
+ return mask
+
+
+# Copied from transformers.models.detr.image_processing_detr.binary_mask_to_rle
+def binary_mask_to_rle(mask):
+ """
+ Converts given binary mask of shape `(height, width)` to the run-length encoding (RLE) format.
+
+ Args:
+ mask (`torch.Tensor` or `numpy.array`):
+ A binary mask tensor of shape `(height, width)` where 0 denotes background and 1 denotes the target
+ segment_id or class_id.
+ Returns:
+ `List`: Run-length encoded list of the binary mask. Refer to COCO API for more information about the RLE
+ format.
+ """
+ if is_torch_tensor(mask):
+ mask = mask.numpy()
+
+ pixels = mask.flatten()
+ pixels = np.concatenate([[0], pixels, [0]])
+ runs = np.where(pixels[1:] != pixels[:-1])[0] + 1
+ runs[1::2] -= runs[::2]
+ return list(runs)
+
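+# Worked example (editorial note): for mask = [[0, 1], [1, 0]], the flattened pixels are
+# [0, 1, 1, 0]; after zero-padding, run boundaries fall at positions 2 and 4, and
+# runs[1::2] -= runs[::2] turns [2, 4] into [2, 2], i.e. a run of length 2 starting at
+# (1-indexed) flattened position 2.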
+
+# Copied from transformers.models.detr.image_processing_detr.convert_segmentation_to_rle
+def convert_segmentation_to_rle(segmentation):
+ """
+ Converts given segmentation map of shape `(height, width)` to the run-length encoding (RLE) format.
+
+ Args:
+ segmentation (`torch.Tensor` or `numpy.array`):
+ A segmentation map of shape `(height, width)` where each value denotes a segment or class id.
+ Returns:
+ `list[list]`: A list of lists, where each list is the run-length encoding of a segment / class id.
+ """
+ segment_ids = torch.unique(segmentation)
+
+ run_length_encodings = []
+ for idx in segment_ids:
+ mask = torch.where(segmentation == idx, 1, 0)
+ rle = binary_mask_to_rle(mask)
+ run_length_encodings.append(rle)
+
+ return run_length_encodings
+
+
+# Copied from transformers.models.detr.image_processing_detr.remove_low_and_no_objects
+def remove_low_and_no_objects(masks, scores, labels, object_mask_threshold, num_labels):
+ """
+ Binarize the given masks using `object_mask_threshold`; it returns the associated values of `masks`, `scores` and
+ `labels`.
+
+ Args:
+ masks (`torch.Tensor`):
+ A tensor of shape `(num_queries, height, width)`.
+ scores (`torch.Tensor`):
+ A tensor of shape `(num_queries)`.
+ labels (`torch.Tensor`):
+ A tensor of shape `(num_queries)`.
+ object_mask_threshold (`float`):
+ A number between 0 and 1 used to binarize the masks.
+ Raises:
+ `ValueError`: Raised when the first dimension doesn't match in all input tensors.
+ Returns:
+ `tuple[`torch.Tensor`, `torch.Tensor`, `torch.Tensor`]`: The `masks`, `scores` and `labels` without the region
+ < `object_mask_threshold`.
+ """
+ if not (masks.shape[0] == scores.shape[0] == labels.shape[0]):
+ raise ValueError("mask, scores and labels must have the same shape!")
+
+ to_keep = labels.ne(num_labels) & (scores > object_mask_threshold)
+
+ return masks[to_keep], scores[to_keep], labels[to_keep]
+
+
+# Copied from transformers.models.detr.image_processing_detr.check_segment_validity
+def check_segment_validity(mask_labels, mask_probs, k, mask_threshold=0.5, overlap_mask_area_threshold=0.8):
+ # Get the mask associated with the k class
+ mask_k = mask_labels == k
+ mask_k_area = mask_k.sum()
+
+ # Compute the area of all the stuff in query k
+ original_area = (mask_probs[k] >= mask_threshold).sum()
+ mask_exists = mask_k_area > 0 and original_area > 0
+
+ # Eliminate disconnected tiny segments
+ if mask_exists:
+ area_ratio = mask_k_area / original_area
+ if not area_ratio.item() > overlap_mask_area_threshold:
+ mask_exists = False
+
+ return mask_exists, mask_k
+
+
+# Copied from transformers.models.detr.image_processing_detr.compute_segments
+def compute_segments(
+ mask_probs,
+ pred_scores,
+ pred_labels,
+ mask_threshold: float = 0.5,
+ overlap_mask_area_threshold: float = 0.8,
+ label_ids_to_fuse: Optional[set[int]] = None,
+ target_size: Optional[tuple[int, int]] = None,
+):
+ height = mask_probs.shape[1] if target_size is None else target_size[0]
+ width = mask_probs.shape[2] if target_size is None else target_size[1]
+
+ segmentation = torch.zeros((height, width), dtype=torch.int32, device=mask_probs.device)
+ segments: list[dict] = []
+
+ if target_size is not None:
+ mask_probs = nn.functional.interpolate(
+ mask_probs.unsqueeze(0), size=target_size, mode="bilinear", align_corners=False
+ )[0]
+
+ current_segment_id = 0
+
+ # Weigh each mask by its prediction score
+ mask_probs *= pred_scores.view(-1, 1, 1)
+ mask_labels = mask_probs.argmax(0) # [height, width]
+
+ # Keep track of instances of each class
+ stuff_memory_list: dict[str, int] = {}
+ for k in range(pred_labels.shape[0]):
+ pred_class = pred_labels[k].item()
+ should_fuse = pred_class in label_ids_to_fuse
+
+ # Check if mask exists and large enough to be a segment
+ mask_exists, mask_k = check_segment_validity(
+ mask_labels, mask_probs, k, mask_threshold, overlap_mask_area_threshold
+ )
+
+ if mask_exists:
+ if pred_class in stuff_memory_list:
+ current_segment_id = stuff_memory_list[pred_class]
+ else:
+ current_segment_id += 1
+
+ # Add current object segment to final segmentation map
+ segmentation[mask_k] = current_segment_id
+ segment_score = round(pred_scores[k].item(), 6)
+ segments.append(
+ {
+ "id": current_segment_id,
+ "label_id": pred_class,
+ "was_fused": should_fuse,
+ "score": segment_score,
+ }
+ )
+ if should_fuse:
+ stuff_memory_list[pred_class] = current_segment_id
+
+ return segmentation, segments
+
+
+# TODO: (Amy) Move to image_transforms
+def convert_segmentation_map_to_binary_masks(
+ segmentation_map: np.ndarray,
+ instance_id_to_semantic_id: Optional[dict[int, int]] = None,
+ ignore_index: Optional[int] = None,
+ do_reduce_labels: bool = False,
+):
+ if do_reduce_labels and ignore_index is None:
+ raise ValueError("If `do_reduce_labels` is True, `ignore_index` must be provided.")
+
+ if do_reduce_labels:
+ segmentation_map = np.where(segmentation_map == 0, ignore_index, segmentation_map - 1)
+
+ # Get unique ids (class or instance ids based on input)
+ all_labels = np.unique(segmentation_map)
+
+ # Drop background label if applicable
+ if ignore_index is not None:
+ all_labels = all_labels[all_labels != ignore_index]
+
+ # Generate a binary mask for each object instance
+ binary_masks = [(segmentation_map == i) for i in all_labels]
+
+ # Stack the binary masks
+ if binary_masks:
+ binary_masks = np.stack(binary_masks, axis=0)
+ else:
+ binary_masks = np.zeros((0, *segmentation_map.shape))
+
+ # Convert instance ids to class ids
+ if instance_id_to_semantic_id is not None:
+ labels = np.zeros(all_labels.shape[0])
+
+ for label in all_labels:
+ class_id = instance_id_to_semantic_id[label + 1 if do_reduce_labels else label]
+ labels[all_labels == label] = class_id - 1 if do_reduce_labels else class_id
+ else:
+ labels = all_labels
+
+ return binary_masks.astype(np.float32), labels.astype(np.int64)
+
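+# Worked example (editorial note): with ignore_index=None and do_reduce_labels=False,
+# a 2x2 map [[0, 0], [1, 1]] has unique labels [0, 1] and yields the stacked float masks
+# [[1, 1], [0, 0]] (label 0) and [[0, 0], [1, 1]] (label 1), with labels array [0, 1].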
+
+def get_maskformer_resize_output_image_size(
+ image: np.ndarray,
+ size: Union[int, tuple[int, int], list[int], tuple[int]],
+ max_size: Optional[int] = None,
+ size_divisor: int = 0,
+ default_to_square: bool = True,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+) -> tuple[int, int]:
+ """
+ Computes the output size given the desired size.
+
+ Args:
+ image (`np.ndarray`):
+ The input image.
+ size (`int` or `tuple[int, int]` or `list[int]` or `tuple[int]`):
+ The size of the output image.
+ max_size (`int`, *optional*):
+ The maximum size of the output image.
+ size_divisor (`int`, *optional*, defaults to 0):
+ If `size_divisor` is given, the output image size will be divisible by the number.
+ default_to_square (`bool`, *optional*, defaults to `True`):
+ Whether to default to square if no size is provided.
+ input_data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format of the input image. If unset, will use the inferred format from the input.
+
+ Returns:
+ `tuple[int, int]`: The output size.
+ """
+ output_size = get_resize_output_image_size(
+ input_image=image,
+ size=size,
+ default_to_square=default_to_square,
+ max_size=max_size,
+ input_data_format=input_data_format,
+ )
+
+ if size_divisor > 0:
+ height, width = output_size
+ height = int(math.ceil(height / size_divisor) * size_divisor)
+ width = int(math.ceil(width / size_divisor) * size_divisor)
+ output_size = (height, width)
+
+ return output_size
+
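+# Worked example (editorial note): if the aspect-ratio-preserving resize above yields
+# (800, 1066) and size_divisor=32, both sides are rounded up to multiples of 32:
+# ceil(800 / 32) * 32 = 800 and ceil(1066 / 32) * 32 = 1088, giving (800, 1088).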
+
+@requires(backends=("vision",))
+class MaskFormerImageProcessor(BaseImageProcessor):
+ r"""
+ Constructs a MaskFormer image processor. The image processor can be used to prepare image(s) and optional targets
+ for the model.
+
+ This image processor inherits from [`BaseImageProcessor`] which contains most of the main methods. Users should
+ refer to this superclass for more information regarding those methods.
+
+ Args:
+ do_resize (`bool`, *optional*, defaults to `True`):
+ Whether to resize the input to a certain `size`.
+ size (`int`, *optional*, defaults to 800):
+ Resize the input to the given size. Only has an effect if `do_resize` is set to `True`. If size is a
+ sequence like `(width, height)`, output size will be matched to this. If size is an int, smaller edge of
+ the image will be matched to this number, i.e., if `height > width`, then the image will be rescaled to `(size *
+ height / width, size)`.
+ size_divisor (`int`, *optional*, defaults to 32):
+ Some backbones need images divisible by a certain number. If not passed, it defaults to the value used in
+ Swin Transformer.
+ resample (`int`, *optional*, defaults to `Resampling.BILINEAR`):
+ An optional resampling filter. This can be one of `PIL.Image.Resampling.NEAREST`,
+ `PIL.Image.Resampling.BOX`, `PIL.Image.Resampling.BILINEAR`, `PIL.Image.Resampling.HAMMING`,
+ `PIL.Image.Resampling.BICUBIC` or `PIL.Image.Resampling.LANCZOS`. Only has an effect if `do_resize` is set
+ to `True`.
+ do_rescale (`bool`, *optional*, defaults to `True`):
+ Whether to rescale the input to a certain `scale`.
+ rescale_factor (`float`, *optional*, defaults to `1 / 255`):
+ Rescale the input by the given factor. Only has an effect if `do_rescale` is set to `True`.
+ do_normalize (`bool`, *optional*, defaults to `True`):
+ Whether or not to normalize the input with mean and standard deviation.
+ image_mean (`float` or `list[float]`, *optional*, defaults to `[0.485, 0.456, 0.406]`):
+ The sequence of means for each channel, to be used when normalizing images. Defaults to the ImageNet mean.
+ image_std (`float` or `list[float]`, *optional*, defaults to `[0.229, 0.224, 0.225]`):
+ The sequence of standard deviations for each channel, to be used when normalizing images. Defaults to the
+ ImageNet std.
+ ignore_index (`int`, *optional*):
+ Label to be assigned to background pixels in segmentation maps. If provided, segmentation map pixels
+ denoted with 0 (background) will be replaced with `ignore_index`.
+ do_reduce_labels (`bool`, *optional*, defaults to `False`):
+ Whether or not to decrement all label values of segmentation maps by 1. Usually used for datasets where 0
+ is used for background, and background itself is not included in all classes of a dataset (e.g. ADE20k).
+ The background label will be replaced by `ignore_index`.
+ num_labels (`int`, *optional*):
+ The number of labels in the segmentation map.
+ pad_size (`dict[str, int]`, *optional*):
+ The size `{"height": int, "width": int}` to pad the images to. Must be larger than any image size
+ provided for preprocessing. If `pad_size` is not provided, images will be padded to the largest
+ height and width in the batch.
+
+ """
+
+ model_input_names = ["pixel_values", "pixel_mask"]
+
+ @filter_out_non_signature_kwargs(extra=["max_size", *INIT_SERVICE_KWARGS])
+ def __init__(
+ self,
+ do_resize: bool = True,
+ size: Optional[dict[str, int]] = None,
+ size_divisor: int = 32,
+ resample: PILImageResampling = PILImageResampling.BILINEAR,
+ do_rescale: bool = True,
+ rescale_factor: float = 1 / 255,
+ do_normalize: bool = True,
+ image_mean: Optional[Union[float, list[float]]] = None,
+ image_std: Optional[Union[float, list[float]]] = None,
+ ignore_index: Optional[int] = None,
+ do_reduce_labels: bool = False,
+ num_labels: Optional[int] = None,
+ pad_size: Optional[dict[str, int]] = None,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+
+ # We make max_size a private attribute so we can pass it as a default value in the preprocess method whilst
+ # `size` can still be passed in as an int
+ self._max_size = kwargs.pop("max_size", 1333)
+
+ size = size if size is not None else {"shortest_edge": 800, "longest_edge": self._max_size}
+ size = get_size_dict(size, max_size=self._max_size, default_to_square=False)
+
+ self.do_resize = do_resize
+ self.size = size
+ self.resample = resample
+ self.size_divisor = size_divisor
+ self.do_rescale = do_rescale
+ self.rescale_factor = rescale_factor
+ self.do_normalize = do_normalize
+ self.image_mean = image_mean if image_mean is not None else IMAGENET_DEFAULT_MEAN
+ self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD
+ self.ignore_index = ignore_index
+ self.do_reduce_labels = do_reduce_labels
+ self.num_labels = num_labels
+ self.pad_size = pad_size
+
+ def to_dict(self) -> dict[str, Any]:
+ """
+ Serializes this instance to a Python dictionary. This method calls the superclass method and then removes the
+ `_max_size` attribute from the dictionary.
+ """
+ image_processor_dict = super().to_dict()
+ image_processor_dict.pop("_max_size", None)
+ return image_processor_dict
+
+ def resize(
+ self,
+ image: np.ndarray,
+ size: dict[str, int],
+ size_divisor: int = 0,
+ resample: PILImageResampling = PILImageResampling.BILINEAR,
+ data_format=None,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ **kwargs,
+ ) -> np.ndarray:
+ """
+ Resize the image to the given size. Size can be min_size (scalar) or `(height, width)` tuple. If size is an
+ int, smaller edge of the image will be matched to this number.
+
+ Args:
+ image (`np.ndarray`):
+ Image to resize.
+ size (`dict[str, int]`):
+ The size of the output image.
+ size_divisor (`int`, *optional*, defaults to 0):
+ If `size_divisor` is given, the output image size will be divisible by the number.
+ resample (`PILImageResampling` resampling filter, *optional*, defaults to `PILImageResampling.BILINEAR`):
+ Resampling filter to use when resizing the image.
+ data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format for the output image. If unset, the channel dimension format of the input
+ image is used.
+ input_data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format of the input image. If not provided, it will be inferred.
+ """
+
+ # Deprecated, backward compatibility
+ max_size = kwargs.pop("max_size", None)
+
+ size = get_size_dict(size, max_size=max_size, default_to_square=False)
+ if "shortest_edge" in size and "longest_edge" in size:
+ size, max_size = size["shortest_edge"], size["longest_edge"]
+ elif "height" in size and "width" in size:
+ size = (size["height"], size["width"])
+ max_size = None
+ else:
+ raise ValueError(
+ "Size must contain 'height' and 'width' keys or 'shortest_edge' and 'longest_edge' keys. Got"
+ f" {size.keys()}."
+ )
+ size = get_maskformer_resize_output_image_size(
+ image=image,
+ size=size,
+ max_size=max_size,
+ size_divisor=size_divisor,
+ default_to_square=False,
+ input_data_format=input_data_format,
+ )
+ image = resize(
+ image, size=size, resample=resample, data_format=data_format, input_data_format=input_data_format, **kwargs
+ )
+ return image
+
+ # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.rescale
+ def rescale(
+ self,
+ image: np.ndarray,
+ rescale_factor: float,
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ ) -> np.ndarray:
+ """
+ Rescale the image by the given factor. image = image * rescale_factor.
+
+ Args:
+ image (`np.ndarray`):
+ Image to rescale.
+ rescale_factor (`float`):
+ The value to use for rescaling.
+ data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format for the output image. If unset, the channel dimension format of the input
+ image is used. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ input_data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format for the input image. If unset, is inferred from the input image. Can be
+ one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ """
+ return rescale(image, rescale_factor, data_format=data_format, input_data_format=input_data_format)
+
+ def convert_segmentation_map_to_binary_masks(
+ self,
+ segmentation_map: np.ndarray,
+ instance_id_to_semantic_id: Optional[dict[int, int]] = None,
+ ignore_index: Optional[int] = None,
+ do_reduce_labels: bool = False,
+ ):
+ do_reduce_labels = do_reduce_labels if do_reduce_labels is not None else self.do_reduce_labels
+ ignore_index = ignore_index if ignore_index is not None else self.ignore_index
+ return convert_segmentation_map_to_binary_masks(
+ segmentation_map=segmentation_map,
+ instance_id_to_semantic_id=instance_id_to_semantic_id,
+ ignore_index=ignore_index,
+ do_reduce_labels=do_reduce_labels,
+ )
+
+ def __call__(self, images, segmentation_maps=None, **kwargs) -> BatchFeature:
+ return self.preprocess(images, segmentation_maps=segmentation_maps, **kwargs)
+
+ def _preprocess(
+ self,
+ image: ImageInput,
+ do_resize: Optional[bool] = None,
+ size: Optional[dict[str, int]] = None,
+ size_divisor: Optional[int] = None,
+ resample: Optional[PILImageResampling] = None,
+ do_rescale: Optional[bool] = None,
+ rescale_factor: Optional[float] = None,
+ do_normalize: Optional[bool] = None,
+ image_mean: Optional[Union[float, list[float]]] = None,
+ image_std: Optional[Union[float, list[float]]] = None,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ ):
+ if do_resize:
+ image = self.resize(
+ image, size=size, size_divisor=size_divisor, resample=resample, input_data_format=input_data_format
+ )
+ if do_rescale:
+ image = self.rescale(image, rescale_factor=rescale_factor, input_data_format=input_data_format)
+ if do_normalize:
+ image = self.normalize(image, mean=image_mean, std=image_std, input_data_format=input_data_format)
+ return image
+
+ def _preprocess_image(
+ self,
+ image: ImageInput,
+ do_resize: Optional[bool] = None,
+ size: Optional[dict[str, int]] = None,
+ size_divisor: Optional[int] = None,
+ resample: Optional[PILImageResampling] = None,
+ do_rescale: Optional[bool] = None,
+ rescale_factor: Optional[float] = None,
+ do_normalize: Optional[bool] = None,
+ image_mean: Optional[Union[float, list[float]]] = None,
+ image_std: Optional[Union[float, list[float]]] = None,
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ ) -> np.ndarray:
+ """Preprocesses a single image."""
+ # All transformations expect numpy arrays.
+ image = to_numpy_array(image)
+ if do_rescale and is_scaled_image(image):
+ logger.warning_once(
+ "It looks like you are trying to rescale already rescaled images. If the input"
+ " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
+ )
+ if input_data_format is None:
+ input_data_format = infer_channel_dimension_format(image)
+ image = self._preprocess(
+ image=image,
+ do_resize=do_resize,
+ size=size,
+ size_divisor=size_divisor,
+ resample=resample,
+ do_rescale=do_rescale,
+ rescale_factor=rescale_factor,
+ do_normalize=do_normalize,
+ image_mean=image_mean,
+ image_std=image_std,
+ input_data_format=input_data_format,
+ )
+ if data_format is not None:
+ image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
+ return image
+
+ def _preprocess_mask(
+ self,
+ segmentation_map: ImageInput,
+ do_resize: Optional[bool] = None,
+ size: Optional[dict[str, int]] = None,
+ size_divisor: int = 0,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ ) -> np.ndarray:
+ """Preprocesses a single mask."""
+ segmentation_map = to_numpy_array(segmentation_map)
+ # Add channel dimension if missing - needed for certain transformations
+ if segmentation_map.ndim == 2:
+ added_channel_dim = True
+ segmentation_map = segmentation_map[None, ...]
+ input_data_format = ChannelDimension.FIRST
+ else:
+ added_channel_dim = False
+ if input_data_format is None:
+ input_data_format = infer_channel_dimension_format(segmentation_map, num_channels=1)
+ # TODO: (Amy)
+ # Rework segmentation map processing to include reducing labels and a resizing step that doesn't
+ # drop segment IDs > 255.
+ segmentation_map = self._preprocess(
+ image=segmentation_map,
+ do_resize=do_resize,
+ resample=PILImageResampling.NEAREST,
+ size=size,
+ size_divisor=size_divisor,
+ do_rescale=False,
+ do_normalize=False,
+ input_data_format=input_data_format,
+ )
+ # Remove extra channel dimension if added for processing
+ if added_channel_dim:
+ segmentation_map = segmentation_map.squeeze(0)
+ return segmentation_map
+
+ @filter_out_non_signature_kwargs()
+ def preprocess(
+ self,
+ images: ImageInput,
+ segmentation_maps: Optional[ImageInput] = None,
+ instance_id_to_semantic_id: Optional[dict[int, int]] = None,
+ do_resize: Optional[bool] = None,
+ size: Optional[dict[str, int]] = None,
+ size_divisor: Optional[int] = None,
+ resample: Optional[PILImageResampling] = None,
+ do_rescale: Optional[bool] = None,
+ rescale_factor: Optional[float] = None,
+ do_normalize: Optional[bool] = None,
+ image_mean: Optional[Union[float, list[float]]] = None,
+ image_std: Optional[Union[float, list[float]]] = None,
+ ignore_index: Optional[int] = None,
+ do_reduce_labels: Optional[bool] = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ pad_size: Optional[dict[str, int]] = None,
+ ) -> BatchFeature:
+ do_resize = do_resize if do_resize is not None else self.do_resize
+ size = size if size is not None else self.size
+ size = get_size_dict(size, default_to_square=False, max_size=self._max_size)
+ size_divisor = size_divisor if size_divisor is not None else self.size_divisor
+ resample = resample if resample is not None else self.resample
+ do_rescale = do_rescale if do_rescale is not None else self.do_rescale
+ rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
+ do_normalize = do_normalize if do_normalize is not None else self.do_normalize
+ image_mean = image_mean if image_mean is not None else self.image_mean
+ image_std = image_std if image_std is not None else self.image_std
+ ignore_index = ignore_index if ignore_index is not None else self.ignore_index
+ do_reduce_labels = do_reduce_labels if do_reduce_labels is not None else self.do_reduce_labels
+ pad_size = self.pad_size if pad_size is None else pad_size
+
+ if not valid_images(images):
+ raise ValueError(
+ "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
+ "torch.Tensor, tf.Tensor or jax.ndarray."
+ )
+
+ validate_preprocess_arguments(
+ do_rescale=do_rescale,
+ rescale_factor=rescale_factor,
+ do_normalize=do_normalize,
+ image_mean=image_mean,
+ image_std=image_std,
+ do_resize=do_resize,
+ size=size,
+ resample=resample,
+ )
+
+ if segmentation_maps is not None and not valid_images(segmentation_maps):
+ raise ValueError(
+ "Invalid segmentation map type. Must be of type PIL.Image.Image, numpy.ndarray, "
+ "torch.Tensor, tf.Tensor or jax.ndarray."
+ )
+
+ images = make_flat_list_of_images(images)
+ if segmentation_maps is not None:
+ segmentation_maps = make_flat_list_of_images(segmentation_maps, expected_ndims=2)
+
+ if segmentation_maps is not None and len(images) != len(segmentation_maps):
+ raise ValueError("Images and segmentation maps must have the same length.")
+
+ images = [
+ self._preprocess_image(
+ image,
+ do_resize=do_resize,
+ size=size,
+ size_divisor=size_divisor,
+ resample=resample,
+ do_rescale=do_rescale,
+ rescale_factor=rescale_factor,
+ do_normalize=do_normalize,
+ image_mean=image_mean,
+ image_std=image_std,
+ data_format=data_format,
+ input_data_format=input_data_format,
+ )
+ for image in images
+ ]
+
+ if segmentation_maps is not None:
+ segmentation_maps = [
+ self._preprocess_mask(
+ segmentation_map, do_resize, size, size_divisor, input_data_format=input_data_format
+ )
+ for segmentation_map in segmentation_maps
+ ]
+ encoded_inputs = self.encode_inputs(
+ images,
+ segmentation_maps,
+ instance_id_to_semantic_id,
+ ignore_index,
+ do_reduce_labels,
+ return_tensors,
+ input_data_format=data_format,
+ pad_size=pad_size,
+ )
+ return encoded_inputs
+
+ # Copied from transformers.models.vilt.image_processing_vilt.ViltImageProcessor._pad_image
+ def _pad_image(
+ self,
+ image: np.ndarray,
+ output_size: tuple[int, int],
+ constant_values: Union[float, Iterable[float]] = 0,
+ data_format: Optional[ChannelDimension] = None,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ ) -> np.ndarray:
+ """
+ Pad an image with zeros to the given size.
+ """
+ input_height, input_width = get_image_size(image, channel_dim=input_data_format)
+ output_height, output_width = output_size
+
+ pad_bottom = output_height - input_height
+ pad_right = output_width - input_width
+ padding = ((0, pad_bottom), (0, pad_right))
+ padded_image = pad(
+ image,
+ padding,
+ mode=PaddingMode.CONSTANT,
+ constant_values=constant_values,
+ data_format=data_format,
+ input_data_format=input_data_format,
+ )
+ return padded_image
+
+ def pad(
+ self,
+ images: list[np.ndarray],
+ constant_values: Union[float, Iterable[float]] = 0,
+ return_pixel_mask: bool = True,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ data_format: Optional[ChannelDimension] = None,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ pad_size: Optional[dict[str, int]] = None,
+ ) -> BatchFeature:
+ """
+ Pads a batch of images, at the bottom and right, with zeros up to the size of the largest height and width
+ in the batch, and optionally returns their corresponding pixel mask.
+
+ Args:
+ images (`list[np.ndarray]`):
+ Images to pad.
+ constant_values (`float` or `Iterable[float]`, *optional*):
+ The value to use for the padding if `mode` is `"constant"`.
+ return_pixel_mask (`bool`, *optional*, defaults to `True`):
+ Whether to return a pixel mask.
+ return_tensors (`str` or `TensorType`, *optional*):
+ The type of tensors to return. Can be one of:
+ - Unset: Return a list of `np.ndarray`.
+ - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
+ - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
+ data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format of the image. If not provided, it will be the same as the input image.
+ input_data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format of the input image. If not provided, it will be inferred.
+ pad_size (`Dict[str, int]`, *optional*):
+ The size `{"height": int, "width" int}` to pad the images to. Must be larger than any image size
+ provided for preprocessing. If `pad_size` is not provided, images will be padded to the largest
+ height and width in the batch.
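+
+ Example (an illustrative sketch using synthetic arrays; the shapes are arbitrary):
+
+ ```python
+ >>> import numpy as np
+ >>> from transformers import MaskFormerImageProcessor
+
+ >>> image_processor = MaskFormerImageProcessor()
+ >>> images = [np.zeros((3, 480, 512), dtype=np.float32), np.zeros((3, 512, 384), dtype=np.float32)]
+ >>> batch = image_processor.pad(images, return_tensors="pt")
+ >>> batch["pixel_values"].shape  # padded to the largest height and width in the batch
+ torch.Size([2, 3, 512, 512])
+ >>> batch["pixel_mask"].shape
+ torch.Size([2, 512, 512])
+ ```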
+ """
+ pad_size = pad_size if pad_size is not None else self.pad_size
+ if pad_size is not None:
+ padded_size = (pad_size["height"], pad_size["width"])
+ else:
+ padded_size = get_max_height_width(images, input_data_format=input_data_format)
+
+ padded_images = [
+ self._pad_image(
+ image,
+ padded_size,
+ constant_values=constant_values,
+ data_format=data_format,
+ input_data_format=input_data_format,
+ )
+ for image in images
+ ]
+ data = {"pixel_values": padded_images}
+
+ if return_pixel_mask:
+ masks = [
+ make_pixel_mask(image=image, output_size=padded_size, input_data_format=input_data_format)
+ for image in images
+ ]
+ data["pixel_mask"] = masks
+
+ return BatchFeature(data=data, tensor_type=return_tensors)
+
+ def encode_inputs(
+ self,
+ pixel_values_list: list[ImageInput],
+ segmentation_maps: Optional[ImageInput] = None,
+ instance_id_to_semantic_id: Optional[Union[list[dict[int, int]], dict[int, int]]] = None,
+ ignore_index: Optional[int] = None,
+ do_reduce_labels: bool = False,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ pad_size: Optional[dict[str, int]] = None,
+ ):
+ """
+ Pad images up to the largest image in a batch and create a corresponding `pixel_mask`.
+
+ MaskFormer addresses semantic segmentation with a mask classification paradigm, thus input segmentation maps
+ will be converted to lists of binary masks and their respective labels. Let's see an example, assuming
+ `segmentation_maps = [[2,6,7,9]]`, the output will contain `mask_labels =
+ [[1,0,0,0],[0,1,0,0],[0,0,1,0],[0,0,0,1]]` (four binary masks) and `class_labels = [2,6,7,9]`, the labels for
+ each mask.
+
+ Args:
+ pixel_values_list (`list[ImageInput]`):
+ List of images (pixel values) to be padded. Each image should be a tensor of shape `(channels, height,
+ width)`.
+
+ Images are always padded up to the largest height and width in the batch (or to `pad_size` if provided),
+ and a corresponding pixel mask is created that is:
+
+ - 1 for pixels that are real (i.e. **not masked**),
+ - 0 for pixels that are padding (i.e. **masked**).
+
+ segmentation_maps (`ImageInput`, *optional*):
+ The corresponding semantic segmentation maps with the pixel-wise annotations.
+
+ instance_id_to_semantic_id (`list[dict[int, int]]` or `dict[int, int]`, *optional*):
+ A mapping between object instance ids and class ids. If passed, `segmentation_maps` is treated as an
+ instance segmentation map where each pixel represents an instance id. Can be provided as a single
+ dictionary with a global/dataset-level mapping or as a list of dictionaries (one per image), to map
+ instance ids in each image separately.
+
+ return_tensors (`str` or [`~file_utils.TensorType`], *optional*):
+ If set, will return tensors instead of NumPy arrays. If set to `'pt'`, return PyTorch `torch.Tensor`
+ objects.
+
+ pad_size (`Dict[str, int]`, *optional*):
+ The size `{"height": int, "width" int}` to pad the images to. Must be larger than any image size
+ provided for preprocessing. If `pad_size` is not provided, images will be padded to the largest
+ height and width in the batch.
+
+ Returns:
+ [`BatchFeature`]: A [`BatchFeature`] with the following fields:
+
+ - **pixel_values** -- Pixel values to be fed to a model.
+ - **pixel_mask** -- Pixel mask to be fed to a model (returned when `"pixel_mask"` is in
+ `self.model_input_names`).
+ - **mask_labels** -- Optional list of mask labels of shape `(labels, height, width)` to be fed to a model
+ (when `segmentation_maps` are provided).
+ - **class_labels** -- Optional list of class labels of shape `(labels)` to be fed to a model (when
+ `segmentation_maps` are provided). They identify the labels of `mask_labels`, i.e. the label of
+ `mask_labels[i][j]` is `class_labels[i][j]`.
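+
+ Example (an illustrative sketch using a synthetic image and segmentation map; the class ids are arbitrary):
+
+ ```python
+ >>> import numpy as np
+ >>> from transformers import MaskFormerImageProcessor
+
+ >>> image_processor = MaskFormerImageProcessor()
+ >>> image = np.zeros((3, 512, 512), dtype=np.float32)
+ >>> seg_map = np.ones((512, 512), dtype=np.int64)
+ >>> seg_map[256:, :] = 7  # top half belongs to class 1, bottom half to class 7
+
+ >>> inputs = image_processor.encode_inputs([image], [seg_map], return_tensors="pt")
+ >>> inputs["mask_labels"][0].shape  # one binary mask per class id present in the map
+ torch.Size([2, 512, 512])
+ >>> inputs["class_labels"][0]
+ tensor([1, 7])
+ ```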
+ """
+ ignore_index = self.ignore_index if ignore_index is None else ignore_index
+ do_reduce_labels = self.do_reduce_labels if do_reduce_labels is None else do_reduce_labels
+
+ pixel_values_list = [to_numpy_array(pixel_values) for pixel_values in pixel_values_list]
+
+ if input_data_format is None:
+ input_data_format = infer_channel_dimension_format(pixel_values_list[0])
+
+ encoded_inputs = self.pad(
+ pixel_values_list, return_tensors=return_tensors, input_data_format=input_data_format, pad_size=pad_size
+ )
+
+ if segmentation_maps is not None:
+ mask_labels = []
+ class_labels = []
+ pad_size = get_max_height_width(pixel_values_list, input_data_format=input_data_format)
+ # Convert to list of binary masks and labels
+ for idx, segmentation_map in enumerate(segmentation_maps):
+ segmentation_map = to_numpy_array(segmentation_map)
+ if isinstance(instance_id_to_semantic_id, list):
+ instance_id = instance_id_to_semantic_id[idx]
+ else:
+ instance_id = instance_id_to_semantic_id
+ # Use instance2class_id mapping per image
+ masks, classes = self.convert_segmentation_map_to_binary_masks(
+ segmentation_map, instance_id, ignore_index=ignore_index, do_reduce_labels=do_reduce_labels
+ )
+ # We add an axis to make them compatible with the transformations library
+ # this will be removed in the future
+ if masks.shape[0] > 0:
+ masks = [mask[None, ...] for mask in masks]
+ masks = [
+ self._pad_image(
+ image=mask,
+ output_size=pad_size,
+ constant_values=ignore_index,
+ input_data_format=ChannelDimension.FIRST,
+ )
+ for mask in masks
+ ]
+ masks = np.concatenate(masks, axis=0)
+ else:
+ masks = np.zeros((0, *pad_size), dtype=np.float32)
+ mask_labels.append(torch.from_numpy(masks))
+ class_labels.append(torch.from_numpy(classes))
+
+ # we cannot batch them since they don't share a common class size
+ encoded_inputs["mask_labels"] = mask_labels
+ encoded_inputs["class_labels"] = class_labels
+
+ return encoded_inputs
+
+ def post_process_segmentation(
+ self, outputs: "MaskFormerForInstanceSegmentationOutput", target_size: Optional[tuple[int, int]] = None
+ ) -> "torch.Tensor":
+ """
+ Converts the output of [`MaskFormerForInstanceSegmentationOutput`] into image segmentation predictions. Only
+ supports PyTorch.
+
+ Args:
+ outputs ([`MaskFormerForInstanceSegmentationOutput`]):
+ The outputs from [`MaskFormerForInstanceSegmentation`].
+
+ target_size (`tuple[int, int]`, *optional*):
+ If set, the `masks_queries_logits` will be resized to `target_size`.
+
+ Returns:
+ `torch.Tensor`:
+ A tensor of shape (`batch_size, num_class_labels, height, width`).
+ """
+ warnings.warn(
+ "`post_process_segmentation` is deprecated and will be removed in v5 of Transformers, please use"
+ " `post_process_instance_segmentation`",
+ FutureWarning,
+ )
+
+ # class_queries_logits has shape [BATCH, QUERIES, CLASSES + 1]
+ class_queries_logits = outputs.class_queries_logits
+ # masks_queries_logits has shape [BATCH, QUERIES, HEIGHT, WIDTH]
+ masks_queries_logits = outputs.masks_queries_logits
+ if target_size is not None:
+ masks_queries_logits = torch.nn.functional.interpolate(
+ masks_queries_logits,
+ size=target_size,
+ mode="bilinear",
+ align_corners=False,
+ )
+ # remove the null class `[..., :-1]`
+ masks_classes = class_queries_logits.softmax(dim=-1)[..., :-1]
+ # mask probs has shape [BATCH, QUERIES, HEIGHT, WIDTH]
+ masks_probs = masks_queries_logits.sigmoid()
+ # now we want to sum over the queries,
+ # $ out_{c,h,w} = \sum_q p_{q,c} * m_{q,h,w} $
+ # where $ softmax(p) \in R^{q, c} $ is the mask classes
+ # and $ sigmoid(m) \in R^{q, h, w}$ is the mask probabilities
+ # b(atch)q(uery)c(lasses), b(atch)q(uery)h(eight)w(idth)
+ segmentation = torch.einsum("bqc, bqhw -> bchw", masks_classes, masks_probs)
+
+ return segmentation
+
+ def post_process_semantic_segmentation(
+ self, outputs, target_sizes: Optional[list[tuple[int, int]]] = None
+ ) -> "torch.Tensor":
+ """
+ Converts the output of [`MaskFormerForInstanceSegmentation`] into semantic segmentation maps. Only supports
+ PyTorch.
+
+ Args:
+ outputs ([`MaskFormerForInstanceSegmentation`]):
+ Raw outputs of the model.
+ target_sizes (`list[tuple[int, int]]`, *optional*):
+ List of length `batch_size`, where each list item (`tuple[int, int]`) corresponds to the requested
+ final size (height, width) of each prediction. If left to `None`, predictions will not be resized.
+ Returns:
+ `list[torch.Tensor]`:
+ A list of length `batch_size`, where each item is a semantic segmentation map of shape (height, width)
+ corresponding to the `target_sizes` entry (if `target_sizes` is specified). Each entry of each
+ `torch.Tensor` corresponds to a semantic class id.
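+
+ Example (illustrative; `facebook/maskformer-swin-base-ade` is one publicly available MaskFormer checkpoint):
+
+ ```python
+ >>> import torch
+ >>> import requests
+ >>> from PIL import Image
+ >>> from transformers import AutoImageProcessor, MaskFormerForInstanceSegmentation
+
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> image_processor = AutoImageProcessor.from_pretrained("facebook/maskformer-swin-base-ade")
+ >>> model = MaskFormerForInstanceSegmentation.from_pretrained("facebook/maskformer-swin-base-ade")
+
+ >>> inputs = image_processor(images=image, return_tensors="pt")
+ >>> with torch.no_grad():
+ ...     outputs = model(**inputs)
+
+ >>> # `target_sizes` expects (height, width) tuples, hence the reversed PIL `image.size`
+ >>> semantic_map = image_processor.post_process_semantic_segmentation(
+ ...     outputs, target_sizes=[image.size[::-1]]
+ ... )[0]
+ ```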
+ """
+ class_queries_logits = outputs.class_queries_logits # [batch_size, num_queries, num_classes+1]
+ masks_queries_logits = outputs.masks_queries_logits # [batch_size, num_queries, height, width]
+
+ # Remove the null class `[..., :-1]`
+ masks_classes = class_queries_logits.softmax(dim=-1)[..., :-1]
+ masks_probs = masks_queries_logits.sigmoid() # [batch_size, num_queries, height, width]
+
+ # Semantic segmentation logits of shape (batch_size, num_classes, height, width)
+ segmentation = torch.einsum("bqc, bqhw -> bchw", masks_classes, masks_probs)
+ batch_size = class_queries_logits.shape[0]
+
+ # Resize logits and compute semantic segmentation maps
+ if target_sizes is not None:
+ if batch_size != len(target_sizes):
+ raise ValueError(
+ "Make sure that you pass in as many target sizes as the batch dimension of the logits"
+ )
+
+ semantic_segmentation = []
+ for idx in range(batch_size):
+ resized_logits = torch.nn.functional.interpolate(
+ segmentation[idx].unsqueeze(dim=0), size=target_sizes[idx], mode="bilinear", align_corners=False
+ )
+ semantic_map = resized_logits[0].argmax(dim=0)
+ semantic_segmentation.append(semantic_map)
+ else:
+ semantic_segmentation = segmentation.argmax(dim=1)
+ semantic_segmentation = [semantic_segmentation[i] for i in range(semantic_segmentation.shape[0])]
+
+ return semantic_segmentation
+
+ def post_process_instance_segmentation(
+ self,
+ outputs,
+ threshold: float = 0.5,
+ mask_threshold: float = 0.5,
+ overlap_mask_area_threshold: float = 0.8,
+ target_sizes: Optional[list[tuple[int, int]]] = None,
+ return_coco_annotation: Optional[bool] = False,
+ return_binary_maps: Optional[bool] = False,
+ ) -> list[dict]:
+ """
+ Converts the output of [`MaskFormerForInstanceSegmentationOutput`] into instance segmentation predictions. Only
+ supports PyTorch. If instances could overlap, set either `return_coco_annotation` or `return_binary_maps`
+ to `True` to get the correct segmentation result.
+
+ Args:
+ outputs ([`MaskFormerForInstanceSegmentation`]):
+ Raw outputs of the model.
+ threshold (`float`, *optional*, defaults to 0.5):
+ The probability score threshold to keep predicted instance masks.
+ mask_threshold (`float`, *optional*, defaults to 0.5):
+ Threshold to use when turning the predicted masks into binary values.
+ overlap_mask_area_threshold (`float`, *optional*, defaults to 0.8):
+ The overlap mask area threshold to merge or discard small disconnected parts within each binary
+ instance mask.
+ target_sizes (`list[tuple[int, int]]`, *optional*):
+ List of length `batch_size`, where each list item (`tuple[int, int]`) corresponds to the requested
+ final size (height, width) of each prediction. If left to None, predictions will not be resized.
+ return_coco_annotation (`bool`, *optional*, defaults to `False`):
+ If set to `True`, segmentation maps are returned in COCO run-length encoding (RLE) format.
+ return_binary_maps (`bool`, *optional*, defaults to `False`):
+ If set to `True`, segmentation maps are returned as a concatenated tensor of binary segmentation maps
+ (one per detected instance).
+ Returns:
+ `list[Dict]`: A list of dictionaries, one per image, each dictionary containing two keys:
+ - **segmentation** -- A tensor of shape `(height, width)` where each pixel represents a `segment_id`, or
+ a `list[list]` run-length encoding (RLE) of the segmentation map if `return_coco_annotation` is set to
+ `True`, or a tensor of shape `(num_instances, height, width)` if `return_binary_maps` is set to `True`.
+ Set to `None` if no mask is found above `threshold`.
+ - **segments_info** -- A dictionary that contains additional information on each segment.
+ - **id** -- An integer representing the `segment_id`.
+ - **label_id** -- An integer representing the label / semantic class id corresponding to `segment_id`.
+ - **score** -- Prediction score of segment with `segment_id`.
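+
+ Example (illustrative; `facebook/maskformer-swin-base-coco` is one publicly available MaskFormer checkpoint):
+
+ ```python
+ >>> import torch
+ >>> import requests
+ >>> from PIL import Image
+ >>> from transformers import AutoImageProcessor, MaskFormerForInstanceSegmentation
+
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> image_processor = AutoImageProcessor.from_pretrained("facebook/maskformer-swin-base-coco")
+ >>> model = MaskFormerForInstanceSegmentation.from_pretrained("facebook/maskformer-swin-base-coco")
+
+ >>> inputs = image_processor(images=image, return_tensors="pt")
+ >>> with torch.no_grad():
+ ...     outputs = model(**inputs)
+
+ >>> result = image_processor.post_process_instance_segmentation(
+ ...     outputs, threshold=0.5, target_sizes=[image.size[::-1]]
+ ... )[0]
+ >>> segmentation = result["segmentation"]  # (height, width) tensor of segment ids, -1 for unassigned pixels
+ >>> segments_info = result["segments_info"]  # list of dicts with "id", "label_id" and "score"
+ ```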
+ """
+ if return_coco_annotation and return_binary_maps:
+ raise ValueError("return_coco_annotation and return_binary_maps can not be both set to True.")
+
+ # [batch_size, num_queries, num_classes+1]
+ class_queries_logits = outputs.class_queries_logits
+ # [batch_size, num_queries, height, width]
+ masks_queries_logits = outputs.masks_queries_logits
+
+ device = masks_queries_logits.device
+ num_classes = class_queries_logits.shape[-1] - 1
+ num_queries = class_queries_logits.shape[-2]
+
+ # Loop over items in batch size
+ results: list[dict[str, TensorType]] = []
+
+ for i in range(class_queries_logits.shape[0]):
+ mask_pred = masks_queries_logits[i]
+ mask_cls = class_queries_logits[i]
+
+ scores = torch.nn.functional.softmax(mask_cls, dim=-1)[:, :-1]
+ labels = torch.arange(num_classes, device=device).unsqueeze(0).repeat(num_queries, 1).flatten(0, 1)
+
+ scores_per_image, topk_indices = scores.flatten(0, 1).topk(num_queries, sorted=False)
+ labels_per_image = labels[topk_indices]
+
+ topk_indices = torch.div(topk_indices, num_classes, rounding_mode="floor")
+ mask_pred = mask_pred[topk_indices]
+ pred_masks = (mask_pred > 0).float()
+
+ # Calculate average mask prob
+ mask_scores_per_image = (mask_pred.sigmoid().flatten(1) * pred_masks.flatten(1)).sum(1) / (
+ pred_masks.flatten(1).sum(1) + 1e-6
+ )
+ pred_scores = scores_per_image * mask_scores_per_image
+ pred_classes = labels_per_image
+
+ segmentation = torch.zeros(masks_queries_logits.shape[2:]) - 1
+ if target_sizes is not None:
+ segmentation = torch.zeros(target_sizes[i]) - 1
+ pred_masks = torch.nn.functional.interpolate(
+ pred_masks.unsqueeze(0), size=target_sizes[i], mode="nearest"
+ )[0]
+
+ instance_maps, segments = [], []
+ current_segment_id = 0
+ for j in range(num_queries):
+ score = pred_scores[j].item()
+
+ if not torch.all(pred_masks[j] == 0) and score >= threshold:
+ segmentation[pred_masks[j] == 1] = current_segment_id
+ segments.append(
+ {
+ "id": current_segment_id,
+ "label_id": pred_classes[j].item(),
+ "was_fused": False,
+ "score": round(score, 6),
+ }
+ )
+ current_segment_id += 1
+ instance_maps.append(pred_masks[j])
+
+ # Return segmentation map in run-length encoding (RLE) format
+ if return_coco_annotation:
+ segmentation = convert_segmentation_to_rle(segmentation)
+
+ # Return a concatenated tensor of binary instance maps
+ if return_binary_maps and len(instance_maps) != 0:
+ segmentation = torch.stack(instance_maps, dim=0)
+
+ results.append({"segmentation": segmentation, "segments_info": segments})
+ return results
+
+ def post_process_panoptic_segmentation(
+ self,
+ outputs,
+ threshold: float = 0.5,
+ mask_threshold: float = 0.5,
+ overlap_mask_area_threshold: float = 0.8,
+ label_ids_to_fuse: Optional[set[int]] = None,
+ target_sizes: Optional[list[tuple[int, int]]] = None,
+ ) -> list[dict]:
+ """
+ Converts the output of [`MaskFormerForInstanceSegmentationOutput`] into image panoptic segmentation
+ predictions. Only supports PyTorch.
+
+ Args:
+ outputs ([`MaskFormerForInstanceSegmentationOutput`]):
+ The outputs from [`MaskFormerForInstanceSegmentation`].
+ threshold (`float`, *optional*, defaults to 0.5):
+ The probability score threshold to keep predicted instance masks.
+ mask_threshold (`float`, *optional*, defaults to 0.5):
+ Threshold to use when turning the predicted masks into binary values.
+ overlap_mask_area_threshold (`float`, *optional*, defaults to 0.8):
+ The overlap mask area threshold to merge or discard small disconnected parts within each binary
+ instance mask.
+ label_ids_to_fuse (`Set[int]`, *optional*):
+ The labels in this set will have all their instances fused together. For instance, we could say
+ there can only be one sky in an image, but several persons, so the label ID for sky would be in that
+ set, but not the one for person.
+ target_sizes (`list[tuple[int, int]]`, *optional*):
+ List of length `batch_size`, where each list item (`tuple[int, int]`) corresponds to the requested
+ final size (height, width) of each prediction in batch. If left to None, predictions will not be
+ resized.
+
+ Returns:
+ `list[Dict]`: A list of dictionaries, one per image, each dictionary containing two keys:
+ - **segmentation** -- a tensor of shape `(height, width)` where each pixel represents a `segment_id`, set
+ to `None` if no mask is found above `threshold`. If `target_sizes` is specified, segmentation is resized
+ to the corresponding `target_sizes` entry.
+ - **segments_info** -- A dictionary that contains additional information on each segment.
+ - **id** -- an integer representing the `segment_id`.
+ - **label_id** -- An integer representing the label / semantic class id corresponding to `segment_id`.
+ - **was_fused** -- a boolean, `True` if `label_id` was in `label_ids_to_fuse`, `False` otherwise.
+ Multiple instances of the same class / label were fused and assigned a single `segment_id`.
+ - **score** -- Prediction score of segment with `segment_id`.
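+
+ Example (illustrative; `facebook/maskformer-swin-base-coco` is one publicly available checkpoint and the fused
+ label id below is arbitrary):
+
+ ```python
+ >>> import torch
+ >>> import requests
+ >>> from PIL import Image
+ >>> from transformers import AutoImageProcessor, MaskFormerForInstanceSegmentation
+
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> image_processor = AutoImageProcessor.from_pretrained("facebook/maskformer-swin-base-coco")
+ >>> model = MaskFormerForInstanceSegmentation.from_pretrained("facebook/maskformer-swin-base-coco")
+
+ >>> inputs = image_processor(images=image, return_tensors="pt")
+ >>> with torch.no_grad():
+ ...     outputs = model(**inputs)
+
+ >>> result = image_processor.post_process_panoptic_segmentation(
+ ...     outputs, label_ids_to_fuse={0}, target_sizes=[image.size[::-1]]
+ ... )[0]
+ >>> panoptic_map = result["segmentation"]  # (height, width) tensor of segment ids
+ >>> segments_info = result["segments_info"]  # per-segment "id", "label_id", "was_fused" and "score"
+ ```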
+ """
+
+ if label_ids_to_fuse is None:
+ logger.warning("`label_ids_to_fuse` unset. No instance will be fused.")
+ label_ids_to_fuse = set()
+
+ class_queries_logits = outputs.class_queries_logits # [batch_size, num_queries, num_classes+1]
+ masks_queries_logits = outputs.masks_queries_logits # [batch_size, num_queries, height, width]
+
+ batch_size = class_queries_logits.shape[0]
+ num_labels = class_queries_logits.shape[-1] - 1
+
+ mask_probs = masks_queries_logits.sigmoid() # [batch_size, num_queries, height, width]
+
+ # Predicted label and score of each query (batch_size, num_queries)
+ pred_scores, pred_labels = nn.functional.softmax(class_queries_logits, dim=-1).max(-1)
+
+ # Loop over items in batch size
+ results: list[dict[str, TensorType]] = []
+
+ for i in range(batch_size):
+ mask_probs_item, pred_scores_item, pred_labels_item = remove_low_and_no_objects(
+ mask_probs[i], pred_scores[i], pred_labels[i], threshold, num_labels
+ )
+
+ # No mask found
+ if mask_probs_item.shape[0] <= 0:
+ height, width = target_sizes[i] if target_sizes is not None else mask_probs_item.shape[1:]
+ segmentation = torch.zeros((height, width)) - 1
+ results.append({"segmentation": segmentation, "segments_info": []})
+ continue
+
+ # Get segmentation map and segment information of batch item
+ target_size = target_sizes[i] if target_sizes is not None else None
+ segmentation, segments = compute_segments(
+ mask_probs=mask_probs_item,
+ pred_scores=pred_scores_item,
+ pred_labels=pred_labels_item,
+ mask_threshold=mask_threshold,
+ overlap_mask_area_threshold=overlap_mask_area_threshold,
+ label_ids_to_fuse=label_ids_to_fuse,
+ target_size=target_size,
+ )
+
+ results.append({"segmentation": segmentation, "segments_info": segments})
+ return results
+
+
+__all__ = ["MaskFormerImageProcessor"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/maskformer/image_processing_maskformer_fast.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/maskformer/image_processing_maskformer_fast.py
new file mode 100644
index 0000000000000000000000000000000000000000..9e15486cfa3524347f0fab1dc94193fd73729398
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/maskformer/image_processing_maskformer_fast.py
@@ -0,0 +1,735 @@
+# coding=utf-8
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Fast Image processor class for MaskFormer."""
+
+import math
+import warnings
+from typing import TYPE_CHECKING, Any, Optional, Union
+
+import torch
+from torch import nn
+from torchvision.transforms.v2 import functional as F
+
+from ...image_processing_utils import BatchFeature, get_size_dict
+from ...image_processing_utils_fast import (
+ BaseImageProcessorFast,
+ DefaultFastImageProcessorKwargs,
+ SizeDict,
+ get_image_size_for_max_height_width,
+ get_max_height_width,
+ group_images_by_shape,
+ reorder_images,
+)
+from ...image_utils import (
+ IMAGENET_DEFAULT_MEAN,
+ IMAGENET_DEFAULT_STD,
+ ChannelDimension,
+ ImageInput,
+ PILImageResampling,
+)
+from ...processing_utils import Unpack
+from ...utils import (
+ TensorType,
+ auto_docstring,
+ logging,
+)
+from .image_processing_maskformer import (
+ compute_segments,
+ convert_segmentation_to_rle,
+ get_size_with_aspect_ratio,
+ remove_low_and_no_objects,
+)
+
+
+logger = logging.get_logger(__name__)
+
+
+if TYPE_CHECKING:
+ from transformers import MaskFormerForInstanceSegmentationOutput
+
+
+def convert_segmentation_map_to_binary_masks_fast(
+ segmentation_map: "torch.Tensor",
+ instance_id_to_semantic_id: Optional[dict[int, int]] = None,
+ ignore_index: Optional[int] = None,
+ do_reduce_labels: bool = False,
+):
+ if do_reduce_labels and ignore_index is None:
+ raise ValueError("If `do_reduce_labels` is True, `ignore_index` must be provided.")
+
+ if do_reduce_labels:
+ segmentation_map = torch.where(segmentation_map == 0, ignore_index, segmentation_map - 1)
+
+ all_labels = torch.unique(segmentation_map)
+
+ if ignore_index is not None:
+ all_labels = all_labels[all_labels != ignore_index] # drop background label if applicable
+
+ binary_masks = [(segmentation_map == i) for i in all_labels]
+ if binary_masks:
+ binary_masks = torch.stack(binary_masks, dim=0)
+ else:
+ binary_masks = torch.zeros((0, *segmentation_map.shape), device=segmentation_map.device)
+
+ # Convert instance ids to class ids
+ if instance_id_to_semantic_id is not None:
+ labels = torch.zeros(all_labels.shape[0], device=segmentation_map.device)
+
+ for i, label in enumerate(all_labels):
+ class_id = instance_id_to_semantic_id[(label.item() + 1 if do_reduce_labels else label.item())]
+ labels[i] = class_id - 1 if do_reduce_labels else class_id
+ else:
+ labels = all_labels
+ return binary_masks.float(), labels.long()
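+
+# Illustrative sketch of the helper above (hypothetical values):
+#   seg_map = torch.tensor([[1, 1], [7, 7]])
+#   masks, labels = convert_segmentation_map_to_binary_masks_fast(seg_map)
+#   masks.shape -> torch.Size([2, 2, 2])  # one float binary mask per label present in the map
+#   labels -> tensor([1, 7])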
+
+
+class MaskFormerFastImageProcessorKwargs(DefaultFastImageProcessorKwargs):
+ r"""
+ size_divisor (`int`, *optional*, defaults to 32):
+ Some backbones need images divisible by a certain number. If not passed, it defaults to the value used in
+ Swin Transformer.
+ ignore_index (`int`, *optional*):
+ Label to be assigned to background pixels in segmentation maps. If provided, segmentation map pixels
+ denoted with 0 (background) will be replaced with `ignore_index`.
+ do_reduce_labels (`bool`, *optional*, defaults to `False`):
+ Whether or not to decrement all label values of segmentation maps by 1. Usually used for datasets where 0
+ is used for background, and background itself is not included in all classes of a dataset (e.g. ADE20k).
+ The background label will be replaced by `ignore_index`.
+ num_labels (`int`, *optional*):
+ The number of labels in the segmentation map.
+ """
+
+ size_divisor: Optional[int]
+ ignore_index: Optional[int]
+ do_reduce_labels: Optional[bool]
+ num_labels: Optional[int]
+
+
+@auto_docstring
+class MaskFormerImageProcessorFast(BaseImageProcessorFast):
+ resample = PILImageResampling.BILINEAR
+ image_mean = IMAGENET_DEFAULT_MEAN
+ image_std = IMAGENET_DEFAULT_STD
+ size = {"shortest_edge": 800, "longest_edge": 1333}
+ default_to_square = False
+ do_resize = True
+ do_rescale = True
+ rescale_factor = 1 / 255
+ do_normalize = True
+ do_pad = True
+ model_input_names = ["pixel_values", "pixel_mask"]
+ size_divisor = 32
+ do_reduce_labels = False
+ valid_kwargs = MaskFormerFastImageProcessorKwargs
+
+ def __init__(self, **kwargs: Unpack[MaskFormerFastImageProcessorKwargs]) -> None:
+ if "pad_and_return_pixel_mask" in kwargs:
+ kwargs["do_pad"] = kwargs.pop("pad_and_return_pixel_mask")
+
+ size = kwargs.pop("size", None)
+ max_size = kwargs.pop("max_size", None)
+
+ if size is None and max_size is not None:
+ size = self.size
+ size["longest_edge"] = max_size
+ elif size is None:
+ size = self.size
+
+ self.size = get_size_dict(size, max_size=max_size, default_to_square=False)
+
+ super().__init__(**kwargs)
+
+ def to_dict(self) -> dict[str, Any]:
+ """
+ Serializes this instance to a Python dictionary. This method calls the superclass method and then removes the
+ `_max_size` attribute from the dictionary.
+ """
+ image_processor_dict = super().to_dict()
+ image_processor_dict.pop("_max_size", None)
+ return image_processor_dict
+
+ def reduce_label(self, labels: list["torch.Tensor"]):
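+ # Maps background (0) to 255 and shifts every other label down by one (1 -> 0, 2 -> 1, ...).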
+ for idx in range(len(labels)):
+ label = labels[idx]
+ label = torch.where(label == 0, torch.tensor(255, dtype=label.dtype), label)
+ label = label - 1
+ label = torch.where(label == 254, torch.tensor(255, dtype=label.dtype), label)
+ labels[idx] = label
+
+ def resize(
+ self,
+ image: torch.Tensor,
+ size: SizeDict,
+ size_divisor: int = 0,
+ interpolation: Optional["F.InterpolationMode"] = None,
+ **kwargs,
+ ) -> torch.Tensor:
+ """
+ Resize the image to the given size. The target size can be specified in several ways (see `size` below), and
+ the output dimensions can optionally be made divisible by `size_divisor`.
+
+ Args:
+ image (`torch.Tensor`):
+ Image to resize.
+ size (`SizeDict`):
+ Size of the image's `(height, width)` dimensions after resizing. Available options are:
+ - `{"height": int, "width": int}`: The image will be resized to the exact size `(height, width)`.
+ Do NOT keep the aspect ratio.
+ - `{"shortest_edge": int, "longest_edge": int}`: The image will be resized to a maximum size respecting
+ the aspect ratio and keeping the shortest edge less or equal to `shortest_edge` and the longest edge
+ less or equal to `longest_edge`.
+ - `{"max_height": int, "max_width": int}`: The image will be resized to the maximum size respecting the
+ aspect ratio and keeping the height less or equal to `max_height` and the width less or equal to
+ `max_width`.
+ size_divisor (`int`, *optional*, defaults to 0):
+ If `size_divisor` is given, the output image size will be divisible by the number.
+ interpolation (`InterpolationMode`, *optional*, defaults to `InterpolationMode.BILINEAR`):
+ Resampling filter to use if resizing the image.
+ """
+ interpolation = interpolation if interpolation is not None else F.InterpolationMode.BILINEAR
+ if size.shortest_edge and size.longest_edge:
+ # Resize the image so that the shortest edge or the longest edge is of the given size
+ # while maintaining the aspect ratio of the original image.
+ new_size = get_size_with_aspect_ratio(
+ image.size()[-2:],
+ size["shortest_edge"],
+ size["longest_edge"],
+ )
+ elif size.max_height and size.max_width:
+ new_size = get_image_size_for_max_height_width(image.size()[-2:], size["max_height"], size["max_width"])
+ elif size.height and size.width:
+ new_size = (size["height"], size["width"])
+ else:
+ raise ValueError(
+ "Size must contain 'height' and 'width' keys or 'shortest_edge' and 'longest_edge' keys. Got"
+ f" {size.keys()}."
+ )
+ if size_divisor > 0:
+ height, width = new_size
+ height = int(math.ceil(height / size_divisor) * size_divisor)
+ width = int(math.ceil(width / size_divisor) * size_divisor)
+ new_size = (height, width)
+
+ image = F.resize(
+ image,
+ size=new_size,
+ interpolation=interpolation,
+ **kwargs,
+ )
+ return image
+
+ def pad(
+ self,
+ images: torch.Tensor,
+ padded_size: tuple[int, int],
+ segmentation_maps: Optional[torch.Tensor] = None,
+ fill: int = 0,
+ ignore_index: int = 255,
+ ) -> BatchFeature:
+ original_size = images.size()[-2:]
+ padding_bottom = padded_size[0] - original_size[0]
+ padding_right = padded_size[1] - original_size[1]
+ if padding_bottom < 0 or padding_right < 0:
+ raise ValueError(
+ f"Padding dimensions are negative. Please make sure that the padded size is larger than the "
+ f"original size. Got padded size: {padded_size}, original size: {original_size}."
+ )
+ if original_size != padded_size:
+ padding = [0, 0, padding_right, padding_bottom]
+ images = F.pad(images, padding, fill=fill)
+ if segmentation_maps is not None:
+ segmentation_maps = F.pad(segmentation_maps, padding, fill=ignore_index)
+
+ # Make a pixel mask for the image, where 1 indicates a valid pixel and 0 indicates padding.
+ pixel_mask = torch.zeros((images.shape[0], *padded_size), dtype=torch.int64, device=images.device)
+ pixel_mask[:, : original_size[0], : original_size[1]] = 1
+
+ return images, pixel_mask, segmentation_maps
+
+ @auto_docstring
+ def preprocess(
+ self,
+ images: ImageInput,
+ segmentation_maps: Optional[ImageInput] = None,
+ instance_id_to_semantic_id: Optional[Union[list[dict[int, int]], dict[int, int]]] = None,
+ **kwargs: Unpack[MaskFormerFastImageProcessorKwargs],
+ ) -> BatchFeature:
+ r"""
+ segmentation_maps (`ImageInput`, *optional*):
+ The segmentation maps.
+ instance_id_to_semantic_id (`Union[list[dict[int, int]], dict[int, int]]`, *optional*):
+ A mapping from instance IDs to semantic IDs.
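+
+ Example (an illustrative sketch using a synthetic tensor input):
+
+ ```python
+ >>> import torch
+ >>> from transformers import MaskFormerImageProcessorFast
+
+ >>> image_processor = MaskFormerImageProcessorFast()
+ >>> image = torch.randint(0, 256, (3, 480, 640), dtype=torch.uint8)
+ >>> inputs = image_processor(images=image, return_tensors="pt")
+ >>> sorted(inputs.keys())
+ ['pixel_mask', 'pixel_values']
+ ```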
+ """
+ return super().preprocess(
+ images,
+ segmentation_maps,
+ instance_id_to_semantic_id,
+ **kwargs,
+ )
+
+ def _preprocess_image_like_inputs(
+ self,
+ images: ImageInput,
+ segmentation_maps: ImageInput,
+ instance_id_to_semantic_id: Optional[Union[list[dict[int, int]], dict[int, int]]],
+ do_convert_rgb: bool,
+ input_data_format: ChannelDimension,
+ device: Optional[Union[str, "torch.device"]] = None,
+ **kwargs: Unpack[MaskFormerFastImageProcessorKwargs],
+ ) -> BatchFeature:
+ """
+ Preprocess image-like inputs.
+ To be overridden by subclasses when image-like inputs other than images should be processed.
+ It can be used for segmentation maps, depth maps, etc.
+ """
+ # Prepare input images
+ images = self._prepare_image_like_inputs(
+ images=images, do_convert_rgb=do_convert_rgb, input_data_format=input_data_format, device=device
+ )
+ if segmentation_maps is not None:
+ segmentation_maps = self._prepare_image_like_inputs(
+ images=segmentation_maps,
+ expected_ndims=2,
+ do_convert_rgb=False,
+ input_data_format=ChannelDimension.FIRST,
+ )
+ return self._preprocess(images, segmentation_maps, instance_id_to_semantic_id, **kwargs)
+
+ def _preprocess(
+ self,
+ images: list["torch.Tensor"],
+ segmentation_maps: Optional["torch.Tensor"],
+ instance_id_to_semantic_id: Optional[dict[int, int]],
+ do_resize: Optional[bool],
+ size: Optional[SizeDict],
+ pad_size: Optional[SizeDict],
+ size_divisor: Optional[int],
+ interpolation: Optional[Union["PILImageResampling", "F.InterpolationMode"]],
+ do_rescale: Optional[bool],
+ rescale_factor: Optional[float],
+ do_normalize: Optional[bool],
+ image_mean: Optional[Union[float, list[float]]],
+ image_std: Optional[Union[float, list[float]]],
+ ignore_index: Optional[int],
+ do_reduce_labels: Optional[bool],
+ disable_grouping: Optional[bool],
+ return_tensors: Optional[Union[str, TensorType]],
+ **kwargs,
+ ) -> BatchFeature:
+ if segmentation_maps is not None and len(images) != len(segmentation_maps):
+ raise ValueError("Images and segmentation maps must have the same length.")
+
+ # Group images by size for batched resizing
+ grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
+ resized_images_grouped = {}
+ if segmentation_maps is not None:
+ grouped_segmentation_maps, grouped_segmentation_maps_index = group_images_by_shape(
+ segmentation_maps, disable_grouping=disable_grouping
+ )
+ resized_segmentation_maps_grouped = {}
+ for shape, stacked_images in grouped_images.items():
+ if do_resize:
+ stacked_images = self.resize(
+ image=stacked_images, size=size, size_divisor=size_divisor, interpolation=interpolation
+ )
+ if segmentation_maps is not None:
+ stacked_segmentation_maps = self.resize(
+ image=grouped_segmentation_maps[shape],
+ size=size,
+ size_divisor=size_divisor,
+ interpolation=F.InterpolationMode.NEAREST_EXACT,
+ )
+ resized_images_grouped[shape] = stacked_images
+ if segmentation_maps is not None:
+ resized_segmentation_maps_grouped[shape] = stacked_segmentation_maps
+ resized_images = reorder_images(resized_images_grouped, grouped_images_index)
+ if segmentation_maps is not None:
+ resized_segmentation_maps = reorder_images(
+ resized_segmentation_maps_grouped, grouped_segmentation_maps_index
+ )
+ if pad_size is not None:
+ padded_size = (pad_size.height, pad_size.width)
+ else:
+ padded_size = get_max_height_width(resized_images)
+
+ if segmentation_maps is not None:
+ mask_labels = []
+ class_labels = []
+ # Convert to list of binary masks and labels
+ for idx, segmentation_map in enumerate(resized_segmentation_maps):
+ if isinstance(instance_id_to_semantic_id, list):
+ instance_id = instance_id_to_semantic_id[idx]
+ else:
+ instance_id = instance_id_to_semantic_id
+ # Use instance2class_id mapping per image
+ masks, classes = convert_segmentation_map_to_binary_masks_fast(
+ segmentation_map.squeeze(0),
+ instance_id,
+ ignore_index=ignore_index,
+ do_reduce_labels=do_reduce_labels,
+ )
+ mask_labels.append(masks)
+ class_labels.append(classes)
+
+ grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping)
+ processed_images_grouped = {}
+ processed_pixel_masks_grouped = {}
+ if segmentation_maps is not None:
+ grouped_segmentation_maps, grouped_segmentation_maps_index = group_images_by_shape(
+ mask_labels, disable_grouping=disable_grouping
+ )
+ processed_segmentation_maps_grouped = {}
+ for shape, stacked_images in grouped_images.items():
+ # Fused rescale and normalize
+ stacked_images = self.rescale_and_normalize(
+ stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std
+ )
+ padded_images, pixel_masks, padded_segmentation_maps = self.pad(
+ images=stacked_images,
+ segmentation_maps=grouped_segmentation_maps[shape] if segmentation_maps is not None else None,
+ padded_size=padded_size,
+ ignore_index=ignore_index,
+ )
+ processed_images_grouped[shape] = padded_images
+ processed_pixel_masks_grouped[shape] = pixel_masks
+ if segmentation_maps is not None:
+ processed_segmentation_maps_grouped[shape] = padded_segmentation_maps.squeeze(1)
+ processed_images = reorder_images(processed_images_grouped, grouped_images_index)
+ processed_pixel_masks = reorder_images(processed_pixel_masks_grouped, grouped_images_index)
+ encoded_inputs = BatchFeature(
+ data={
+ "pixel_values": torch.stack(processed_images, dim=0) if return_tensors else processed_images,
+ "pixel_mask": torch.stack(processed_pixel_masks, dim=0) if return_tensors else processed_pixel_masks,
+ },
+ tensor_type=return_tensors,
+ )
+ if segmentation_maps is not None:
+ mask_labels = reorder_images(processed_segmentation_maps_grouped, grouped_segmentation_maps_index)
+ # we cannot batch them since they don't share a common class size
+ encoded_inputs["mask_labels"] = mask_labels
+ encoded_inputs["class_labels"] = class_labels
+
+ return encoded_inputs
+
+ # Copied from transformers.models.maskformer.image_processing_maskformer.MaskFormerImageProcessor.post_process_segmentation
+ def post_process_segmentation(
+ self, outputs: "MaskFormerForInstanceSegmentationOutput", target_size: Optional[tuple[int, int]] = None
+ ) -> "torch.Tensor":
+ """
+ Converts the output of [`MaskFormerForInstanceSegmentationOutput`] into image segmentation predictions. Only
+ supports PyTorch.
+
+ Args:
+ outputs ([`MaskFormerForInstanceSegmentationOutput`]):
+ The outputs from [`MaskFormerForInstanceSegmentation`].
+
+ target_size (`tuple[int, int]`, *optional*):
+ If set, the `masks_queries_logits` will be resized to `target_size`.
+
+ Returns:
+ `torch.Tensor`:
+ A tensor of shape (`batch_size, num_class_labels, height, width`).
+ """
+ warnings.warn(
+ "`post_process_segmentation` is deprecated and will be removed in v5 of Transformers, please use"
+ " `post_process_instance_segmentation`",
+ FutureWarning,
+ )
+
+ # class_queries_logits has shape [BATCH, QUERIES, CLASSES + 1]
+ class_queries_logits = outputs.class_queries_logits
+ # masks_queries_logits has shape [BATCH, QUERIES, HEIGHT, WIDTH]
+ masks_queries_logits = outputs.masks_queries_logits
+ if target_size is not None:
+ masks_queries_logits = torch.nn.functional.interpolate(
+ masks_queries_logits,
+ size=target_size,
+ mode="bilinear",
+ align_corners=False,
+ )
+ # remove the null class `[..., :-1]`
+ masks_classes = class_queries_logits.softmax(dim=-1)[..., :-1]
+ # mask probs has shape [BATCH, QUERIES, HEIGHT, WIDTH]
+ masks_probs = masks_queries_logits.sigmoid()
+ # now we want to sum over the queries,
+ # $ out_{c,h,w} = \sum_q p_{q,c} * m_{q,h,w} $
+ # where $ softmax(p) \in R^{q, c} $ is the mask classes
+ # and $ sigmoid(m) \in R^{q, h, w}$ is the mask probabilities
+ # b(atch)q(uery)c(lasses), b(atch)q(uery)h(eight)w(idth)
+ segmentation = torch.einsum("bqc, bqhw -> bchw", masks_classes, masks_probs)
+
+ return segmentation
+
+ # Copied from transformers.models.maskformer.image_processing_maskformer.MaskFormerImageProcessor.post_process_semantic_segmentation
+ def post_process_semantic_segmentation(
+ self, outputs, target_sizes: Optional[list[tuple[int, int]]] = None
+ ) -> "torch.Tensor":
+ """
+ Converts the output of [`MaskFormerForInstanceSegmentation`] into semantic segmentation maps. Only supports
+ PyTorch.
+
+ Args:
+ outputs ([`MaskFormerForInstanceSegmentation`]):
+ Raw outputs of the model.
+ target_sizes (`list[tuple[int, int]]`, *optional*):
+ List of length `batch_size`, where each list item (`tuple[int, int]`) corresponds to the requested
+ final size (height, width) of each prediction. If left to None, predictions will not be resized.
+ Returns:
+ `list[torch.Tensor]`:
+ A list of length `batch_size`, where each item is a semantic segmentation map of shape (height, width)
+ corresponding to the `target_sizes` entry (if `target_sizes` is specified). Each entry of each
+ `torch.Tensor` corresponds to a semantic class id.
+ """
+ class_queries_logits = outputs.class_queries_logits # [batch_size, num_queries, num_classes+1]
+ masks_queries_logits = outputs.masks_queries_logits # [batch_size, num_queries, height, width]
+
+ # Remove the null class `[..., :-1]`
+ masks_classes = class_queries_logits.softmax(dim=-1)[..., :-1]
+ masks_probs = masks_queries_logits.sigmoid() # [batch_size, num_queries, height, width]
+
+ # Semantic segmentation logits of shape (batch_size, num_classes, height, width)
+ segmentation = torch.einsum("bqc, bqhw -> bchw", masks_classes, masks_probs)
+ batch_size = class_queries_logits.shape[0]
+
+ # Resize logits and compute semantic segmentation maps
+ if target_sizes is not None:
+ if batch_size != len(target_sizes):
+ raise ValueError(
+ "Make sure that you pass in as many target sizes as the batch dimension of the logits"
+ )
+
+ semantic_segmentation = []
+ for idx in range(batch_size):
+ resized_logits = torch.nn.functional.interpolate(
+ segmentation[idx].unsqueeze(dim=0), size=target_sizes[idx], mode="bilinear", align_corners=False
+ )
+ semantic_map = resized_logits[0].argmax(dim=0)
+ semantic_segmentation.append(semantic_map)
+ else:
+ semantic_segmentation = segmentation.argmax(dim=1)
+ semantic_segmentation = [semantic_segmentation[i] for i in range(semantic_segmentation.shape[0])]
+
+ return semantic_segmentation
+
+ # Copied from transformers.models.maskformer.image_processing_maskformer.MaskFormerImageProcessor.post_process_instance_segmentation
+ def post_process_instance_segmentation(
+ self,
+ outputs,
+ threshold: float = 0.5,
+ mask_threshold: float = 0.5,
+ overlap_mask_area_threshold: float = 0.8,
+ target_sizes: Optional[list[tuple[int, int]]] = None,
+ return_coco_annotation: Optional[bool] = False,
+ return_binary_maps: Optional[bool] = False,
+ ) -> list[dict]:
+ """
+ Converts the output of [`MaskFormerForInstanceSegmentationOutput`] into instance segmentation predictions. Only
+ supports PyTorch. If instances could overlap, set either `return_coco_annotation` or `return_binary_maps`
+ to `True` to get the correct segmentation result.
+
+ Args:
+ outputs ([`MaskFormerForInstanceSegmentation`]):
+ Raw outputs of the model.
+ threshold (`float`, *optional*, defaults to 0.5):
+ The probability score threshold to keep predicted instance masks.
+ mask_threshold (`float`, *optional*, defaults to 0.5):
+ Threshold to use when turning the predicted masks into binary values.
+ overlap_mask_area_threshold (`float`, *optional*, defaults to 0.8):
+ The overlap mask area threshold to merge or discard small disconnected parts within each binary
+ instance mask.
+ target_sizes (`list[tuple[int, int]]`, *optional*):
+ List of length `batch_size`, where each list item (`tuple[int, int]`) corresponds to the requested
+ final size (height, width) of each prediction. If left to None, predictions will not be resized.
+ return_coco_annotation (`bool`, *optional*, defaults to `False`):
+ If set to `True`, segmentation maps are returned in COCO run-length encoding (RLE) format.
+ return_binary_maps (`bool`, *optional*, defaults to `False`):
+ If set to `True`, segmentation maps are returned as a concatenated tensor of binary segmentation maps
+ (one per detected instance).
+ Returns:
+ `list[Dict]`: A list of dictionaries, one per image, each dictionary containing two keys:
+ - **segmentation** -- A tensor of shape `(height, width)` where each pixel represents a `segment_id`, or
+ a `list[list]` run-length encoding (RLE) of the segmentation map if `return_coco_annotation` is set to
+ `True`, or a tensor of shape `(num_instances, height, width)` if `return_binary_maps` is set to `True`.
+ Set to `None` if no mask is found above `threshold`.
+ - **segments_info** -- A dictionary that contains additional information on each segment.
+ - **id** -- An integer representing the `segment_id`.
+ - **label_id** -- An integer representing the label / semantic class id corresponding to `segment_id`.
+ - **score** -- Prediction score of segment with `segment_id`.
+ """
+ if return_coco_annotation and return_binary_maps:
+ raise ValueError("return_coco_annotation and return_binary_maps can not be both set to True.")
+
+ # [batch_size, num_queries, num_classes+1]
+ class_queries_logits = outputs.class_queries_logits
+ # [batch_size, num_queries, height, width]
+ masks_queries_logits = outputs.masks_queries_logits
+
+ device = masks_queries_logits.device
+ num_classes = class_queries_logits.shape[-1] - 1
+ num_queries = class_queries_logits.shape[-2]
+
+ # Loop over items in batch size
+ results: list[dict[str, TensorType]] = []
+
+ for i in range(class_queries_logits.shape[0]):
+ mask_pred = masks_queries_logits[i]
+ mask_cls = class_queries_logits[i]
+
+ scores = torch.nn.functional.softmax(mask_cls, dim=-1)[:, :-1]
+ labels = torch.arange(num_classes, device=device).unsqueeze(0).repeat(num_queries, 1).flatten(0, 1)
+
+ scores_per_image, topk_indices = scores.flatten(0, 1).topk(num_queries, sorted=False)
+ labels_per_image = labels[topk_indices]
+
+ topk_indices = torch.div(topk_indices, num_classes, rounding_mode="floor")
+ mask_pred = mask_pred[topk_indices]
+ pred_masks = (mask_pred > 0).float()
+
+ # Calculate average mask prob
+ mask_scores_per_image = (mask_pred.sigmoid().flatten(1) * pred_masks.flatten(1)).sum(1) / (
+ pred_masks.flatten(1).sum(1) + 1e-6
+ )
+ pred_scores = scores_per_image * mask_scores_per_image
+ pred_classes = labels_per_image
+
+ segmentation = torch.zeros(masks_queries_logits.shape[2:]) - 1
+ if target_sizes is not None:
+ segmentation = torch.zeros(target_sizes[i]) - 1
+ pred_masks = torch.nn.functional.interpolate(
+ pred_masks.unsqueeze(0), size=target_sizes[i], mode="nearest"
+ )[0]
+
+ instance_maps, segments = [], []
+ current_segment_id = 0
+ for j in range(num_queries):
+ score = pred_scores[j].item()
+
+ if not torch.all(pred_masks[j] == 0) and score >= threshold:
+ segmentation[pred_masks[j] == 1] = current_segment_id
+ segments.append(
+ {
+ "id": current_segment_id,
+ "label_id": pred_classes[j].item(),
+ "was_fused": False,
+ "score": round(score, 6),
+ }
+ )
+ current_segment_id += 1
+ instance_maps.append(pred_masks[j])
+
+ # Return segmentation map in run-length encoding (RLE) format
+ if return_coco_annotation:
+ segmentation = convert_segmentation_to_rle(segmentation)
+
+ # Return a concatenated tensor of binary instance maps
+ if return_binary_maps and len(instance_maps) != 0:
+ segmentation = torch.stack(instance_maps, dim=0)
+
+ results.append({"segmentation": segmentation, "segments_info": segments})
+ return results
+
+ # Copied from transformers.models.maskformer.image_processing_maskformer.MaskFormerImageProcessor.post_process_panoptic_segmentation
+ def post_process_panoptic_segmentation(
+ self,
+ outputs,
+ threshold: float = 0.5,
+ mask_threshold: float = 0.5,
+ overlap_mask_area_threshold: float = 0.8,
+ label_ids_to_fuse: Optional[set[int]] = None,
+ target_sizes: Optional[list[tuple[int, int]]] = None,
+ ) -> list[dict]:
+ """
+ Converts the output of [`MaskFormerForInstanceSegmentationOutput`] into image panoptic segmentation
+ predictions. Only supports PyTorch.
+
+ Args:
+ outputs ([`MaskFormerForInstanceSegmentationOutput`]):
+ The outputs from [`MaskFormerForInstanceSegmentation`].
+ threshold (`float`, *optional*, defaults to 0.5):
+ The probability score threshold to keep predicted instance masks.
+ mask_threshold (`float`, *optional*, defaults to 0.5):
+ Threshold to use when turning the predicted masks into binary values.
+ overlap_mask_area_threshold (`float`, *optional*, defaults to 0.8):
+ The overlap mask area threshold to merge or discard small disconnected parts within each binary
+ instance mask.
+ label_ids_to_fuse (`Set[int]`, *optional*):
+ The labels in this set will have all their instances fused together. For instance, we could say
+ there can only be one sky in an image, but several persons, so the label ID for sky would be in that
+ set, but not the one for person.
+ target_sizes (`list[tuple[int, int]]`, *optional*):
+ List of length `batch_size`, where each list item (`tuple[int, int]`) corresponds to the requested
+ final size (height, width) of each prediction in batch. If left to None, predictions will not be
+ resized.
+
+ Returns:
+ `list[Dict]`: A list of dictionaries, one per image, each dictionary containing two keys:
+ - **segmentation** -- a tensor of shape `(height, width)` where each pixel represents a `segment_id`, set
+ to `None` if no mask is found above `threshold`. If `target_sizes` is specified, segmentation is resized
+ to the corresponding `target_sizes` entry.
+ - **segments_info** -- A dictionary that contains additional information on each segment.
+ - **id** -- an integer representing the `segment_id`.
+ - **label_id** -- An integer representing the label / semantic class id corresponding to `segment_id`.
+ - **was_fused** -- a boolean, `True` if `label_id` was in `label_ids_to_fuse`, `False` otherwise.
+ Multiple instances of the same class / label were fused and assigned a single `segment_id`.
+ - **score** -- Prediction score of segment with `segment_id`.
+ """
+
+ if label_ids_to_fuse is None:
+ logger.warning("`label_ids_to_fuse` unset. No instance will be fused.")
+ label_ids_to_fuse = set()
+
+ class_queries_logits = outputs.class_queries_logits # [batch_size, num_queries, num_classes+1]
+ masks_queries_logits = outputs.masks_queries_logits # [batch_size, num_queries, height, width]
+
+ batch_size = class_queries_logits.shape[0]
+ num_labels = class_queries_logits.shape[-1] - 1
+
+ mask_probs = masks_queries_logits.sigmoid() # [batch_size, num_queries, height, width]
+
+ # Predicted label and score of each query (batch_size, num_queries)
+ pred_scores, pred_labels = nn.functional.softmax(class_queries_logits, dim=-1).max(-1)
+
+ # Loop over items in batch size
+ results: list[dict[str, TensorType]] = []
+
+ for i in range(batch_size):
+ mask_probs_item, pred_scores_item, pred_labels_item = remove_low_and_no_objects(
+ mask_probs[i], pred_scores[i], pred_labels[i], threshold, num_labels
+ )
+
+ # No mask found
+ if mask_probs_item.shape[0] <= 0:
+ height, width = target_sizes[i] if target_sizes is not None else mask_probs_item.shape[1:]
+ segmentation = torch.zeros((height, width)) - 1
+ results.append({"segmentation": segmentation, "segments_info": []})
+ continue
+
+ # Get segmentation map and segment information of batch item
+ target_size = target_sizes[i] if target_sizes is not None else None
+ segmentation, segments = compute_segments(
+ mask_probs=mask_probs_item,
+ pred_scores=pred_scores_item,
+ pred_labels=pred_labels_item,
+ mask_threshold=mask_threshold,
+ overlap_mask_area_threshold=overlap_mask_area_threshold,
+ label_ids_to_fuse=label_ids_to_fuse,
+ target_size=target_size,
+ )
+
+ results.append({"segmentation": segmentation, "segments_info": segments})
+ return results
+
+
+__all__ = ["MaskFormerImageProcessorFast"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/maskformer/modeling_maskformer.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/maskformer/modeling_maskformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..772f0a9fad0a47f073c41eb57e3e0442ec00aa23
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/maskformer/modeling_maskformer.py
@@ -0,0 +1,1803 @@
+# coding=utf-8
+# Copyright 2022 Meta Platforms, Inc. and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch MaskFormer model."""
+
+import math
+from dataclasses import dataclass
+from numbers import Number
+from typing import Optional, Union
+
+import numpy as np
+import torch
+from torch import Tensor, nn
+
+from ...activations import ACT2FN
+from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_outputs import BaseModelOutputWithCrossAttentions
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import compile_compatible_method_lru_cache
+from ...utils import (
+ ModelOutput,
+ auto_docstring,
+ is_accelerate_available,
+ is_scipy_available,
+ logging,
+ requires_backends,
+)
+from ...utils.backbone_utils import load_backbone
+from ..detr import DetrConfig
+from .configuration_maskformer import MaskFormerConfig
+from .configuration_maskformer_swin import MaskFormerSwinConfig
+
+
+if is_accelerate_available():
+ from accelerate import PartialState
+ from accelerate.utils import reduce
+
+if is_scipy_available():
+ from scipy.optimize import linear_sum_assignment
+
+logger = logging.get_logger(__name__)
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Base class for outputs of the DETR decoder. This class adds one attribute to BaseModelOutputWithCrossAttentions,
+ namely an optional stack of intermediate decoder activations, i.e. the output of each decoder layer, each of them
+ gone through a layernorm. This is useful when training the model with auxiliary decoding losses.
+ """
+)
+# Copied from transformers.models.detr.modeling_detr.DetrDecoderOutput
+class DetrDecoderOutput(BaseModelOutputWithCrossAttentions):
+ r"""
+ cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` and `config.add_cross_attention=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`. Attentions weights of the decoder's cross-attention layer, after the attention softmax,
+ used to compute the weighted average in the cross-attention heads.
+ intermediate_hidden_states (`torch.FloatTensor` of shape `(config.decoder_layers, batch_size, num_queries, hidden_size)`, *optional*, returned when `config.auxiliary_loss=True`):
+ Intermediate decoder activations, i.e. the output of each decoder layer, each of them gone through a
+ layernorm.
+ """
+
+ intermediate_hidden_states: Optional[torch.FloatTensor] = None
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ MaskFormer's pixel level module output. It returns both the last and (optionally) the hidden states from the
+ `encoder` and `decoder`. By default, the `encoder` is a MaskFormerSwin Transformer and the `decoder` is a Feature
+ Pyramid Network (FPN).
+
+ The `encoder_last_hidden_state` is referred to in the paper as **image features**, while the
+ `decoder_last_hidden_state` is referred to as **pixel embeddings**.
+ """
+)
+class MaskFormerPixelLevelModuleOutput(ModelOutput):
+ r"""
+ encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+ Last hidden states (final feature map) of the last stage of the encoder.
+ decoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+ Last hidden states (final feature map) of the last stage of the decoder.
+ encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
+ shape `(batch_size, num_channels, height, width)`. Hidden-states (also called feature maps) of the model at
+ the output of each stage.
+ decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
+ shape `(batch_size, num_channels, height, width)`. Hidden-states (also called feature maps) of the model at
+ the output of each stage.
+ """
+
+ encoder_last_hidden_state: Optional[torch.FloatTensor] = None
+ decoder_last_hidden_state: Optional[torch.FloatTensor] = None
+ encoder_hidden_states: Optional[tuple[torch.FloatTensor]] = None
+ decoder_hidden_states: Optional[tuple[torch.FloatTensor]] = None
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ MaskFormer's pixel decoder module output, practically a Feature Pyramid Network. It returns the last hidden state
+ and (optionally) the hidden states.
+ """
+)
+class MaskFormerPixelDecoderOutput(ModelOutput):
+ r"""
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+ Last hidden states (final feature map) of the last stage of the model.
+ """
+
+ last_hidden_state: Optional[torch.FloatTensor] = None
+ hidden_states: Optional[tuple[torch.FloatTensor]] = None
+ attentions: Optional[tuple[torch.FloatTensor]] = None
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Class for outputs of [`MaskFormerModel`]. This class returns all the needed hidden states to compute the logits.
+ """
+)
+class MaskFormerModelOutput(ModelOutput):
+ r"""
+ encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+ Last hidden states (final feature map) of the last stage of the encoder model (backbone).
+ pixel_decoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+ Last hidden states (final feature map) of the last stage of the pixel decoder model (FPN).
+ transformer_decoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+ Last hidden states (final feature map) of the last stage of the transformer decoder model.
+ encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
+ shape `(batch_size, num_channels, height, width)`. Hidden-states (also called feature maps) of the encoder
+ model at the output of each stage.
+ pixel_decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
+ shape `(batch_size, num_channels, height, width)`. Hidden-states (also called feature maps) of the pixel
+ decoder model at the output of each stage.
+ transformer_decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
+ shape `(batch_size, sequence_length, hidden_size)`. Hidden-states (also called feature maps) of the
+ transformer decoder at the output of each stage.
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` containing `encoder_hidden_states`, `pixel_decoder_hidden_states` and
+ `decoder_hidden_states`
+ """
+
+ encoder_last_hidden_state: Optional[torch.FloatTensor] = None
+ pixel_decoder_last_hidden_state: Optional[torch.FloatTensor] = None
+ transformer_decoder_last_hidden_state: Optional[torch.FloatTensor] = None
+ encoder_hidden_states: Optional[tuple[torch.FloatTensor]] = None
+ pixel_decoder_hidden_states: Optional[tuple[torch.FloatTensor]] = None
+ transformer_decoder_hidden_states: Optional[tuple[torch.FloatTensor]] = None
+ hidden_states: Optional[tuple[torch.FloatTensor]] = None
+ attentions: Optional[tuple[torch.FloatTensor]] = None
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Class for outputs of [`MaskFormerForInstanceSegmentation`].
+
+ This output can be directly passed to [`~MaskFormerImageProcessor.post_process_semantic_segmentation`],
+ [`~MaskFormerImageProcessor.post_process_instance_segmentation`] or
+ [`~MaskFormerImageProcessor.post_process_panoptic_segmentation`] depending on the task. Please see
+ [`~MaskFormerImageProcessor`] for details regarding usage.
+ """
+)
+class MaskFormerForInstanceSegmentationOutput(ModelOutput):
+ r"""
+ loss (`torch.Tensor`, *optional*):
+ The computed loss, returned when labels are present.
+ class_queries_logits (`torch.FloatTensor`):
+ A tensor of shape `(batch_size, num_queries, num_labels + 1)` representing the proposed classes for each
+ query. Note the `+ 1` is needed because we incorporate the null class.
+ masks_queries_logits (`torch.FloatTensor`):
+ A tensor of shape `(batch_size, num_queries, height, width)` representing the proposed masks for each
+ query.
+ auxiliary_logits (`Dict[str, torch.FloatTensor]`, *optional*, returned when `output_auxiliary_logits=True`):
+ Dictionary containing auxiliary predictions for each decoder layer when auxiliary losses are enabled.
+ encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+ Last hidden states (final feature map) of the last stage of the encoder model (backbone).
+ pixel_decoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+ Last hidden states (final feature map) of the last stage of the pixel decoder model (FPN).
+ transformer_decoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+ Last hidden states (final feature map) of the last stage of the transformer decoder model.
+ encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
+ shape `(batch_size, num_channels, height, width)`. Hidden-states (also called feature maps) of the encoder
+ model at the output of each stage.
+ pixel_decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
+ shape `(batch_size, num_channels, height, width)`. Hidden-states (also called feature maps) of the pixel
+ decoder model at the output of each stage.
+ transformer_decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
+ shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the transformer decoder at the output
+ of each stage.
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` containing `encoder_hidden_states`, `pixel_decoder_hidden_states` and
+ `decoder_hidden_states`.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ class_queries_logits: Optional[torch.FloatTensor] = None
+ masks_queries_logits: Optional[torch.FloatTensor] = None
+ auxiliary_logits: Optional[torch.FloatTensor] = None
+ encoder_last_hidden_state: Optional[torch.FloatTensor] = None
+ pixel_decoder_last_hidden_state: Optional[torch.FloatTensor] = None
+ transformer_decoder_last_hidden_state: Optional[torch.FloatTensor] = None
+ encoder_hidden_states: Optional[tuple[torch.FloatTensor]] = None
+ pixel_decoder_hidden_states: Optional[tuple[torch.FloatTensor]] = None
+ transformer_decoder_hidden_states: Optional[tuple[torch.FloatTensor]] = None
+ hidden_states: Optional[tuple[torch.FloatTensor]] = None
+ attentions: Optional[tuple[torch.FloatTensor]] = None
+
+
+def upsample_like(pixel_values: Tensor, like: Tensor, mode: str = "bilinear") -> Tensor:
+ """
+ A utility function that upsamples `pixel_values` to match the spatial dimensions of `like`.
+
+ Args:
+ pixel_values (`torch.Tensor`):
+ The tensor we wish to upsample.
+ like (`torch.Tensor`):
+ The tensor we wish to use as size target.
+ mode (str, *optional*, defaults to `"bilinear"`):
+ The interpolation mode.
+
+ Returns:
+ `torch.Tensor`: The upsampled tensor
+ """
+ _, _, height, width = like.shape
+ upsampled = nn.functional.interpolate(pixel_values, size=(height, width), mode=mode, align_corners=False)
+ return upsampled
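+
+
+# Illustrative sketch (not part of the original file): `upsample_like` only reads the spatial
+# size of `like`, so the two tensors may differ in channel count, e.g.
+#
+#   masks = torch.randn(2, 100, 32, 32)    # low-resolution mask logits
+#   image = torch.randn(2, 3, 128, 128)    # reference tensor
+#   upsample_like(masks, image).shape      # -> torch.Size([2, 100, 128, 128])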
+
+
+# refactored from original implementation
+def dice_loss(inputs: Tensor, labels: Tensor, num_masks: int) -> Tensor:
+ r"""
+ Compute the DICE loss, similar to a generalized IoU for masks, as follows:
+
+ $$ \mathcal{L}_{\text{dice}}(x, y) = 1 - \frac{2 \, x \cap y}{x \cup y + 1} $$
+
+ In practice, since `labels` is a binary mask (only 0s and 1s), the dice loss can be computed as follows:
+
+ $$ \mathcal{L}_{\text{dice}}(x, y) = 1 - \frac{2 \, x \cdot y}{x + y + 1} $$
+
+ Args:
+ inputs (`torch.Tensor`):
+ A tensor representing a mask.
+ labels (`torch.Tensor`):
+ A tensor with the same shape as inputs. Stores the binary classification labels for each element in inputs
+ (0 for the negative class and 1 for the positive class).
+ num_masks (`int`):
+ The number of masks present in the current batch, used for normalization.
+
+ Returns:
+ `torch.Tensor`: The computed loss.
+ """
+ probs = inputs.sigmoid().flatten(1)
+ numerator = 2 * (probs * labels).sum(-1)
+ denominator = probs.sum(-1) + labels.sum(-1)
+ loss = 1 - (numerator + 1) / (denominator + 1)
+ loss = loss.sum() / num_masks
+ return loss
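+
+
+# Illustrative sketch (not part of the original file): `dice_loss` takes raw mask logits and
+# binary targets of the same shape and returns a scalar normalized by `num_masks`, e.g.
+#
+#   logits = torch.randn(3, 16 * 16)                    # 3 predicted masks, flattened spatially
+#   labels = torch.randint(0, 2, (3, 16 * 16)).float()  # matching binary ground truth
+#   loss = dice_loss(logits, labels, num_masks=3)       # scalar tensor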
+
+
+# refactored from original implementation
+def sigmoid_focal_loss(
+ inputs: Tensor, labels: Tensor, num_masks: int, alpha: float = 0.25, gamma: float = 2
+) -> Tensor:
+ r"""
+ Focal loss proposed in [Focal Loss for Dense Object Detection](https://huggingface.co/papers/1708.02002), originally used in
+ RetinaNet. The loss is computed as follows:
+
+ $$ \mathcal{L}_{\text{focal}} = -(1 - p_t)^{\gamma}\log{(p_t)} $$
+
+ where \\(CE(p_t) = -\log{(p_t)}\\) is the standard Cross Entropy loss.
+
+ Please refer to equations (1), (2) and (3) of the paper for a better understanding.
+
+ Args:
+ inputs (`torch.Tensor`):
+ A float tensor of arbitrary shape.
+ labels (`torch.Tensor`):
+ A tensor with the same shape as inputs. Stores the binary classification labels for each element in inputs
+ (0 for the negative class and 1 for the positive class).
+ num_masks (`int`):
+ The number of masks present in the current batch, used for normalization.
+ alpha (float, *optional*, defaults to 0.25):
+ Weighting factor in range (0,1) to balance positive vs negative examples.
+ gamma (float, *optional*, defaults to 2.0):
+ Exponent of the modulating factor \\(1 - p_t\\) to balance easy vs hard examples.
+
+ Returns:
+ `torch.Tensor`: The computed loss.
+ """
+ criterion = nn.BCEWithLogitsLoss(reduction="none")
+ probs = inputs.sigmoid()
+ cross_entropy_loss = criterion(inputs, labels)
+ p_t = probs * labels + (1 - probs) * (1 - labels)
+ loss = cross_entropy_loss * ((1 - p_t) ** gamma)
+
+ if alpha >= 0:
+ alpha_t = alpha * labels + (1 - alpha) * (1 - labels)
+ loss = alpha_t * loss
+
+ loss = loss.mean(1).sum() / num_masks
+ return loss
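+
+
+# Illustrative sketch (not part of the original file): `sigmoid_focal_loss` uses the same calling
+# convention as `dice_loss`; with `gamma=0` and `alpha<0` it reduces to a plain binary
+# cross-entropy (up to the `num_masks` normalization), while the defaults down-weight
+# well-classified pixels, e.g.
+#
+#   logits = torch.randn(3, 16 * 16)
+#   labels = torch.randint(0, 2, (3, 16 * 16)).float()
+#   loss = sigmoid_focal_loss(logits, labels, num_masks=3, alpha=0.25, gamma=2.0)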
+
+
+# refactored from original implementation
+def pair_wise_dice_loss(inputs: Tensor, labels: Tensor) -> Tensor:
+ """
+ A pair wise version of the dice loss, see `dice_loss` for usage.
+
+ Args:
+ inputs (`torch.Tensor`):
+ A tensor representing a mask
+ labels (`torch.Tensor`):
+ A tensor with the same shape as inputs. Stores the binary classification labels for each element in inputs
+ (0 for the negative class and 1 for the positive class).
+
+ Returns:
+ `torch.Tensor`: The computed loss between each pair.
+ """
+ inputs = inputs.sigmoid().flatten(1)
+ numerator = 2 * torch.matmul(inputs, labels.T)
+ # using broadcasting to get a [num_queries, NUM_CLASSES] matrix
+ denominator = inputs.sum(-1)[:, None] + labels.sum(-1)[None, :]
+ loss = 1 - (numerator + 1) / (denominator + 1)
+ return loss
+
+
+# refactored from original implementation
+def pair_wise_sigmoid_focal_loss(inputs: Tensor, labels: Tensor, alpha: float = 0.25, gamma: float = 2.0) -> Tensor:
+ r"""
+ A pair wise version of the focal loss, see `sigmoid_focal_loss` for usage.
+
+ Args:
+ inputs (`torch.Tensor`):
+ A tensor representing a mask.
+ labels (`torch.Tensor`):
+ A tensor with the same shape as inputs. Stores the binary classification labels for each element in inputs
+ (0 for the negative class and 1 for the positive class).
+ alpha (float, *optional*, defaults to 0.25):
+ Weighting factor in range (0,1) to balance positive vs negative examples.
+ gamma (float, *optional*, defaults to 2.0):
+ Exponent of the modulating factor \\(1 - p_t\\) to balance easy vs hard examples.
+
+ Returns:
+ `torch.Tensor`: The computed loss between each pair.
+ """
+ if alpha < 0:
+ raise ValueError("alpha must be positive")
+
+ height_and_width = inputs.shape[1]
+
+ criterion = nn.BCEWithLogitsLoss(reduction="none")
+ prob = inputs.sigmoid()
+ cross_entropy_loss_pos = criterion(inputs, torch.ones_like(inputs))
+ focal_pos = ((1 - prob) ** gamma) * cross_entropy_loss_pos
+ focal_pos *= alpha
+
+ cross_entropy_loss_neg = criterion(inputs, torch.zeros_like(inputs))
+
+ focal_neg = (prob**gamma) * cross_entropy_loss_neg
+ focal_neg *= 1 - alpha
+
+ loss = torch.matmul(focal_pos, labels.T) + torch.matmul(focal_neg, (1 - labels).T)
+
+ return loss / height_and_width
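+
+
+# Illustrative sketch (not part of the original file): the pair-wise variants return a full cost
+# matrix between every predicted mask and every ground-truth mask, which is what the Hungarian
+# matcher below consumes, e.g.
+#
+#   preds = torch.randn(100, 16 * 16)                    # num_queries flattened mask logits
+#   targets = torch.randint(0, 2, (5, 16 * 16)).float()  # 5 ground-truth masks
+#   pair_wise_dice_loss(preds, targets).shape            # -> torch.Size([100, 5])
+#   pair_wise_sigmoid_focal_loss(preds, targets).shape   # -> torch.Size([100, 5])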
+
+
+# Copied from transformers.models.detr.modeling_detr.DetrAttention
+class DetrAttention(nn.Module):
+ """
+ Multi-headed attention from 'Attention Is All You Need' paper.
+
+ Here, we add position embeddings to the queries and keys (as explained in the DETR paper).
+ """
+
+ def __init__(
+ self,
+ embed_dim: int,
+ num_heads: int,
+ dropout: float = 0.0,
+ bias: bool = True,
+ ):
+ super().__init__()
+ self.embed_dim = embed_dim
+ self.num_heads = num_heads
+ self.dropout = dropout
+ self.head_dim = embed_dim // num_heads
+ if self.head_dim * num_heads != self.embed_dim:
+ raise ValueError(
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
+ f" {num_heads})."
+ )
+ self.scaling = self.head_dim**-0.5
+
+ self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+ self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+ self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+ self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+
+ def _shape(self, tensor: torch.Tensor, seq_len: int, batch_size: int):
+ return tensor.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
+
+ def with_pos_embed(self, tensor: torch.Tensor, object_queries: Optional[Tensor]):
+ return tensor if object_queries is None else tensor + object_queries
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ object_queries: Optional[torch.Tensor] = None,
+ key_value_states: Optional[torch.Tensor] = None,
+ spatial_position_embeddings: Optional[torch.Tensor] = None,
+ output_attentions: bool = False,
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
+ """Input shape: Batch x Time x Channel"""
+ # if key_value_states are provided this layer is used as a cross-attention layer
+ # for the decoder
+ is_cross_attention = key_value_states is not None
+ batch_size, target_len, embed_dim = hidden_states.size()
+
+ # add position embeddings to the hidden states before projecting to queries and keys
+ if object_queries is not None:
+ hidden_states_original = hidden_states
+ hidden_states = self.with_pos_embed(hidden_states, object_queries)
+
+ # add key-value position embeddings to the key value states
+ if spatial_position_embeddings is not None:
+ key_value_states_original = key_value_states
+ key_value_states = self.with_pos_embed(key_value_states, spatial_position_embeddings)
+
+ # get query proj
+ query_states = self.q_proj(hidden_states) * self.scaling
+ # get key, value proj
+ if is_cross_attention:
+ # cross_attentions
+ key_states = self._shape(self.k_proj(key_value_states), -1, batch_size)
+ value_states = self._shape(self.v_proj(key_value_states_original), -1, batch_size)
+ else:
+ # self_attention
+ key_states = self._shape(self.k_proj(hidden_states), -1, batch_size)
+ value_states = self._shape(self.v_proj(hidden_states_original), -1, batch_size)
+
+ proj_shape = (batch_size * self.num_heads, -1, self.head_dim)
+ query_states = self._shape(query_states, target_len, batch_size).view(*proj_shape)
+ key_states = key_states.view(*proj_shape)
+ value_states = value_states.view(*proj_shape)
+
+ source_len = key_states.size(1)
+
+ attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
+
+ if attn_weights.size() != (batch_size * self.num_heads, target_len, source_len):
+ raise ValueError(
+ f"Attention weights should be of size {(batch_size * self.num_heads, target_len, source_len)}, but is"
+ f" {attn_weights.size()}"
+ )
+
+ if attention_mask is not None:
+ if attention_mask.size() != (batch_size, 1, target_len, source_len):
+ raise ValueError(
+ f"Attention mask should be of size {(batch_size, 1, target_len, source_len)}, but is"
+ f" {attention_mask.size()}"
+ )
+ if attention_mask.dtype == torch.bool:
+ attention_mask = torch.zeros_like(attention_mask, dtype=attn_weights.dtype).masked_fill_(
+ attention_mask, -torch.inf
+ )
+ attn_weights = attn_weights.view(batch_size, self.num_heads, target_len, source_len) + attention_mask
+ attn_weights = attn_weights.view(batch_size * self.num_heads, target_len, source_len)
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+
+ if output_attentions:
+ # this operation is a bit awkward, but it's required to
+ # make sure that attn_weights keeps its gradient.
+ # In order to do so, attn_weights have to reshaped
+ # twice and have to be reused in the following
+ attn_weights_reshaped = attn_weights.view(batch_size, self.num_heads, target_len, source_len)
+ attn_weights = attn_weights_reshaped.view(batch_size * self.num_heads, target_len, source_len)
+ else:
+ attn_weights_reshaped = None
+
+ attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
+
+ attn_output = torch.bmm(attn_probs, value_states)
+
+ if attn_output.size() != (batch_size * self.num_heads, target_len, self.head_dim):
+ raise ValueError(
+ f"`attn_output` should be of size {(batch_size, self.num_heads, target_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
+ )
+
+ attn_output = attn_output.view(batch_size, self.num_heads, target_len, self.head_dim)
+ attn_output = attn_output.transpose(1, 2)
+ attn_output = attn_output.reshape(batch_size, target_len, embed_dim)
+
+ attn_output = self.out_proj(attn_output)
+
+ return attn_output, attn_weights_reshaped
+
+
+# Copied from transformers.models.detr.modeling_detr.DetrDecoderLayer
+class DetrDecoderLayer(GradientCheckpointingLayer):
+ def __init__(self, config: DetrConfig):
+ super().__init__()
+ self.embed_dim = config.d_model
+
+ self.self_attn = DetrAttention(
+ embed_dim=self.embed_dim,
+ num_heads=config.decoder_attention_heads,
+ dropout=config.attention_dropout,
+ )
+ self.dropout = config.dropout
+ self.activation_fn = ACT2FN[config.activation_function]
+ self.activation_dropout = config.activation_dropout
+
+ self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
+ self.encoder_attn = DetrAttention(
+ self.embed_dim,
+ config.decoder_attention_heads,
+ dropout=config.attention_dropout,
+ )
+ self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
+ self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
+ self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)
+ self.final_layer_norm = nn.LayerNorm(self.embed_dim)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ object_queries: Optional[torch.Tensor] = None,
+ query_position_embeddings: Optional[torch.Tensor] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = False,
+ ):
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+ attention_mask (`torch.FloatTensor`): attention mask of size
+ `(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative
+ values.
+ object_queries (`torch.FloatTensor`, *optional*):
+ object_queries that are added to the hidden states
+ in the cross-attention layer.
+ query_position_embeddings (`torch.FloatTensor`, *optional*):
+ position embeddings that are added to the queries and keys
+ in the self-attention layer.
+ encoder_hidden_states (`torch.FloatTensor`):
+ cross attention input to the layer of shape `(batch, seq_len, embed_dim)`
+ encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
+ `(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative
+ values.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ """
+ residual = hidden_states
+
+ # Self Attention
+ hidden_states, self_attn_weights = self.self_attn(
+ hidden_states=hidden_states,
+ object_queries=query_position_embeddings,
+ attention_mask=attention_mask,
+ output_attentions=output_attentions,
+ )
+
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+ hidden_states = residual + hidden_states
+ hidden_states = self.self_attn_layer_norm(hidden_states)
+
+ # Cross-Attention Block
+ cross_attn_weights = None
+ if encoder_hidden_states is not None:
+ residual = hidden_states
+
+ hidden_states, cross_attn_weights = self.encoder_attn(
+ hidden_states=hidden_states,
+ object_queries=query_position_embeddings,
+ key_value_states=encoder_hidden_states,
+ attention_mask=encoder_attention_mask,
+ spatial_position_embeddings=object_queries,
+ output_attentions=output_attentions,
+ )
+
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+ hidden_states = residual + hidden_states
+ hidden_states = self.encoder_attn_layer_norm(hidden_states)
+
+ # Fully Connected
+ residual = hidden_states
+ hidden_states = self.activation_fn(self.fc1(hidden_states))
+ hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
+ hidden_states = self.fc2(hidden_states)
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+ hidden_states = residual + hidden_states
+ hidden_states = self.final_layer_norm(hidden_states)
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (self_attn_weights, cross_attn_weights)
+
+ return outputs
+
+
+class DetrDecoder(nn.Module):
+ """
+ Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`DetrDecoderLayer`].
+
+ The decoder updates the query embeddings through multiple self-attention and cross-attention layers.
+
+ Some small tweaks for DETR:
+
+ - object_queries and query_position_embeddings are added to the forward pass.
+ - if self.config.auxiliary_loss is set to True, also returns a stack of activations from all decoding layers.
+
+ Args:
+ config: DetrConfig
+ """
+
+ def __init__(self, config: DetrConfig):
+ super().__init__()
+ self.config = config
+ self.dropout = config.dropout
+ self.layerdrop = config.decoder_layerdrop
+
+ self.layers = nn.ModuleList([DetrDecoderLayer(config) for _ in range(config.decoder_layers)])
+ # in DETR, the decoder uses layernorm after the last decoder layer output
+ self.layernorm = nn.LayerNorm(config.d_model)
+
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ inputs_embeds=None,
+ attention_mask=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ object_queries=None,
+ query_position_embeddings=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ ):
+ r"""
+ Args:
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+ The query embeddings that are passed into the decoder.
+
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on certain queries. Mask values selected in `[0, 1]`:
+
+ - 1 for queries that are **not masked**,
+ - 0 for queries that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
+ of the decoder.
+ encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*):
+ Mask to avoid performing cross-attention on padding pixel_values of the encoder. Mask values selected
+ in `[0, 1]`:
+
+ - 1 for pixels that are real (i.e. **not masked**),
+ - 0 for pixels that are padding (i.e. **masked**).
+
+ object_queries (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Position embeddings that are added to the queries and keys in each cross-attention layer.
+ query_position_embeddings (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*):
+ Position embeddings that are added to the queries and keys in each self-attention layer.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+ for more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if inputs_embeds is not None:
+ hidden_states = inputs_embeds
+ input_shape = inputs_embeds.size()[:-1]
+
+ # expand encoder attention mask
+ if encoder_hidden_states is not None and encoder_attention_mask is not None:
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ encoder_attention_mask = _prepare_4d_attention_mask(
+ encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
+ )
+
+ # optional intermediate hidden states
+ intermediate = () if self.config.auxiliary_loss else None
+
+ # decoder layers
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+ all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
+
+ for idx, decoder_layer in enumerate(self.layers):
+ # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+ if self.training:
+ dropout_probability = torch.rand([])
+ if dropout_probability < self.layerdrop:
+ continue
+
+ layer_outputs = decoder_layer(
+ hidden_states,
+ None, # attention_mask
+ object_queries,
+ query_position_embeddings,
+ encoder_hidden_states, # as a positional argument for gradient checkpointing
+ encoder_attention_mask=encoder_attention_mask,
+ output_attentions=output_attentions,
+ )
+
+ hidden_states = layer_outputs[0]
+
+ if self.config.auxiliary_loss:
+ hidden_states = self.layernorm(hidden_states)
+ intermediate += (hidden_states,)
+
+ if output_attentions:
+ all_self_attns += (layer_outputs[1],)
+
+ if encoder_hidden_states is not None:
+ all_cross_attentions += (layer_outputs[2],)
+
+ # finally, apply layernorm
+ hidden_states = self.layernorm(hidden_states)
+
+ # add hidden states from the last decoder layer
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ # stack intermediate decoder activations
+ if self.config.auxiliary_loss:
+ intermediate = torch.stack(intermediate)
+
+ if not return_dict:
+ return tuple(
+ v
+ for v in [hidden_states, all_hidden_states, all_self_attns, all_cross_attentions, intermediate]
+ if v is not None
+ )
+ return DetrDecoderOutput(
+ last_hidden_state=hidden_states,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ cross_attentions=all_cross_attentions,
+ intermediate_hidden_states=intermediate,
+ )
+
+
+# refactored from original implementation
+class MaskFormerHungarianMatcher(nn.Module):
+ """This class computes an assignment between the labels and the predictions of the network.
+
+ For efficiency reasons, the labels don't include the no_object. Because of this, in general, there are more
+ predictions than labels. In this case, we do a 1-to-1 matching of the best predictions, while the others are
+ un-matched (and thus treated as non-objects).
+ """
+
+ def __init__(self, cost_class: float = 1.0, cost_mask: float = 1.0, cost_dice: float = 1.0):
+ """Creates the matcher
+
+ Params:
+ cost_class (float, *optional*, defaults to 1.0):
+ This is the relative weight of the classification error in the matching cost.
+ cost_mask (float, *optional*, defaults to 1.0):
+ This is the relative weight of the focal loss of the binary mask in the matching cost.
+ cost_dice (float, *optional*, defaults to 1.0):
+ This is the relative weight of the dice loss of the binary mask in the matching cost
+ """
+ super().__init__()
+ if cost_class == 0 and cost_mask == 0 and cost_dice == 0:
+ raise ValueError("All costs can't be 0")
+ self.cost_class = cost_class
+ self.cost_mask = cost_mask
+ self.cost_dice = cost_dice
+
+ @torch.no_grad()
+ def forward(self, masks_queries_logits, class_queries_logits, mask_labels, class_labels) -> list[tuple[Tensor]]:
+ """Performs the matching
+
+ Params:
+ masks_queries_logits (`torch.Tensor`):
+ A tensor of dim `batch_size, num_queries, height, width` with the
+ predicted masks.
+ class_queries_logits (`torch.Tensor`):
+ A tensor of dim `batch_size, num_queries, num_labels` with the
+ classification logits.
+
+ class_labels (`torch.Tensor`):
+ A tensor of dim `num_target_boxes` (where num_target_boxes is the number
+ of ground-truth objects in the target) containing the class labels.
+ mask_labels (`torch.Tensor`):
+ A tensor of dim `num_target_boxes, height, width` containing the target
+ masks.
+
+ Returns:
+ `list[tuple[Tensor]]`: A list of size batch_size, containing tuples of (index_i, index_j) where:
+ - index_i is the indices of the selected predictions (in order)
+ - index_j is the indices of the corresponding selected labels (in order)
+ For each batch element, it holds:
+ len(index_i) = len(index_j) = min(num_queries, num_target_boxes).
+ """
+ indices: list[tuple[np.array]] = []
+
+ preds_masks = masks_queries_logits
+ preds_probs = class_queries_logits
+ # iterate through batch size
+ for pred_probs, pred_mask, target_mask, labels in zip(preds_probs, preds_masks, mask_labels, class_labels):
+ # downsample the target mask, save memory
+ target_mask = nn.functional.interpolate(target_mask[:, None], size=pred_mask.shape[-2:], mode="nearest")
+ pred_probs = pred_probs.softmax(-1)
+ # Compute the classification cost. Contrary to the loss, we don't use the NLL,
+ # but approximate it in 1 - proba[target class].
+ # The 1 is a constant that doesn't change the matching, it can be omitted.
+ cost_class = -pred_probs[:, labels]
+ # flatten spatial dimension "q h w -> q (h w)"
+ pred_mask_flat = pred_mask.flatten(1) # [num_queries, height*width]
+ # same for target_mask "c h w -> c (h w)"
+ target_mask_flat = target_mask[:, 0].flatten(1) # [num_total_labels, height*width]
+ # compute the focal loss between each mask pairs -> shape (num_queries, num_labels)
+ cost_mask = pair_wise_sigmoid_focal_loss(pred_mask_flat, target_mask_flat)
+ # Compute the dice loss between each mask pairs -> shape (num_queries, num_labels)
+ cost_dice = pair_wise_dice_loss(pred_mask_flat, target_mask_flat)
+ # final cost matrix
+ cost_matrix = self.cost_mask * cost_mask + self.cost_class * cost_class + self.cost_dice * cost_dice
+ # do the assignment using the hungarian algorithm in scipy
+ assigned_indices: tuple[np.array] = linear_sum_assignment(cost_matrix.cpu())
+ indices.append(assigned_indices)
+
+ # It could be stacked in one tensor
+ matched_indices = [
+ (torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices
+ ]
+ return matched_indices
+
+ def __repr__(self):
+ head = "Matcher " + self.__class__.__name__
+ body = [
+ f"cost_class: {self.cost_class}",
+ f"cost_mask: {self.cost_mask}",
+ f"cost_dice: {self.cost_dice}",
+ ]
+ _repr_indent = 4
+ lines = [head] + [" " * _repr_indent + line for line in body]
+ return "\n".join(lines)
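+
+
+# Illustrative sketch (not part of the original file): the matcher returns one
+# (prediction_indices, target_indices) pair per image, e.g.
+#
+#   matcher = MaskFormerHungarianMatcher(cost_class=1.0, cost_mask=1.0, cost_dice=1.0)
+#   masks_logits = torch.randn(1, 100, 32, 32)             # (batch, num_queries, height, width)
+#   class_logits = torch.randn(1, 100, 151)                # (batch, num_queries, num_labels + 1)
+#   mask_labels = [torch.randint(0, 2, (4, 128, 128)).float()]
+#   class_labels = [torch.tensor([5, 17, 17, 42])]
+#   pred_idx, tgt_idx = matcher(masks_logits, class_logits, mask_labels, class_labels)[0]
+#   # pred_idx and tgt_idx are length-4 LongTensors pairing matched queries with targets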
+
+
+# copied and adapted from original implementation
+class MaskFormerLoss(nn.Module):
+ def __init__(
+ self,
+ num_labels: int,
+ matcher: MaskFormerHungarianMatcher,
+ weight_dict: dict[str, float],
+ eos_coef: float,
+ ):
+ """
+ The MaskFormer loss. The loss is computed very similarly to DETR. The process happens in two steps: 1) we compute
+ the Hungarian assignment between ground truth masks and the outputs of the model, 2) we supervise each pair of
+ matched ground-truth / prediction (supervise class and mask).
+
+ Args:
+ num_labels (`int`):
+ The number of classes.
+ matcher (`MaskFormerHungarianMatcher`):
+ A torch module that computes the assignments between the predictions and labels.
+ weight_dict (`dict[str, float]`):
+ A dictionary of weights to be applied to the different losses.
+ eos_coef (`float`):
+ Weight to apply to the null class.
+ """
+
+ super().__init__()
+ requires_backends(self, ["scipy"])
+ self.num_labels = num_labels
+ self.matcher = matcher
+ self.weight_dict = weight_dict
+ self.eos_coef = eos_coef
+ empty_weight = torch.ones(self.num_labels + 1)
+ empty_weight[-1] = self.eos_coef
+ self.register_buffer("empty_weight", empty_weight)
+
+ def _max_by_axis(self, the_list: list[list[int]]) -> list[int]:
+ maxes = the_list[0]
+ for sublist in the_list[1:]:
+ for index, item in enumerate(sublist):
+ maxes[index] = max(maxes[index], item)
+ return maxes
+
+ def _pad_images_to_max_in_batch(self, tensors: list[Tensor]) -> tuple[Tensor, Tensor]:
+ # get the maximum size in the batch
+ max_size = self._max_by_axis([list(tensor.shape) for tensor in tensors])
+ batch_size = len(tensors)
+ # compute final size
+ batch_shape = [batch_size] + max_size
+ b, _, h, w = batch_shape
+ # get metadata
+ dtype = tensors[0].dtype
+ device = tensors[0].device
+ padded_tensors = torch.zeros(batch_shape, dtype=dtype, device=device)
+ padding_masks = torch.ones((b, h, w), dtype=torch.bool, device=device)
+ # pad the tensors to the size of the biggest one
+ for tensor, padded_tensor, padding_mask in zip(tensors, padded_tensors, padding_masks):
+ padded_tensor[: tensor.shape[0], : tensor.shape[1], : tensor.shape[2]].copy_(tensor)
+ padding_mask[: tensor.shape[1], : tensor.shape[2]] = False
+
+ return padded_tensors, padding_masks
+
+ def loss_labels(
+ self, class_queries_logits: Tensor, class_labels: list[Tensor], indices: tuple[np.array]
+ ) -> dict[str, Tensor]:
+ """Compute the losses related to the labels using cross entropy.
+
+ Args:
+ class_queries_logits (`torch.Tensor`):
+ A tensor of shape `batch_size, num_queries, num_labels`
+ class_labels (`list[torch.Tensor]`):
+ List of class labels of shape `(labels)`.
+ indices (`tuple[np.array]`):
+ The indices computed by the Hungarian matcher.
+
+ Returns:
+ `dict[str, Tensor]`: A dict of `torch.Tensor` containing the following key:
+ - **loss_cross_entropy** -- The loss computed using cross entropy on the predicted and ground truth labels.
+ """
+
+ pred_logits = class_queries_logits
+ batch_size, num_queries, _ = pred_logits.shape
+ criterion = nn.CrossEntropyLoss(weight=self.empty_weight)
+ idx = self._get_predictions_permutation_indices(indices)
+ # shape = (batch_size, num_queries)
+ target_classes_o = torch.cat([target[j] for target, (_, j) in zip(class_labels, indices)])
+ # shape = (batch_size, num_queries)
+ target_classes = torch.full(
+ (batch_size, num_queries), fill_value=self.num_labels, dtype=torch.int64, device=pred_logits.device
+ )
+ target_classes[idx] = target_classes_o
+ # CrossEntropyLoss expects logits of shape (batch_size, num_labels, num_queries), so permute pred_logits "b q c -> b c q"
+ pred_logits_transposed = pred_logits.transpose(1, 2)
+ loss_ce = criterion(pred_logits_transposed, target_classes)
+ losses = {"loss_cross_entropy": loss_ce}
+ return losses
+
+ def loss_masks(
+ self, masks_queries_logits: Tensor, mask_labels: list[Tensor], indices: tuple[np.array], num_masks: int
+ ) -> dict[str, Tensor]:
+ """Compute the losses related to the masks using focal and dice loss.
+
+ Args:
+ masks_queries_logits (`torch.Tensor`):
+ A tensor of shape `batch_size, num_queries, height, width`
+ mask_labels (`list[torch.Tensor]`):
+ List of mask labels of shape `(labels, height, width)`.
+ indices (`tuple[np.array]`):
+ The indices computed by the Hungarian matcher.
+ num_masks (`int`):
+ The number of masks, used for normalization.
+
+ Returns:
+ `dict[str, Tensor]`: A dict of `torch.Tensor` containing two keys:
+ - **loss_mask** -- The loss computed using sigmoid focal loss on the predicted and ground truth masks.
+ - **loss_dice** -- The loss computed using dice loss on the predicted and ground truth masks.
+ """
+ src_idx = self._get_predictions_permutation_indices(indices)
+ tgt_idx = self._get_targets_permutation_indices(indices)
+ # shape (batch_size * num_queries, height, width)
+ pred_masks = masks_queries_logits[src_idx]
+ # shape (batch_size, num_queries, height, width)
+ # pad all and stack the targets to the num_labels dimension
+ target_masks, _ = self._pad_images_to_max_in_batch(mask_labels)
+ target_masks = target_masks[tgt_idx]
+ # upsample predictions to the target size, we have to add one dim to use interpolate
+ pred_masks = nn.functional.interpolate(
+ pred_masks[:, None], size=target_masks.shape[-2:], mode="bilinear", align_corners=False
+ )
+ pred_masks = pred_masks[:, 0].flatten(1)
+
+ target_masks = target_masks.flatten(1)
+ losses = {
+ "loss_mask": sigmoid_focal_loss(pred_masks, target_masks, num_masks),
+ "loss_dice": dice_loss(pred_masks, target_masks, num_masks),
+ }
+ return losses
+
+ def _get_predictions_permutation_indices(self, indices):
+ # permute predictions following indices
+ batch_indices = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
+ predictions_indices = torch.cat([src for (src, _) in indices])
+ return batch_indices, predictions_indices
+
+ def _get_targets_permutation_indices(self, indices):
+ # permute labels following indices
+ batch_indices = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
+ target_indices = torch.cat([tgt for (_, tgt) in indices])
+ return batch_indices, target_indices
+
+ def forward(
+ self,
+ masks_queries_logits: Tensor,
+ class_queries_logits: Tensor,
+ mask_labels: list[Tensor],
+ class_labels: list[Tensor],
+ auxiliary_predictions: Optional[dict[str, Tensor]] = None,
+ ) -> dict[str, Tensor]:
+ """
+ This performs the loss computation.
+
+ Args:
+ masks_queries_logits (`torch.Tensor`):
+ A tensor of shape `batch_size, num_queries, height, width`
+ class_queries_logits (`torch.Tensor`):
+ A tensor of shape `batch_size, num_queries, num_labels`
+ mask_labels (`list[torch.Tensor]`):
+ List of mask labels of shape `(labels, height, width)`.
+ class_labels (`list[torch.Tensor]`):
+ List of class labels of shape `(labels)`.
+ auxiliary_predictions (`dict[str, torch.Tensor]`, *optional*):
+ if `use_auxiliary_loss` was set to `true` in [`MaskFormerConfig`], then it contains the logits from the
+ inner layers of the Detr's Decoder.
+
+ Returns:
+ `dict[str, Tensor]`: A dict of `torch.Tensor` containing two keys:
+ - **loss_cross_entropy** -- The loss computed using cross entropy on the predicted and ground truth labels.
+ - **loss_mask** -- The loss computed using sigmoid focal loss on the predicted and ground truth masks.
+ - **loss_dice** -- The loss computed using dice loss on the predicted and ground truth masks.
+ If `use_auxiliary_loss` was set to `true` in [`MaskFormerConfig`], the dictionary contains additional losses
+ for each auxiliary prediction.
+ """
+
+ # retrieve the matching between the outputs of the last layer and the labels
+ indices = self.matcher(masks_queries_logits, class_queries_logits, mask_labels, class_labels)
+ # compute the average number of target masks for normalization purposes
+ num_masks: Number = self.get_num_masks(class_labels, device=class_labels[0].device)
+ # get all the losses
+ losses: dict[str, Tensor] = {
+ **self.loss_masks(masks_queries_logits, mask_labels, indices, num_masks),
+ **self.loss_labels(class_queries_logits, class_labels, indices),
+ }
+ # in case of auxiliary losses, we repeat this process with the output of each intermediate layer.
+ if auxiliary_predictions is not None:
+ for idx, aux_outputs in enumerate(auxiliary_predictions):
+ masks_queries_logits = aux_outputs["masks_queries_logits"]
+ class_queries_logits = aux_outputs["class_queries_logits"]
+ loss_dict = self.forward(masks_queries_logits, class_queries_logits, mask_labels, class_labels)
+ loss_dict = {f"{key}_{idx}": value for key, value in loss_dict.items()}
+ losses.update(loss_dict)
+
+ return losses
+
+ def get_num_masks(self, class_labels: torch.Tensor, device: torch.device) -> torch.Tensor:
+ """
+ Computes the average number of target masks across the batch, for normalization purposes.
+ """
+ num_masks = sum(len(classes) for classes in class_labels)
+ num_masks = torch.as_tensor(num_masks, dtype=torch.float, device=device)
+ world_size = 1
+ if is_accelerate_available():
+ if PartialState._shared_state != {}:
+ num_masks = reduce(num_masks)
+ world_size = PartialState().num_processes
+
+ num_masks = torch.clamp(num_masks / world_size, min=1)
+ return num_masks
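+
+
+# Illustrative sketch (not part of the original file): `MaskFormerForInstanceSegmentation` below
+# wires the matcher and this criterion together along the lines of
+#
+#   matcher = MaskFormerHungarianMatcher(cost_class=1.0, cost_mask=20.0, cost_dice=1.0)
+#   weight_dict = {"loss_cross_entropy": 1.0, "loss_mask": 20.0, "loss_dice": 1.0}
+#   criterion = MaskFormerLoss(num_labels=150, matcher=matcher, weight_dict=weight_dict, eos_coef=0.1)
+#
+# where the exact weights and label count come from `MaskFormerConfig`; the training loss is the
+# weighted sum of the dictionary returned by the criterion.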
+
+
+class MaskFormerFPNConvLayer(nn.Module):
+ def __init__(self, in_features: int, out_features: int, kernel_size: int = 3, padding: int = 1):
+ """
+ A basic module that executes conv - norm - activation in sequence, used in MaskFormer.
+
+ Args:
+ in_features (`int`):
+ The number of input features (channels).
+ out_features (`int`):
+ The number of output features (channels).
+ """
+ super().__init__()
+ self.layers = [
+ nn.Conv2d(in_features, out_features, kernel_size=kernel_size, padding=padding, bias=False),
+ nn.GroupNorm(32, out_features),
+ nn.ReLU(inplace=True),
+ ]
+ for i, layer in enumerate(self.layers):
+ # Provide backwards compatibility from when the class inherited from nn.Sequential
+ # In nn.Sequential subclasses, the name given to the layer is its index in the sequence.
+ # In nn.Module subclasses they are derived from the instance attribute they are assigned to, e.g.
+ # self.my_layer_name = Layer()
+ # We can't give instance attributes integer names, i.e. self.0 is not permitted, so we need to register them
+ # explicitly
+ self.add_module(str(i), layer)
+
+ def forward(self, input: Tensor) -> Tensor:
+ hidden_state = input
+ for layer in self.layers:
+ hidden_state = layer(hidden_state)
+ return hidden_state
+
+
+class MaskFormerFPNLayer(nn.Module):
+ def __init__(self, in_features: int, lateral_features: int):
+ """
+ A Feature Pyramid Network (FPN) layer. It creates a feature map by aggregating features from the previous
+ and backbone layer. Due to the spatial mismatch, the tensor coming from the previous layer is upsampled.
+
+ Args:
+ in_features (`int`):
+ The number of input features (channels).
+ lateral_features (`int`):
+ The number of lateral features (channels).
+ """
+ super().__init__()
+ self.proj = nn.Sequential(
+ nn.Conv2d(lateral_features, in_features, kernel_size=1, padding=0, bias=False),
+ nn.GroupNorm(32, in_features),
+ )
+
+ self.block = MaskFormerFPNConvLayer(in_features, in_features)
+
+ def forward(self, down: Tensor, left: Tensor) -> Tensor:
+ left = self.proj(left)
+ down = nn.functional.interpolate(down, size=left.shape[-2:], mode="nearest")
+ down += left
+ down = self.block(down)
+ return down
+
+
+class MaskFormerFPNModel(nn.Module):
+ def __init__(self, in_features: int, lateral_widths: list[int], feature_size: int = 256):
+ """
+ Feature Pyramid Network: given an input tensor and a set of feature maps of different feature/spatial sizes, it
+ creates a list of feature maps with the same feature size.
+
+ Args:
+ in_features (`int`):
+ The number of input features (channels).
+ lateral_widths (`list[int]`):
+ A list with the features (channels) size of each lateral connection.
+ feature_size (int, *optional*, defaults to 256):
+ The features (channels) of the resulting feature maps.
+ """
+ super().__init__()
+ self.stem = MaskFormerFPNConvLayer(in_features, feature_size)
+ self.layers = nn.Sequential(
+ *[MaskFormerFPNLayer(feature_size, lateral_width) for lateral_width in lateral_widths[::-1]]
+ )
+
+ def forward(self, features: list[Tensor]) -> list[Tensor]:
+ fpn_features = []
+ last_feature = features[-1]
+ other_features = features[:-1]
+ output = self.stem(last_feature)
+ for layer, left in zip(self.layers, other_features[::-1]):
+ output = layer(output, left)
+ fpn_features.append(output)
+ return fpn_features
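+
+
+# Illustrative sketch (not part of the original file): the FPN consumes backbone feature maps
+# ordered from highest to lowest resolution and returns maps that all have `feature_size`
+# channels, ordered from coarsest to finest (the stem output itself is not returned), e.g.
+#
+#   features = [torch.randn(1, c, s, s) for c, s in [(96, 64), (192, 32), (384, 16), (768, 8)]]
+#   fpn = MaskFormerFPNModel(in_features=768, lateral_widths=[96, 192, 384], feature_size=256)
+#   [tuple(f.shape[-2:]) for f in fpn(features)]   # -> [(16, 16), (32, 32), (64, 64)]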
+
+
+class MaskFormerPixelDecoder(nn.Module):
+ def __init__(self, *args, feature_size: int = 256, mask_feature_size: int = 256, **kwargs):
+ r"""
+ Pixel Decoder Module proposed in [Per-Pixel Classification is Not All You Need for Semantic
+ Segmentation](https://huggingface.co/papers/2107.06278). It first runs the backbone's features through a Feature
+ Pyramid Network, creating a list of feature maps. Then, it projects the last one to the correct `mask_feature_size`.
+
+ Args:
+ feature_size (`int`, *optional*, defaults to 256):
+ The feature size (channel dimension) of the FPN feature maps.
+ mask_feature_size (`int`, *optional*, defaults to 256):
+ The features (channels) of the target masks size \\(C_{\epsilon}\\) in the paper.
+ """
+ super().__init__()
+
+ self.fpn = MaskFormerFPNModel(*args, feature_size=feature_size, **kwargs)
+ self.mask_projection = nn.Conv2d(feature_size, mask_feature_size, kernel_size=3, padding=1)
+
+ def forward(
+ self, features: list[Tensor], output_hidden_states: bool = False, return_dict: bool = True
+ ) -> MaskFormerPixelDecoderOutput:
+ fpn_features = self.fpn(features)
+ # we use the last feature map
+ last_feature_projected = self.mask_projection(fpn_features[-1])
+
+ if not return_dict:
+ return (last_feature_projected, tuple(fpn_features)) if output_hidden_states else (last_feature_projected,)
+
+ return MaskFormerPixelDecoderOutput(
+ last_hidden_state=last_feature_projected, hidden_states=tuple(fpn_features) if output_hidden_states else ()
+ )
+
+
+# copied and adapted from original implementation, also practically equal to DetrSinePositionEmbedding
+class MaskFormerSinePositionEmbedding(nn.Module):
+ """
+ This is a more standard version of the position embedding, very similar to the one used by the Attention is all you
+ need paper, generalized to work on images.
+ """
+
+ def __init__(
+ self, num_pos_feats: int = 64, temperature: int = 10000, normalize: bool = False, scale: Optional[float] = None
+ ):
+ super().__init__()
+ if scale is not None and normalize is False:
+ raise ValueError("normalize should be True if scale is passed")
+ self.num_pos_feats = num_pos_feats
+ self.temperature = temperature
+ self.normalize = normalize
+ self.scale = 2 * math.pi if scale is None else scale
+
+ @compile_compatible_method_lru_cache(maxsize=1)
+ def forward(
+ self,
+ shape: torch.Size,
+ device: Union[torch.device, str],
+ dtype: torch.dtype,
+ mask: Optional[Tensor] = None,
+ ) -> Tensor:
+ if mask is None:
+ mask = torch.zeros((shape[0], shape[2], shape[3]), device=device, dtype=torch.bool)
+ not_mask = (~mask).to(dtype)
+ y_embed = not_mask.cumsum(1)
+ x_embed = not_mask.cumsum(2)
+ if self.normalize:
+ eps = 1e-6
+ y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
+ x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
+
+ dim_t = torch.arange(self.num_pos_feats, dtype=torch.int64, device=device).to(dtype)
+ dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / self.num_pos_feats)
+
+ pos_x = x_embed[:, :, :, None] / dim_t
+ pos_y = y_embed[:, :, :, None] / dim_t
+ pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
+ pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
+ pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
+ return pos
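+
+
+# Illustrative sketch (not part of the original file): the embedding is built from the feature map
+# shape alone and returns channel-first sine/cosine encodings with 2 * num_pos_feats channels:
+#
+#   pos_embed = MaskFormerSinePositionEmbedding(num_pos_feats=128, normalize=True)
+#   pos = pos_embed(torch.Size([2, 256, 32, 32]), device=torch.device("cpu"), dtype=torch.float32)
+#   pos.shape   # -> torch.Size([2, 256, 32, 32])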
+
+
+class PredictionBlock(nn.Module):
+ def __init__(self, in_dim: int, out_dim: int, activation: nn.Module) -> None:
+ super().__init__()
+ self.layers = [nn.Linear(in_dim, out_dim), activation]
+ # Maintain submodule indexing as if part of a Sequential block
+ for i, layer in enumerate(self.layers):
+ self.add_module(str(i), layer)
+
+ def forward(self, input: Tensor) -> Tensor:
+ hidden_state = input
+ for layer in self.layers:
+ hidden_state = layer(hidden_state)
+ return hidden_state
+
+
+class MaskformerMLPPredictionHead(nn.Module):
+ def __init__(self, input_dim: int, hidden_dim: int, output_dim: int, num_layers: int = 3):
+ """
+ A classic Multi Layer Perceptron (MLP).
+
+ Args:
+ input_dim (`int`):
+ The input dimensions.
+ hidden_dim (`int`):
+ The hidden dimensions.
+ output_dim (`int`):
+ The output dimensions.
+ num_layers (int, *optional*, defaults to 3):
+ The number of layers.
+ """
+ super().__init__()
+ in_dims = [input_dim] + [hidden_dim] * (num_layers - 1)
+ out_dims = [hidden_dim] * (num_layers - 1) + [output_dim]
+
+ self.layers = []
+ for i, (in_dim, out_dim) in enumerate(zip(in_dims, out_dims)):
+ activation = nn.ReLU() if i < num_layers - 1 else nn.Identity()
+ layer = PredictionBlock(in_dim, out_dim, activation=activation)
+ self.layers.append(layer)
+ # Provide backwards compatibility from when the class inherited from nn.Sequential
+ # In nn.Sequential subclasses, the name given to the layer is its index in the sequence.
+ # In nn.Module subclasses they are derived from the instance attribute they are assigned to, e.g.
+ # self.my_layer_name = Layer()
+ # We can't give instance attributes integer names, i.e. self.0 is not permitted, so we need to register them
+ # explicitly
+ self.add_module(str(i), layer)
+
+ def forward(self, input: Tensor) -> Tensor:
+ hidden_state = input
+ for layer in self.layers:
+ hidden_state = layer(hidden_state)
+ return hidden_state
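+
+
+# Illustrative sketch (not part of the original file): the MLP head acts on the last dimension, so
+# it can map per-query decoder states directly to mask embeddings, e.g.
+#
+#   head = MaskformerMLPPredictionHead(input_dim=256, hidden_dim=256, output_dim=256, num_layers=3)
+#   head(torch.randn(2, 100, 256)).shape   # -> torch.Size([2, 100, 256])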
+
+
+class MaskFormerPixelLevelModule(nn.Module):
+ def __init__(self, config: MaskFormerConfig):
+ """
+ Pixel Level Module proposed in [Per-Pixel Classification is Not All You Need for Semantic
+ Segmentation](https://huggingface.co/papers/2107.06278). It runs the input image through a backbone and a pixel
+ decoder, generating an image feature map and pixel embeddings.
+
+ Args:
+ config ([`MaskFormerConfig`]):
+ The configuration used to instantiate this model.
+ """
+ super().__init__()
+ if getattr(config, "backbone_config") is not None and config.backbone_config.model_type == "swin":
+ # for backwards compatibility
+ backbone_config = config.backbone_config
+ backbone_config = MaskFormerSwinConfig.from_dict(backbone_config.to_dict())
+ backbone_config.out_features = ["stage1", "stage2", "stage3", "stage4"]
+ config.backbone_config = backbone_config
+ self.encoder = load_backbone(config)
+
+ feature_channels = self.encoder.channels
+ self.decoder = MaskFormerPixelDecoder(
+ in_features=feature_channels[-1],
+ feature_size=config.fpn_feature_size,
+ mask_feature_size=config.mask_feature_size,
+ lateral_widths=feature_channels[:-1],
+ )
+
+ def forward(
+ self, pixel_values: Tensor, output_hidden_states: bool = False, return_dict: bool = True
+ ) -> MaskFormerPixelLevelModuleOutput:
+ features = self.encoder(pixel_values).feature_maps
+ decoder_output = self.decoder(features, output_hidden_states, return_dict=return_dict)
+
+ if not return_dict:
+ last_hidden_state = decoder_output[0]
+ outputs = (features[-1], last_hidden_state)
+ if output_hidden_states:
+ hidden_states = decoder_output[1]
+ outputs = outputs + (tuple(features),) + (hidden_states,)
+ return outputs
+
+ return MaskFormerPixelLevelModuleOutput(
+ # the last feature is actually the output from the last layer
+ encoder_last_hidden_state=features[-1],
+ decoder_last_hidden_state=decoder_output.last_hidden_state,
+ encoder_hidden_states=tuple(features) if output_hidden_states else (),
+ decoder_hidden_states=decoder_output.hidden_states if output_hidden_states else (),
+ )
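+
+# Note (added for clarity, not part of the original code): `encoder_last_hidden_state` is the
+# backbone's lowest-resolution feature map, while `decoder_last_hidden_state` holds the per-pixel
+# embeddings produced by the FPN-style pixel decoder, with `config.mask_feature_size` channels at
+# the resolution of the highest-resolution backbone stage.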
+
+
+class MaskFormerTransformerModule(nn.Module):
+ """
+ The MaskFormer's transformer module.
+ """
+
+ def __init__(self, in_features: int, config: MaskFormerConfig):
+ super().__init__()
+ hidden_size = config.decoder_config.hidden_size
+ should_project = in_features != hidden_size
+ self.position_embedder = MaskFormerSinePositionEmbedding(num_pos_feats=hidden_size // 2, normalize=True)
+ self.queries_embedder = nn.Embedding(config.decoder_config.num_queries, hidden_size)
+ self.input_projection = nn.Conv2d(in_features, hidden_size, kernel_size=1) if should_project else None
+ self.decoder = DetrDecoder(config=config.decoder_config)
+
+ def forward(
+ self,
+ image_features: Tensor,
+ output_hidden_states: bool = False,
+ output_attentions: bool = False,
+ return_dict: Optional[bool] = None,
+ ) -> DetrDecoderOutput:
+ if self.input_projection is not None:
+ image_features = self.input_projection(image_features)
+ object_queries = self.position_embedder(image_features.shape, image_features.device, image_features.dtype)
+ # repeat the queries "q c -> b q c"
+ batch_size = image_features.shape[0]
+ queries_embeddings = self.queries_embedder.weight.unsqueeze(0).repeat(batch_size, 1, 1)
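+ # Note (added for clarity): as in DETR, the decoder starts from all-zero query features;
+ # the learned queries above are injected through `query_position_embeddings` instead.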
+ inputs_embeds = torch.zeros_like(queries_embeddings, requires_grad=self.training)
+
+ # torch.export.export does not support requires_grad
+ if self.training:
+ inputs_embeds.requires_grad_(True)
+
+ batch_size, num_channels, height, width = image_features.shape
+ # rearrange both image_features and object_queries "b c h w -> b (h w) c"
+ image_features = image_features.view(batch_size, num_channels, height * width).permute(0, 2, 1)
+ object_queries = object_queries.view(batch_size, num_channels, height * width).permute(0, 2, 1)
+
+ decoder_output: DetrDecoderOutput = self.decoder(
+ inputs_embeds=inputs_embeds,
+ attention_mask=None,
+ encoder_hidden_states=image_features,
+ encoder_attention_mask=None,
+ object_queries=object_queries,
+ query_position_embeddings=queries_embeddings,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ return decoder_output
+
+
+@auto_docstring
+class MaskFormerPreTrainedModel(PreTrainedModel):
+ config: MaskFormerConfig
+ base_model_prefix = "model"
+ main_input_name = "pixel_values"
+
+ def _init_weights(self, module: nn.Module):
+ xavier_std = self.config.init_xavier_std
+ std = self.config.init_std
+ if isinstance(module, MaskFormerTransformerModule):
+ if module.input_projection is not None:
+ nn.init.xavier_uniform_(module.input_projection.weight, gain=xavier_std)
+ nn.init.constant_(module.input_projection.bias, 0)
+ # FPN
+ elif isinstance(module, MaskFormerFPNModel):
+ nn.init.xavier_uniform_(module.stem.get_submodule("0").weight, gain=xavier_std)
+
+ elif isinstance(module, MaskFormerFPNLayer):
+ nn.init.xavier_uniform_(module.proj[0].weight, gain=xavier_std)
+
+ elif isinstance(module, MaskFormerFPNConvLayer):
+ nn.init.xavier_uniform_(module.get_submodule("0").weight, gain=xavier_std)
+ # The MLP head
+ elif isinstance(module, MaskformerMLPPredictionHead):
+ # I was not able to find the correct initializer in the original implementation
+ # we'll use xavier
+ for submodule in module.modules():
+ if isinstance(submodule, nn.Linear):
+ nn.init.xavier_uniform_(submodule.weight, gain=xavier_std)
+ nn.init.constant_(submodule.bias, 0)
+ elif isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+ # copied from DETR
+ if isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)):
+ # Slightly different from the TF version which uses truncated_normal for initialization
+ # cf https://github.com/pytorch/pytorch/pull/5617
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+
+
+@auto_docstring
+class MaskFormerModel(MaskFormerPreTrainedModel):
+ def __init__(self, config: MaskFormerConfig):
+ super().__init__(config)
+ self.pixel_level_module = MaskFormerPixelLevelModule(config)
+ self.transformer_module = MaskFormerTransformerModule(
+ in_features=self.pixel_level_module.encoder.channels[-1], config=config
+ )
+
+ self.post_init()
+
+ @auto_docstring
+ def forward(
+ self,
+ pixel_values: Tensor,
+ pixel_mask: Optional[Tensor] = None,
+ output_hidden_states: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> MaskFormerModelOutput:
+ r"""
+ Examples:
+
+ ```python
+ >>> from transformers import AutoImageProcessor, MaskFormerModel
+ >>> from PIL import Image
+ >>> import requests
+
+ >>> # load MaskFormer fine-tuned on ADE20k semantic segmentation
+ >>> image_processor = AutoImageProcessor.from_pretrained("facebook/maskformer-swin-base-ade")
+ >>> model = MaskFormerModel.from_pretrained("facebook/maskformer-swin-base-ade")
+
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> inputs = image_processor(image, return_tensors="pt")
+
+ >>> # forward pass
+ >>> outputs = model(**inputs)
+
+ >>> # the decoder of MaskFormer outputs hidden states of shape (batch_size, num_queries, hidden_size)
+ >>> transformer_decoder_last_hidden_state = outputs.transformer_decoder_last_hidden_state
+ >>> list(transformer_decoder_last_hidden_state.shape)
+ [1, 100, 256]
+ ```"""
+
+ if pixel_values is None:
+ raise ValueError("You have to specify pixel_values")
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ batch_size, _, height, width = pixel_values.shape
+
+ if pixel_mask is None:
+ pixel_mask = torch.ones((batch_size, height, width), device=pixel_values.device)
+
+ pixel_level_module_output = self.pixel_level_module(
+ pixel_values, output_hidden_states, return_dict=return_dict
+ )
+ image_features = pixel_level_module_output[0]
+ pixel_embeddings = pixel_level_module_output[1]
+
+ transformer_module_output = self.transformer_module(image_features, output_hidden_states, output_attentions)
+ queries = transformer_module_output.last_hidden_state
+
+ encoder_hidden_states = None
+ pixel_decoder_hidden_states = None
+ transformer_decoder_hidden_states = None
+ hidden_states = None
+
+ if output_hidden_states:
+ encoder_hidden_states = pixel_level_module_output[2]
+ pixel_decoder_hidden_states = pixel_level_module_output[3]
+ transformer_decoder_hidden_states = transformer_module_output[1]
+ hidden_states = encoder_hidden_states + pixel_decoder_hidden_states + transformer_decoder_hidden_states
+
+ output = MaskFormerModelOutput(
+ encoder_last_hidden_state=image_features,
+ pixel_decoder_last_hidden_state=pixel_embeddings,
+ transformer_decoder_last_hidden_state=queries,
+ encoder_hidden_states=encoder_hidden_states,
+ pixel_decoder_hidden_states=pixel_decoder_hidden_states,
+ transformer_decoder_hidden_states=transformer_decoder_hidden_states,
+ hidden_states=hidden_states,
+ attentions=transformer_module_output.attentions,
+ )
+
+ if not return_dict:
+ output = tuple(v for v in output.values())
+
+ return output
+
+
+class MaskFormerForInstanceSegmentation(MaskFormerPreTrainedModel):
+ def __init__(self, config: MaskFormerConfig):
+ super().__init__(config)
+ self.model = MaskFormerModel(config)
+ hidden_size = config.decoder_config.hidden_size
+ # + 1 because we add the "null" class
+ self.class_predictor = nn.Linear(hidden_size, config.num_labels + 1)
+ self.mask_embedder = MaskformerMLPPredictionHead(hidden_size, hidden_size, config.mask_feature_size)
+
+ self.matcher = MaskFormerHungarianMatcher(
+ cost_class=1.0, cost_dice=config.dice_weight, cost_mask=config.mask_weight
+ )
+
+ self.weight_dict: dict[str, float] = {
+ "loss_cross_entropy": config.cross_entropy_weight,
+ "loss_mask": config.mask_weight,
+ "loss_dice": config.dice_weight,
+ }
+
+ self.criterion = MaskFormerLoss(
+ config.num_labels,
+ matcher=self.matcher,
+ weight_dict=self.weight_dict,
+ eos_coef=config.no_object_weight,
+ )
+
+ self.post_init()
+
+ def get_loss_dict(
+ self,
+ masks_queries_logits: Tensor,
+ class_queries_logits: Tensor,
+ mask_labels: Tensor,
+ class_labels: Tensor,
+ auxiliary_logits: dict[str, Tensor],
+ ) -> dict[str, Tensor]:
+ loss_dict: dict[str, Tensor] = self.criterion(
+ masks_queries_logits, class_queries_logits, mask_labels, class_labels, auxiliary_logits
+ )
+ # weight each loss by its corresponding entry in `self.weight_dict`, including auxiliary losses
+ for key, weight in self.weight_dict.items():
+ for loss_key, loss in loss_dict.items():
+ if key in loss_key:
+ loss *= weight
+
+ return loss_dict
+
+ def get_loss(self, loss_dict: dict[str, Tensor]) -> Tensor:
+ return sum(loss_dict.values())
+
+ def get_logits(self, outputs: MaskFormerModelOutput) -> tuple[Tensor, Tensor, dict[str, Tensor]]:
+ pixel_embeddings = outputs.pixel_decoder_last_hidden_state
+ # get the auxiliary predictions (one for each decoder's layer)
+ auxiliary_logits: list[dict[str, Tensor]] = []
+
+ # This code is a little bit cumbersome; an improvement would be to always return a list of predictions. If auxiliary loss is used, the list contains more than one element
+ if self.config.use_auxiliary_loss:
+ stacked_transformer_decoder_outputs = torch.stack(outputs.transformer_decoder_hidden_states)
+ classes = self.class_predictor(stacked_transformer_decoder_outputs)
+ class_queries_logits = classes[-1]
+ # get the masks
+ mask_embeddings = self.mask_embedder(stacked_transformer_decoder_outputs)
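+ # Note (added for clarity): in the einsum below, l=decoder layers, b=batch, q=queries,
+ # c=mask embedding channels, h/w=spatial dims of the pixel embeddings, so each query embedding
+ # is dotted with every pixel embedding to produce per-layer mask logits.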
+ binaries_masks = torch.einsum("lbqc, bchw -> lbqhw", mask_embeddings, pixel_embeddings)
+
+ masks_queries_logits = binaries_masks[-1]
+ # only go up to [:-1] because the last element is the main prediction, which is used above
+ for aux_binary_masks, aux_classes in zip(binaries_masks[:-1], classes[:-1]):
+ auxiliary_logits.append(
+ {"masks_queries_logits": aux_binary_masks, "class_queries_logits": aux_classes}
+ )
+
+ else:
+ transformer_decoder_hidden_states = outputs.transformer_decoder_last_hidden_state
+ classes = self.class_predictor(transformer_decoder_hidden_states)
+ class_queries_logits = classes
+ # get the masks
+ mask_embeddings = self.mask_embedder(transformer_decoder_hidden_states)
+ # sum up over the channels
+ masks_queries_logits = torch.einsum("bqc, bchw -> bqhw", mask_embeddings, pixel_embeddings)
+
+ return class_queries_logits, masks_queries_logits, auxiliary_logits
+
+ @auto_docstring
+ def forward(
+ self,
+ pixel_values: Tensor,
+ mask_labels: Optional[list[Tensor]] = None,
+ class_labels: Optional[list[Tensor]] = None,
+ pixel_mask: Optional[Tensor] = None,
+ output_auxiliary_logits: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> MaskFormerForInstanceSegmentationOutput:
+ r"""
+ mask_labels (`list[torch.Tensor]`, *optional*):
+ List of mask labels of shape `(num_labels, height, width)` to be fed to a model
+ class_labels (`list[torch.LongTensor]`, *optional*):
+ List of target class labels of shape `(num_labels,)` to be fed to a model. They identify the labels of
+ `mask_labels`, e.g. the label of `mask_labels[i][j]` is `class_labels[i][j]`.
+ output_auxiliary_logits (`bool`, *optional*):
+ Whether or not to output auxiliary logits.
+
+ Examples:
+
+ Semantic segmentation example:
+
+ ```python
+ >>> from transformers import AutoImageProcessor, MaskFormerForInstanceSegmentation
+ >>> from PIL import Image
+ >>> import requests
+
+ >>> # load MaskFormer fine-tuned on ADE20k semantic segmentation
+ >>> image_processor = AutoImageProcessor.from_pretrained("facebook/maskformer-swin-base-ade")
+ >>> model = MaskFormerForInstanceSegmentation.from_pretrained("facebook/maskformer-swin-base-ade")
+
+ >>> url = (
+ ... "https://huggingface.co/datasets/hf-internal-testing/fixtures_ade20k/resolve/main/ADE_val_00000001.jpg"
+ ... )
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+ >>> inputs = image_processor(images=image, return_tensors="pt")
+
+ >>> outputs = model(**inputs)
+ >>> # model predicts class_queries_logits of shape `(batch_size, num_queries, num_labels + 1)`
+ >>> # and masks_queries_logits of shape `(batch_size, num_queries, height, width)`
+ >>> class_queries_logits = outputs.class_queries_logits
+ >>> masks_queries_logits = outputs.masks_queries_logits
+
+ >>> # you can pass them to image_processor for postprocessing
+ >>> predicted_semantic_map = image_processor.post_process_semantic_segmentation(
+ ... outputs, target_sizes=[(image.height, image.width)]
+ ... )[0]
+
+ >>> # we refer to the demo notebooks for visualization (see "Resources" section in the MaskFormer docs)
+ >>> list(predicted_semantic_map.shape)
+ [512, 683]
+ ```
+
+ Panoptic segmentation example:
+
+ ```python
+ >>> from transformers import AutoImageProcessor, MaskFormerForInstanceSegmentation
+ >>> from PIL import Image
+ >>> import requests
+
+ >>> # load MaskFormer fine-tuned on COCO panoptic segmentation
+ >>> image_processor = AutoImageProcessor.from_pretrained("facebook/maskformer-swin-base-coco")
+ >>> model = MaskFormerForInstanceSegmentation.from_pretrained("facebook/maskformer-swin-base-coco")
+
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+ >>> inputs = image_processor(images=image, return_tensors="pt")
+
+ >>> outputs = model(**inputs)
+ >>> # model predicts class_queries_logits of shape `(batch_size, num_queries, num_labels + 1)`
+ >>> # and masks_queries_logits of shape `(batch_size, num_queries, height, width)`
+ >>> class_queries_logits = outputs.class_queries_logits
+ >>> masks_queries_logits = outputs.masks_queries_logits
+
+ >>> # you can pass them to image_processor for postprocessing
+ >>> result = image_processor.post_process_panoptic_segmentation(outputs, target_sizes=[(image.height, image.width)])[0]
+
+ >>> # we refer to the demo notebooks for visualization (see "Resources" section in the MaskFormer docs)
+ >>> predicted_panoptic_map = result["segmentation"]
+ >>> list(predicted_panoptic_map.shape)
+ [480, 640]
+ ```
+ """
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ raw_outputs = self.model(
+ pixel_values,
+ pixel_mask,
+ output_hidden_states=output_hidden_states or self.config.use_auxiliary_loss,
+ return_dict=return_dict,
+ output_attentions=output_attentions,
+ )
+ # We need to have raw_outputs optionally be returned as a dict to use torch.compile. For backwards
+ # compatibility we convert to a dataclass for the rest of the model logic
+ outputs = MaskFormerModelOutput(
+ encoder_last_hidden_state=raw_outputs[0],
+ pixel_decoder_last_hidden_state=raw_outputs[1],
+ transformer_decoder_last_hidden_state=raw_outputs[2],
+ encoder_hidden_states=raw_outputs[3] if output_hidden_states else None,
+ pixel_decoder_hidden_states=raw_outputs[4] if output_hidden_states else None,
+ transformer_decoder_hidden_states=raw_outputs[5] if output_hidden_states else None,
+ hidden_states=raw_outputs[6] if output_hidden_states else None,
+ attentions=raw_outputs[-1] if output_attentions else None,
+ )
+
+ loss, loss_dict, auxiliary_logits = None, None, None
+
+ class_queries_logits, masks_queries_logits, auxiliary_logits = self.get_logits(outputs)
+
+ if mask_labels is not None and class_labels is not None:
+ loss_dict: dict[str, Tensor] = self.get_loss_dict(
+ masks_queries_logits, class_queries_logits, mask_labels, class_labels, auxiliary_logits
+ )
+ loss = self.get_loss(loss_dict)
+
+ output_auxiliary_logits = (
+ self.config.output_auxiliary_logits if output_auxiliary_logits is None else output_auxiliary_logits
+ )
+ if not output_auxiliary_logits:
+ auxiliary_logits = None
+
+ if not return_dict:
+ output = tuple(
+ v
+ for v in (loss, class_queries_logits, masks_queries_logits, auxiliary_logits, *outputs.values())
+ if v is not None
+ )
+ return output
+
+ return MaskFormerForInstanceSegmentationOutput(
+ loss=loss,
+ **outputs,
+ class_queries_logits=class_queries_logits,
+ masks_queries_logits=masks_queries_logits,
+ auxiliary_logits=auxiliary_logits,
+ )
+
+
+__all__ = ["MaskFormerForInstanceSegmentation", "MaskFormerModel", "MaskFormerPreTrainedModel"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/maskformer/modeling_maskformer_swin.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/maskformer/modeling_maskformer_swin.py
new file mode 100644
index 0000000000000000000000000000000000000000..3a3e076a5a4c5e34da60cebae899ff77abb65bf6
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/maskformer/modeling_maskformer_swin.py
@@ -0,0 +1,927 @@
+# coding=utf-8
+# Copyright 2022 Meta Platforms, Inc. and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""MaskFormer Swin Transformer. The reason Swin Transformer is implemented here is because MaskFormer uses the hidden
+states before downsampling, which is different from the default Swin Transformer."""
+
+import collections.abc
+import math
+from dataclasses import dataclass
+from typing import Optional
+
+import torch
+from torch import Tensor, nn
+
+from ...activations import ACT2FN
+from ...file_utils import ModelOutput
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_outputs import BackboneOutput
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import find_pruneable_heads_and_indices, meshgrid, prune_linear_layer
+from ...utils import auto_docstring, torch_int
+from ...utils.backbone_utils import BackboneMixin
+from .configuration_maskformer_swin import MaskFormerSwinConfig
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Class for MaskFormerSwinModel's outputs that also contains the spatial dimensions of the hidden states.
+ """
+)
+class MaskFormerSwinModelOutputWithPooling(ModelOutput):
+ r"""
+ pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`):
+ Last layer hidden-state after a mean pooling operation.
+ hidden_states_spatial_dimensions (`tuple(tuple(int, int))`, *optional*):
+ A tuple containing the spatial dimension of each `hidden_state` needed to reshape the `hidden_states` to
+ `batch, channels, height, width`. Due to padding, their spatial size cannot be inferred before the
+ `forward` method.
+ """
+
+ last_hidden_state: Optional[torch.FloatTensor] = None
+ pooler_output: Optional[torch.FloatTensor] = None
+ hidden_states: Optional[tuple[torch.FloatTensor]] = None
+ hidden_states_spatial_dimensions: tuple[tuple[int, int]] = None
+ attentions: Optional[tuple[torch.FloatTensor]] = None
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Class for SwinEncoder's outputs.
+ """
+)
+class MaskFormerSwinBaseModelOutput(ModelOutput):
+ r"""
+ hidden_states_spatial_dimensions (`tuple(tuple(int, int))`, *optional*):
+ A tuple containing the spatial dimension of each `hidden_state` needed to reshape the `hidden_states` to
+ `batch, channels, height, width`. Due to padding, their spatial size cannot be inferred before the `forward`
+ method.
+ """
+
+ last_hidden_state: Optional[torch.FloatTensor] = None
+ hidden_states: Optional[tuple[torch.FloatTensor]] = None
+ hidden_states_spatial_dimensions: tuple[tuple[int, int]] = None
+ attentions: Optional[tuple[torch.FloatTensor]] = None
+
+
+# Copied from transformers.models.swin.modeling_swin.window_partition
+def window_partition(input_feature, window_size):
+ """
+ Partitions the given input into windows.
+ """
+ batch_size, height, width, num_channels = input_feature.shape
+ input_feature = input_feature.view(
+ batch_size, height // window_size, window_size, width // window_size, window_size, num_channels
+ )
+ windows = input_feature.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, num_channels)
+ return windows
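+
+# Illustrative example (added for clarity, not part of the original code): with window_size=4,
+# an input of shape (1, 8, 8, C) is split into (8 // 4) * (8 // 4) = 4 non-overlapping windows,
+# giving an output of shape (4, 4, 4, C).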
+
+
+# Copied from transformers.models.swin.modeling_swin.window_reverse
+def window_reverse(windows, window_size, height, width):
+ """
+ Merges windows to produce higher resolution features.
+ """
+ num_channels = windows.shape[-1]
+ windows = windows.view(-1, height // window_size, width // window_size, window_size, window_size, num_channels)
+ windows = windows.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, height, width, num_channels)
+ return windows
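+
+# Note (added for clarity, not part of the original code): `window_reverse` is the inverse of
+# `window_partition`, so when height and width are multiples of window_size,
+# window_reverse(window_partition(x, M), M, H, W) recovers a tensor of the original shape (B, H, W, C).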
+
+
+# Copied from transformers.models.swin.modeling_swin.drop_path
+def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
+ """
+ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+
+ Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
+ however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
+ See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
+ layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
+ argument.
+ """
+ if drop_prob == 0.0 or not training:
+ return input
+ keep_prob = 1 - drop_prob
+ shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
+ random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
+ random_tensor.floor_() # binarize
+ output = input.div(keep_prob) * random_tensor
+ return output
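+
+# Illustrative sketch (added for clarity, not part of the original code): during training with
+# drop_prob=0.1, each sample's residual branch is zeroed with probability 0.1 and surviving samples
+# are rescaled by 1 / 0.9, so the expected value of the output matches the input; at evaluation
+# time the function is the identity.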
+
+
+class MaskFormerSwinEmbeddings(nn.Module):
+ """
+ Construct the patch and position embeddings.
+ """
+
+ def __init__(self, config):
+ super().__init__()
+
+ self.patch_embeddings = MaskFormerSwinPatchEmbeddings(config)
+ num_patches = self.patch_embeddings.num_patches
+ self.patch_grid = self.patch_embeddings.grid_size
+
+ if config.use_absolute_embeddings:
+ self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.embed_dim))
+ else:
+ self.position_embeddings = None
+
+ self.norm = nn.LayerNorm(config.embed_dim)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+ self.patch_size = config.patch_size
+
+ # Copied from transformers.models.vit.modeling_vit.ViTEmbeddings.interpolate_pos_encoding
+ def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
+ """
+ This method allows interpolating the pre-trained position encodings so that the model can be used on higher resolution
+ images. This method is also adapted to support torch.jit tracing.
+
+ Adapted from:
+ - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
+ - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
+ """
+
+ num_patches = embeddings.shape[1] - 1
+ num_positions = self.position_embeddings.shape[1] - 1
+
+ # always interpolate when tracing to ensure the exported model works for dynamic input shapes
+ if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
+ return self.position_embeddings
+
+ class_pos_embed = self.position_embeddings[:, :1]
+ patch_pos_embed = self.position_embeddings[:, 1:]
+
+ dim = embeddings.shape[-1]
+
+ new_height = height // self.patch_size
+ new_width = width // self.patch_size
+
+ sqrt_num_positions = torch_int(num_positions**0.5)
+ patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
+ patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
+
+ patch_pos_embed = nn.functional.interpolate(
+ patch_pos_embed,
+ size=(new_height, new_width),
+ mode="bicubic",
+ align_corners=False,
+ )
+
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+
+ return torch.cat((class_pos_embed, patch_pos_embed), dim=1)
+
+ def forward(self, pixel_values, interpolate_pos_encoding):
+ _, num_channels, height, width = pixel_values.shape
+ embeddings, output_dimensions = self.patch_embeddings(pixel_values)
+ embeddings = self.norm(embeddings)
+
+ if self.position_embeddings is not None:
+ if interpolate_pos_encoding:
+ embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
+ else:
+ embeddings = embeddings + self.position_embeddings
+
+ embeddings = self.dropout(embeddings)
+
+ return embeddings, output_dimensions
+
+
+# Copied from transformers.models.swin.modeling_swin.SwinPatchEmbeddings with Swin->MaskFormerSwin
+class MaskFormerSwinPatchEmbeddings(nn.Module):
+ """
+ This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
+ `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
+ Transformer.
+ """
+
+ def __init__(self, config):
+ super().__init__()
+ image_size, patch_size = config.image_size, config.patch_size
+ num_channels, hidden_size = config.num_channels, config.embed_dim
+ image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
+ patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
+ num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
+ self.image_size = image_size
+ self.patch_size = patch_size
+ self.num_channels = num_channels
+ self.num_patches = num_patches
+ self.grid_size = (image_size[0] // patch_size[0], image_size[1] // patch_size[1])
+
+ self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
+
+ def maybe_pad(self, pixel_values, height, width):
+ if width % self.patch_size[1] != 0:
+ pad_values = (0, self.patch_size[1] - width % self.patch_size[1])
+ pixel_values = nn.functional.pad(pixel_values, pad_values)
+ if height % self.patch_size[0] != 0:
+ pad_values = (0, 0, 0, self.patch_size[0] - height % self.patch_size[0])
+ pixel_values = nn.functional.pad(pixel_values, pad_values)
+ return pixel_values
+
+ def forward(self, pixel_values: Optional[torch.FloatTensor]) -> tuple[torch.Tensor, tuple[int]]:
+ _, num_channels, height, width = pixel_values.shape
+ # pad the input to be divisible by self.patch_size, if needed
+ pixel_values = self.maybe_pad(pixel_values, height, width)
+ embeddings = self.projection(pixel_values)
+ _, _, height, width = embeddings.shape
+ output_dimensions = (height, width)
+ embeddings = embeddings.flatten(2).transpose(1, 2)
+
+ return embeddings, output_dimensions
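+
+# Illustrative example (added for clarity, not part of the original code; values assume a typical
+# Swin config): with num_channels=3, patch_size=4 and embed_dim=96, a (2, 3, 224, 224) image yields
+# embeddings of shape (2, 56 * 56, 96) together with output_dimensions == (56, 56); inputs whose
+# size is not a multiple of the patch size are padded first.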
+
+
+# Copied from transformers.models.swin.modeling_swin.SwinPatchMerging
+class MaskFormerSwinPatchMerging(nn.Module):
+ """
+ Patch Merging Layer.
+
+ Args:
+ input_resolution (`tuple[int]`):
+ Resolution of input feature.
+ dim (`int`):
+ Number of input channels.
+ norm_layer (`nn.Module`, *optional*, defaults to `nn.LayerNorm`):
+ Normalization layer class.
+ """
+
+ def __init__(self, input_resolution: tuple[int], dim: int, norm_layer: nn.Module = nn.LayerNorm) -> None:
+ super().__init__()
+ self.input_resolution = input_resolution
+ self.dim = dim
+ self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
+ self.norm = norm_layer(4 * dim)
+
+ def maybe_pad(self, input_feature, height, width):
+ should_pad = (height % 2 == 1) or (width % 2 == 1)
+ if should_pad:
+ pad_values = (0, 0, 0, width % 2, 0, height % 2)
+ input_feature = nn.functional.pad(input_feature, pad_values)
+
+ return input_feature
+
+ def forward(self, input_feature: torch.Tensor, input_dimensions: tuple[int, int]) -> torch.Tensor:
+ height, width = input_dimensions
+ # `dim` is height * width
+ batch_size, dim, num_channels = input_feature.shape
+
+ input_feature = input_feature.view(batch_size, height, width, num_channels)
+ # pad input to be divisible by width and height, if needed
+ input_feature = self.maybe_pad(input_feature, height, width)
+ # [batch_size, height/2, width/2, num_channels]
+ input_feature_0 = input_feature[:, 0::2, 0::2, :]
+ # [batch_size, height/2, width/2, num_channels]
+ input_feature_1 = input_feature[:, 1::2, 0::2, :]
+ # [batch_size, height/2, width/2, num_channels]
+ input_feature_2 = input_feature[:, 0::2, 1::2, :]
+ # [batch_size, height/2, width/2, num_channels]
+ input_feature_3 = input_feature[:, 1::2, 1::2, :]
+ # batch_size height/2 width/2 4*num_channels
+ input_feature = torch.cat([input_feature_0, input_feature_1, input_feature_2, input_feature_3], -1)
+ input_feature = input_feature.view(batch_size, -1, 4 * num_channels) # batch_size height/2*width/2 4*C
+
+ input_feature = self.norm(input_feature)
+ input_feature = self.reduction(input_feature)
+
+ return input_feature
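+
+# Illustrative example (added for clarity, not part of the original code): patch merging halves the
+# spatial resolution and doubles the channel dimension, e.g. an input of shape (B, 56 * 56, 96) with
+# input_dimensions=(56, 56) becomes (B, 28 * 28, 192) after the 4C -> 2C linear reduction.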
+
+
+# Copied from transformers.models.swin.modeling_swin.SwinDropPath with Swin->MaskFormerSwin
+class MaskFormerSwinDropPath(nn.Module):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
+
+ def __init__(self, drop_prob: Optional[float] = None) -> None:
+ super().__init__()
+ self.drop_prob = drop_prob
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ return drop_path(hidden_states, self.drop_prob, self.training)
+
+ def extra_repr(self) -> str:
+ return f"p={self.drop_prob}"
+
+
+# Copied from transformers.models.swin.modeling_swin.SwinSelfAttention with Swin->MaskFormerSwin
+class MaskFormerSwinSelfAttention(nn.Module):
+ def __init__(self, config, dim, num_heads, window_size):
+ super().__init__()
+ if dim % num_heads != 0:
+ raise ValueError(
+ f"The hidden size ({dim}) is not a multiple of the number of attention heads ({num_heads})"
+ )
+
+ self.num_attention_heads = num_heads
+ self.attention_head_size = int(dim / num_heads)
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
+ self.window_size = (
+ window_size if isinstance(window_size, collections.abc.Iterable) else (window_size, window_size)
+ )
+
+ self.relative_position_bias_table = nn.Parameter(
+ torch.zeros((2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1), num_heads)
+ )
+
+ # get pair-wise relative position index for each token inside the window
+ coords_h = torch.arange(self.window_size[0])
+ coords_w = torch.arange(self.window_size[1])
+ coords = torch.stack(meshgrid([coords_h, coords_w], indexing="ij"))
+ coords_flatten = torch.flatten(coords, 1)
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous()
+ relative_coords[:, :, 0] += self.window_size[0] - 1
+ relative_coords[:, :, 1] += self.window_size[1] - 1
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
+ relative_position_index = relative_coords.sum(-1)
+ self.register_buffer("relative_position_index", relative_position_index)
+
+ self.query = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias)
+ self.key = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias)
+ self.value = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias)
+
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ output_attentions: Optional[bool] = False,
+ ) -> tuple[torch.Tensor]:
+ batch_size, dim, num_channels = hidden_states.shape
+ hidden_shape = (batch_size, dim, -1, self.attention_head_size)
+
+ query_layer = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
+ key_layer = self.key(hidden_states).view(hidden_shape).transpose(1, 2)
+ value_layer = self.value(hidden_states).view(hidden_shape).transpose(1, 2)
+
+ # Take the dot product between "query" and "key" to get the raw attention scores.
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
+
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
+
+ relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)]
+ relative_position_bias = relative_position_bias.view(
+ self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1
+ )
+
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
+ attention_scores = attention_scores + relative_position_bias.unsqueeze(0)
+
+ if attention_mask is not None:
+ # Apply the attention mask (precomputed for all layers in MaskFormerSwinModel forward() function)
+ mask_shape = attention_mask.shape[0]
+ attention_scores = attention_scores.view(
+ batch_size // mask_shape, mask_shape, self.num_attention_heads, dim, dim
+ )
+ attention_scores = attention_scores + attention_mask.unsqueeze(1).unsqueeze(0)
+ attention_scores = attention_scores.view(-1, self.num_attention_heads, dim, dim)
+
+ # Normalize the attention scores to probabilities.
+ attention_probs = nn.functional.softmax(attention_scores, dim=-1)
+
+ # This is actually dropping out entire tokens to attend to, which might
+ # seem a bit unusual, but is taken from the original Transformer paper.
+ attention_probs = self.dropout(attention_probs)
+
+ # Mask heads if we want to
+ if head_mask is not None:
+ attention_probs = attention_probs * head_mask
+
+ context_layer = torch.matmul(attention_probs, value_layer)
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+ context_layer = context_layer.view(new_context_layer_shape)
+
+ outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
+
+ return outputs
+
+
+# Copied from transformers.models.swin.modeling_swin.SwinSelfOutput with Swin->MaskFormerSwin
+class MaskFormerSwinSelfOutput(nn.Module):
+ def __init__(self, config, dim):
+ super().__init__()
+ self.dense = nn.Linear(dim, dim)
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
+
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+
+ return hidden_states
+
+
+# Copied from transformers.models.swin.modeling_swin.SwinAttention with Swin->MaskFormerSwin
+class MaskFormerSwinAttention(nn.Module):
+ def __init__(self, config, dim, num_heads, window_size):
+ super().__init__()
+ self.self = MaskFormerSwinSelfAttention(config, dim, num_heads, window_size)
+ self.output = MaskFormerSwinSelfOutput(config, dim)
+ self.pruned_heads = set()
+
+ def prune_heads(self, heads):
+ if len(heads) == 0:
+ return
+ heads, index = find_pruneable_heads_and_indices(
+ heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
+ )
+
+ # Prune linear layers
+ self.self.query = prune_linear_layer(self.self.query, index)
+ self.self.key = prune_linear_layer(self.self.key, index)
+ self.self.value = prune_linear_layer(self.self.value, index)
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
+
+ # Update hyper params and store pruned heads
+ self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
+ self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
+ self.pruned_heads = self.pruned_heads.union(heads)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ output_attentions: Optional[bool] = False,
+ ) -> tuple[torch.Tensor]:
+ self_outputs = self.self(hidden_states, attention_mask, head_mask, output_attentions)
+ attention_output = self.output(self_outputs[0], hidden_states)
+ outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
+ return outputs
+
+
+# Copied from transformers.models.swin.modeling_swin.SwinIntermediate with Swin->MaskFormerSwin
+class MaskFormerSwinIntermediate(nn.Module):
+ def __init__(self, config, dim):
+ super().__init__()
+ self.dense = nn.Linear(dim, int(config.mlp_ratio * dim))
+ if isinstance(config.hidden_act, str):
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
+ else:
+ self.intermediate_act_fn = config.hidden_act
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.intermediate_act_fn(hidden_states)
+ return hidden_states
+
+
+# Copied from transformers.models.swin.modeling_swin.SwinOutput with Swin->MaskFormerSwin
+class MaskFormerSwinOutput(nn.Module):
+ def __init__(self, config, dim):
+ super().__init__()
+ self.dense = nn.Linear(int(config.mlp_ratio * dim), dim)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ return hidden_states
+
+
+class MaskFormerSwinLayer(nn.Module):
+ def __init__(self, config, dim, input_resolution, num_heads, drop_path_rate=0.0, shift_size=0):
+ super().__init__()
+ self.shift_size = shift_size
+ self.window_size = config.window_size
+ self.input_resolution = input_resolution
+ self.layernorm_before = nn.LayerNorm(dim, eps=config.layer_norm_eps)
+ self.attention = MaskFormerSwinAttention(config, dim, num_heads, self.window_size)
+ self.drop_path = MaskFormerSwinDropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
+ self.layernorm_after = nn.LayerNorm(dim, eps=config.layer_norm_eps)
+ self.intermediate = MaskFormerSwinIntermediate(config, dim)
+ self.output = MaskFormerSwinOutput(config, dim)
+
+ def get_attn_mask(self, input_resolution):
+ if self.shift_size > 0:
+ # calculate attention mask for SW-MSA
+ height, width = input_resolution
+ img_mask = torch.zeros((1, height, width, 1))
+ height_slices = (
+ slice(0, -self.window_size),
+ slice(-self.window_size, -self.shift_size),
+ slice(-self.shift_size, None),
+ )
+ width_slices = (
+ slice(0, -self.window_size),
+ slice(-self.window_size, -self.shift_size),
+ slice(-self.shift_size, None),
+ )
+ count = 0
+ for height_slice in height_slices:
+ for width_slice in width_slices:
+ img_mask[:, height_slice, width_slice, :] = count
+ count += 1
+
+ mask_windows = window_partition(img_mask, self.window_size)
+ mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, -100.0).masked_fill(attn_mask == 0, 0.0)
+ else:
+ attn_mask = None
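+ # Note (added for clarity): the returned mask has shape
+ # (num_windows, window_size * window_size, window_size * window_size); pairs of positions that
+ # belong to different shifted regions receive -100.0 so their attention scores are suppressed by the softmax.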
+ return attn_mask
+
+ def maybe_pad(self, hidden_states, height, width):
+ pad_left = pad_top = 0
+ pad_right = (self.window_size - width % self.window_size) % self.window_size
+ pad_bottom = (self.window_size - height % self.window_size) % self.window_size
+ pad_values = (0, 0, pad_left, pad_right, pad_top, pad_bottom)
+ hidden_states = nn.functional.pad(hidden_states, pad_values)
+ return hidden_states, pad_values
+
+ def forward(self, hidden_states, input_dimensions, head_mask=None, output_attentions=False):
+ height, width = input_dimensions
+ batch_size, dim, channels = hidden_states.size()
+ shortcut = hidden_states
+
+ hidden_states = self.layernorm_before(hidden_states)
+ hidden_states = hidden_states.view(batch_size, height, width, channels)
+ # pad hidden_states to multiples of window size
+ hidden_states, pad_values = self.maybe_pad(hidden_states, height, width)
+
+ _, height_pad, width_pad, _ = hidden_states.shape
+ # cyclic shift
+ if self.shift_size > 0:
+ shifted_hidden_states = torch.roll(hidden_states, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
+ else:
+ shifted_hidden_states = hidden_states
+
+ # partition windows
+ hidden_states_windows = window_partition(shifted_hidden_states, self.window_size)
+ hidden_states_windows = hidden_states_windows.view(-1, self.window_size * self.window_size, channels)
+ attn_mask = self.get_attn_mask((height_pad, width_pad))
+ if attn_mask is not None:
+ attn_mask = attn_mask.to(hidden_states_windows.device)
+
+ self_attention_outputs = self.attention(
+ hidden_states_windows, attn_mask, head_mask, output_attentions=output_attentions
+ )
+
+ attention_output = self_attention_outputs[0]
+
+ outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
+
+ attention_windows = attention_output.view(-1, self.window_size, self.window_size, channels)
+ shifted_windows = window_reverse(
+ attention_windows, self.window_size, height_pad, width_pad
+ ) # B height' width' C
+
+ # reverse cyclic shift
+ if self.shift_size > 0:
+ attention_windows = torch.roll(shifted_windows, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
+ else:
+ attention_windows = shifted_windows
+
+ was_padded = pad_values[3] > 0 or pad_values[5] > 0
+ if was_padded:
+ attention_windows = attention_windows[:, :height, :width, :].contiguous()
+
+ attention_windows = attention_windows.view(batch_size, height * width, channels)
+
+ hidden_states = shortcut + self.drop_path(attention_windows)
+
+ layer_output = self.layernorm_after(hidden_states)
+ layer_output = self.intermediate(layer_output)
+ layer_output = hidden_states + self.output(layer_output)
+
+ outputs = (layer_output,) + outputs
+
+ return outputs
+
+
+class MaskFormerSwinStage(GradientCheckpointingLayer):
+ # Copied from transformers.models.swin.modeling_swin.SwinStage.__init__ with Swin->MaskFormerSwin
+ def __init__(self, config, dim, input_resolution, depth, num_heads, drop_path, downsample):
+ super().__init__()
+ self.config = config
+ self.dim = dim
+ self.blocks = nn.ModuleList(
+ [
+ MaskFormerSwinLayer(
+ config=config,
+ dim=dim,
+ input_resolution=input_resolution,
+ num_heads=num_heads,
+ drop_path_rate=drop_path[i],
+ shift_size=0 if (i % 2 == 0) else config.window_size // 2,
+ )
+ for i in range(depth)
+ ]
+ )
+
+ # patch merging layer
+ if downsample is not None:
+ self.downsample = downsample(input_resolution, dim=dim, norm_layer=nn.LayerNorm)
+ else:
+ self.downsample = None
+
+ self.pointing = False
+
+ def forward(
+ self, hidden_states, input_dimensions, head_mask=None, output_attentions=False, output_hidden_states=False
+ ):
+ all_hidden_states = () if output_hidden_states else None
+
+ height, width = input_dimensions
+ for i, block_module in enumerate(self.blocks):
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ layer_head_mask = head_mask[i] if head_mask is not None else None
+
+ block_hidden_states = block_module(hidden_states, input_dimensions, layer_head_mask, output_attentions)
+
+ hidden_states = block_hidden_states[0]
+
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ if self.downsample is not None:
+ height_downsampled, width_downsampled = (height + 1) // 2, (width + 1) // 2
+ output_dimensions = (height, width, height_downsampled, width_downsampled)
+ hidden_states = self.downsample(hidden_states, input_dimensions)
+ else:
+ output_dimensions = (height, width, height, width)
+
+ return hidden_states, output_dimensions, all_hidden_states
+
+
+class MaskFormerSwinEncoder(nn.Module):
+ # Copied from transformers.models.swin.modeling_swin.SwinEncoder.__init__ with Swin->MaskFormerSwin
+ def __init__(self, config, grid_size):
+ super().__init__()
+ self.num_layers = len(config.depths)
+ self.config = config
+ dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths), device="cpu")]
+ self.layers = nn.ModuleList(
+ [
+ MaskFormerSwinStage(
+ config=config,
+ dim=int(config.embed_dim * 2**i_layer),
+ input_resolution=(grid_size[0] // (2**i_layer), grid_size[1] // (2**i_layer)),
+ depth=config.depths[i_layer],
+ num_heads=config.num_heads[i_layer],
+ drop_path=dpr[sum(config.depths[:i_layer]) : sum(config.depths[: i_layer + 1])],
+ downsample=MaskFormerSwinPatchMerging if (i_layer < self.num_layers - 1) else None,
+ )
+ for i_layer in range(self.num_layers)
+ ]
+ )
+
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states,
+ input_dimensions,
+ head_mask=None,
+ output_attentions=False,
+ output_hidden_states=False,
+ return_dict=True,
+ ):
+ all_hidden_states = () if output_hidden_states else None
+ all_input_dimensions = ()
+ all_self_attentions = () if output_attentions else None
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ for i, layer_module in enumerate(self.layers):
+ layer_head_mask = head_mask[i] if head_mask is not None else None
+
+ layer_hidden_states, output_dimensions, layer_all_hidden_states = layer_module(
+ hidden_states,
+ input_dimensions,
+ layer_head_mask,
+ output_attentions,
+ output_hidden_states,
+ )
+
+ input_dimensions = (output_dimensions[-2], output_dimensions[-1])
+ all_input_dimensions += (input_dimensions,)
+ if output_hidden_states:
+ all_hidden_states += (layer_all_hidden_states,)
+
+ hidden_states = layer_hidden_states
+
+ if output_attentions:
+ all_self_attentions = all_self_attentions + (layer_all_hidden_states[1],)
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
+
+ return MaskFormerSwinBaseModelOutput(
+ last_hidden_state=hidden_states,
+ hidden_states=all_hidden_states,
+ hidden_states_spatial_dimensions=all_input_dimensions,
+ attentions=all_self_attentions,
+ )
+
+
+@auto_docstring
+class MaskFormerSwinPreTrainedModel(PreTrainedModel):
+ config: MaskFormerSwinConfig
+ base_model_prefix = "model"
+ main_input_name = "pixel_values"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["MaskFormerSwinStage"]
+
+ def _init_weights(self, module):
+ """Initialize the weights"""
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
+ # Slightly different from the TF version which uses truncated_normal for initialization
+ # cf https://github.com/pytorch/pytorch/pull/5617
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+ elif isinstance(module, MaskFormerSwinEmbeddings):
+ if module.position_embeddings is not None:
+ module.position_embeddings.data.zero_()
+ elif isinstance(module, MaskFormerSwinSelfAttention):
+ module.relative_position_bias_table.data.zero_()
+
+
+class MaskFormerSwinModel(MaskFormerSwinPreTrainedModel):
+ def __init__(self, config, add_pooling_layer=True):
+ super().__init__(config)
+ self.config = config
+ self.num_layers = len(config.depths)
+ self.num_features = int(config.embed_dim * 2 ** (self.num_layers - 1))
+
+ self.embeddings = MaskFormerSwinEmbeddings(config)
+ self.encoder = MaskFormerSwinEncoder(config, self.embeddings.patch_grid)
+
+ self.layernorm = nn.LayerNorm(self.num_features, eps=config.layer_norm_eps)
+ self.pooler = nn.AdaptiveAvgPool1d(1) if add_pooling_layer else None
+
+ def get_input_embeddings(self):
+ return self.embeddings.patch_embeddings
+
+ def _prune_heads(self, heads_to_prune):
+ """
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
+ class PreTrainedModel
+ """
+ for layer, heads in heads_to_prune.items():
+ self.encoder.layer[layer].attention.prune_heads(heads)
+
+ def forward(
+ self,
+ pixel_values=None,
+ head_mask=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ interpolate_pos_encoding=False,
+ return_dict=None,
+ ):
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if pixel_values is None:
+ raise ValueError("You have to specify pixel_values")
+
+ # Prepare head mask if needed
+ # 1.0 in head_mask indicate we keep the head
+ # attention_probs has shape bsz x n_heads x N x N
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+ head_mask = self.get_head_mask(head_mask, len(self.config.depths))
+
+ embedding_output, input_dimensions = self.embeddings(
+ pixel_values, interpolate_pos_encoding=interpolate_pos_encoding
+ )
+
+ encoder_outputs = self.encoder(
+ embedding_output,
+ input_dimensions,
+ head_mask=head_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ sequence_output = encoder_outputs.last_hidden_state if return_dict else encoder_outputs[0]
+ sequence_output = self.layernorm(sequence_output)
+
+ pooled_output = None
+ if self.pooler is not None:
+ pooled_output = self.pooler(sequence_output.transpose(1, 2))
+ pooled_output = torch.flatten(pooled_output, 1)
+
+ if not return_dict:
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
+
+ hidden_states_spatial_dimensions = (input_dimensions,) + encoder_outputs.hidden_states_spatial_dimensions
+
+ return MaskFormerSwinModelOutputWithPooling(
+ last_hidden_state=sequence_output,
+ pooler_output=pooled_output,
+ hidden_states=encoder_outputs.hidden_states,
+ hidden_states_spatial_dimensions=hidden_states_spatial_dimensions,
+ attentions=encoder_outputs.attentions,
+ )
+
+
+class MaskFormerSwinBackbone(MaskFormerSwinPreTrainedModel, BackboneMixin):
+ """
+ MaskFormerSwin backbone, designed especially for the MaskFormer framework.
+
+ This class reshapes `hidden_states` from `(batch_size, sequence_length, hidden_size)` to `(batch_size,
+ num_channels, height, width)`. It also adds additional layernorms after each stage.
+
+ Args:
+ config (`MaskFormerSwinConfig`):
+ The configuration used by [`MaskFormerSwinModel`].
+ """
+
+ def __init__(self, config: MaskFormerSwinConfig):
+ super().__init__(config)
+ super()._init_backbone(config)
+
+ self.model = MaskFormerSwinModel(config)
+ if "stem" in self.out_features:
+ raise ValueError("This backbone does not support 'stem' in the `out_features`.")
+ self.num_features = [config.embed_dim] + [int(config.embed_dim * 2**i) for i in range(len(config.depths))]
+ self.hidden_states_norms = nn.ModuleList(
+ [nn.LayerNorm(num_channels) for num_channels in self.num_features[1:]]
+ )
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def forward(
+ self,
+ pixel_values: Tensor,
+ output_hidden_states: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> BackboneOutput:
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+
+ outputs = self.model(
+ pixel_values, output_hidden_states=True, output_attentions=output_attentions, return_dict=True
+ )
+
+ # we skip the stem
+ hidden_states = outputs.hidden_states[1:]
+
+ # we need to reshape the hidden states to their original spatial dimensions
+ # spatial dimensions contains all the heights and widths of each stage, including after the embeddings
+ spatial_dimensions: tuple[tuple[int, int]] = outputs.hidden_states_spatial_dimensions
+ feature_maps = ()
+ for i, (hidden_state, stage, (height, width)) in enumerate(
+ zip(hidden_states, self.stage_names[1:], spatial_dimensions)
+ ):
+ norm = self.hidden_states_norms[i]
+ # the last element corresponds to the layer's last block output, before patch merging
+ hidden_state_unpolled = hidden_state[-1]
+ hidden_state_norm = norm(hidden_state_unpolled)
+ # the pixel decoder (FPN) expects 3D tensors (features)
+ batch_size, _, hidden_size = hidden_state_norm.shape
+ # reshape "b (h w) d -> b d h w"
+ hidden_state_permuted = (
+ hidden_state_norm.permute(0, 2, 1).view((batch_size, hidden_size, height, width)).contiguous()
+ )
+ if stage in self.out_features:
+ feature_maps += (hidden_state_permuted,)
+
+ if not return_dict:
+ output = (feature_maps,)
+ if output_hidden_states:
+ output += (outputs.hidden_states,)
+ if output_attentions:
+ output += (outputs.attentions,)
+ return output
+
+ return BackboneOutput(
+ feature_maps=feature_maps,
+ hidden_states=outputs.hidden_states if output_hidden_states else None,
+ attentions=outputs.attentions,
+ )
+
+
+__all__ = ["MaskFormerSwinBackbone", "MaskFormerSwinModel", "MaskFormerSwinPreTrainedModel"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/mbart50/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/mbart50/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f7cd8c28da631fdd44cc09458380527b6323d044
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/mbart50/__init__.py
@@ -0,0 +1,27 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .tokenization_mbart50 import *
+ from .tokenization_mbart50_fast import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/mbart50/tokenization_mbart50.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/mbart50/tokenization_mbart50.py
new file mode 100644
index 0000000000000000000000000000000000000000..413beaa03a83e4eddb9eb5a050d85c347ecbe2b5
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/mbart50/tokenization_mbart50.py
@@ -0,0 +1,359 @@
+# coding=utf-8
+# Copyright 2021 The Facebook AI Research Team Authors and The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+from shutil import copyfile
+from typing import Any, Optional
+
+import sentencepiece as spm
+
+from ...tokenization_utils import AddedToken, BatchEncoding, PreTrainedTokenizer
+from ...utils import logging
+from ...utils.import_utils import requires
+
+
+logger = logging.get_logger(__name__)
+
+SPIECE_UNDERLINE = "▁"
+
+VOCAB_FILES_NAMES = {"vocab_file": "sentencepiece.bpe.model"}
+
+
+FAIRSEQ_LANGUAGE_CODES = ["ar_AR", "cs_CZ", "de_DE", "en_XX", "es_XX", "et_EE", "fi_FI", "fr_XX", "gu_IN", "hi_IN", "it_IT", "ja_XX", "kk_KZ", "ko_KR", "lt_LT", "lv_LV", "my_MM", "ne_NP", "nl_XX", "ro_RO", "ru_RU", "si_LK", "tr_TR", "vi_VN", "zh_CN", "af_ZA", "az_AZ", "bn_IN", "fa_IR", "he_IL", "hr_HR", "id_ID", "ka_GE", "km_KH", "mk_MK", "ml_IN", "mn_MN", "mr_IN", "pl_PL", "ps_AF", "pt_XX", "sv_SE", "sw_KE", "ta_IN", "te_IN", "th_TH", "tl_XX", "uk_UA", "ur_PK", "xh_ZA", "gl_ES", "sl_SI"] # fmt: skip
+
+
+@requires(backends=("sentencepiece",))
+class MBart50Tokenizer(PreTrainedTokenizer):
+ """
+ Construct a MBart50 tokenizer. Based on [SentencePiece](https://github.com/google/sentencepiece).
+
+ This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
+ this superclass for more information regarding those methods.
+
+ Args:
+ vocab_file (`str`):
+ Path to the vocabulary file.
+ src_lang (`str`, *optional*):
+ A string representing the source language.
+ tgt_lang (`str`, *optional*):
+ A string representing the target language.
+ eos_token (`str`, *optional*, defaults to `"</s>"`):
+ The end of sequence token.
+ sep_token (`str`, *optional*, defaults to `"</s>"`):
+ The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
+ sequence classification or for a text and a question for question answering. It is also used as the last
+ token of a sequence built with special tokens.
+ cls_token (`str`, *optional*, defaults to `"<s>"`):
+ The classifier token which is used when doing sequence classification (classification of the whole sequence
+ instead of per-token classification). It is the first token of the sequence when built with special tokens.
+ unk_token (`str`, *optional*, defaults to `"<unk>"`):
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+ token instead.
+ pad_token (`str`, *optional*, defaults to `"<pad>"`):
+ The token used for padding, for example when batching sequences of different lengths.
+ mask_token (`str`, *optional*, defaults to `"<mask>"`):
+ The token used for masking values. This is the token used when training this model with masked language
+ modeling. This is the token which the model will try to predict.
+ sp_model_kwargs (`dict`, *optional*):
+ Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for
+ SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things,
+ to set:
+
+ - `enable_sampling`: Enable subword regularization.
+ - `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout.
+
+ - `nbest_size = {0,1}`: No sampling is performed.
+ - `nbest_size > 1`: samples from the nbest_size results.
+ - `nbest_size < 0`: assuming that nbest_size is infinite and samples from the all hypothesis (lattice)
+ using forward-filtering-and-backward-sampling algorithm.
+
+ - `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for
+ BPE-dropout.
+
+ Examples:
+
+ ```python
+ >>> from transformers import MBart50Tokenizer
+
+ >>> tokenizer = MBart50Tokenizer.from_pretrained("facebook/mbart-large-50", src_lang="en_XX", tgt_lang="ro_RO")
+ >>> src_text = " UN Chief Says There Is No Military Solution in Syria"
+ >>> tgt_text = "Şeful ONU declară că nu există o soluţie militară în Siria"
+ >>> model_inputs = tokenizer(src_text, text_target=tgt_text, return_tensors="pt")
+ >>> # model(**model_inputs) should work
+ ```"""
+
+ vocab_files_names = VOCAB_FILES_NAMES
+ model_input_names = ["input_ids", "attention_mask"]
+
+ prefix_tokens: list[int] = []
+ suffix_tokens: list[int] = []
+
+ def __init__(
+ self,
+ vocab_file,
+ src_lang=None,
+ tgt_lang=None,
+ eos_token="",
+ sep_token="",
+ cls_token="",
+ unk_token="",
+ pad_token="",
+ mask_token="",
+ sp_model_kwargs: Optional[dict[str, Any]] = None,
+ **kwargs,
+ ) -> None:
+ # Mask token behaves like a normal word, i.e. includes the space before it
+ mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token
+
+ self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
+
+ kwargs["additional_special_tokens"] = kwargs.get("additional_special_tokens", []) or []
+ kwargs["additional_special_tokens"] += [
+ code for code in FAIRSEQ_LANGUAGE_CODES if code not in kwargs["additional_special_tokens"]
+ ]
+
+ self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
+ self.sp_model.Load(str(vocab_file))
+ self.vocab_file = vocab_file
+
+ # Original fairseq vocab and spm vocab must be "aligned":
+ # Vocab | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9
+ # -------- | ------- | ------- | ------ | ------- | --- | --- | --- | ----- | ----- | ----
+ # fairseq | '<s>' | '<pad>' | '</s>' | '<unk>' | ',' | '.' | '▁' | 's' | '▁de' | '-'
+ # spm | '<unk>' | '<s>' | '</s>' | ',' | '.' | '▁' | 's' | '▁de' | '-' | '▁a'
+
+ # Mimic fairseq token-to-id alignment for the first 4 tokens
+ self.fairseq_tokens_to_ids = {"<s>": 0, "<pad>": 1, "</s>": 2, "<unk>": 3}
+
+ # The first "real" token "," has position 4 in the original fairseq vocab and position 3 in the spm vocab
+ self.fairseq_offset = 1
+
+ self.sp_model_size = len(self.sp_model)
+ self.lang_code_to_id = {
+ code: self.sp_model_size + i + self.fairseq_offset for i, code in enumerate(FAIRSEQ_LANGUAGE_CODES)
+ }
+ self.id_to_lang_code = {v: k for k, v in self.lang_code_to_id.items()}
+ self.fairseq_tokens_to_ids[""] = len(self.sp_model) + len(self.lang_code_to_id) + self.fairseq_offset
+
+ self.fairseq_tokens_to_ids.update(self.lang_code_to_id)
+ self.fairseq_ids_to_tokens = {v: k for k, v in self.fairseq_tokens_to_ids.items()}
+
+ super().__init__(
+ src_lang=src_lang,
+ tgt_lang=tgt_lang,
+ eos_token=eos_token,
+ unk_token=unk_token,
+ sep_token=sep_token,
+ cls_token=cls_token,
+ pad_token=pad_token,
+ mask_token=mask_token,
+ sp_model_kwargs=self.sp_model_kwargs,
+ **kwargs,
+ )
+
+ self._src_lang = src_lang if src_lang is not None else "en_XX"
+ self.cur_lang_code_id = self.lang_code_to_id[self._src_lang]
+ self.tgt_lang = tgt_lang
+ self.set_src_lang_special_tokens(self._src_lang)
+
+ @property
+ def vocab_size(self) -> int:
+ return len(self.sp_model) + len(self.lang_code_to_id) + self.fairseq_offset + 1 # Plus 1 for the mask token
+
+ @property
+ def src_lang(self) -> str:
+ return self._src_lang
+
+ @src_lang.setter
+ def src_lang(self, new_src_lang: str) -> None:
+ self._src_lang = new_src_lang
+ self.set_src_lang_special_tokens(self._src_lang)
+
+ def __getstate__(self) -> dict:
+ state = self.__dict__.copy()
+ state["sp_model"] = None
+ return state
+
+ def __setstate__(self, d: dict) -> None:
+ self.__dict__ = d
+
+ # for backward compatibility
+ if not hasattr(self, "sp_model_kwargs"):
+ self.sp_model_kwargs = {}
+
+ self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
+ self.sp_model.Load(self.vocab_file)
+
+ def get_vocab(self) -> dict:
+ vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
+ vocab.update(self.added_tokens_encoder)
+ return vocab
+
+ def _tokenize(self, text: str) -> list[str]:
+ return self.sp_model.encode(text, out_type=str)
+
+ def _convert_token_to_id(self, token: str) -> int:
+ """Converts a token (str) in an id using the vocab."""
+ if token in self.fairseq_tokens_to_ids:
+ return self.fairseq_tokens_to_ids[token]
+ spm_id = self.sp_model.PieceToId(token)
+
+ # Need to return unknown token if the SP model returned 0
+ return spm_id + self.fairseq_offset if spm_id else self.unk_token_id
+
+ def _convert_id_to_token(self, index: int) -> str:
+ """Converts an index (integer) in a token (str) using the vocab."""
+ if index in self.fairseq_ids_to_tokens:
+ return self.fairseq_ids_to_tokens[index]
+ return self.sp_model.IdToPiece(index - self.fairseq_offset)
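+
+ # Illustrative sketch of the offset logic above (comments only, not executed): with `fairseq_offset = 1`,
+ # a SentencePiece piece id `p > 0` maps to the fairseq id `p + 1`, ids 0-3 stay reserved for
+ # "<s>"/"<pad>"/"</s>"/"<unk>", and the language codes plus "<mask>" occupy the ids above the
+ # SentencePiece range. Assuming "," has spm id 3 (as in the alignment table in `__init__`),
+ # `_convert_token_to_id(",")` returns 4 and `_convert_id_to_token(4)` returns ","; a piece id of 0
+ # (spm's unknown) is remapped to `unk_token_id`.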
+
+ # Copied from transformers.models.albert.tokenization_albert.AlbertTokenizer.convert_tokens_to_string
+ def convert_tokens_to_string(self, tokens):
+ """Converts a sequence of tokens (string) in a single string."""
+ current_sub_tokens = []
+ out_string = ""
+ prev_is_special = False
+ for token in tokens:
+ # make sure that special tokens are not decoded using sentencepiece model
+ if token in self.all_special_tokens:
+ if not prev_is_special:
+ out_string += " "
+ out_string += self.sp_model.decode(current_sub_tokens) + token
+ prev_is_special = True
+ current_sub_tokens = []
+ else:
+ current_sub_tokens.append(token)
+ prev_is_special = False
+ out_string += self.sp_model.decode(current_sub_tokens)
+ return out_string.strip()
+
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]:
+ if not os.path.isdir(save_directory):
+ logger.error(f"Vocabulary path ({save_directory}) should be a directory")
+ return
+ out_vocab_file = os.path.join(
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
+ )
+
+ if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
+ copyfile(self.vocab_file, out_vocab_file)
+ elif not os.path.isfile(self.vocab_file):
+ with open(out_vocab_file, "wb") as fi:
+ content_spiece_model = self.sp_model.serialized_model_proto()
+ fi.write(content_spiece_model)
+
+ return (out_vocab_file,)
+
+ def get_special_tokens_mask(
+ self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None, already_has_special_tokens: bool = False
+ ) -> list[int]:
+ """
+ Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
+ special tokens using the tokenizer `prepare_for_model` method.
+
+ Args:
+ token_ids_0 (`list[int]`):
+ List of IDs.
+ token_ids_1 (`list[int]`, *optional*):
+ Optional second list of IDs for sequence pairs.
+ already_has_special_tokens (`bool`, *optional*, defaults to `False`):
+ Whether or not the token list is already formatted with special tokens for the model.
+
+ Returns:
+ `list[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
+ """
+
+ if already_has_special_tokens:
+ return super().get_special_tokens_mask(
+ token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
+ )
+
+ prefix_ones = [1] * len(self.prefix_tokens)
+ suffix_ones = [1] * len(self.suffix_tokens)
+ if token_ids_1 is None:
+ return prefix_ones + ([0] * len(token_ids_0)) + suffix_ones
+ return prefix_ones + ([0] * len(token_ids_0)) + ([0] * len(token_ids_1)) + suffix_ones
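+
+ # Hedged example (comment only): with the default source language "en_XX", `prefix_tokens` holds the
+ # language-code id and `suffix_tokens` holds the eos id, so a single sequence of three regular tokens
+ # yields the mask [1, 0, 0, 0, 1] -- one special token on each side.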
+
+ def build_inputs_with_special_tokens(
+ self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None
+ ) -> list[int]:
+ """
+ Build model inputs from a sequence or a pair of sequences for sequence classification tasks by concatenating
+ and adding special tokens. An MBART-50 sequence has the following format, where `X` represents the sequence:
+
+ - `input_ids` (for encoder): `[src_lang_code] X [eos]`
+ - `labels` (for decoder): `[tgt_lang_code] X [eos]`
+
+ BOS is never used. Pairs of sequences are not the expected use case, but they will be handled without a
+ separator.
+
+ Args:
+ token_ids_0 (`list[int]`):
+ List of IDs to which the special tokens will be added.
+ token_ids_1 (`list[int]`, *optional*):
+ Optional second list of IDs for sequence pairs.
+
+ Returns:
+ `list[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
+ """
+ if token_ids_1 is None:
+ return self.prefix_tokens + token_ids_0 + self.suffix_tokens
+ # We don't expect to process pairs, but leave the pair logic for API consistency
+ return self.prefix_tokens + token_ids_0 + token_ids_1 + self.suffix_tokens
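+
+ # Illustrative doctest-style sketch (not executed; reuses the checkpoint from the class docstring):
+ # >>> tok = MBart50Tokenizer.from_pretrained("facebook/mbart-large-50", src_lang="en_XX")
+ # >>> ids = tok("Hello")["input_ids"]
+ # >>> tok.convert_ids_to_tokens(ids)[0], tok.convert_ids_to_tokens(ids)[-1]
+ # ('en_XX', '</s>')  # i.e. [src_lang_code] X [eos], no BOS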
+
+ def _build_translation_inputs(
+ self, raw_inputs, return_tensors: str, src_lang: Optional[str], tgt_lang: Optional[str], **extra_kwargs
+ ):
+ """Used by translation pipeline, to prepare inputs for the generate function"""
+ if src_lang is None or tgt_lang is None:
+ raise ValueError("Translation requires a `src_lang` and a `tgt_lang` for this model")
+ self.src_lang = src_lang
+ inputs = self(raw_inputs, add_special_tokens=True, return_tensors=return_tensors, **extra_kwargs)
+ tgt_lang_id = self.convert_tokens_to_ids(tgt_lang)
+ inputs["forced_bos_token_id"] = tgt_lang_id
+ return inputs
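+
+ # Hedged sketch of how the translation pipeline uses this hook (comments only): it sets `src_lang`,
+ # tokenizes the raw inputs, and stores the target-language id as `forced_bos_token_id` so that
+ # generation starts with the target language code, e.g.
+ # >>> batch = tok._build_translation_inputs("Hello", return_tensors="pt", src_lang="en_XX", tgt_lang="ro_RO")
+ # >>> batch["forced_bos_token_id"] == tok.convert_tokens_to_ids("ro_RO")
+ # True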
+
+ def prepare_seq2seq_batch(
+ self,
+ src_texts: list[str],
+ src_lang: str = "en_XX",
+ tgt_texts: Optional[list[str]] = None,
+ tgt_lang: str = "ro_RO",
+ **kwargs,
+ ) -> BatchEncoding:
+ self.src_lang = src_lang
+ self.tgt_lang = tgt_lang
+ return super().prepare_seq2seq_batch(src_texts, tgt_texts, **kwargs)
+
+ def _switch_to_input_mode(self):
+ return self.set_src_lang_special_tokens(self.src_lang)
+
+ def _switch_to_target_mode(self):
+ return self.set_tgt_lang_special_tokens(self.tgt_lang)
+
+ def set_src_lang_special_tokens(self, src_lang: str) -> None:
+ """Reset the special tokens to the source lang setting. prefix=[src_lang_code] and suffix=[eos]."""
+ self.cur_lang_code_id = self.lang_code_to_id[src_lang]
+ self.prefix_tokens = [self.cur_lang_code_id]
+ self.suffix_tokens = [self.eos_token_id]
+
+ def set_tgt_lang_special_tokens(self, tgt_lang: str) -> None:
+ """Reset the special tokens to the target language setting. prefix=[tgt_lang_code] and suffix=[eos]."""
+ self.cur_lang_code_id = self.lang_code_to_id[tgt_lang]
+ self.prefix_tokens = [self.cur_lang_code_id]
+ self.suffix_tokens = [self.eos_token_id]
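+
+ # Hedged usage note: `_switch_to_input_mode`/`_switch_to_target_mode` call these two setters, so encoding
+ # source text prefixes the source language code while encoding targets (e.g. via `text_target=...`)
+ # prefixes the target language code instead; both variants end with `[eos]`.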
+
+
+__all__ = ["MBart50Tokenizer"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/mbart50/tokenization_mbart50_fast.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/mbart50/tokenization_mbart50_fast.py
new file mode 100644
index 0000000000000000000000000000000000000000..985b0929f87c5516a2edd8e02c6aaaa8926d496e
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/mbart50/tokenization_mbart50_fast.py
@@ -0,0 +1,258 @@
+# coding=utf-8
+# Copyright 2021 The Facebook AI Research Team Authors and The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+from shutil import copyfile
+from typing import Optional
+
+from tokenizers import processors
+
+from ...tokenization_utils import AddedToken, BatchEncoding
+from ...tokenization_utils_fast import PreTrainedTokenizerFast
+from ...utils import is_sentencepiece_available, logging
+
+
+if is_sentencepiece_available():
+ from .tokenization_mbart50 import MBart50Tokenizer
+else:
+ MBart50Tokenizer = None
+
+
+logger = logging.get_logger(__name__)
+
+VOCAB_FILES_NAMES = {"vocab_file": "sentencepiece.bpe.model", "tokenizer_file": "tokenizer.json"}
+
+
+FAIRSEQ_LANGUAGE_CODES = ["ar_AR", "cs_CZ", "de_DE", "en_XX", "es_XX", "et_EE", "fi_FI", "fr_XX", "gu_IN", "hi_IN", "it_IT", "ja_XX", "kk_KZ", "ko_KR", "lt_LT", "lv_LV", "my_MM", "ne_NP", "nl_XX", "ro_RO", "ru_RU", "si_LK", "tr_TR", "vi_VN", "zh_CN", "af_ZA", "az_AZ", "bn_IN", "fa_IR", "he_IL", "hr_HR", "id_ID", "ka_GE", "km_KH", "mk_MK", "ml_IN", "mn_MN", "mr_IN", "pl_PL", "ps_AF", "pt_XX", "sv_SE", "sw_KE", "ta_IN", "te_IN", "th_TH", "tl_XX", "uk_UA", "ur_PK", "xh_ZA", "gl_ES", "sl_SI"] # fmt: skip
+
+
+class MBart50TokenizerFast(PreTrainedTokenizerFast):
+ """
+ Construct a "fast" MBART tokenizer for mBART-50 (backed by HuggingFace's *tokenizers* library). Based on
+ [BPE](https://huggingface.co/docs/tokenizers/python/latest/components.html?highlight=BPE#models).
+
+ This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should
+ refer to this superclass for more information regarding those methods.
+
+ Args:
+ vocab_file (`str`):
+ Path to the vocabulary file.
+ src_lang (`str`, *optional*):
+ A string representing the source language.
+ tgt_lang (`str`, *optional*):
+ A string representing the target language.
+ eos_token (`str`, *optional*, defaults to `"</s>"`):
+ The end of sequence token.
+ sep_token (`str`, *optional*, defaults to `"</s>"`):
+ The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
+ sequence classification or for a text and a question for question answering. It is also used as the last
+ token of a sequence built with special tokens.
+ cls_token (`str`, *optional*, defaults to `"<s>"`):
+ The classifier token which is used when doing sequence classification (classification of the whole sequence
+ instead of per-token classification). It is the first token of the sequence when built with special tokens.
+ unk_token (`str`, *optional*, defaults to `"<unk>"`):
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+ token instead.
+ pad_token (`str`, *optional*, defaults to `"<pad>"`):
+ The token used for padding, for example when batching sequences of different lengths.
+ mask_token (`str`, *optional*, defaults to `"<mask>"`):
+ The token used for masking values. This is the token used when training this model with masked language
+ modeling. This is the token which the model will try to predict.
+
+ Examples:
+
+ ```python
+ >>> from transformers import MBart50TokenizerFast
+
+ >>> tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50", src_lang="en_XX", tgt_lang="ro_RO")
+ >>> src_text = " UN Chief Says There Is No Military Solution in Syria"
+ >>> tgt_text = "Şeful ONU declară că nu există o soluţie militară în Siria"
+ >>> model_inputs = tokenizer(src_text, text_target=tgt_text, return_tensors="pt")
+ >>> # model(**model_inputs) should work
+ ```"""
+
+ vocab_files_names = VOCAB_FILES_NAMES
+ model_input_names = ["input_ids", "attention_mask"]
+ slow_tokenizer_class = MBart50Tokenizer
+
+ prefix_tokens: list[int] = []
+ suffix_tokens: list[int] = []
+
+ def __init__(
+ self,
+ vocab_file=None,
+ src_lang=None,
+ tgt_lang=None,
+ tokenizer_file=None,
+ eos_token="",
+ sep_token="",
+ cls_token="",
+ unk_token="",
+ pad_token="",
+ mask_token="",
+ **kwargs,
+ ):
+ # Mask token behaves like a normal word, i.e. includes the space before it
+ mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token
+
+ kwargs["additional_special_tokens"] = kwargs.get("additional_special_tokens", []) or []
+ kwargs["additional_special_tokens"] += [
+ code for code in FAIRSEQ_LANGUAGE_CODES if code not in kwargs["additional_special_tokens"]
+ ]
+
+ super().__init__(
+ vocab_file,
+ src_lang=src_lang,
+ tgt_lang=tgt_lang,
+ tokenizer_file=tokenizer_file,
+ eos_token=eos_token,
+ sep_token=sep_token,
+ cls_token=cls_token,
+ unk_token=unk_token,
+ pad_token=pad_token,
+ mask_token=mask_token,
+ **kwargs,
+ )
+
+ self.vocab_file = vocab_file
+
+ self.lang_code_to_id = {
+ lang_code: self.convert_tokens_to_ids(lang_code) for lang_code in FAIRSEQ_LANGUAGE_CODES
+ }
+
+ self._src_lang = src_lang if src_lang is not None else "en_XX"
+ self.tgt_lang = tgt_lang
+ self.cur_lang_code_id = self.lang_code_to_id[self._src_lang]
+ self.set_src_lang_special_tokens(self._src_lang)
+
+ @property
+ def src_lang(self) -> str:
+ return self._src_lang
+
+ @src_lang.setter
+ def src_lang(self, new_src_lang: str) -> None:
+ self._src_lang = new_src_lang
+ self.set_src_lang_special_tokens(self._src_lang)
+
+ def build_inputs_with_special_tokens(
+ self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None
+ ) -> list[int]:
+ """
+ Build model inputs from a sequence or a pair of sequences for sequence classification tasks by concatenating
+ and adding special tokens. The special tokens depend on calling set_lang.
+
+ An MBART-50 sequence has the following format, where `X` represents the sequence:
+
+ - `input_ids` (for encoder): `[src_lang_code] X [eos]`
+ - `labels` (for decoder): `[tgt_lang_code] X [eos]`
+
+ BOS is never used. Pairs of sequences are not the expected use case, but they will be handled without a
+ separator.
+
+ Args:
+ token_ids_0 (`list[int]`):
+ List of IDs to which the special tokens will be added.
+ token_ids_1 (`list[int]`, *optional*):
+ Optional second list of IDs for sequence pairs.
+
+ Returns:
+ `list[int]`: list of [input IDs](../glossary#input-ids) with the appropriate special tokens.
+ """
+ if token_ids_1 is None:
+ return self.prefix_tokens + token_ids_0 + self.suffix_tokens
+ # We don't expect to process pairs, but leave the pair logic for API consistency
+ return self.prefix_tokens + token_ids_0 + token_ids_1 + self.suffix_tokens
+
+ def prepare_seq2seq_batch(
+ self,
+ src_texts: list[str],
+ src_lang: str = "en_XX",
+ tgt_texts: Optional[list[str]] = None,
+ tgt_lang: str = "ro_RO",
+ **kwargs,
+ ) -> BatchEncoding:
+ self.src_lang = src_lang
+ self.tgt_lang = tgt_lang
+ return super().prepare_seq2seq_batch(src_texts, tgt_texts, **kwargs)
+
+ def _switch_to_input_mode(self):
+ return self.set_src_lang_special_tokens(self.src_lang)
+
+ def _switch_to_target_mode(self):
+ return self.set_tgt_lang_special_tokens(self.tgt_lang)
+
+ def set_src_lang_special_tokens(self, src_lang: str) -> None:
+ """Reset the special tokens to the source lang setting. prefix=[src_lang_code] and suffix=[eos]."""
+ self.cur_lang_code_id = self.convert_tokens_to_ids(src_lang)
+ self.prefix_tokens = [self.cur_lang_code_id]
+ self.suffix_tokens = [self.eos_token_id]
+
+ prefix_tokens_str = self.convert_ids_to_tokens(self.prefix_tokens)
+ suffix_tokens_str = self.convert_ids_to_tokens(self.suffix_tokens)
+
+ self._tokenizer.post_processor = processors.TemplateProcessing(
+ single=prefix_tokens_str + ["$A"] + suffix_tokens_str,
+ pair=prefix_tokens_str + ["$A", "$B"] + suffix_tokens_str,
+ special_tokens=list(zip(prefix_tokens_str + suffix_tokens_str, self.prefix_tokens + self.suffix_tokens)),
+ )
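+
+ # Hedged sketch (comment only): with src_lang="en_XX" the template above is equivalent to
+ # single = ["en_XX", "$A", "</s>"] and pair = ["en_XX", "$A", "$B", "</s>"], so the underlying
+ # `tokenizers` post-processor attaches the language code and eos ids without re-tokenizing them.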
+
+ def set_tgt_lang_special_tokens(self, tgt_lang: str) -> None:
+ """Reset the special tokens to the target language setting. prefix=[src_lang_code] and suffix=[eos]."""
+ self.cur_lang_code_id = self.convert_tokens_to_ids(tgt_lang)
+ self.prefix_tokens = [self.cur_lang_code_id]
+ self.suffix_tokens = [self.eos_token_id]
+
+ prefix_tokens_str = self.convert_ids_to_tokens(self.prefix_tokens)
+ suffix_tokens_str = self.convert_ids_to_tokens(self.suffix_tokens)
+
+ self._tokenizer.post_processor = processors.TemplateProcessing(
+ single=prefix_tokens_str + ["$A"] + suffix_tokens_str,
+ pair=prefix_tokens_str + ["$A", "$B"] + suffix_tokens_str,
+ special_tokens=list(zip(prefix_tokens_str + suffix_tokens_str, self.prefix_tokens + self.suffix_tokens)),
+ )
+
+ def _build_translation_inputs(
+ self, raw_inputs, return_tensors: str, src_lang: Optional[str], tgt_lang: Optional[str], **extra_kwargs
+ ):
+ """Used by translation pipeline, to prepare inputs for the generate function"""
+ if src_lang is None or tgt_lang is None:
+ raise ValueError("Translation requires a `src_lang` and a `tgt_lang` for this model")
+ self.src_lang = src_lang
+ inputs = self(raw_inputs, add_special_tokens=True, return_tensors=return_tensors, **extra_kwargs)
+ tgt_lang_id = self.convert_tokens_to_ids(tgt_lang)
+ inputs["forced_bos_token_id"] = tgt_lang_id
+ return inputs
+
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]:
+ if not self.can_save_slow_tokenizer:
+ raise ValueError(
+ "Your fast tokenizer does not have the necessary information to save the vocabulary for a slow "
+ "tokenizer."
+ )
+
+ if not os.path.isdir(save_directory):
+ logger.error(f"Vocabulary path ({save_directory}) should be a directory")
+ return
+ out_vocab_file = os.path.join(
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
+ )
+
+ if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
+ copyfile(self.vocab_file, out_vocab_file)
+
+ return (out_vocab_file,)
+
+
+__all__ = ["MBart50TokenizerFast"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/mimi/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/mimi/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..0ed44a4324ddbf586f5aa9e1971422c98755eb3a
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/mimi/__init__.py
@@ -0,0 +1,27 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_mimi import *
+ from .modeling_mimi import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/mimi/configuration_mimi.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/mimi/configuration_mimi.py
new file mode 100644
index 0000000000000000000000000000000000000000..c53ce475f9e05f3e5730280677c17466d884aca3
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/mimi/configuration_mimi.py
@@ -0,0 +1,279 @@
+# coding=utf-8
+# Copyright 2024 Meta Platforms, Inc. and affiliates, and the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Mimi model configuration"""
+
+import math
+
+import numpy as np
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class MimiConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`MimiModel`]. It is used to instantiate a
+ Mimi model according to the specified arguments, defining the model architecture. Instantiating a configuration
+ with the defaults will yield a similar configuration to that of the
+ [kyutai/mimi](https://huggingface.co/kyutai/mimi) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ sampling_rate (`int`, *optional*, defaults to 24000):
+ The sampling rate at which the audio waveform should be digitized, expressed in hertz (Hz).
+ frame_rate (`float`, *optional*):
+ Should be computed from the other parameters, yet kept for backward compatibility.
+ audio_channels (`int`, *optional*, defaults to 1):
+ Number of channels in the audio data. Either 1 for mono or 2 for stereo.
+ hidden_size (`int`, *optional*, defaults to 512):
+ Intermediate representation dimension.
+ num_filters (`int`, *optional*, defaults to 64):
+ Number of convolution kernels of first `MimiConv1d` down sampling layer.
+ num_residual_layers (`int`, *optional*, defaults to 1):
+ Number of residual layers.
+ upsampling_ratios (`Sequence[int]`, *optional*):
+ Kernel size and stride ratios. The encoder uses downsampling ratios instead of upsampling ratios, hence it
+ will use the ratios in the reverse order to the ones specified here, which must match the decoder order.
+ If not specified, defaults to `[8, 6, 5, 4]`.
+ kernel_size (`int`, *optional*, defaults to 7):
+ Kernel size for the initial convolution.
+ last_kernel_size (`int`, *optional*, defaults to 3):
+ Kernel size for the last convolution layer.
+ residual_kernel_size (`int`, *optional*, defaults to 3):
+ Kernel size for the residual layers.
+ dilation_growth_rate (`int`, *optional*, defaults to 2):
+ How much to increase the dilation with each layer.
+ use_causal_conv (`bool`, *optional*, defaults to `True`):
+ Whether to use fully causal convolution.
+ pad_mode (`str`, *optional*, defaults to `"constant"`):
+ Padding mode for the convolutions.
+ compress (`int`, *optional*, defaults to 2):
+ Reduced dimensionality in residual branches.
+ trim_right_ratio (`float`, *optional*, defaults to 1.0):
+ Ratio for trimming at the right of the transposed convolution under the `use_causal_conv = True` setup. If
+ equal to 1.0, it means that all the trimming is done at the right.
+ codebook_size (`int`, *optional*, defaults to 2048):
+ Number of discrete codes in each codebook.
+ codebook_dim (`int`, *optional*, defaults to 256):
+ Dimension of the unquantized codebook vectors. If not defined, uses `hidden_size`.
+ num_quantizers (`int`, *optional*, defaults to 32):
+ Number of quantizer channels, or codebooks, in the quantizer.
+ use_conv_shortcut (`bool`, *optional*, defaults to `False`):
+ Whether to use a convolutional layer as the 'skip' connection in the `MimiResnetBlock` block. If False,
+ an identity function will be used, giving a generic residual connection.
+ vector_quantization_hidden_dimension (`int`, *optional*, defaults to 256):
+ Intermediate representation dimension in the residual vector quantization space.
+ num_semantic_quantizers (`int`, *optional*, defaults to 1):
+ Number of semantic quantizer channels, or codebooks, in the semantic quantizer. Must be lower than `num_quantizers`.
+ upsample_groups (`int`, *optional*, defaults to 512):
+ If `frame_rate!=encodec_frame_rate`, indicates the number of groups used in the upsampling operation to go from one rate to another.
+ num_hidden_layers (`int`, *optional*, defaults to 8):
+ Number of hidden layers in the Transformer models.
+ intermediate_size (`int`, *optional*, defaults to 2048):
+ Dimension of the MLP representations.
+ num_attention_heads (`int`, *optional*, defaults to 8):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ num_key_value_heads (`int`, *optional*, defaults to 8):
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
+ `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+ by meanpooling all the original heads within that group. For more details, check out [this
+ paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to `8`.
+ head_dim (`int`, *optional*, defaults to `hidden_size // num_attention_heads`):
+ The attention head dimension.
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
+ The non-linear activation function (function or string) in the decoder.
+ max_position_embeddings (`int`, *optional*, defaults to 8000):
+ The maximum sequence length that this model might ever be used with. Mimi's sliding window attention
+ allows sequences of up to 8000 tokens.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ norm_eps (`float`, *optional*, defaults to 1e-05):
+ The epsilon used by the LayerNorm normalization layers.
+ use_cache (`bool`, *optional*, defaults to `False`):
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
+ relevant if `config.is_decoder=True`.
+ use_streaming (`bool`, *optional*, defaults to `False`):
+ Whether to use streaming mode. If `True`, the model `encode` method will also return a padding cache that can be reused in subsequent calls to `encode`.
+ rope_theta (`float`, *optional*, defaults to 10000.0):
+ The base period of the RoPE embeddings.
+ sliding_window (`int`, *optional*, defaults to 250):
+ Sliding window attention window size. If not specified, will default to `250`.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ layer_scale_initial_scale (`float`, *optional*, defaults to 0.01):
+ Initial scale of the residual rescaling operation done in the Transformer models.
+ attention_bias (`bool`, *optional*, defaults to `False`):
+ Whether to use a bias in the query, key, value and output projection layers during self-attention.
+ Example:
+
+ ```python
+ >>> from transformers import MimiModel, MimiConfig
+
+ >>> # Initializing a "kyutai/mimi" style configuration
+ >>> configuration = MimiConfig()
+
+ >>> # Initializing a model (with random weights) from the "kyutai/mimi" style configuration
+ >>> model = MimiModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "mimi"
+
+ def __init__(
+ self,
+ sampling_rate=24_000,
+ frame_rate=None,
+ audio_channels=1,
+ hidden_size=512,
+ num_filters=64,
+ num_residual_layers=1,
+ upsampling_ratios=None,
+ kernel_size=7,
+ last_kernel_size=3,
+ residual_kernel_size=3,
+ dilation_growth_rate=2,
+ use_causal_conv=True,
+ pad_mode="constant",
+ compress=2,
+ trim_right_ratio=1.0,
+ codebook_size=2048,
+ codebook_dim=256,
+ num_quantizers=32,
+ use_conv_shortcut=False,
+ vector_quantization_hidden_dimension=256,
+ num_semantic_quantizers=1,
+ upsample_groups=512,
+ num_hidden_layers=8,
+ intermediate_size=2048,
+ num_attention_heads=8,
+ num_key_value_heads=8,
+ head_dim=None,
+ hidden_act="gelu",
+ max_position_embeddings=8000,
+ initializer_range=0.02,
+ norm_eps=1e-5,
+ use_cache=False,
+ use_streaming=False,
+ rope_theta=10000.0,
+ sliding_window=250,
+ attention_dropout=0.0,
+ layer_scale_initial_scale=0.01,
+ attention_bias=False,
+ **kwargs,
+ ):
+ self.sampling_rate = sampling_rate
+ self.audio_channels = audio_channels
+ self.hidden_size = hidden_size
+ self.num_filters = num_filters
+ self.num_residual_layers = num_residual_layers
+ self.upsampling_ratios = upsampling_ratios if upsampling_ratios else [8, 6, 5, 4]
+ self.kernel_size = kernel_size
+ self.last_kernel_size = last_kernel_size
+ self.residual_kernel_size = residual_kernel_size
+ self.dilation_growth_rate = dilation_growth_rate
+ self.use_causal_conv = use_causal_conv
+ self.pad_mode = pad_mode
+ self.compress = compress
+ self.trim_right_ratio = trim_right_ratio
+ self.codebook_size = codebook_size
+ self.codebook_dim = codebook_dim if codebook_dim is not None else hidden_size
+ self.num_quantizers = num_quantizers
+ self.use_conv_shortcut = use_conv_shortcut
+ self.vector_quantization_hidden_dimension = vector_quantization_hidden_dimension
+ self.upsample_groups = upsample_groups
+ self.num_hidden_layers = num_hidden_layers
+ self.intermediate_size = intermediate_size
+ self.num_attention_heads = num_attention_heads
+ self.num_key_value_heads = num_key_value_heads
+ self.hidden_act = hidden_act
+ self.max_position_embeddings = max_position_embeddings
+ self.initializer_range = initializer_range
+ self.norm_eps = norm_eps
+ self.use_cache = use_cache
+ self.use_streaming = use_streaming
+ self.rope_theta = rope_theta
+ self.sliding_window = sliding_window
+ self.attention_dropout = attention_dropout
+ self.head_dim = head_dim or hidden_size // num_attention_heads
+ self.layer_scale_initial_scale = layer_scale_initial_scale
+ self.attention_bias = attention_bias
+
+ # Handle backward compatibility for frame_rate:
+ # If frame_rate is explicitly provided, use it (backward compatibility)
+ # Otherwise, compute it from other parameters (correctly)
+ if frame_rate is not None:
+ self._frame_rate = frame_rate
+ else:
+ self._frame_rate = None
+
+ if num_semantic_quantizers >= self.num_quantizers:
+ raise ValueError(
+ f"The number of semantic quantizers should be lower than the total number of quantizers {self.num_quantizers}, but is currently {num_semantic_quantizers}."
+ )
+ self.num_semantic_quantizers = num_semantic_quantizers
+ super().__init__(**kwargs)
+
+ @property
+ def encodec_frame_rate(self) -> int:
+ hop_length = np.prod(self.upsampling_ratios)
+ return math.ceil(self.sampling_rate / hop_length)
+
+ @property
+ def num_codebooks(self) -> int:
+ # alias to num_quantizers
+ return self.num_quantizers
+
+ @property
+ def frame_size(self) -> int:
+ # 1. we need each encoder conv stride
+ # first conv
+ strides = [1]
+
+ # layer convs
+ for ratio in reversed(self.upsampling_ratios):
+ for j in range(self.num_residual_layers):
+ len_kernel_sizes = len(self.residual_kernel_size) if isinstance(self.residual_kernel_size, list) else 1
+ strides.extend([1] * (len_kernel_sizes + 1))
+ if self.use_conv_shortcut: # skip connection
+ strides.append(1)
+
+ strides.append(ratio)
+
+ # last conv
+ strides.append(1)
+
+ # downsampling layer
+ strides.append(2)
+
+ return math.prod(strides)
+
+ @property
+ def frame_rate(self) -> float:
+ # handle backward compatibility
+ if self._frame_rate is not None:
+ return self._frame_rate
+ return self.sampling_rate / self.frame_size
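+
+ # Illustrative arithmetic under the default configuration (comments only): the convolutional strides
+ # tracked by `frame_size` multiply to 4 * 5 * 6 * 8 * 2 = 1920 samples per frame, so
+ # frame_rate = 24000 / 1920 = 12.5 Hz, while encodec_frame_rate = ceil(24000 / prod([8, 6, 5, 4])) = 25 Hz
+ # before the extra stride-2 downsampling.
+ # >>> MimiConfig().frame_rate
+ # 12.5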
+
+
+__all__ = ["MimiConfig"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/mimi/modeling_mimi.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/mimi/modeling_mimi.py
new file mode 100644
index 0000000000000000000000000000000000000000..f22cad968247dd690fb28453dc9d986d4c5d47b7
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/mimi/modeling_mimi.py
@@ -0,0 +1,1753 @@
+# coding=utf-8
+# Copyright 2024 Kyutai, and the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch Mimi model."""
+
+import math
+from dataclasses import dataclass
+from typing import Optional, Union
+
+import torch
+from torch import nn
+
+from ...activations import ACT2FN
+from ...cache_utils import Cache, DynamicCache, StaticCache
+from ...masking_utils import create_causal_mask
+from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_outputs import BaseModelOutputWithPast
+from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
+from ...modeling_utils import PreTrainedModel
+from ...utils import ModelOutput, auto_docstring, logging
+from ...utils.deprecation import deprecate_kwarg
+from .configuration_mimi import MimiConfig
+
+
+if is_flash_attn_available():
+ from ...modeling_flash_attention_utils import _flash_attention_forward
+
+
+logger = logging.get_logger(__name__)
+
+
+@dataclass
+@auto_docstring
+class MimiOutput(ModelOutput):
+ r"""
+ audio_codes (`torch.LongTensor` of shape `(batch_size, num_quantizers, codes_length)`, *optional*):
+ Discrete code embeddings computed using `model.encode`.
+ audio_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Decoded audio values, obtained using the decoder part of Mimi.
+ encoder_past_key_values (`Cache`, *optional*):
+ Pre-computed hidden-states (key and values in the self-attention blocks) that can be used to speed up sequential decoding of the encoder transformer.
+ This typically consists in the `past_key_values` returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
+
+ The model will output the same cache format that is fed as input.
+
+ If `past_key_values` are used, the user can optionally input only the last `audio_values` or `audio_codes` (those that don't
+ have their past key value states given to this model).
+ decoder_past_key_values (`Cache`, *optional*):
+ Pre-computed hidden-states (key and values in the self-attention blocks) that can be used to speed up sequential decoding of the decoder transformer.
+ This typically consists in the `past_key_values` returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
+
+ The model will output the same cache format that is fed as input.
+
+ If `past_key_values` are used, the user can optionally input only the last `audio_values` or `audio_codes` (those that don't
+ have their past key value states given to this model).
+ """
+
+ audio_codes: Optional[torch.LongTensor] = None
+ audio_values: Optional[torch.FloatTensor] = None
+ encoder_past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None
+ decoder_past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None
+
+
+class MimiConv1dPaddingCache:
+ """
+ Padding cache for MimiConv1d causal convolutions in order to support streaming via cache padding.
+ See: https://huggingface.co/papers/2005.06720 & https://huggingface.co/papers/2204.07064
+
+ A padding cache is a list of cached partial hidden states for each convolution layer.
+ Hidden states are cached from the previous call to the MimiConv1d forward pass, given the padding size.
+ """
+
+ def __init__(
+ self,
+ num_layers: int,
+ per_layer_padding: list[int],
+ per_layer_padding_mode: list[str],
+ per_layer_in_channels: list[int],
+ ):
+ # ensure correct number of layers for each arg
+ from_args_num_layers = {len(per_layer_padding), len(per_layer_padding_mode), len(per_layer_in_channels)}
+
+ if len(from_args_num_layers) != 1 or from_args_num_layers.pop() != num_layers:
+ raise ValueError(
+ f"Expected `num_layers` ({num_layers}) values in `per_layer_padding`, `per_layer_padding_mode` and `per_layer_in_channels`"
+ )
+ elif not all(mode in ["constant", "replicate"] for mode in per_layer_padding_mode):
+ raise NotImplementedError(
+ "`padding_cache` is not supported for convolutions using other than `constant` or `replicate` padding mode"
+ )
+
+ self.per_layer_padding = per_layer_padding
+ self.per_layer_padding_mode = per_layer_padding_mode
+ self.per_layer_in_channels = per_layer_in_channels
+ self.per_layer_is_init = [True] * num_layers
+
+ self.padding_cache = [None] * num_layers
+
+ def update(self, hidden_states: torch.Tensor, layer_idx: int):
+ """
+ Updates the padding cache with the new padding states for the layer `layer_idx` and returns the current cache.
+
+ Parameters:
+ hidden_states (`torch.Tensor`):
+ The hidden states to be partially cached.
+ layer_idx (`int`):
+ The index of the layer to cache the states for.
+ Returns:
+ `torch.Tensor`, the cached padding states to prepend for this layer (zeros or replicated values on the first call).
+ """
+ batch_size, dtype, device = hidden_states.shape[0], hidden_states.dtype, hidden_states.device
+ padding = self.per_layer_padding[layer_idx]
+ padding_mode = self.per_layer_padding_mode[layer_idx]
+ in_channels = self.per_layer_in_channels[layer_idx]
+
+ if self.padding_cache[layer_idx] is None:
+ if padding_mode == "constant":
+ current_cache = torch.zeros(
+ batch_size,
+ in_channels,
+ padding,
+ device=device,
+ dtype=dtype,
+ )
+ elif padding_mode == "replicate":
+ current_cache = (
+ torch.ones(
+ batch_size,
+ in_channels,
+ padding,
+ device=device,
+ dtype=dtype,
+ )
+ * hidden_states[..., :1]
+ )
+ else:
+ current_cache = self.padding_cache[layer_idx]
+
+ # update the cache
+ if padding > 0:
+ padding_states = hidden_states[:, :, -padding:]
+ else:
+ padding_states = torch.empty(batch_size, in_channels, padding, dtype=dtype, device=device)
+ self.padding_cache[layer_idx] = padding_states
+
+ return current_cache
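+
+ # Hedged usage sketch (comments only): in streaming mode the same cache object is threaded through
+ # successive `encode` calls; each causal MimiConv1d asks `cache.update(hidden_states, self.layer_idx)`
+ # for the left context saved from the previous chunk, prepends it to the new chunk, and the cache keeps
+ # the last `padding` timesteps for the next call, so chunked encoding matches a single full-length pass.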
+
+
+@dataclass
+@auto_docstring
+class MimiEncoderOutput(ModelOutput):
+ r"""
+ audio_codes (`torch.LongTensor` of shape `(batch_size, num_quantizers, codes_length)`, *optional*):
+ Discrete code embeddings computed using `model.encode`.
+ encoder_past_key_values (`Cache`, *optional*):
+ Pre-computed hidden-states (key and values in the self-attention blocks) that can be used to speed up sequential decoding of the encoder transformer.
+ This typically consists in the `past_key_values` returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
+
+ The model will output the same cache format that is fed as input.
+
+ If `past_key_values` are used, the user can optionally input only the last `audio_values` or `audio_codes` (those that don't
+ have their past key value states given to this model).
+ padding_cache (`MimiConv1dPaddingCache`, *optional*):
+ Padding cache for MimiConv1d causal convolutions in order to support streaming via cache padding.
+ """
+
+ audio_codes: Optional[torch.LongTensor] = None
+ encoder_past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None
+ padding_cache: Optional[MimiConv1dPaddingCache] = None
+
+
+@dataclass
+@auto_docstring
+class MimiDecoderOutput(ModelOutput):
+ r"""
+ audio_values (`torch.FloatTensor` of shape `(batch_size, segment_length)`, *optional*):
+ Decoded audio values, obtained using the decoder part of Mimi.
+ decoder_past_key_values (`Cache`, *optional*):
+ Pre-computed hidden-states (key and values in the self-attention blocks) that can be used to speed up sequential decoding of the decoder transformer.
+ This typically consists in the `past_key_values` returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
+
+ The model will output the same cache format that is fed as input.
+
+ If `past_key_values` are used, the user can optionally input only the last `audio_values` or `audio_codes` (those that don't
+ have their past key value states given to this model).
+ """
+
+ audio_values: Optional[torch.FloatTensor] = None
+ decoder_past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None
+
+
+class MimiConv1d(nn.Module):
+ """Conv1d with asymmetric or causal padding and normalization."""
+
+ def __init__(
+ self,
+ config,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: int,
+ stride: int = 1,
+ dilation: int = 1,
+ groups: int = 1,
+ pad_mode: Optional[str] = None,
+ bias: bool = True,
+ layer_idx: Optional[int] = None,
+ ):
+ super().__init__()
+ self.causal = config.use_causal_conv
+ self.pad_mode = config.pad_mode if pad_mode is None else pad_mode
+ self.layer_idx = layer_idx
+ self.in_channels = in_channels
+
+ # warn user on unusual setup between dilation and stride
+ if stride > 1 and dilation > 1:
+ logger.warning(
+ "MimiConv1d has been initialized with stride > 1 and dilation > 1"
+ f" (kernel_size={kernel_size} stride={stride}, dilation={dilation})."
+ )
+
+ self.conv = nn.Conv1d(
+ in_channels, out_channels, kernel_size, stride, dilation=dilation, groups=groups, bias=bias
+ )
+
+ kernel_size = self.conv.kernel_size[0]
+ stride = torch.tensor(self.conv.stride[0], dtype=torch.int64)
+ dilation = self.conv.dilation[0]
+
+ # Effective kernel size with dilations.
+ kernel_size = torch.tensor((kernel_size - 1) * dilation + 1, dtype=torch.int64)
+
+ self.register_buffer("stride", stride, persistent=False)
+ self.register_buffer("kernel_size", kernel_size, persistent=False)
+ self.register_buffer("padding_total", kernel_size - stride, persistent=False)
+
+ # Asymmetric padding required for odd strides
+ self.padding_right = self.padding_total // 2
+ self.padding_left = self.padding_total - self.padding_right
+
+ def apply_weight_norm(self):
+ weight_norm = nn.utils.weight_norm
+ if hasattr(nn.utils.parametrizations, "weight_norm"):
+ weight_norm = nn.utils.parametrizations.weight_norm
+
+ weight_norm(self.conv)
+
+ def remove_weight_norm(self):
+ nn.utils.remove_weight_norm(self.conv)
+
+ # Copied from transformers.models.encodec.modeling_encodec.EncodecConv1d._get_extra_padding_for_conv1d
+ def _get_extra_padding_for_conv1d(
+ self,
+ hidden_states: torch.Tensor,
+ ) -> torch.Tensor:
+ """See `pad_for_conv1d`."""
+ length = hidden_states.shape[-1]
+ n_frames = (length - self.kernel_size + self.padding_total) / self.stride + 1
+ n_frames = torch.ceil(n_frames).to(torch.int64) - 1
+ ideal_length = n_frames * self.stride + self.kernel_size - self.padding_total
+
+ return ideal_length - length
+
+ @staticmethod
+ # Copied from transformers.models.encodec.modeling_encodec.EncodecConv1d._pad1d
+ def _pad1d(hidden_states: torch.Tensor, paddings: tuple[int, int], mode: str = "zero", value: float = 0.0):
+ """Tiny wrapper around torch.nn.functional.pad, just to allow for reflect padding on small input.
+ If this is the case, we insert extra 0 padding to the right before the reflection happens.
+ """
+ length = hidden_states.shape[-1]
+ padding_left, padding_right = paddings
+ if mode != "reflect":
+ return nn.functional.pad(hidden_states, paddings, mode, value)
+
+ max_pad = max(padding_left, padding_right)
+ extra_pad = 0
+ if length <= max_pad:
+ extra_pad = max_pad - length + 1
+ hidden_states = nn.functional.pad(hidden_states, (0, extra_pad))
+ padded = nn.functional.pad(hidden_states, paddings, mode, value)
+ end = padded.shape[-1] - extra_pad
+ return padded[..., :end]
+
+ def _get_output_length(self, input_length: torch.LongTensor) -> torch.LongTensor:
+ """
+ Return the length of the output of the MimiConv1d.
+ """
+ # padding size
+ n_frames = (input_length - self.kernel_size + self.padding_total) / self.stride + 1
+ n_frames = torch.ceil(n_frames).to(torch.int64) - 1
+ ideal_length = n_frames * self.stride + self.kernel_size - self.padding_total
+ extra_padding = ideal_length - input_length
+
+ if self.causal:
+ padding_left = self.padding_total
+ padding_right = extra_padding
+ else:
+ padding_left = self.padding_left
+ padding_right = self.padding_right + extra_padding
+
+ # padding
+ input_length = input_length + padding_left + padding_right
+
+ # conv
+ output_length = (
+ input_length + 2 * self.conv.padding[0] - self.conv.dilation[0] * (self.conv.kernel_size[0] - 1) - 1
+ ) // self.conv.stride[0] + 1
+ return output_length
+
+ def forward(self, hidden_states, padding_cache=None):
+ extra_padding = self._get_extra_padding_for_conv1d(hidden_states)
+
+ if not self.causal and padding_cache is not None:
+ raise ValueError("`padding_cache` is not supported for non-causal convolutions.")
+
+ if self.causal and padding_cache is not None:
+ layer_padding_cache = padding_cache.update(hidden_states, self.layer_idx)
+ hidden_states = torch.cat([layer_padding_cache, hidden_states], dim=2)
+
+ elif self.causal:
+ # Left padding for causal
+ hidden_states = self._pad1d(hidden_states, (self.padding_total, extra_padding), mode=self.pad_mode)
+
+ else:
+ hidden_states = self._pad1d(
+ hidden_states, (self.padding_left, self.padding_right + extra_padding), mode=self.pad_mode
+ )
+
+ hidden_states = self.conv(hidden_states)
+ return hidden_states
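+
+ # Illustrative padding arithmetic (comments only): with kernel_size=7, stride=1 and dilation=1,
+ # padding_total = 7 - 1 = 6; a causal layer puts all 6 padded timesteps on the left (plus any
+ # `extra_padding` needed to reach a whole number of frames), whereas a non-causal layer splits them
+ # as padding_left=3 / padding_right=3.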
+
+
+class MimiConvTranspose1d(nn.Module):
+ """ConvTranspose1d with asymmetric or causal padding and normalization."""
+
+ def __init__(
+ self,
+ config,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: int,
+ stride: int = 1,
+ groups: int = 1,
+ bias=True,
+ ):
+ super().__init__()
+ self.causal = config.use_causal_conv
+ self.trim_right_ratio = config.trim_right_ratio
+ self.conv = nn.ConvTranspose1d(in_channels, out_channels, kernel_size, stride, groups=groups, bias=bias)
+
+ if not (self.causal or self.trim_right_ratio == 1.0):
+ raise ValueError("`trim_right_ratio` != 1.0 only makes sense for causal convolutions")
+
+ kernel_size = self.conv.kernel_size[0]
+ stride = self.conv.stride[0]
+ padding_total = kernel_size - stride
+
+ # We will only trim fixed padding. Extra padding from `pad_for_conv1d` would be
+ # removed at the very end, when keeping only the right length for the output,
+ # as removing it here would require also passing the length at the matching layer
+ # in the encoder.
+ if self.causal:
+ # Trim the padding on the right according to the specified ratio
+ # if trim_right_ratio = 1.0, trim everything from right
+ self.padding_right = math.ceil(padding_total * self.trim_right_ratio)
+ else:
+ # Asymmetric padding required for odd strides
+ self.padding_right = padding_total // 2
+
+ self.padding_left = padding_total - self.padding_right
+
+ def apply_weight_norm(self):
+ weight_norm = nn.utils.weight_norm
+ if hasattr(nn.utils.parametrizations, "weight_norm"):
+ weight_norm = nn.utils.parametrizations.weight_norm
+
+ weight_norm(self.conv)
+
+ def remove_weight_norm(self):
+ nn.utils.remove_weight_norm(self.conv)
+
+ def forward(self, hidden_states):
+ hidden_states = self.conv(hidden_states)
+
+ # unpad
+ end = hidden_states.shape[-1] - self.padding_right
+ hidden_states = hidden_states[..., self.padding_left : end]
+ return hidden_states
+
+
+class MimiResnetBlock(nn.Module):
+ """
+ Residual block from SEANet model as used by Mimi.
+ """
+
+ def __init__(self, config: MimiConfig, dim: int, dilations: list[int]):
+ super().__init__()
+ kernel_sizes = (config.residual_kernel_size, 1)
+ if len(kernel_sizes) != len(dilations):
+ raise ValueError("Number of kernel sizes should match number of dilations")
+
+ hidden = dim // config.compress
+ block = []
+ for i, (kernel_size, dilation) in enumerate(zip(kernel_sizes, dilations)):
+ in_chs = dim if i == 0 else hidden
+ out_chs = dim if i == len(kernel_sizes) - 1 else hidden
+ block += [nn.ELU()]
+ block += [MimiConv1d(config, in_chs, out_chs, kernel_size, dilation=dilation)]
+ self.block = nn.ModuleList(block)
+
+ if config.use_conv_shortcut:
+ self.shortcut = MimiConv1d(config, dim, dim, kernel_size=1)
+ else:
+ self.shortcut = nn.Identity()
+
+ def forward(self, hidden_states, padding_cache=None):
+ residual = hidden_states
+
+ for layer in self.block:
+ if isinstance(layer, MimiConv1d):
+ hidden_states = layer(hidden_states, padding_cache=padding_cache)
+ else:
+ hidden_states = layer(hidden_states)
+
+ if isinstance(self.shortcut, MimiConv1d):
+ residual = self.shortcut(residual, padding_cache=padding_cache)
+ else:
+ residual = self.shortcut(residual)
+
+ return residual + hidden_states
+
+
+class MimiEncoder(nn.Module):
+ """SEANet encoder as used by Mimi."""
+
+ def __init__(self, config: MimiConfig):
+ super().__init__()
+ model = [MimiConv1d(config, config.audio_channels, config.num_filters, config.kernel_size)]
+ scaling = 1
+
+ # keep track of MimiConv1d submodule layer names for easy encoded length computation
+ mimiconv1d_layer_names = ["layers.0"]
+
+ # Downsample to raw audio scale
+ for ratio in reversed(config.upsampling_ratios):
+ current_scale = scaling * config.num_filters
+ # Add residual layers
+ for j in range(config.num_residual_layers):
+ mimiconv1d_layer_names.extend([f"layers.{len(model)}.block.1", f"layers.{len(model)}.block.3"])
+ model += [MimiResnetBlock(config, current_scale, [config.dilation_growth_rate**j, 1])]
+ # Add downsampling layers
+ model += [nn.ELU()]
+ mimiconv1d_layer_names.append(f"layers.{len(model)}")
+ model += [MimiConv1d(config, current_scale, current_scale * 2, kernel_size=ratio * 2, stride=ratio)]
+ scaling *= 2
+
+ model += [nn.ELU()]
+ mimiconv1d_layer_names.append(f"layers.{len(model)}")
+ model += [MimiConv1d(config, scaling * config.num_filters, config.hidden_size, config.last_kernel_size)]
+
+ self.layers = nn.ModuleList(model)
+ self._mimiconv1d_layer_names = mimiconv1d_layer_names
+
+ # initialize layer_idx for MimiConv1d submodules, necessary for padding_cache
+ for layer_idx, layername in enumerate(self._mimiconv1d_layer_names):
+ conv_layer = self.get_submodule(layername)
+ setattr(conv_layer, "layer_idx", layer_idx)
+
+ def forward(self, hidden_states, padding_cache=None):
+ for layer in self.layers:
+ if isinstance(layer, (MimiConv1d, MimiResnetBlock)):
+ hidden_states = layer(hidden_states, padding_cache=padding_cache)
+ else:
+ hidden_states = layer(hidden_states)
+ return hidden_states
+
+
+class MimiLayerScale(nn.Module):
+ """Layer scale from [Touvron et al 2021] (https://huggingface.co/papers/2103.17239).
+ This diagonally rescales the residual branch outputs with a learnt per-channel scale initialized close to 0.
+ """
+
+ def __init__(self, config):
+ super().__init__()
+ channels = config.hidden_size
+ initial_scale = config.layer_scale_initial_scale
+ self.scale = nn.Parameter(torch.full((channels,), initial_scale, requires_grad=True))
+
+ def forward(self, x: torch.Tensor):
+ return self.scale * x
+
+
+# Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Mimi
+class MimiRotaryEmbedding(nn.Module):
+ inv_freq: torch.Tensor # fix linting for `register_buffer`
+
+ def __init__(self, config: MimiConfig, device=None):
+ super().__init__()
+ # BC: "rope_type" was originally "type"
+ if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
+ else:
+ self.rope_type = "default"
+ self.max_seq_len_cached = config.max_position_embeddings
+ self.original_max_seq_len = config.max_position_embeddings
+
+ self.config = config
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
+ self.original_inv_freq = self.inv_freq
+
+ @torch.no_grad()
+ @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
+ def forward(self, x, position_ids):
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
+ position_ids_expanded = position_ids[:, None, :].float()
+
+ device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
+ with torch.autocast(device_type=device_type, enabled=False): # Force float32
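+ # Outer product of inverse frequencies and positions yields the per-position rotation angles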
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+ emb = torch.cat((freqs, freqs), dim=-1)
+ cos = emb.cos() * self.attention_scaling
+ sin = emb.sin() * self.attention_scaling
+
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
+# Copied from transformers.models.llama.modeling_llama.rotate_half
+def rotate_half(x):
+ """Rotates half the hidden dims of the input."""
+ x1 = x[..., : x.shape[-1] // 2]
+ x2 = x[..., x.shape[-1] // 2 :]
+ return torch.cat((-x2, x1), dim=-1)
+
+
+# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+ """Applies Rotary Position Embedding to the query and key tensors.
+
+ Args:
+ q (`torch.Tensor`): The query tensor.
+ k (`torch.Tensor`): The key tensor.
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
+ position_ids (`torch.Tensor`, *optional*):
+ Deprecated and unused.
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+ Returns:
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+ """
+ cos = cos.unsqueeze(unsqueeze_dim)
+ sin = sin.unsqueeze(unsqueeze_dim)
+ q_embed = (q * cos) + (rotate_half(q) * sin)
+ k_embed = (k * cos) + (rotate_half(k) * sin)
+ return q_embed, k_embed
+
+
+class MimiMLP(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.activation_fn = ACT2FN[config.hidden_act]
+ self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
+ self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
+
+ # Copied from transformers.models.clip.modeling_clip.CLIPMLP.forward
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.fc1(hidden_states)
+ hidden_states = self.activation_fn(hidden_states)
+ hidden_states = self.fc2(hidden_states)
+ return hidden_states
+
+
+# Copied from transformers.models.llama.modeling_llama.repeat_kv
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+ """
+ This is the equivalent of torch.repeat_interleave(hidden_states, dim=1, repeats=n_rep). The hidden states go from (batch,
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim).
+ """
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+ if n_rep == 1:
+ return hidden_states
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
+# copied from transformers.models.gemma.modeling_gemma.GemmaAttention with Gemma->Mimi
+# no longer copied after attention refactors
+class MimiAttention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(self, config: MimiConfig, layer_idx: Optional[int] = None):
+ super().__init__()
+ self.config = config
+ self.layer_idx = layer_idx
+ if layer_idx is None:
+ logger.warning_once(
+ f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
+ "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
+ "when creating this class."
+ )
+
+ self.attention_dropout = config.attention_dropout
+ self.hidden_size = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.head_dim = config.head_dim
+ self.num_key_value_heads = config.num_key_value_heads
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
+ self.max_position_embeddings = config.max_position_embeddings
+ self.rope_theta = config.rope_theta
+ self.is_causal = True
+ self.scaling = 1 / math.sqrt(config.head_dim)
+
+ if self.hidden_size % self.num_heads != 0:
+ raise ValueError(
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
+ f" and `num_heads`: {self.num_heads})."
+ )
+
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
+ self.rotary_emb = MimiRotaryEmbedding(config)
+ self.sliding_window = config.sliding_window # Ignore copy
+
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ cos, sin = self.rotary_emb(value_states, position_ids)
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_values is not None:
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling
+
+ if attention_mask is not None: # no matter the length, we just slice it
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+ attn_weights = attn_weights + causal_mask
+
+ # upcast attention to fp32
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+ attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
+ attn_output = torch.matmul(attn_weights, value_states)
+
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+ raise ValueError(
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
+ )
+
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ attn_output = attn_output.view(bsz, q_len, -1)
+ attn_output = self.o_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights
+
+
+# NO LONGER EXIST Copied from transformers.models.gemma.modeling_gemma.GemmaFlashAttention2 with Gemma->Mimi
+# TODO cyril: modular
+class MimiFlashAttention2(MimiAttention):
+ """
+ Mimi flash attention module. This module inherits from `MimiAttention` as the weights of the module stay
+ untouched. The only required change is in the forward pass, where it needs to correctly call the public API of
+ flash attention and deal with padding tokens in case the input contains any of them.
+ """
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
+ # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
+ self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask()
+
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+ if isinstance(past_key_values, StaticCache):
+ raise ValueError(
+ "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2`. "
+ "Make sure to use `sdpa` in the meantime, and open an issue at https://github.com/huggingface/transformers"
+ )
+
+ output_attentions = False
+
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ # Flash attention requires the input to have the shape
+ # batch_size x seq_length x num_heads x head_dim
+ # therefore we just need to keep the original shape
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ cos, sin = self.rotary_emb(value_states, position_ids)
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_values is not None:
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
+ # to be able to avoid many of these transpose/reshape/view.
+ query_states = query_states.transpose(1, 2)
+ key_states = key_states.transpose(1, 2)
+ value_states = value_states.transpose(1, 2)
+
+ dropout_rate = self.attention_dropout if self.training else 0.0
+
+ # In PEFT, the layer norms are usually cast to float32 for training stability reasons,
+ # so the input hidden states may have been silently cast to float32 as well. Hence, we need to
+ # cast them back to the correct dtype just to be sure everything works as expected.
+ # This might slow down training & inference, so it is recommended not to cast the LayerNorms
+ # to fp32.
+
+ input_dtype = query_states.dtype
+ device_type = query_states.device.type if query_states.device.type != "mps" else "cpu"
+ if input_dtype == torch.float32:
+ if torch.is_autocast_enabled():
+ target_dtype = (
+ torch.get_autocast_dtype(device_type)
+ if hasattr(torch, "get_autocast_dtype")
+ else torch.get_autocast_gpu_dtype()
+ )
+ # Handle the case where the model is quantized
+ elif hasattr(self.config, "_pre_quantization_dtype"):
+ target_dtype = self.config._pre_quantization_dtype
+ else:
+ target_dtype = self.q_proj.weight.dtype
+
+ logger.warning_once(
+ f"The input hidden states seem to have been silently cast to float32; this might be related to"
+ f" having upcasted embedding or layer norm layers to float32. We will cast the input back to"
+ f" {target_dtype}."
+ )
+
+ query_states = query_states.to(target_dtype)
+ key_states = key_states.to(target_dtype)
+ value_states = value_states.to(target_dtype)
+
+ attn_output = _flash_attention_forward(
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ q_len,
+ position_ids=position_ids,
+ dropout=dropout_rate,
+ sliding_window=getattr(self, "sliding_window", None),
+ is_causal=self.is_causal,
+ use_top_left_mask=self._flash_attn_uses_top_left_mask,
+ )
+
+ attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
+ attn_output = self.o_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights
+
+
+# NO LONGER EXIST Copied from transformers.models.gemma.modeling_gemma.GemmaSdpaAttention with Gemma->Mimi
+# TODO cyril: modular
+class MimiSdpaAttention(MimiAttention):
+ """
+ Mimi attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
+ `MimiAttention` as the weights of the module stay untouched. The only changes are in the forward pass, to adapt to
+ the SDPA API.
+ """
+
+ # Adapted from MimiAttention.forward
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs,
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+ if output_attentions:
+ # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
+ logger.warning_once(
+ "MimiModel is using MimiSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
+ 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+ )
+ return super().forward(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ )
+
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ cos, sin = self.rotary_emb(value_states, position_ids)
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_values is not None:
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ causal_mask = attention_mask
+ if attention_mask is not None:
+ causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
+
+ # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
+ # Reference: https://github.com/pytorch/pytorch/issues/112577.
+ if query_states.device.type == "cuda" and causal_mask is not None:
+ query_states = query_states.contiguous()
+ key_states = key_states.contiguous()
+ value_states = value_states.contiguous()
+
+ # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+ # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
+ is_causal = causal_mask is None and q_len > 1
+
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
+ query_states,
+ key_states,
+ value_states,
+ attn_mask=causal_mask,
+ dropout_p=self.attention_dropout if self.training else 0.0,
+ is_causal=is_causal,
+ )
+
+ attn_output = attn_output.transpose(1, 2).contiguous()
+ attn_output = attn_output.view(bsz, q_len, -1)
+
+ attn_output = self.o_proj(attn_output)
+
+ return attn_output, None
+
+
+MIMI_ATTENTION_CLASSES = {
+ "eager": MimiAttention,
+ "flash_attention_2": MimiFlashAttention2,
+ "sdpa": MimiSdpaAttention,
+}
+
+
+class MimiTransformerLayer(GradientCheckpointingLayer):
+ def __init__(self, config: MimiConfig, layer_idx: int):
+ super().__init__()
+ self.hidden_size = config.hidden_size
+
+ self.self_attn = MIMI_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
+
+ self.mlp = MimiMLP(config)
+ self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps)
+ self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps)
+ self.self_attn_layer_scale = MimiLayerScale(config)
+ self.mlp_layer_scale = MimiLayerScale(config)
+
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ output_attentions: Optional[bool] = False,
+ use_cache: Optional[bool] = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs,
+ ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+ attention_mask (`torch.FloatTensor`, *optional*):
+ attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
+ query_sequence_length, key_sequence_length)` if default attention is used.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+ (see `past_key_values`).
+ past_key_values (`Cache`, *optional*): cached past key and value projection states
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+ Indices depicting the position of the input sequence tokens in the sequence
+ kwargs (`dict`, *optional*):
+ Arbitrary kwargs to be ignored, used for FSDP and other methods that inject code
+ into the model.
+ """
+ residual = hidden_states
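+ # Pre-norm block: LayerNorm -> self-attention -> LayerScale is added to the residual,
+ # then LayerNorm -> MLP -> LayerScale is added to the residual again.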
+
+ hidden_states = self.input_layernorm(hidden_states)
+
+ # Self Attention
+ hidden_states, self_attn_weights = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ **kwargs,
+ )
+ hidden_states = residual + self.self_attn_layer_scale(hidden_states)
+
+ # Fully Connected
+ residual = hidden_states
+ hidden_states = self.post_attention_layernorm(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = residual + self.mlp_layer_scale(hidden_states)
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (self_attn_weights,)
+
+ return outputs
+
+
+class MimiTransformerModel(nn.Module):
+ """
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`MimiTransformerLayer`]
+
+ Args:
+ config: MimiConfig
+ """
+
+ def __init__(self, config: MimiConfig):
+ super().__init__()
+
+ self.layers = nn.ModuleList(
+ [MimiTransformerLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+ )
+ self._attn_implementation = config._attn_implementation
+
+ self.gradient_checkpointing = False
+ self.config = config
+
+ def forward(
+ self,
+ hidden_states: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ ) -> Union[tuple, BaseModelOutputWithPast]:
+ """
+ Args:
+ hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Embedded representation that will be contextualized by the model
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
+ `past_key_values`).
+
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
+ and modify it to your needs. See diagram 1 in [the paper](https://huggingface.co/papers/1910.13461) for more
+ information on the default strategy.
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+ config.n_positions - 1]`.
+
+ [What are position IDs?](../glossary#position-ids)
+ past_key_values (`Cache`, *optional*):
+ It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
+
+ If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
+ have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
+ of shape `(batch_size, sequence_length)`.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+ `past_key_values`).
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if self.gradient_checkpointing and self.training and use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
+ )
+ use_cache = False
+
+ if use_cache and past_key_values is None:
+ past_key_values = DynamicCache(config=self.config)
+
+ if cache_position is None:
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ cache_position = torch.arange(
+ past_seen_tokens, past_seen_tokens + hidden_states.shape[1], device=hidden_states.device
+ )
+
+ if position_ids is None:
+ position_ids = cache_position.unsqueeze(0)
+
+ causal_mask = create_causal_mask(
+ config=self.config,
+ input_embeds=hidden_states,
+ attention_mask=attention_mask,
+ cache_position=cache_position,
+ past_key_values=past_key_values,
+ position_ids=position_ids,
+ )
+
+ # decoder layers
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+
+ for decoder_layer in self.layers:
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ layer_outputs = decoder_layer(
+ hidden_states,
+ attention_mask=causal_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ )
+
+ hidden_states = layer_outputs[0]
+
+ if output_attentions:
+ all_self_attns += (layer_outputs[1],)
+
+ # add hidden states from the last decoder layer
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ if not return_dict:
+ return tuple(
+ v for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns] if v is not None
+ )
+
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=past_key_values,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ )
+
+
+class MimiDecoder(nn.Module):
+ """SEANet decoder as used by Mimi."""
+
+ def __init__(self, config: MimiConfig):
+ super().__init__()
+ scaling = int(2 ** len(config.upsampling_ratios))
+ model = [MimiConv1d(config, config.hidden_size, scaling * config.num_filters, config.kernel_size)]
+
+ # Upsample to raw audio scale
+ for ratio in config.upsampling_ratios:
+ current_scale = scaling * config.num_filters
+ # Add upsampling layers
+ model += [nn.ELU()]
+ model += [
+ MimiConvTranspose1d(config, current_scale, current_scale // 2, kernel_size=ratio * 2, stride=ratio)
+ ]
+ # Add residual layers
+ for j in range(config.num_residual_layers):
+ model += [MimiResnetBlock(config, current_scale // 2, (config.dilation_growth_rate**j, 1))]
+ scaling //= 2
+
+ # Add final layers
+ model += [nn.ELU()]
+ model += [MimiConv1d(config, config.num_filters, config.audio_channels, config.last_kernel_size)]
+ self.layers = nn.ModuleList(model)
+
+ # Copied from transformers.models.encodec.modeling_encodec.EncodecDecoder.forward
+ def forward(self, hidden_states):
+ for layer in self.layers:
+ hidden_states = layer(hidden_states)
+ return hidden_states
+
+
+class MimiEuclideanCodebook(nn.Module):
+ """Codebook with Euclidean distance."""
+
+ def __init__(self, config: MimiConfig, epsilon: float = 1e-5):
+ super().__init__()
+ embed = torch.zeros(config.codebook_size, config.codebook_dim)
+
+ self.codebook_size = config.codebook_size
+
+ self.register_buffer("initialized", torch.tensor([True], dtype=torch.float32))
+ self.register_buffer("cluster_usage", torch.ones(config.codebook_size))
+ self.register_buffer("embed_sum", embed)
+ self._embed = None
+ self.epsilon = epsilon
+
+ @property
+ def embed(self) -> torch.Tensor:
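+ # Codebook vectors are computed lazily as embed_sum / cluster_usage (clamped by epsilon to avoid division by zero)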
+ if self._embed is None:
+ self._embed = self.embed_sum / self.cluster_usage.clamp(min=self.epsilon)[:, None]
+ return self._embed
+
+ def quantize(self, hidden_states):
+ # Projects each vector in `hidden_states` onto the nearest centroid and returns its index.
+ # `hidden_states` should be `[N, D]` with `N` the number of input vectors and `D` the dimension.
+ dists = torch.cdist(hidden_states[None].float(), self.embed[None].float(), p=2)[0]
+ embed_ind = dists.argmin(dim=-1)
+ return embed_ind
+
+ # Copied from transformers.models.encodec.modeling_encodec.EncodecEuclideanCodebook.encode
+ def encode(self, hidden_states):
+ shape = hidden_states.shape
+ # pre-process
+ hidden_states = hidden_states.reshape((-1, shape[-1]))
+ # quantize
+ embed_ind = self.quantize(hidden_states)
+ # post-process
+ embed_ind = embed_ind.view(*shape[:-1])
+ return embed_ind
+
+ # Copied from transformers.models.encodec.modeling_encodec.EncodecEuclideanCodebook.decode
+ def decode(self, embed_ind):
+ quantize = nn.functional.embedding(embed_ind, self.embed)
+ return quantize
+
+
+# Copied from transformers.models.encodec.modeling_encodec.EncodecVectorQuantization with Encodec->Mimi
+class MimiVectorQuantization(nn.Module):
+ """
+ Vector quantization implementation. Currently supports only euclidean distance.
+ """
+
+ def __init__(self, config: MimiConfig):
+ super().__init__()
+ self.codebook = MimiEuclideanCodebook(config)
+
+ def encode(self, hidden_states):
+ hidden_states = hidden_states.permute(0, 2, 1)
+ embed_in = self.codebook.encode(hidden_states)
+ return embed_in
+
+ def decode(self, embed_ind):
+ quantize = self.codebook.decode(embed_ind)
+ quantize = quantize.permute(0, 2, 1)
+ return quantize
+
+
+class MimiResidualVectorQuantizer(nn.Module):
+ """Residual Vector Quantizer."""
+
+ def __init__(self, config: MimiConfig, num_quantizers: Optional[int] = None):
+ super().__init__()
+ self.codebook_size = config.codebook_size
+ self.frame_rate = config.frame_rate
+ self.num_quantizers = num_quantizers if num_quantizers is not None else config.num_quantizers
+ self.layers = nn.ModuleList([MimiVectorQuantization(config) for _ in range(self.num_quantizers)])
+
+ self.input_proj = None
+ self.output_proj = None
+ if config.vector_quantization_hidden_dimension != config.hidden_size:
+ self.input_proj = torch.nn.Conv1d(
+ config.hidden_size, config.vector_quantization_hidden_dimension, 1, bias=False
+ )
+ self.output_proj = torch.nn.Conv1d(
+ config.vector_quantization_hidden_dimension, config.hidden_size, 1, bias=False
+ )
+
+ def encode(self, embeddings: torch.Tensor, num_quantizers: Optional[int] = None) -> torch.Tensor:
+ """
+ Encode the given input tensor using the given number of quantizers / codebooks. The RVQ encode method selects
+ the appropriate number of quantizers to use and returns indices for each quantizer.
+ """
+ if self.input_proj is not None:
+ embeddings = self.input_proj(embeddings)
+
+ num_quantizers = num_quantizers if num_quantizers is not None else self.num_quantizers
+
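+ # Residual VQ: each codebook quantizes the residual left by the previous ones,
+ # so the stacked indices form a progressively finer approximation of the input.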
+ residual = embeddings
+ all_indices = []
+ for layer in self.layers[:num_quantizers]:
+ indices = layer.encode(residual)
+ quantized = layer.decode(indices)
+ residual = residual - quantized
+ all_indices.append(indices)
+ out_indices = torch.stack(all_indices)
+ return out_indices
+
+ def decode(self, codes: torch.Tensor) -> torch.Tensor:
+ """Decode the given codes of shape [B, K, T] to the quantized representation."""
+ quantized_out = torch.tensor(0.0, device=codes.device)
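+ # Each quantizer layer decodes its indices to codebook vectors; summing them reconstructs the embedding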
+ codes = codes.transpose(0, 1)
+ for i, indices in enumerate(codes):
+ layer = self.layers[i]
+ quantized = layer.decode(indices)
+ quantized_out = quantized_out + quantized
+
+ if self.output_proj is not None:
+ quantized_out = self.output_proj(quantized_out)
+ return quantized_out
+
+
+class MimiSplitResidualVectorQuantizer(nn.Module):
+ """Split Residual Vector Quantizer."""
+
+ def __init__(self, config: MimiConfig):
+ super().__init__()
+ self.codebook_size = config.codebook_size
+ self.frame_rate = config.frame_rate
+ self.max_num_quantizers = config.num_quantizers
+
+ self.num_semantic_quantizers = config.num_semantic_quantizers
+ self.num_acoustic_quantizers = config.num_quantizers - config.num_semantic_quantizers
+
+ self.semantic_residual_vector_quantizer = MimiResidualVectorQuantizer(config, self.num_semantic_quantizers)
+ self.acoustic_residual_vector_quantizer = MimiResidualVectorQuantizer(config, self.num_acoustic_quantizers)
+
+ def encode(self, embeddings: torch.Tensor, num_quantizers: Optional[int] = None) -> torch.Tensor:
+ """
+ Encode the given input tensor using the given number of quantizers / codebooks. The RVQ encode method selects
+ the appropriate number of quantizers to use and returns indices for each quantizer.
+ """
+
+ num_quantizers = self.max_num_quantizers if num_quantizers is None else num_quantizers
+
+ if num_quantizers > self.max_num_quantizers:
+ raise ValueError(
+ f"The number of quantizers (i.e. codebooks) requested must be at most the total number of quantizers {self.max_num_quantizers}, but is currently {num_quantizers}."
+ )
+
+ if num_quantizers < self.num_semantic_quantizers:
+ raise ValueError(
+ f"The number of quantizers (i.e. codebooks) requested must be at least the number of semantic quantizers {self.num_semantic_quantizers}, but is currently {num_quantizers}."
+ )
+
+ # codes is [K, B, T], with T frames, K nb of codebooks.
+ codes = self.semantic_residual_vector_quantizer.encode(embeddings)
+
+ if num_quantizers > self.num_semantic_quantizers:
+ acoustic_codes = self.acoustic_residual_vector_quantizer.encode(
+ embeddings, num_quantizers=num_quantizers - self.num_semantic_quantizers
+ )
+ codes = torch.cat([codes, acoustic_codes], dim=0)
+
+ return codes
+
+ def decode(self, codes: torch.Tensor) -> torch.Tensor:
+ """Decode the given codes to the quantized representation."""
+
+ # The first num_semantic_quantizers codebooks are decoded using the semantic RVQ
+ quantized_out = self.semantic_residual_vector_quantizer.decode(codes[:, : self.num_semantic_quantizers])
+
+ # The rest of the codebooks are decoded using the acoustic RVQ
+ if codes.shape[1] > self.num_semantic_quantizers:
+ quantized_out += self.acoustic_residual_vector_quantizer.decode(codes[:, self.num_semantic_quantizers :])
+ return quantized_out
+
+
+@auto_docstring
+class MimiPreTrainedModel(PreTrainedModel):
+ config: MimiConfig
+ base_model_prefix = "mimi"
+ main_input_name = "input_values"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["MimiDecoderLayer"]
+ _skip_keys_device_placement = "past_key_values"
+ _supports_flash_attn = True
+ _supports_sdpa = True
+
+ _can_compile_fullgraph = True
+
+ def _init_weights(self, module):
+ """Initialize the weights"""
+ if isinstance(module, nn.Linear):
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+ elif isinstance(module, (nn.Conv1d, nn.ConvTranspose1d)):
+ nn.init.kaiming_normal_(module.weight)
+ if module.bias is not None:
+ k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0]))
+ nn.init.uniform_(module.bias, a=-k, b=k)
+ elif isinstance(module, MimiLayerScale):
+ module.scale.data.fill_(self.config.layer_scale_initial_scale)
+
+
+@auto_docstring(
+ custom_intro="""
+ The Mimi neural audio codec model.
+ """
+)
+class MimiModel(MimiPreTrainedModel):
+ def __init__(self, config: MimiConfig):
+ super().__init__(config)
+ self.config = config
+
+ self.encoder = MimiEncoder(config)
+ self.encoder_transformer = MimiTransformerModel(config)
+
+ self.downsample = None
+ self.upsample = None
+ if config.frame_rate != config.encodec_frame_rate:
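+ # When the quantizer frame rate differs from the SEANet (encodec) frame rate, an extra strided conv
+ # downsamples the embeddings before quantization and a transposed conv upsamples them back before decoding.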
+ self.downsample = MimiConv1d(
+ config,
+ config.hidden_size,
+ config.hidden_size,
+ kernel_size=2 * int(config.encodec_frame_rate / config.frame_rate),
+ stride=2,
+ bias=False,
+ pad_mode="replicate",
+ layer_idx=len(self.encoder._mimiconv1d_layer_names),
+ )
+
+ self.upsample = MimiConvTranspose1d(
+ config,
+ config.hidden_size,
+ config.hidden_size,
+ kernel_size=2 * int(config.encodec_frame_rate / config.frame_rate),
+ stride=2,
+ bias=False,
+ groups=config.upsample_groups,
+ )
+
+ self.decoder_transformer = MimiTransformerModel(config)
+ self.decoder = MimiDecoder(config)
+
+ self.quantizer = MimiSplitResidualVectorQuantizer(config)
+
+ self.bits_per_codebook = int(math.log2(self.config.codebook_size))
+ if 2**self.bits_per_codebook != self.config.codebook_size:
+ raise ValueError("The codebook_size must be a power of 2.")
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_encoder(self):
+ return self.encoder
+
+ def _encode_frame(
+ self,
+ input_values: torch.Tensor,
+ num_quantizers: int,
+ padding_mask: torch.Tensor,
+ past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
+ padding_cache: Optional[MimiConv1dPaddingCache] = None,
+ return_dict: Optional[bool] = None,
+ ) -> tuple[torch.Tensor, Optional[Union[Cache, list[torch.FloatTensor]]], Optional[MimiConv1dPaddingCache]]:
+ """
+ Encodes the given input using the underlying VQVAE.
+ """
+
+ # TODO: @eustlb, let's make the encoder support padding_mask so that batched inputs are supported.
+ embeddings = self.encoder(input_values, padding_cache=padding_cache)
+
+ # TODO: @eustlb, convert the padding mask to attention mask.
+ encoder_outputs = self.encoder_transformer(
+ embeddings.transpose(1, 2), past_key_values=past_key_values, return_dict=return_dict
+ )
+ if return_dict:
+ past_key_values = encoder_outputs.get("past_key_values")
+ elif len(encoder_outputs) > 1:
+ past_key_values = encoder_outputs[1]
+ embeddings = encoder_outputs[0].transpose(1, 2)
+ embeddings = self.downsample(embeddings, padding_cache=padding_cache)
+
+ codes = self.quantizer.encode(embeddings, num_quantizers)
+ codes = codes.transpose(0, 1)
+ return codes, past_key_values, padding_cache
+
+ def get_encoded_length(self, input_length: torch.LongTensor) -> torch.LongTensor:
+ """
+ Return the number of frames of the encoded audio waveform.
+ """
+ output_length = input_length
+
+ # encoder
+ for layer_name in self.encoder._mimiconv1d_layer_names:
+ output_length = self.encoder.get_submodule(layer_name)._get_output_length(output_length)
+
+ # downsample
+ output_length = self.downsample._get_output_length(output_length)
+
+ return output_length
+
+ def get_audio_codes_mask(self, padding_mask: torch.Tensor, padding_side: str = "right"):
+ """
+ Get the mask for the audio codes from the original padding mask.
+ """
+ encoded_lengths = self.get_encoded_length(padding_mask.sum(dim=-1))
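+ # Compare each code-frame index against the per-sample encoded length to mask out frames produced by padding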
+
+ audio_codes_mask = torch.arange(encoded_lengths.max(), device=encoded_lengths.device).expand(
+ len(encoded_lengths), -1
+ )
+ audio_codes_mask = audio_codes_mask < encoded_lengths.unsqueeze(1)
+ audio_codes_mask = audio_codes_mask.to(padding_mask.device)
+
+ if padding_side == "right":
+ return audio_codes_mask
+ else:
+ return audio_codes_mask.flip(dims=[-1])
+
+ def encode(
+ self,
+ input_values: torch.Tensor,
+ padding_mask: Optional[torch.Tensor] = None,
+ num_quantizers: Optional[int] = None,
+ encoder_past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
+ padding_cache: Optional[MimiConv1dPaddingCache] = None,
+ use_streaming: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[tuple[torch.Tensor, Optional[torch.Tensor]], MimiEncoderOutput]:
+ """
+ Encodes the input audio waveform into discrete codes.
+
+ Args:
+ input_values (`torch.Tensor` of shape `(batch_size, channels, sequence_length)`):
+ Float values of the input audio waveform.
+ padding_mask (`torch.Tensor` of shape `(batch_size, channels, sequence_length)`):
+ Indicates which inputs are to be ignored due to padding, where elements are either 1 for *not masked* or 0
+ for *masked*.
+ num_quantizers (`int`, *optional*):
+ Number of quantizers (i.e. codebooks) to use. By default, all quantizers are used.
+ encoder_past_key_values (`Cache`, *optional*):
+ Pre-computed hidden-states (key and values in the self-attention blocks) that can be used to speed up sequential decoding of the encoder transformer.
+ This typically consists of the `past_key_values` returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
+
+ The model will output the same cache format that is fed as input.
+
+ If `past_key_values` are used, the user can optionally input only the last `audio_values` or `audio_codes` (those that don't
+ have their past key value states given to this model).
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+
+ Returns:
+ `codebook` of shape `[batch_size, num_codebooks, frames]`, the discrete encoded codes for the input audio waveform.
+ """
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
+ use_streaming = use_streaming if use_streaming is not None else self.config.use_streaming
+
+ num_quantizers = self.config.num_quantizers if num_quantizers is None else num_quantizers
+
+ if num_quantizers > self.config.num_quantizers:
+ raise ValueError(
+ f"The number of quantizers (i.e. codebooks) requested must be at most the total number of quantizers {self.config.num_quantizers}, but is currently {num_quantizers}."
+ )
+
+ _, channels, input_length = input_values.shape
+
+ if channels < 1 or channels > 2:
+ raise ValueError(f"Number of audio channels must be 1 or 2, but got {channels}")
+
+ if padding_mask is None:
+ padding_mask = torch.ones_like(input_values).bool()
+
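+ # Streaming: lazily build a per-layer padding cache (one entry per MimiConv1d in the encoder plus one
+ # for the downsample conv) so successive audio chunks can be encoded without re-padding at each call.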
+ if use_streaming and padding_cache is None:
+ per_layer_padding, per_layer_padding_mode, per_layer_in_channels = [], [], []
+ for layer_name in self.encoder._mimiconv1d_layer_names:
+ per_layer_padding.append(self.encoder.get_submodule(layer_name).padding_total)
+ per_layer_padding_mode.append(self.encoder.get_submodule(layer_name).pad_mode)
+ per_layer_in_channels.append(self.encoder.get_submodule(layer_name).in_channels)
+
+ # downsample layer
+ per_layer_padding.append(self.downsample.padding_total)
+ per_layer_padding_mode.append(self.downsample.pad_mode)
+ per_layer_in_channels.append(self.downsample.in_channels)
+
+ padding_cache = MimiConv1dPaddingCache(
+ num_layers=len(self.encoder._mimiconv1d_layer_names) + 1,
+ per_layer_padding=per_layer_padding,
+ per_layer_padding_mode=per_layer_padding_mode,
+ per_layer_in_channels=per_layer_in_channels,
+ )
+
+ encoded_frames, encoder_past_key_values, padding_cache = self._encode_frame(
+ input_values,
+ num_quantizers,
+ padding_mask.bool(),
+ past_key_values=encoder_past_key_values,
+ padding_cache=padding_cache,
+ return_dict=return_dict,
+ )
+
+ if not return_dict:
+ return (
+ encoded_frames,
+ encoder_past_key_values,
+ padding_cache,
+ )
+
+ return MimiEncoderOutput(encoded_frames, encoder_past_key_values, padding_cache)
+
+ def _decode_frame(
+ self,
+ codes: torch.Tensor,
+ past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
+ return_dict: Optional[bool] = None,
+ ) -> torch.Tensor:
+ embeddings = self.quantizer.decode(codes)
+
+ embeddings = self.upsample(embeddings)
+ decoder_outputs = self.decoder_transformer(
+ embeddings.transpose(1, 2), past_key_values=past_key_values, return_dict=return_dict
+ )
+ if return_dict:
+ past_key_values = decoder_outputs.get("past_key_values")
+ elif len(decoder_outputs) > 1:
+ past_key_values = decoder_outputs[1]
+ embeddings = decoder_outputs[0].transpose(1, 2)
+ outputs = self.decoder(embeddings)
+ return outputs, past_key_values
+
+ def decode(
+ self,
+ audio_codes: torch.Tensor,
+ padding_mask: Optional[torch.Tensor] = None,
+ decoder_past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[tuple[torch.Tensor, torch.Tensor], MimiDecoderOutput]:
+ """
+ Decodes the given frames into an output audio waveform.
+
+ Note that the output might be a bit bigger than the input. In that case, any extra steps at the end can be
+ trimmed.
+
+ Args:
+ audio_codes (`torch.LongTensor` of shape `(batch_size, num_quantizers, codes_length)`, *optional*):
+ Discrete code embeddings computed using `model.encode`.
+ padding_mask (`torch.Tensor` of shape `(batch_size, channels, sequence_length)`):
+ Indicates which inputs are to be ignored due to padding, where elements are either 1 for *not masked* or 0
+ for *masked*.
+ decoder_past_key_values (`Cache`, *optional*):
+ Pre-computed hidden-states (key and values in the self-attention blocks) that can be used to speed up sequential decoding of the decoder transformer.
+ This typically consists of the `past_key_values` returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
+
+ The model will output the same cache format that is fed as input.
+
+ If `past_key_values` are used, the user can optionally input only the last `audio_values` or `audio_codes` (those that don't
+ have their past key value states given to this model).
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+
+ """
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
+
+ audio_values, decoder_past_key_values = self._decode_frame(
+ audio_codes, past_key_values=decoder_past_key_values, return_dict=return_dict
+ )
+
+ # truncate based on padding mask
+ if padding_mask is not None and padding_mask.shape[-1] < audio_values.shape[-1]:
+ audio_values = audio_values[..., : padding_mask.shape[-1]]
+
+ if not return_dict:
+ return (
+ audio_values,
+ decoder_past_key_values,
+ )
+ return MimiDecoderOutput(audio_values, decoder_past_key_values)
+
+ @auto_docstring
+ def forward(
+ self,
+ input_values: torch.Tensor,
+ padding_mask: Optional[torch.Tensor] = None,
+ num_quantizers: Optional[int] = None,
+ audio_codes: Optional[torch.Tensor] = None,
+ encoder_past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
+ decoder_past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[tuple[torch.Tensor, torch.Tensor], MimiOutput]:
+ r"""
+ input_values (`torch.FloatTensor` of shape `(batch_size, channels, sequence_length)`, *optional*):
+ Raw audio input converted to Float.
+ padding_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Indicates which inputs are to be ignored due to padding, where elements are either 1 for *not masked* or 0
+ for *masked*.
+ num_quantizers (`int`, *optional*):
+ Number of quantizers (i.e. codebooks) to use. By default, all quantizers are used.
+ audio_codes (`torch.LongTensor` of shape `(batch_size, num_quantizers, codes_length)`, *optional*):
+ Discrete code embeddings computed using `model.encode`.
+ encoder_past_key_values (`Cache`, *optional*):
+ Pre-computed hidden-states (key and values in the self-attention blocks) that can be used to speed up sequential decoding of the encoder transformer.
+ This typically consists of the `past_key_values` returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
+
+ The model will output the same cache format that is fed as input.
+
+ If `past_key_values` are used, the user can optionally input only the last `audio_values` or `audio_codes` (those that don't
+ have their past key value states given to this model).
+ decoder_past_key_values (`Cache`, *optional*):
+ Pre-computed hidden-states (key and values in the self-attention blocks) that can be used to speed up sequential decoding of the decoder transformer.
+ This typically consists of the `past_key_values` returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
+
+ The model will output the same cache format that is fed as input.
+
+ If `past_key_values` are used, the user can optionally input only the last `audio_values` or `audio_codes` (those that don't
+ have their past key value states given to this model).
+
+ Examples:
+
+ ```python
+ >>> from datasets import load_dataset
+ >>> from transformers import AutoFeatureExtractor, MimiModel
+
+ >>> dataset = load_dataset("hf-internal-testing/ashraq-esc50-1-dog-example")
+ >>> audio_sample = dataset["train"]["audio"][0]["array"]
+
+ >>> model_id = "kyutai/mimi"
+ >>> model = MimiModel.from_pretrained(model_id)
+ >>> feature_extractor = AutoFeatureExtractor.from_pretrained(model_id)
+
+ >>> inputs = feature_extractor(raw_audio=audio_sample, return_tensors="pt")
+
+ >>> outputs = model(**inputs)
+ >>> audio_codes = outputs.audio_codes
+ >>> audio_values = outputs.audio_values
+ ```"""
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
+
+ if padding_mask is None:
+ padding_mask = torch.ones_like(input_values).bool()
+
+ if audio_codes is None:
+ encoder_outputs = self.encode(
+ input_values, padding_mask, num_quantizers, encoder_past_key_values, return_dict=return_dict
+ )
+ audio_codes = encoder_outputs[0]
+ if return_dict:
+ encoder_past_key_values = encoder_outputs.get("past_key_values")
+ elif len(encoder_outputs) > 1:
+ encoder_past_key_values = encoder_outputs[1]
+
+ decoder_outputs = self.decode(audio_codes, padding_mask, decoder_past_key_values, return_dict=return_dict)
+ audio_values = decoder_outputs[0]
+ if return_dict:
+ decoder_past_key_values = decoder_outputs.get("past_key_values")
+ elif len(decoder_outputs) > 1:
+ decoder_past_key_values = decoder_outputs[1]
+
+ if not return_dict:
+ return (audio_codes, audio_values, encoder_past_key_values, decoder_past_key_values)
+
+ return MimiOutput(
+ audio_codes=audio_codes,
+ audio_values=audio_values,
+ encoder_past_key_values=encoder_past_key_values,
+ decoder_past_key_values=decoder_past_key_values,
+ )
+
+
+__all__ = ["MimiModel", "MimiPreTrainedModel"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/minimax/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/minimax/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..91834eb6a2246ff1a5790116c94f3b0b4a559725
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/minimax/__init__.py
@@ -0,0 +1,29 @@
+# coding=utf-8
+# Copyright 2025 MiniMaxAI and HuggingFace Inc. teams. All rights reserved.
+#
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_minimax import *
+ from .modeling_minimax import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/minimax/configuration_minimax.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/minimax/configuration_minimax.py
new file mode 100644
index 0000000000000000000000000000000000000000..2ab46efb2cf822548242ffce2fdb122dbce8b6b8
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/minimax/configuration_minimax.py
@@ -0,0 +1,230 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/minimax/modular_minimax.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_minimax.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2025 MiniMaxAI and HuggingFace Inc. teams. All rights reserved.
+#
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from ...configuration_utils import PretrainedConfig, layer_type_validation
+
+
+class MiniMaxConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`MiniMaxModel`]. It is used to instantiate an
+ MiniMax model according to the specified arguments, defining the model architecture. Instantiating a configuration
+ with the defaults will yield a similar configuration to that of the MiniMax.
+
+ [MiniMaxAI/MiniMax-Text-01-hf](https://huggingface.co/MiniMaxAI/MiniMax-Text-01-hf)
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 32000):
+ Vocabulary size of the MiniMax model. Defines the number of different tokens that can be represented by the
+ `input_ids` passed when calling [`MiniMaxModel`]
+ hidden_size (`int`, *optional*, defaults to 4096):
+ Dimension of the hidden representations.
+ intermediate_size (`int`, *optional*, defaults to 14336):
+ Dimension of the MLP representations.
+ num_hidden_layers (`int`, *optional*, defaults to 32):
+ Number of hidden layers in the Transformer encoder.
+ num_attention_heads (`int`, *optional*, defaults to 32):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ num_key_value_heads (`int`, *optional*, defaults to 8):
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
+ `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+ by meanpooling all the original heads within that group. For more details, check out [this
+ paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to `8`.
+ head_dim (`int`, *optional*, defaults to `hidden_size // num_attention_heads`):
+ The attention head dimension.
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+ The non-linear activation function (function or string) in the decoder.
+ max_position_embeddings (`int`, *optional*, defaults to `4096*32`):
+ The maximum sequence length that this model might ever be used with. MiniMax's sliding window attention
+ allows sequences of up to 4096*32 tokens.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ rms_norm_eps (`float`, *optional*, defaults to 1e-05):
+ The epsilon used by the rms normalization layers.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
+ relevant if `config.is_decoder=True`.
+ pad_token_id (`int`, *optional*):
+ The id of the padding token.
+ bos_token_id (`int`, *optional*, defaults to 1):
+ The id of the "beginning-of-sequence" token.
+ eos_token_id (`int`, *optional*, defaults to 2):
+ The id of the "end-of-sequence" token.
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+ Whether the model's input and output word embeddings should be tied.
+ rope_theta (`float`, *optional*, defaults to 1000000.0):
+ The base period of the RoPE embeddings.
+ sliding_window (`int`, *optional*):
+ Sliding window attention window size. If not specified, will default to `4096`.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ num_experts_per_tok (`int`, *optional*, defaults to 2):
+ The number of experts to route per token; can also be interpreted as the `top-k` routing
+ parameter.
+ num_local_experts (`int`, *optional*, defaults to 8):
+ Number of experts per Sparse MLP layer.
+ output_router_logits (`bool`, *optional*, defaults to `False`):
+ Whether or not the router logits should be returned by the model. Enabling this will also
+ allow the model to output the auxiliary loss. See [here]() for more details
+ router_aux_loss_coef (`float`, *optional*, defaults to 0.001):
+ The aux loss factor for the total loss.
+ router_jitter_noise (`float`, *optional*, defaults to 0.0):
+ Amount of noise to add to the router.
+ layer_types (`list`, *optional*):
+ Attention pattern for each layer.
+ block_size (`int`, *optional*, defaults to 256):
+ The length of each attention block, determining how queries, keys, and values
+ are grouped and processed for intra- and inter-block attention.
+ full_attn_alpha_factor (`float`, *optional*, defaults to 1):
+ Weight for residual value in residual connection after normal attention.
+ full_attn_beta_factor (`float`, *optional*, defaults to 1):
+ Weight for hidden state value in residual connection after normal attention.
+ linear_attn_alpha_factor (`float`, *optional*, defaults to 1):
+ Weight for residual value in residual connection after lightning attention.
+ linear_attn_beta_factor (`float`, *optional*, defaults to 1):
+ Weight for hidden state value in residual connection after lightning attention.
+ mlp_alpha_factor (`float`, *optional*, defaults to 1):
+ Weight for residual value in residual connection after MLP.
+ mlp_beta_factor (`float`, *optional*, defaults to 1):
+ Weight for hidden state value in residual connection after MLP.
+
+ ```python
+ >>> from transformers import MiniMaxModel, MiniMaxConfig
+
+ >>> # Initializing a MiniMax style configuration
+ >>> configuration = MiniMaxConfig()
+
+ >>> # Initializing a model from the MiniMax style configuration
+ >>> model = MiniMaxModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "minimax"
+ keys_to_ignore_at_inference = ["past_key_values"]
+ base_model_tp_plan = {
+ "layers.*.self_attn.q_proj": "colwise",
+ "layers.*.self_attn.k_proj": "colwise",
+ "layers.*.self_attn.v_proj": "colwise",
+ "layers.*.self_attn.o_proj": "rowwise",
+ "layers.*.block_sparse_moe.gate": "colwise_rep", # we need to replicate here to correctly route experts
+ "layers.*.block_sparse_moe.experts.*.w1": "colwise",
+ "layers.*.block_sparse_moe.experts.*.w2": "rowwise",
+ "layers.*.block_sparse_moe.experts.*.w3": "colwise",
+ }
+ base_model_pp_plan = {
+ "embed_tokens": (["input_ids"], ["inputs_embeds"]),
+ "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
+ "norm": (["hidden_states"], ["hidden_states"]),
+ }
+
+ def __init__(
+ self,
+ vocab_size=32000,
+ hidden_size=4096,
+ intermediate_size=14336,
+ num_hidden_layers=32,
+ num_attention_heads=32,
+ num_key_value_heads=8,
+ head_dim=None,
+ hidden_act="silu",
+ max_position_embeddings=4096 * 32,
+ initializer_range=0.02,
+ rms_norm_eps=1e-5,
+ use_cache=True,
+ pad_token_id=None,
+ bos_token_id=1,
+ eos_token_id=2,
+ tie_word_embeddings=False,
+ rope_theta=1e6,
+ sliding_window=None,
+ attention_dropout=0.0,
+ num_experts_per_tok=2,
+ num_local_experts=8,
+ output_router_logits=False,
+ router_aux_loss_coef=0.001,
+ router_jitter_noise=0.0,
+ layer_types=None,
+ block_size=256,
+ full_attn_alpha_factor=1,
+ full_attn_beta_factor=1,
+ linear_attn_alpha_factor=1,
+ linear_attn_beta_factor=1,
+ mlp_alpha_factor=1,
+ mlp_beta_factor=1,
+ **kwargs,
+ ):
+ super().__init__(
+ pad_token_id=pad_token_id,
+ bos_token_id=bos_token_id,
+ eos_token_id=eos_token_id,
+ tie_word_embeddings=tie_word_embeddings,
+ **kwargs,
+ )
+ self.vocab_size = vocab_size
+ self.max_position_embeddings = max_position_embeddings
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.sliding_window = sliding_window
+
+ # for backward compatibility
+ if num_key_value_heads is None:
+ num_key_value_heads = num_attention_heads
+
+ self.num_key_value_heads = num_key_value_heads
+ self.hidden_act = hidden_act
+ self.initializer_range = initializer_range
+ self.rms_norm_eps = rms_norm_eps
+ self.use_cache = use_cache
+ self.rope_theta = rope_theta
+ self.attention_dropout = attention_dropout
+ self.head_dim = head_dim
+
+ self.num_experts_per_tok = num_experts_per_tok
+ self.num_local_experts = num_local_experts
+ self.output_router_logits = output_router_logits
+ self.router_aux_loss_coef = router_aux_loss_coef
+ self.router_jitter_noise = router_jitter_noise
+ self.layer_types = layer_types
+ self.block_size = block_size
+ self.full_attn_alpha_factor = full_attn_alpha_factor
+ self.full_attn_beta_factor = full_attn_beta_factor
+ self.linear_attn_alpha_factor = linear_attn_alpha_factor
+ self.linear_attn_beta_factor = linear_attn_beta_factor
+ self.mlp_alpha_factor = mlp_alpha_factor
+ self.mlp_beta_factor = mlp_beta_factor
+
+ if self.layer_types is None:
+ self.layer_types = [
+ "full_attention" if bool((i + 1) % 2) else "linear_attention" for i in range(self.num_hidden_layers)
+ ]
+ layer_type_validation(self.layer_types, self.num_hidden_layers)
+
+
+__all__ = ["MiniMaxConfig"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/minimax/modeling_minimax.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/minimax/modeling_minimax.py
new file mode 100644
index 0000000000000000000000000000000000000000..df01bbd0d39ebe84d01ce6ebb05a12fe84e01fbd
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/minimax/modeling_minimax.py
@@ -0,0 +1,930 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/minimax/modular_minimax.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_minimax.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2025 MiniMaxAI and HuggingFace Inc. teams. All rights reserved.
+#
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Callable, Optional, Union
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from ...activations import ACT2FN
+from ...cache_utils import Cache, DynamicCache
+from ...generation import GenerationMixin
+from ...integrations import use_kernel_forward_from_hub
+from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
+from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_layers import (
+ GenericForQuestionAnswering,
+ GenericForSequenceClassification,
+ GenericForTokenClassification,
+ GradientCheckpointingLayer,
+)
+from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast
+from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
+from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
+from ...utils.deprecation import deprecate_kwarg
+from ...utils.generic import OutputRecorder, check_model_inputs
+from .configuration_minimax import MiniMaxConfig
+
+
+@use_kernel_forward_from_hub("RMSNorm")
+class MiniMaxRMSNorm(nn.Module):
+ def __init__(self, hidden_size, eps=1e-6):
+ """
+ MiniMaxRMSNorm is equivalent to T5LayerNorm
+ """
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+ self.variance_epsilon = eps
+
+ def forward(self, hidden_states):
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+ return self.weight * hidden_states.to(input_dtype)
+
+ def extra_repr(self):
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
+
+
+class MiniMaxCache(DynamicCache):
+ def __init__(self):
+ super().__init__()
+ self.linear_cache: list[torch.Tensor] = []
+
+ def set_linear_cache(self, layer_idx, linear_cache):
+ # There may be skipped layers, fill them with empty lists
+ for _ in range(len(self.linear_cache), layer_idx + 1):
+ self.linear_cache.append([])
+ self.linear_cache[layer_idx] = linear_cache
+
+ def get_linear_cache(self, layer_idx: int):
+ if layer_idx < len(self):
+ return self.linear_cache[layer_idx]
+ return None
+
+ def __len__(self):
+ return max(super().__len__(), len(self.linear_cache))
+
+ def __getitem__(self, layer_idx: int):
+ if layer_idx < len(self.linear_cache) and self.linear_cache[layer_idx] != []:
+ return (self.linear_cache[layer_idx],)
+ return super().__getitem__(layer_idx)
+
+ def __iter__(self):
+ for layer_idx in range(len(self)):
+ yield self[layer_idx]
+
+ def batch_repeat_interleave(self, repeats: int):
+ for layer_idx in range(len(self)):
+ if self.linear_cache[layer_idx] != []:
+ self.linear_cache[layer_idx] = self.linear_cache[layer_idx].repeat_interleave(repeats, dim=0)
+ else:
+ self.layers[layer_idx].batch_repeat_interleave(repeats)
+
+ def batch_select_indices(self, indices: torch.Tensor):
+ for layer_idx in range(len(self)):
+ if self.linear_cache[layer_idx] != []:
+ self.linear_cache[layer_idx] = self.linear_cache[layer_idx][indices, ...]
+ else:
+ self.layers[layer_idx].batch_select_indices(indices)
+
+ def crop(self, max_length: int):
+ raise RuntimeError("MiniMaxCache doesnot support `crop` method")
+
+
+class MiniMaxLightningAttention(nn.Module):
+ def __init__(self, config: MiniMaxConfig, layer_idx: int):
+ super().__init__()
+ self.layer_idx = layer_idx
+ self.head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
+ self.num_attention_heads = config.num_attention_heads
+ self.num_hidden_layers = config.num_hidden_layers
+ self.block_size = config.block_size
+
+ self.act_fn = ACT2FN[config.hidden_act]
+ self.norm = MiniMaxRMSNorm(self.head_dim * self.num_attention_heads)
+ self.qkv_proj = nn.Linear(config.hidden_size, self.num_attention_heads * self.head_dim * 3, bias=False)
+ self.out_proj = nn.Linear(self.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
+ self.output_gate = nn.Linear(config.hidden_size, self.num_attention_heads * self.head_dim, bias=False)
+
+ slope_rate = self.get_slope_rate()
+ query_decay, key_decay, diagonal_decay = self.decay_factors(slope_rate)
+
+ self.register_buffer("slope_rate", slope_rate)
+ self.register_buffer("query_decay", query_decay)
+ self.register_buffer("key_decay", key_decay)
+ self.register_buffer("diagonal_decay", diagonal_decay)
+
+ def get_slope_rate(self):
+ base = 1 / (2 ** (8 / self.num_attention_heads))
+ exponent = torch.arange(self.num_attention_heads) + 1
+ factor = 1 - self.layer_idx / (self.num_hidden_layers - 1 + 1e-5) + 1e-5
+
+ rate = base**exponent
+ rate = rate * factor
+ rate = rate[:, None, None]
+
+ return rate
+
+ def decay_factors(self, slope_rate):
+ block_size_range = torch.arange(self.block_size) + 1
+
+ query_decay = torch.exp(-slope_rate * block_size_range[:, None])
+ key_decay = torch.exp(-slope_rate * (self.block_size - block_size_range[:, None]))
+
+ diagonal_decay = block_size_range[:, None] - block_size_range[None, :]
+ diagonal_decay = diagonal_decay[None, None, :, :]
+ diagonal_decay = slope_rate * diagonal_decay
+ diagonal_decay = torch.where(diagonal_decay >= 0, -diagonal_decay, float("-inf"))
+ diagonal_decay = torch.exp(diagonal_decay)
+
+ return query_decay, key_decay, diagonal_decay
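`get_slope_rate` assigns each head an ALiBi-style decay slope, and `decay_factors` expands it into the per-position factors used inside a block: queries decay with their position `t` in the block (`exp(-slope * t)`), keys with the distance to the block end (`exp(-slope * (B - t))`), and the diagonal factor zeroes out future positions before exponentiating. A small numeric sketch of those formulas for one head, not part of the diff:

```python
import torch

# Illustrative only: decay factors for a single head and a tiny block, following
# the formulas in decay_factors above.
slope, block_size = 0.5, 4
t = torch.arange(block_size, dtype=torch.float32) + 1   # positions 1..B within a block

query_decay = torch.exp(-slope * t)                      # exp(-slope * t)
key_decay = torch.exp(-slope * (block_size - t))         # exp(-slope * (B - t))

diag = t[:, None] - t[None, :]                           # i - j
diag = torch.where(diag >= 0, -slope * diag, torch.full_like(diag, float("-inf")))
diagonal_decay = torch.exp(diag)                         # causal: zero above the diagonal

print(query_decay)        # tensor([0.6065, 0.3679, 0.2231, 0.1353])
print(diagonal_decay[1])  # tensor([0.6065, 1.0000, 0.0000, 0.0000])
```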
+
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
+ attention_mask: Optional[torch.Tensor],
+ past_key_values: Optional[Cache] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs: Unpack[FlashAttentionKwargs],
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
+ batch_size, seq_len, hidden_size = hidden_states.shape
+ num_blocks = (seq_len + self.block_size - 1) // self.block_size
+
+ qkv_states = self.act_fn(self.qkv_proj(hidden_states))
+ qkv_states = qkv_states.reshape(batch_size, seq_len, self.num_attention_heads, 3 * self.head_dim)
+
+ query_states, key_states, value_states = torch.split(qkv_states, self.head_dim, dim=3)
+
+ query_states = query_states.transpose(1, 2)
+ key_states = key_states.transpose(1, 2)
+ value_states = value_states.transpose(1, 2)
+
+ # calculate (K.T @ V) and save it as cache
+ attn_weights_inter = None
+ if past_key_values is not None:
+ attn_weights_inter = past_key_values.get_linear_cache(self.layer_idx)
+
+ if attn_weights_inter is None:
+ attn_weights_inter = torch.zeros(batch_size, self.num_attention_heads, self.head_dim, self.head_dim).to(
+ value_states
+ )
+
+ # apply attention_mask
+ if attention_mask is not None:
+ attention_mask = attention_mask.to(dtype=torch.bool) # Ensure it's a boolean tensor
+ value_states = value_states.masked_fill(~attention_mask.unsqueeze(1).unsqueeze(-1), 0)
+
+ attn_output = []
+ for i in range(num_blocks):
+ start_idx = i * self.block_size
+ end_idx = min(start_idx + self.block_size, seq_len)
+ current_block_size = end_idx - start_idx
+
+ current_query_states = query_states[:, :, start_idx:end_idx]
+ current_key_states = key_states[:, :, start_idx:end_idx]
+ current_value_states = value_states[:, :, start_idx:end_idx]
+
+ current_query_decay = self.query_decay[:, :current_block_size]
+ current_key_decay = self.key_decay[:, -current_block_size:]
+ current_diagonal_decay = self.diagonal_decay[:, :, :current_block_size, :current_block_size]
+ block_decay = torch.exp(-self.slope_rate * current_block_size)
+
+ # intra: ( Q @ K.T ) @ V -> QK * V
+ attn_weights_intra = torch.matmul(current_query_states, current_key_states.transpose(-1, -2))
+ attn_output_intra = torch.matmul(attn_weights_intra * current_diagonal_decay, current_value_states)
+
+ # inter: Q @ ( K.T @ V ) -> Q * KV
+ attn_output_inter = torch.matmul(current_query_states * current_query_decay, attn_weights_inter)
+
+ # final attention output
+ current_attn_output = attn_output_inter + attn_output_intra
+ attn_output.append(current_attn_output)
+
+ # calculate attn_weights_inter for next block or cache
+ next_attn_weights_inter = torch.matmul(
+ (current_key_states * current_key_decay).transpose(-1, -2), current_value_states
+ )
+ attn_weights_inter = attn_weights_inter * block_decay + next_attn_weights_inter
+
+ else:
+ ratio = torch.exp(-self.slope_rate)
+ attn_output = []
+ for i in range(seq_len):
+ current_query_states = query_states[:, :, i : i + 1]
+ current_key_states = key_states[:, :, i : i + 1]
+ current_value_states = value_states[:, :, i : i + 1]
+
+ current_attn_weights_inter = torch.matmul(current_key_states.transpose(-1, -2), current_value_states)
+ attn_weights_inter = ratio * attn_weights_inter + current_attn_weights_inter
+ current_attn_output = torch.matmul(current_query_states, attn_weights_inter)
+
+ attn_output.append(current_attn_output)
+
+ # concatenate attention outputs over all blocks
+ attn_output = torch.cat(attn_output, dim=-2)
+
+ # final output projection
+ attn_output = attn_output.transpose(1, 2)
+ attn_output = attn_output.reshape(batch_size, seq_len, self.num_attention_heads * self.head_dim)
+ attn_output = self.norm(attn_output)
+ attn_output = F.sigmoid(self.output_gate(hidden_states)) * attn_output
+ attn_output = self.out_proj(attn_output)
+
+ # update cache
+ if past_key_values is not None:
+ past_key_values.set_linear_cache(self.layer_idx, attn_weights_inter)
+
+ return attn_output, attn_weights_inter
+
+
+def rotate_half(x):
+ """Rotates half the hidden dims of the input."""
+ x1 = x[..., : x.shape[-1] // 2]
+ x2 = x[..., x.shape[-1] // 2 :]
+ return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+ """Applies Rotary Position Embedding to the query and key tensors.
+
+ Args:
+ q (`torch.Tensor`): The query tensor.
+ k (`torch.Tensor`): The key tensor.
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
+ position_ids (`torch.Tensor`, *optional*):
+ Deprecated and unused.
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+ Returns:
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+ """
+ cos = cos.unsqueeze(unsqueeze_dim)
+ sin = sin.unsqueeze(unsqueeze_dim)
+ q_embed = (q * cos) + (rotate_half(q) * sin)
+ k_embed = (k * cos) + (rotate_half(k) * sin)
+ return q_embed, k_embed
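As the docstring notes, `cos` and `sin` arrive as `[batch, seq_len, head_dim]` and `unsqueeze_dim=1` inserts the missing heads axis so they broadcast against `[batch, heads, seq_len, head_dim]` queries and keys. A shape-only sketch, not part of the diff (it assumes `apply_rotary_pos_emb` and `rotate_half` from this file are in scope):

```python
import torch

# Illustrative only: broadcasting in apply_rotary_pos_emb with unsqueeze_dim=1.
batch, heads, seq_len, head_dim = 2, 4, 8, 16
q = torch.randn(batch, heads, seq_len, head_dim)
k = torch.randn(batch, heads, seq_len, head_dim)
cos = torch.randn(batch, seq_len, head_dim)
sin = torch.randn(batch, seq_len, head_dim)

q_embed, k_embed = apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1)
print(q_embed.shape, k_embed.shape)  # torch.Size([2, 4, 8, 16]) torch.Size([2, 4, 8, 16])
```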
+
+
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+ """
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+ """
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+ if n_rep == 1:
+ return hidden_states
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
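The docstring's claim that `repeat_kv` matches `torch.repeat_interleave` along the head axis is easy to check directly. A tiny sketch, not part of the diff (it assumes `repeat_kv` from this file is in scope):

```python
import torch

# Illustrative only: repeat_kv vs. Tensor.repeat_interleave along the KV-head axis.
x = torch.randn(2, 8, 5, 16)   # (batch, num_key_value_heads, seq_len, head_dim)
out = repeat_kv(x, n_rep=4)
print(out.shape)                                         # torch.Size([2, 32, 5, 16])
print(torch.equal(out, x.repeat_interleave(4, dim=1)))   # True
```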
+
+
+def eager_attention_forward(
+ module: nn.Module,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_mask: Optional[torch.Tensor],
+ scaling: float,
+ dropout: float = 0.0,
+ **kwargs: Unpack[TransformersKwargs],
+):
+ key_states = repeat_kv(key, module.num_key_value_groups)
+ value_states = repeat_kv(value, module.num_key_value_groups)
+
+ attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
+ if attention_mask is not None:
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+ attn_weights = attn_weights + causal_mask
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+ attn_output = torch.matmul(attn_weights, value_states)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ return attn_output, attn_weights
+
+
+class MiniMaxAttention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(self, config: MiniMaxConfig, layer_idx: int):
+ super().__init__()
+ self.config = config
+ self.layer_idx = layer_idx
+ self.head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
+ self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
+ self.scaling = self.head_dim**-0.5
+ self.attention_dropout = config.attention_dropout
+ self.is_causal = True
+ self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=False)
+ self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
+ self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
+ self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
+
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
+ attention_mask: Optional[torch.Tensor],
+ past_key_values: Optional[Cache] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs: Unpack[FlashAttentionKwargs],
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+ input_shape = hidden_states.shape[:-1]
+ hidden_shape = (*input_shape, -1, self.head_dim)
+
+ query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+ key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+ value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+
+ cos, sin = position_embeddings
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_values is not None:
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ attn_output, attn_weights = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ dropout=0.0 if not self.training else self.attention_dropout,
+ scaling=self.scaling,
+ sliding_window=getattr(self.config, "sliding_window", None), # main diff with Llama
+ **kwargs,
+ )
+
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+ attn_output = self.o_proj(attn_output)
+ return attn_output, attn_weights
+
+
+class MiniMaxBlockSparseTop2MLP(nn.Module):
+ def __init__(self, config: MiniMaxConfig):
+ super().__init__()
+ self.ffn_dim = config.intermediate_size
+ self.hidden_dim = config.hidden_size
+
+ self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
+ self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False)
+ self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
+
+ self.act_fn = ACT2FN[config.hidden_act]
+
+ def forward(self, hidden_states):
+ current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)
+ current_hidden_states = self.w2(current_hidden_states)
+ return current_hidden_states
+
+
+class MiniMaxSparseMoeBlock(nn.Module):
+ """
+ This implementation is
+ strictly equivalent to standard MoE with full capacity (no
+ dropped tokens). It's faster since it formulates MoE operations
+ in terms of block-sparse operations to accommodate imbalanced
+ assignments of tokens to experts, whereas standard MoE either
+ (1) drops tokens at the cost of reduced performance or (2) sets
+ the capacity factor to the number of experts and thus wastes computation
+ and memory on padding.
+ """
+
+ def __init__(self, config):
+ super().__init__()
+ self.hidden_dim = config.hidden_size
+ self.ffn_dim = config.intermediate_size
+ self.num_experts = config.num_local_experts
+ self.top_k = config.num_experts_per_tok
+
+ # gating
+ self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
+
+ self.experts = nn.ModuleList([MiniMaxBlockSparseTop2MLP(config) for _ in range(self.num_experts)])
+
+ # Jitter parameters
+ self.jitter_noise = config.router_jitter_noise
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ """ """
+ batch_size, sequence_length, hidden_dim = hidden_states.shape
+ if self.training and self.jitter_noise > 0:
+ hidden_states *= torch.empty_like(hidden_states).uniform_(1.0 - self.jitter_noise, 1.0 + self.jitter_noise)
+ hidden_states = hidden_states.view(-1, hidden_dim)
+ # router_logits: (batch * sequence_length, n_experts)
+ router_logits = self.gate(hidden_states)
+
+ routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
+ routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
+ routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
+ # we cast back to the input dtype
+ routing_weights = routing_weights.to(hidden_states.dtype)
+
+ final_hidden_states = torch.zeros(
+ (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
+ )
+
+ # One hot encode the selected experts to create an expert mask
+ # this will be used to easily index which expert is going to be solicited
+ expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
+
+ expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
+ for expert_idx in expert_hit:
+ expert_layer = self.experts[expert_idx]
+ idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
+ # Index the correct hidden states and compute the expert hidden state for
+ # the current expert. We need to make sure to multiply the output hidden
+ # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
+ current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
+ current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]
+
+ # However `index_add_` only supports torch tensors for indexing so we'll use
+ # the `top_x` tensor here.
+ final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
+ final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
+ return final_hidden_states, router_logits
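The routing above is plain top-k softmax routing: router logits are softmaxed over experts, the top `k` probabilities per token are kept and renormalized, and each hit expert processes only its tokens, whose outputs are scattered back with `index_add_`. A self-contained sketch of just the routing arithmetic, not part of the diff:

```python
import torch
import torch.nn.functional as F

# Illustrative only: top-2 routing weights as computed in MiniMaxSparseMoeBlock.forward.
num_tokens, num_experts, top_k = 5, 8, 2
router_logits = torch.randn(num_tokens, num_experts)

routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)

print(selected_experts.shape, routing_weights.shape)  # torch.Size([5, 2]) torch.Size([5, 2])
print(routing_weights.sum(dim=-1))                     # all ones after renormalization
```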
+
+
+class MiniMaxDecoderLayer(GradientCheckpointingLayer):
+ def __init__(self, config: MiniMaxConfig, layer_idx: int):
+ super().__init__()
+ self.hidden_size = config.hidden_size
+
+ self.self_attn = MiniMaxAttention(config, layer_idx)
+
+ self.block_sparse_moe = MiniMaxSparseMoeBlock(config)
+ self.input_layernorm = MiniMaxRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.post_attention_layernorm = MiniMaxRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+ self.layer_idx = layer_idx
+ self.layer_type = config.layer_types[layer_idx]
+ self.mlp_alpha_factor = config.mlp_alpha_factor
+ self.mlp_beta_factor = config.mlp_beta_factor
+
+ if self.layer_type == "linear_attention":
+ self.self_attn = MiniMaxLightningAttention(config, layer_idx)
+ self.attn_alpha_factor = config.linear_attn_alpha_factor
+ self.attn_beta_factor = config.linear_attn_beta_factor
+ else:
+ self.self_attn = MiniMaxAttention(config, layer_idx)
+ self.attn_alpha_factor = config.full_attn_alpha_factor
+ self.attn_beta_factor = config.full_attn_beta_factor
+
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ output_attentions: Optional[bool] = False,
+ output_router_logits: Optional[bool] = False,
+ use_cache: Optional[bool] = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs: Unpack[FlashAttentionKwargs],
+ ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+ position_embeddings (`tuple[torch.FloatTensor, torch.FloatTensor]`):
+ Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
+ with `head_dim` being the embedding dimension of each attention head.
+ attention_mask (`torch.Tensor`, *optional*): attention mask of size
+ `(batch, sequence_length)` where padding elements are indicated by 0.
+ past_key_values (`Cache`, *optional*): cached past key and value projection states
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ output_router_logits (`bool`, *optional*):
+ Whether or not to return the logits of all the routers. They are useful for computing the router loss, and
+ should not be returned during inference.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+ (see `past_key_values`).
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+ Indices depicting the position of the input sequence tokens in the sequence.
+ kwargs (`dict`, *optional*):
+ Arbitrary kwargs to be ignored, used for FSDP and other methods that inject code
+ into the model
+ """
+
+ hidden_states = self.input_layernorm(hidden_states)
+ residual = hidden_states
+
+ # Self Attention
+ hidden_states, _ = self.self_attn(
+ hidden_states=hidden_states,
+ position_embeddings=position_embeddings,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ **kwargs,
+ )
+ hidden_states = residual * self.attn_alpha_factor + hidden_states * self.attn_beta_factor
+
+ # Fully Connected
+ hidden_states = self.post_attention_layernorm(hidden_states)
+ residual = hidden_states
+ hidden_states, _ = self.block_sparse_moe(hidden_states)
+ hidden_states = residual * self.mlp_alpha_factor + hidden_states * self.mlp_beta_factor
+
+ return hidden_states
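Note that the residual here is taken after the layer norm rather than before it, and the skip connection is the weighted sum `residual * alpha + sublayer_out * beta` with the per-sublayer factors from the config (all 1 by default). A toy sketch of that combination, not part of the diff:

```python
import torch

# Illustrative only: the weighted residual used after attention and MoE in
# MiniMaxDecoderLayer; with alpha = beta = 1 it reduces to a standard skip connection.
alpha, beta = 1.0, 1.0
residual = torch.randn(2, 8, 16)      # normalized hidden states entering the sublayer
sublayer_out = torch.randn(2, 8, 16)  # attention / MoE output

hidden_states = residual * alpha + sublayer_out * beta
print(hidden_states.shape)  # torch.Size([2, 8, 16])
```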
+
+
+@auto_docstring
+class MiniMaxPreTrainedModel(PreTrainedModel):
+ config: MiniMaxConfig
+ base_model_prefix = "model"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["MiniMaxDecoderLayer"]
+ _skip_keys_device_placement = ["past_key_values"]
+ _supports_flash_attn = True
+ _supports_sdpa = True
+ _supports_flex_attn = True
+ _can_compile_fullgraph = False
+ _supports_attention_backend = True
+ _can_record_outputs = {
+ "router_logits": OutputRecorder(MiniMaxSparseMoeBlock, index=1),
+ "hidden_states": MiniMaxDecoderLayer,
+ "attentions": [MiniMaxAttention, MiniMaxLightningAttention],
+ }
+
+
+class MiniMaxRotaryEmbedding(nn.Module):
+ inv_freq: torch.Tensor # fix linting for `register_buffer`
+
+ def __init__(self, config: MiniMaxConfig, device=None):
+ super().__init__()
+ # BC: "rope_type" was originally "type"
+ if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
+ else:
+ self.rope_type = "default"
+ self.max_seq_len_cached = config.max_position_embeddings
+ self.original_max_seq_len = config.max_position_embeddings
+
+ self.config = config
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
+ self.original_inv_freq = self.inv_freq
+
+ @torch.no_grad()
+ @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
+ def forward(self, x, position_ids):
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
+ position_ids_expanded = position_ids[:, None, :].float()
+
+ device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
+ with torch.autocast(device_type=device_type, enabled=False): # Force float32
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+ emb = torch.cat((freqs, freqs), dim=-1)
+ cos = emb.cos() * self.attention_scaling
+ sin = emb.sin() * self.attention_scaling
+
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
+@auto_docstring
+class MiniMaxModel(MiniMaxPreTrainedModel):
+ def __init__(self, config: MiniMaxConfig):
+ super().__init__(config)
+ self.padding_idx = config.pad_token_id
+ self.vocab_size = config.vocab_size
+
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+ self.layers = nn.ModuleList(
+ [MiniMaxDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+ )
+ self.norm = MiniMaxRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.rotary_emb = MiniMaxRotaryEmbedding(config=config)
+ self.gradient_checkpointing = False
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @check_model_inputs
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[MiniMaxCache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> MoeModelOutputWithPast:
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ if use_cache and past_key_values is None:
+ past_key_values = MiniMaxCache()
+ elif use_cache and not isinstance(past_key_values, MiniMaxCache):
+ raise ValueError(
+ f"MiniMax uses cache of its own and is not compatible with `past_key_values` of type {type(past_key_values)}."
+ )
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+
+ if cache_position is None:
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ cache_position = torch.arange(
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+ )
+ if position_ids is None:
+ position_ids = cache_position.unsqueeze(0)
+
+ mask_function = create_causal_mask if self.config.sliding_window is None else create_sliding_window_causal_mask
+ causal_mask = mask_function(
+ config=self.config,
+ input_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ cache_position=cache_position,
+ past_key_values=past_key_values,
+ position_ids=position_ids,
+ )
+
+ hidden_states = inputs_embeds
+
+ # create position embeddings to be shared across the decoder layers
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+ for decoder_layer in self.layers:
+ if decoder_layer.layer_type == "full_attention":
+ input_attention_mask = causal_mask
+ else:
+ # lightning attention uses original attention_mask, and uses it only for the first step
+ input_attention_mask = attention_mask
+
+ hidden_states = decoder_layer(
+ hidden_states,
+ position_embeddings=position_embeddings,
+ attention_mask=input_attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+ hidden_states = self.norm(hidden_states)
+
+ return MoeModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=past_key_values,
+ )
+
+
+def load_balancing_loss_func(
+ gate_logits: Union[torch.Tensor, tuple[torch.Tensor], None],
+ num_experts: Optional[int] = None,
+ top_k=2,
+ attention_mask: Optional[torch.Tensor] = None,
+) -> Union[torch.Tensor, int]:
+ r"""
+ Computes auxiliary load balancing loss as in Switch Transformer - implemented in PyTorch.
+
+ See Switch Transformer (https://huggingface.co/papers/2101.03961) for more details. This function implements the loss
+ function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between
+ experts is too unbalanced.
+
+ Args:
+ gate_logits:
+ Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of
+ shape [batch_size X sequence_length, num_experts].
+ num_experts:
+ Number of experts
+ top_k:
+ The number of experts to route per token; can also be interpreted as the `top-k` routing
+ parameter.
+ attention_mask (`torch.Tensor`, *optional*):
+ The attention_mask used in forward function
+ shape [batch_size X sequence_length] if not None.
+
+ Returns:
+ The auxiliary loss.
+ """
+ if gate_logits is None or not isinstance(gate_logits, tuple):
+ return 0
+
+ if isinstance(gate_logits, tuple):
+ compute_device = gate_logits[0].device
+ concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0)
+
+ routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)
+
+ _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
+
+ expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)
+
+ if attention_mask is None:
+ # Compute the percentage of tokens routed to each expert
+ tokens_per_expert = torch.mean(expert_mask.float(), dim=0)
+
+ # Compute the average probability of routing to these experts
+ router_prob_per_expert = torch.mean(routing_weights, dim=0)
+ else:
+ batch_size, sequence_length = attention_mask.shape
+ num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length)
+
+ # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask
+ expert_attention_mask = (
+ attention_mask[None, :, :, None, None]
+ .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts))
+ .reshape(-1, top_k, num_experts)
+ .to(compute_device)
+ )
+
+ # Compute the percentage of tokens routed to each expert
+ tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum(
+ expert_attention_mask, dim=0
+ )
+
+ # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert
+ router_per_expert_attention_mask = (
+ attention_mask[None, :, :, None]
+ .expand((num_hidden_layers, batch_size, sequence_length, num_experts))
+ .reshape(-1, num_experts)
+ .to(compute_device)
+ )
+
+ # Compute the average probability of routing to these experts
+ router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum(
+ router_per_expert_attention_mask, dim=0
+ )
+
+ overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
+ return overall_loss * num_experts
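Per expert, the loss multiplies the fraction of tokens routed to that expert by the mean router probability it received, sums over experts and top-k slots, and scales by `num_experts`; under perfectly balanced routing it evaluates to roughly `top_k`. A quick sketch, not part of the diff (it assumes `load_balancing_loss_func` above is in scope):

```python
import torch

# Illustrative only: with near-uniform router logits the auxiliary loss lands close
# to top_k (here 2), its value under perfectly balanced routing.
torch.manual_seed(0)
num_layers, tokens, num_experts = 4, 64, 8
gate_logits = tuple(0.01 * torch.randn(tokens, num_experts) for _ in range(num_layers))

aux_loss = load_balancing_loss_func(gate_logits, num_experts=num_experts, top_k=2)
print(aux_loss)  # ~2.0
```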
+
+
+@auto_docstring
+class MiniMaxForCausalLM(MiniMaxPreTrainedModel, GenerationMixin):
+ _tied_weights_keys = ["lm_head.weight"]
+ _tp_plan = {"lm_head": "colwise_rep"}
+ _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.model = MiniMaxModel(config)
+ self.vocab_size = config.vocab_size
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+ self.router_aux_loss_coef = config.router_aux_loss_coef
+ self.num_experts = config.num_local_experts
+ self.num_experts_per_tok = config.num_experts_per_tok
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_router_logits: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ logits_to_keep: Union[int, torch.Tensor] = 0,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> MoeCausalLMOutputWithPast:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, MiniMaxForCausalLM
+
+ >>> model = MiniMaxForCausalLM.from_pretrained("MiniMaxAI/MiniMax-Text-01-hf")
+ >>> tokenizer = AutoTokenizer.from_pretrained("MiniMaxAI/MiniMax-Text-01-hf")
+
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+ >>> # Generate
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+ ```"""
+
+ output_router_logits = (
+ output_router_logits if output_router_logits is not None else self.config.output_router_logits
+ )
+
+ # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
+ outputs: MoeModelOutputWithPast = self.model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_router_logits=output_router_logits,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+ hidden_states = outputs.last_hidden_state
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)
+
+ aux_loss = None
+ if output_router_logits:
+ aux_loss = load_balancing_loss_func(
+ outputs.router_logits,
+ self.num_experts,
+ self.num_experts_per_tok,
+ attention_mask,
+ )
+ if labels is not None:
+ loss += self.router_aux_loss_coef * aux_loss.to(loss.device) # make sure it is on the same device
+
+ return MoeCausalLMOutputWithPast(
+ loss=loss,
+ aux_loss=aux_loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ router_logits=outputs.router_logits,
+ )
+
+
+class MiniMaxForSequenceClassification(GenericForSequenceClassification, MiniMaxPreTrainedModel):
+ pass
+
+
+class MiniMaxForTokenClassification(GenericForTokenClassification, MiniMaxPreTrainedModel):
+ pass
+
+
+class MiniMaxForQuestionAnswering(GenericForQuestionAnswering, MiniMaxPreTrainedModel):
+ pass
+
+
+__all__ = [
+ "MiniMaxPreTrainedModel",
+ "MiniMaxModel",
+ "MiniMaxForCausalLM",
+ "MiniMaxForSequenceClassification",
+ "MiniMaxForTokenClassification",
+ "MiniMaxForQuestionAnswering",
+]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/minimax/modular_minimax.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/minimax/modular_minimax.py
new file mode 100644
index 0000000000000000000000000000000000000000..bb0777f08cefd2028687c6e17c555529ebe45629
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/minimax/modular_minimax.py
@@ -0,0 +1,605 @@
+# coding=utf-8
+# Copyright 2025 MiniMaxAI and HuggingFace Inc. teams. All rights reserved.
+#
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch MiniMax model."""
+
+from typing import Optional
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from ...activations import ACT2FN
+from ...cache_utils import Cache, DynamicCache
+from ...configuration_utils import layer_type_validation
+from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
+from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_outputs import MoeModelOutputWithPast
+from ...processing_utils import Unpack
+from ...utils import TransformersKwargs, logging
+from ...utils.deprecation import deprecate_kwarg
+from ...utils.generic import OutputRecorder, check_model_inputs
+from ..mixtral.configuration_mixtral import MixtralConfig
+from ..mixtral.modeling_mixtral import (
+ MixtralAttention,
+ MixtralDecoderLayer,
+ MixtralForCausalLM,
+ MixtralForQuestionAnswering,
+ MixtralForSequenceClassification,
+ MixtralForTokenClassification,
+ MixtralModel,
+ MixtralPreTrainedModel,
+ MixtralRMSNorm,
+ MixtralSparseMoeBlock,
+)
+
+
+logger = logging.get_logger(__name__)
+
+
+class MiniMaxConfig(MixtralConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`MiniMaxModel`]. It is used to instantiate a
+ MiniMax model according to the specified arguments, defining the model architecture. Instantiating a configuration
+ with the defaults will yield a similar configuration to that of the MiniMax.
+
+ [MiniMaxAI/MiniMax-Text-01-hf](https://huggingface.co/MiniMaxAI/MiniMax-Text-01-hf)
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 32000):
+ Vocabulary size of the MiniMax model. Defines the number of different tokens that can be represented by the
+ `input_ids` passed when calling [`MiniMaxModel`].
+ hidden_size (`int`, *optional*, defaults to 4096):
+ Dimension of the hidden representations.
+ intermediate_size (`int`, *optional*, defaults to 14336):
+ Dimension of the MLP representations.
+ num_hidden_layers (`int`, *optional*, defaults to 32):
+ Number of hidden layers in the Transformer encoder.
+ num_attention_heads (`int`, *optional*, defaults to 32):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ num_key_value_heads (`int`, *optional*, defaults to 8):
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
+ `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+ by meanpooling all the original heads within that group. For more details, check out [this
+ paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to `8`.
+ head_dim (`int`, *optional*, defaults to `hidden_size // num_attention_heads`):
+ The attention head dimension.
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+ The non-linear activation function (function or string) in the decoder.
+ max_position_embeddings (`int`, *optional*, defaults to `4096*32`):
+ The maximum sequence length that this model might ever be used with. MiniMax's sliding window attention
+ allows sequences of up to 4096*32 tokens.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ rms_norm_eps (`float`, *optional*, defaults to 1e-05):
+ The epsilon used by the rms normalization layers.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
+ relevant if `config.is_decoder=True`.
+ pad_token_id (`int`, *optional*):
+ The id of the padding token.
+ bos_token_id (`int`, *optional*, defaults to 1):
+ The id of the "beginning-of-sequence" token.
+ eos_token_id (`int`, *optional*, defaults to 2):
+ The id of the "end-of-sequence" token.
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+ Whether the model's input and output word embeddings should be tied.
+ rope_theta (`float`, *optional*, defaults to 1000000.0):
+ The base period of the RoPE embeddings.
+ sliding_window (`int`, *optional*):
+ Sliding window attention window size. If not specified, will default to `4096`.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ num_experts_per_tok (`int`, *optional*, defaults to 2):
+ The number of experts to route per token; can also be interpreted as the `top-k` routing
+ parameter.
+ num_local_experts (`int`, *optional*, defaults to 8):
+ Number of experts per Sparse MLP layer.
+ output_router_logits (`bool`, *optional*, defaults to `False`):
+ Whether or not the router logits should be returned by the model. Enabling this will also
+ allow the model to output the auxiliary loss. See [here]() for more details
+ router_aux_loss_coef (`float`, *optional*, defaults to 0.001):
+ The aux loss factor for the total loss.
+ router_jitter_noise (`float`, *optional*, defaults to 0.0):
+ Amount of noise to add to the router.
+ layer_types (`list`, *optional*):
+ Attention pattern for each layer.
+ block_size (`int`, *optional*, defaults to 256):
+ The length of each attention block, determining how queries, keys, and values
+ are grouped and processed for intra- and inter-block attention.
+ full_attn_alpha_factor (`float`, *optional*, defaults to 1):
+ Weight for residual value in residual connection after normal attention.
+ full_attn_beta_factor (`float`, *optional*, defaults to 1):
+ Weight for hidden state value in residual connection after normal attention.
+ linear_attn_alpha_factor (`float`, *optional*, defaults to 1):
+ Weight for residual value in residual connection after lightning attention.
+ linear_attn_beta_factor (`float`, *optional*, defaults to 1):
+ Weight for hidden state value in residual connection after lightning attention.
+ mlp_alpha_factor (`float`, *optional*, defaults to 1):
+ Weight for residual value in residual connection after MLP.
+ mlp_beta_factor (`float`, *optional*, defaults to 1):
+ Weight for hidden state value in residual connection after MLP.
+
+ ```python
+ >>> from transformers import MiniMaxModel, MiniMaxConfig
+
+ >>> # Initializing a MiniMax style configuration
+ >>> configuration = MiniMaxConfig()
+
+ >>> # Initializing a model from the MiniMax style configuration
+ >>> model = MiniMaxModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ def __init__(
+ self,
+ layer_types=None,
+ block_size=256,
+ full_attn_alpha_factor=1,
+ full_attn_beta_factor=1,
+ linear_attn_alpha_factor=1,
+ linear_attn_beta_factor=1,
+ mlp_alpha_factor=1,
+ mlp_beta_factor=1,
+ **super_kwargs,
+ ):
+ super().__init__(**super_kwargs)
+ self.layer_types = layer_types
+ self.block_size = block_size
+ self.full_attn_alpha_factor = full_attn_alpha_factor
+ self.full_attn_beta_factor = full_attn_beta_factor
+ self.linear_attn_alpha_factor = linear_attn_alpha_factor
+ self.linear_attn_beta_factor = linear_attn_beta_factor
+ self.mlp_alpha_factor = mlp_alpha_factor
+ self.mlp_beta_factor = mlp_beta_factor
+
+ if self.layer_types is None:
+ self.layer_types = [
+ "full_attention" if bool((i + 1) % 2) else "linear_attention" for i in range(self.num_hidden_layers)
+ ]
+ layer_type_validation(self.layer_types, self.num_hidden_layers)
+
+
+class MiniMaxRMSNorm(MixtralRMSNorm):
+ pass
+
+
+class MiniMaxCache(DynamicCache):
+ def __init__(self):
+ super().__init__()
+ self.linear_cache: list[torch.Tensor] = []
+
+ def set_linear_cache(self, layer_idx, linear_cache):
+ # There may be skipped layers, fill them with empty lists
+ for _ in range(len(self.linear_cache), layer_idx + 1):
+ self.linear_cache.append([])
+ self.linear_cache[layer_idx] = linear_cache
+
+ def get_linear_cache(self, layer_idx: int):
+ if layer_idx < len(self):
+ return self.linear_cache[layer_idx]
+ return None
+
+ def __len__(self):
+ return max(super().__len__(), len(self.linear_cache))
+
+ def __getitem__(self, layer_idx: int):
+ if layer_idx < len(self.linear_cache) and self.linear_cache[layer_idx] != []:
+ return (self.linear_cache[layer_idx],)
+ return super().__getitem__(layer_idx)
+
+ def __iter__(self):
+ for layer_idx in range(len(self)):
+ yield self[layer_idx]
+
+ def batch_repeat_interleave(self, repeats: int):
+ for layer_idx in range(len(self)):
+ if self.linear_cache[layer_idx] != []:
+ self.linear_cache[layer_idx] = self.linear_cache[layer_idx].repeat_interleave(repeats, dim=0)
+ else:
+ self.layers[layer_idx].batch_repeat_interleave(repeats)
+
+ def batch_select_indices(self, indices: torch.Tensor):
+ for layer_idx in range(len(self)):
+ if self.linear_cache[layer_idx] != []:
+ self.linear_cache[layer_idx] = self.linear_cache[layer_idx][indices, ...]
+ else:
+ self.layers[layer_idx].batch_select_indices(indices)
+
+ def crop(self, max_length: int):
+ raise RuntimeError("MiniMaxCache doesnot support `crop` method")
+
+
+class MiniMaxLightningAttention(nn.Module):
+ def __init__(self, config: MiniMaxConfig, layer_idx: int):
+ super().__init__()
+ self.layer_idx = layer_idx
+ self.head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
+ self.num_attention_heads = config.num_attention_heads
+ self.num_hidden_layers = config.num_hidden_layers
+ self.block_size = config.block_size
+
+ self.act_fn = ACT2FN[config.hidden_act]
+ self.norm = MiniMaxRMSNorm(self.head_dim * self.num_attention_heads)
+ self.qkv_proj = nn.Linear(config.hidden_size, self.num_attention_heads * self.head_dim * 3, bias=False)
+ self.out_proj = nn.Linear(self.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
+ self.output_gate = nn.Linear(config.hidden_size, self.num_attention_heads * self.head_dim, bias=False)
+
+ slope_rate = self.get_slope_rate()
+ query_decay, key_decay, diagonal_decay = self.decay_factors(slope_rate)
+
+ self.register_buffer("slope_rate", slope_rate)
+ self.register_buffer("query_decay", query_decay)
+ self.register_buffer("key_decay", key_decay)
+ self.register_buffer("diagonal_decay", diagonal_decay)
+
+ def get_slope_rate(self):
+ base = 1 / (2 ** (8 / self.num_attention_heads))
+ exponent = torch.arange(self.num_attention_heads) + 1
+ factor = 1 - self.layer_idx / (self.num_hidden_layers - 1 + 1e-5) + 1e-5
+
+ rate = base**exponent
+ rate = rate * factor
+ rate = rate[:, None, None]
+
+ return rate
+
+ def decay_factors(self, slope_rate):
+ block_size_range = torch.arange(self.block_size) + 1
+
+ query_decay = torch.exp(-slope_rate * block_size_range[:, None])
+ key_decay = torch.exp(-slope_rate * (self.block_size - block_size_range[:, None]))
+
+ diagonal_decay = block_size_range[:, None] - block_size_range[None, :]
+ diagonal_decay = diagonal_decay[None, None, :, :]
+ diagonal_decay = slope_rate * diagonal_decay
+ diagonal_decay = torch.where(diagonal_decay >= 0, -diagonal_decay, float("-inf"))
+ diagonal_decay = torch.exp(diagonal_decay)
+
+ return query_decay, key_decay, diagonal_decay
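To make the decay math concrete, a small self-contained sketch that mirrors `get_slope_rate` and `decay_factors` above for a toy setup (two heads, block size 4, layer 0 of 8); all sizes are illustrative.

```python
import torch

# Toy reproduction of the slope/decay computation above.
num_heads, block_size, layer_idx, num_layers = 2, 4, 0, 8

base = 1 / (2 ** (8 / num_heads))
exponent = torch.arange(num_heads) + 1
factor = 1 - layer_idx / (num_layers - 1 + 1e-5) + 1e-5
slope_rate = (base**exponent * factor)[:, None, None]            # [heads, 1, 1]

r = torch.arange(block_size) + 1
query_decay = torch.exp(-slope_rate * r[:, None])                 # [heads, block, 1]
key_decay = torch.exp(-slope_rate * (block_size - r[:, None]))    # [heads, block, 1]
diagonal_decay = r[:, None] - r[None, :]
diagonal_decay = slope_rate * diagonal_decay[None, None, :, :]
diagonal_decay = torch.exp(torch.where(diagonal_decay >= 0, -diagonal_decay, float("-inf")))
print(query_decay.shape, key_decay.shape, diagonal_decay.shape)
# torch.Size([2, 4, 1]) torch.Size([2, 4, 1]) torch.Size([1, 2, 4, 4])
```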
+
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
+ attention_mask: Optional[torch.Tensor],
+ past_key_values: Optional[Cache] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs: Unpack[FlashAttentionKwargs],
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
+ batch_size, seq_len, hidden_size = hidden_states.shape
+ num_blocks = (seq_len + self.block_size - 1) // self.block_size
+
+ qkv_states = self.act_fn(self.qkv_proj(hidden_states))
+ qkv_states = qkv_states.reshape(batch_size, seq_len, self.num_attention_heads, 3 * self.head_dim)
+
+ query_states, key_states, value_states = torch.split(qkv_states, self.head_dim, dim=3)
+
+ query_states = query_states.transpose(1, 2)
+ key_states = key_states.transpose(1, 2)
+ value_states = value_states.transpose(1, 2)
+
+ # (K.T @ V) is calculated and saved as the cache
+ attn_weights_inter = None
+ if past_key_values is not None:
+ attn_weights_inter = past_key_values.get_linear_cache(self.layer_idx)
+
+ if attn_weights_inter is None:
+ attn_weights_inter = torch.zeros(batch_size, self.num_attention_heads, self.head_dim, self.head_dim).to(
+ value_states
+ )
+
+ # apply attention_mask
+ if attention_mask is not None:
+ attention_mask = attention_mask.to(dtype=torch.bool) # Ensure it's a boolean tensor
+ value_states = value_states.masked_fill(~attention_mask.unsqueeze(1).unsqueeze(-1), 0)
+
+ attn_output = []
+ for i in range(num_blocks):
+ start_idx = i * self.block_size
+ end_idx = min(start_idx + self.block_size, seq_len)
+ current_block_size = end_idx - start_idx
+
+ current_query_states = query_states[:, :, start_idx:end_idx]
+ current_key_states = key_states[:, :, start_idx:end_idx]
+ current_value_states = value_states[:, :, start_idx:end_idx]
+
+ current_query_decay = self.query_decay[:, :current_block_size]
+ current_key_decay = self.key_decay[:, -current_block_size:]
+ current_diagonal_decay = self.diagonal_decay[:, :, :current_block_size, :current_block_size]
+ block_decay = torch.exp(-self.slope_rate * current_block_size)
+
+ # intra: ( Q @ K.T ) @ V -> QK * V
+ attn_weights_intra = torch.matmul(current_query_states, current_key_states.transpose(-1, -2))
+ attn_output_intra = torch.matmul(attn_weights_intra * current_diagonal_decay, current_value_states)
+
+ # inter: Q @ ( K.T @ V ) -> Q * KV
+ attn_output_inter = torch.matmul(current_query_states * current_query_decay, attn_weights_inter)
+
+ # final attention output
+ current_attn_output = attn_output_inter + attn_output_intra
+ attn_output.append(current_attn_output)
+
+ # calculate attn_weights_inter for next block or cache
+ next_attn_weights_inter = torch.matmul(
+ (current_key_states * current_key_decay).transpose(-1, -2), current_value_states
+ )
+ attn_weights_inter = attn_weights_inter * block_decay + next_attn_weights_inter
+
+ else:
+ ratio = torch.exp(-self.slope_rate)
+ attn_output = []
+ for i in range(seq_len):
+ current_query_states = query_states[:, :, i : i + 1]
+ current_key_states = key_states[:, :, i : i + 1]
+ current_value_states = value_states[:, :, i : i + 1]
+
+ current_attn_weights_inter = torch.matmul(current_key_states.transpose(-1, -2), current_value_states)
+ attn_weights_inter = ratio * attn_weights_inter + current_attn_weights_inter
+ current_attn_output = torch.matmul(current_query_states, attn_weights_inter)
+
+ attn_output.append(current_attn_output)
+
+ # concatenate attention outputs over all blocks
+ attn_output = torch.cat(attn_output, dim=-2)
+
+ # final output projection
+ attn_output = attn_output.transpose(1, 2)
+ attn_output = attn_output.reshape(batch_size, seq_len, self.num_attention_heads * self.head_dim)
+ attn_output = self.norm(attn_output)
+ attn_output = F.sigmoid(self.output_gate(hidden_states)) * attn_output
+ attn_output = self.out_proj(attn_output)
+
+ # update cache
+ if past_key_values is not None:
+ past_key_values.set_linear_cache(self.layer_idx, attn_weights_inter)
+
+ return attn_output, attn_weights_inter
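The decode branch above maintains a running `K^T V` state instead of a growing key/value cache. Below is a minimal standalone sketch of that recurrence (single head, toy sizes, scalar decay); it is only an illustration of the update rule, not the module itself.

```python
import torch

# Toy single-head linear-attention recurrence: kv <- ratio * kv + k_t^T v_t, out_t = q_t @ kv
head_dim, seq_len = 4, 3
q, k, v = (torch.randn(seq_len, head_dim) for _ in range(3))
ratio = torch.exp(torch.tensor(-0.5))     # stand-in for exp(-slope_rate)

kv = torch.zeros(head_dim, head_dim)      # running (K^T V) state, the "linear cache"
outputs = []
for t in range(seq_len):
    kv = ratio * kv + torch.outer(k[t], v[t])
    outputs.append(q[t] @ kv)
print(torch.stack(outputs).shape)         # torch.Size([3, 4])
```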
+
+
+class MiniMaxAttention(MixtralAttention):
+ pass
+
+
+class MiniMaxSparseMoeBlock(MixtralSparseMoeBlock):
+ pass
+
+
+class MiniMaxDecoderLayer(MixtralDecoderLayer, GradientCheckpointingLayer):
+ def __init__(self, config: MiniMaxConfig, layer_idx: int):
+ super().__init__(config, layer_idx)
+
+ self.layer_idx = layer_idx
+ self.layer_type = config.layer_types[layer_idx]
+ self.mlp_alpha_factor = config.mlp_alpha_factor
+ self.mlp_beta_factor = config.mlp_beta_factor
+
+ if self.layer_type == "linear_attention":
+ self.self_attn = MiniMaxLightningAttention(config, layer_idx)
+ self.attn_alpha_factor = config.linear_attn_alpha_factor
+ self.attn_beta_factor = config.linear_attn_beta_factor
+ else:
+ self.self_attn = MiniMaxAttention(config, layer_idx)
+ self.attn_alpha_factor = config.full_attn_alpha_factor
+ self.attn_beta_factor = config.full_attn_beta_factor
+
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ output_attentions: Optional[bool] = False,
+ output_router_logits: Optional[bool] = False,
+ use_cache: Optional[bool] = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs: Unpack[FlashAttentionKwargs],
+ ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+ position_embeddings (`tuple[torch.FloatTensor, torch.FloatTensor]`):
+ Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
+ with `head_dim` being the embedding dimension of each attention head.
+ attention_mask (`torch.Tensor`, *optional*): attention mask of size
+ `(batch, sequence_length)` where padding elements are indicated by 0.
+ past_key_values (`Cache`, *optional*): cached past key and value projection states
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ output_router_logits (`bool`, *optional*):
+ Whether or not to return the logits of all the routers. They are useful for computing the router loss, and
+ should not be returned during inference.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+ (see `past_key_values`).
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+ Indices depicting the position of the input sequence tokens in the sequence.
+ kwargs (`dict`, *optional*):
+ Arbitrary kwargs to be ignored, used for FSDP and other methods that inject code
+ into the model.
+ """
+
+ hidden_states = self.input_layernorm(hidden_states)
+ residual = hidden_states
+
+ # Self Attention
+ hidden_states, _ = self.self_attn(
+ hidden_states=hidden_states,
+ position_embeddings=position_embeddings,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ **kwargs,
+ )
+ hidden_states = residual * self.attn_alpha_factor + hidden_states * self.attn_beta_factor
+
+ # Fully Connected
+ hidden_states = self.post_attention_layernorm(hidden_states)
+ residual = hidden_states
+ hidden_states, _ = self.block_sparse_moe(hidden_states)
+ hidden_states = residual * self.mlp_alpha_factor + hidden_states * self.mlp_beta_factor
+
+ return hidden_states
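The residual wiring above differs from a plain pre-norm block: the residual is taken from the *normalized* states, and both branches are scaled by the config's alpha/beta factors. A hedged sketch with stand-in modules (LayerNorm and Linear replace the model's RMSNorm and attention/MoE sublayers):

```python
import torch
import torch.nn as nn

hidden = torch.randn(1, 5, 16)
norm = nn.LayerNorm(16)        # stand-in for the model's RMSNorm
sublayer = nn.Linear(16, 16)   # stand-in for the attention / MoE sublayer
alpha, beta = 1.0, 1.0         # e.g. config.full_attn_alpha_factor / full_attn_beta_factor

normed = norm(hidden)          # hidden_states = input_layernorm(hidden_states)
residual = normed              # residual is the normalized tensor, not the raw input
out = residual * alpha + sublayer(normed) * beta
print(out.shape)               # torch.Size([1, 5, 16])
```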
+
+
+class MiniMaxPreTrainedModel(MixtralPreTrainedModel):
+ _can_compile_fullgraph = False
+ _can_record_outputs = {
+ "router_logits": OutputRecorder(MiniMaxSparseMoeBlock, index=1),
+ "hidden_states": MiniMaxDecoderLayer,
+ "attentions": [MiniMaxAttention, MiniMaxLightningAttention],
+ }
+
+
+class MiniMaxModel(MixtralModel):
+ @check_model_inputs
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[MiniMaxCache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> MoeModelOutputWithPast:
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ if use_cache and past_key_values is None:
+ past_key_values = MiniMaxCache()
+ elif use_cache and not isinstance(past_key_values, MiniMaxCache):
+ raise ValueError(
+ f"MiniMax uses cache of its own and is not compatible with `past_key_values` of type {type(past_key_values)}."
+ )
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+
+ if cache_position is None:
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ cache_position = torch.arange(
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+ )
+ if position_ids is None:
+ position_ids = cache_position.unsqueeze(0)
+
+ mask_function = create_causal_mask if self.config.sliding_window is None else create_sliding_window_causal_mask
+ causal_mask = mask_function(
+ config=self.config,
+ input_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ cache_position=cache_position,
+ past_key_values=past_key_values,
+ position_ids=position_ids,
+ )
+
+ hidden_states = inputs_embeds
+
+ # create position embeddings to be shared across the decoder layers
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+ for decoder_layer in self.layers:
+ if decoder_layer.layer_type == "full_attention":
+ input_attention_mask = causal_mask
+ else:
+ # lightning attention uses original attention_mask, and uses it only for the first step
+ input_attention_mask = attention_mask
+
+ hidden_states = decoder_layer(
+ hidden_states,
+ position_embeddings=position_embeddings,
+ attention_mask=input_attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+ hidden_states = self.norm(hidden_states)
+
+ return MoeModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=past_key_values,
+ )
+
+
+class MiniMaxForCausalLM(MixtralForCausalLM):
+ def forward(self, **super_kwargs):
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, MiniMaxForCausalLM
+
+ >>> model = MiniMaxForCausalLM.from_pretrained("MiniMaxAI/MiniMax-Text-01-hf")
+ >>> tokenizer = AutoTokenizer.from_pretrained("MiniMaxAI/MiniMax-Text-01-hf")
+
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+ >>> # Generate
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+ ```"""
+ return super().forward(**super_kwargs)
+
+
+class MiniMaxForSequenceClassification(MixtralForSequenceClassification):
+ pass
+
+
+class MiniMaxForTokenClassification(MixtralForTokenClassification):
+ pass
+
+
+class MiniMaxForQuestionAnswering(MixtralForQuestionAnswering):
+ pass
+
+
+__all__ = [
+ "MiniMaxConfig",
+ "MiniMaxPreTrainedModel",
+ "MiniMaxModel",
+ "MiniMaxForCausalLM",
+ "MiniMaxForSequenceClassification",
+ "MiniMaxForTokenClassification",
+ "MiniMaxForQuestionAnswering",
+]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/mlcd/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/mlcd/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..3556177d5eb8e493aa5e051785b980b0e10d9c4a
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/mlcd/__init__.py
@@ -0,0 +1,27 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_mlcd import *
+ from .modeling_mlcd import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/mlcd/configuration_mlcd.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/mlcd/configuration_mlcd.py
new file mode 100644
index 0000000000000000000000000000000000000000..f28a5f1a7cab3d453c3bf07caea9fa9fd25a593b
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/mlcd/configuration_mlcd.py
@@ -0,0 +1,117 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/mlcd/modular_mlcd.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_mlcd.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2025 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ...configuration_utils import PretrainedConfig
+
+
+class MLCDVisionConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`MLCDVisionModel`]. It is used to instantiate a MLCD
+ vision encoder according to the specified arguments, defining the model architecture. Instantiating a configuration
+ with the defaults will yield a similar configuration to that of the vision encoder of the MLCD
+ [DeepGlint-AI/mlcd-vit-bigG-patch14-336](https://huggingface.co/DeepGlint-AI/mlcd-vit-bigG-patch14-336) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ hidden_size (`int`, *optional*, defaults to 1664):
+ Dimensionality of the encoder layers and the pooler layer.
+ intermediate_size (`int`, *optional*, defaults to 8192):
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
+ projection_dim (`int`, *optional*, defaults to 1024):
+ Dimensionality of text and vision projection layers.
+ num_hidden_layers (`int`, *optional*, defaults to 48):
+ Number of hidden layers in the Transformer encoder.
+ num_attention_heads (`int`, *optional*, defaults to 16):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ num_channels (`int`, *optional*, defaults to 3):
+ The number of input channels.
+ image_size (`int`, *optional*, defaults to 336):
+ The size (resolution) of each image.
+ patch_size (`int`, *optional*, defaults to 14):
+ The size (resolution) of each patch.
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+ `"relu"`, `"selu"`, `"gelu_new"` and `"quick_gelu"` are supported.
+ layer_norm_eps (`float`, *optional*, defaults to 1e-05):
+ The epsilon used by the layer normalization layers.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ initializer_factor (`float`, *optional*, defaults to 1.0):
+ A factor for initializing all weight matrices (should be kept to 1, used internally for initialization
+ testing).
+
+ Example:
+
+ ```python
+ >>> from transformers import MLCDVisionConfig, MLCDVisionModel
+
+ >>> # Initializing a MLCDVisionConfig with DeepGlint-AI/mlcd-vit-bigG-patch14-336 style configuration
+ >>> configuration = MLCDVisionConfig()
+
+ >>> # Initializing a MLCDVisionModel (with random weights) from the DeepGlint-AI/mlcd-vit-bigG-patch14-336 style configuration
+ >>> model = MLCDVisionModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "mlcd_vision_model"
+ base_config_key = "vision_config"
+
+ def __init__(
+ self,
+ hidden_size=1664,
+ intermediate_size=8192,
+ num_hidden_layers=48,
+ num_attention_heads=16,
+ num_key_value_groups=1,
+ num_channels=3,
+ image_size=336,
+ patch_size=14,
+ hidden_act="gelu",
+ layer_norm_eps=1e-5,
+ attention_dropout=0.0,
+ initializer_range=0.02,
+ initializer_factor=1.0,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.num_key_value_groups = num_key_value_groups
+ self.num_channels = num_channels
+ self.patch_size = patch_size
+ self.image_size = image_size
+ self.initializer_range = initializer_range
+ self.initializer_factor = initializer_factor
+ self.attention_dropout = attention_dropout
+ self.layer_norm_eps = layer_norm_eps
+ self.hidden_act = hidden_act
+
+
+__all__ = ["MLCDVisionConfig"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/mlcd/modeling_mlcd.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/mlcd/modeling_mlcd.py
new file mode 100644
index 0000000000000000000000000000000000000000..5379c5313726dd2da0401b766580b569e354e277
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/mlcd/modeling_mlcd.py
@@ -0,0 +1,611 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/mlcd/modular_mlcd.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_mlcd.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2025 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Callable, Optional, Union
+
+import torch
+import torch.nn as nn
+
+from ...activations import ACT2FN
+from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
+from ...utils import TransformersKwargs, auto_docstring, torch_int
+from .configuration_mlcd import MLCDVisionConfig
+
+
+class MLCDMLP(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.activation_fn = ACT2FN[config.hidden_act]
+ self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
+ self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.fc1(hidden_states)
+ hidden_states = self.activation_fn(hidden_states)
+ hidden_states = self.fc2(hidden_states)
+ return hidden_states
+
+
+class MLCDRotaryEmbedding(nn.Module):
+ inv_freq: torch.Tensor # fix linting for `register_buffer`
+
+ def __init__(self, dim: int, theta: float = 10000.0) -> None:
+ super().__init__()
+ inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
+
+ def forward(self, num_patches_height: int, num_patches_width: int) -> torch.Tensor:
+ """
+ Calculate the Rotary Position Embedding (RoPE) for MLCDVisionModel based on the grid size.
+
+ Args:
+ num_patches_height (int): Number of patches in the height dimension.
+ num_patches_width (int): Number of patches in the width dimension.
+
+ Returns:
+ torch.Tensor: Rotary positional embeddings for the given grid size.
+ """
+ # Generate position IDs for height and width dimensions
+ hpos_ids = (
+ torch.arange(num_patches_height, device=self.inv_freq.device).unsqueeze(1).expand(-1, num_patches_width)
+ )
+ wpos_ids = (
+ torch.arange(num_patches_width, device=self.inv_freq.device).unsqueeze(0).expand(num_patches_height, -1)
+ )
+
+ # Flatten and stack the position IDs
+ pos_ids = torch.stack([hpos_ids.flatten(), wpos_ids.flatten()], dim=-1)
+
+ # Generate the full rotary positional embeddings for the maximum grid size
+ max_grid_size = max(num_patches_height, num_patches_width)
+ seq = torch.arange(max_grid_size, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
+ rotary_pos_emb_full = torch.outer(seq, self.inv_freq)
+
+ # Select and flatten the embeddings based on the position IDs
+ rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
+
+ return rotary_pos_emb
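The grid-to-RoPE mapping above can be traced with a tiny standalone example; the 3x2 grid and rotary dimension are toy values, not tied to the real config.

```python
import torch

# Toy 2D RoPE index construction mirroring MLCDRotaryEmbedding.forward above.
num_patches_height, num_patches_width, dim = 3, 2, 8   # dim = rotary embedding dim per position
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))

hpos = torch.arange(num_patches_height).unsqueeze(1).expand(-1, num_patches_width)
wpos = torch.arange(num_patches_width).unsqueeze(0).expand(num_patches_height, -1)
pos_ids = torch.stack([hpos.flatten(), wpos.flatten()], dim=-1)            # [H*W, 2]

max_grid = max(num_patches_height, num_patches_width)
table = torch.outer(torch.arange(max_grid, dtype=torch.float), inv_freq)   # [max_grid, dim//2]
rotary_pos_emb = table[pos_ids].flatten(1)                                 # [H*W, dim]
print(rotary_pos_emb.shape)  # torch.Size([6, 8])
```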
+
+
+class MLCDVisionEmbeddings(nn.Module):
+ def __init__(self, config: MLCDVisionConfig):
+ super().__init__()
+ self.config = config
+ self.embed_dim = config.hidden_size
+ self.image_size = config.image_size
+ self.patch_size = config.patch_size
+
+ self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))
+
+ self.patch_embedding = nn.Conv2d(
+ in_channels=config.num_channels,
+ out_channels=self.embed_dim,
+ kernel_size=self.patch_size,
+ stride=self.patch_size,
+ bias=False,
+ )
+
+ self.num_patches = (self.image_size // self.patch_size) ** 2
+ self.num_positions = self.num_patches + 1
+ self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)
+
+ def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
+ """
+ This method allows interpolating the pre-trained position encodings so that the model can be used on
+ higher-resolution images. This method is also adapted to support torch.jit tracing.
+
+ Adapted from:
+ - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
+ - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
+ """
+
+ num_patches = embeddings.shape[1] - 1
+ position_embedding = self.position_embedding.weight.unsqueeze(0)
+ num_positions = position_embedding.shape[1] - 1
+
+ # always interpolate when tracing to ensure the exported model works for dynamic input shapes
+ if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
+ return self.position_embedding(self.position_ids)
+
+ class_pos_embed = position_embedding[:, :1]
+ patch_pos_embed = position_embedding[:, 1:]
+
+ dim = embeddings.shape[-1]
+
+ new_height = height // self.patch_size
+ new_width = width // self.patch_size
+
+ sqrt_num_positions = torch_int(num_positions**0.5)
+ patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
+ patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
+
+ patch_pos_embed = nn.functional.interpolate(
+ patch_pos_embed,
+ size=(new_height, new_width),
+ mode="bicubic",
+ align_corners=False,
+ )
+
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+
+ return torch.cat((class_pos_embed, patch_pos_embed), dim=1)
+
+ def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
+ batch_size = pixel_values.shape[0]
+ target_dtype = self.patch_embedding.weight.dtype
+ # patch_embeds -> shape = [batch, embed_dim, grid, grid]
+ patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))
+ patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
+
+ class_embeds = self.class_embedding.expand(batch_size, 1, -1)
+ embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
+
+ return embeddings
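A short shape walk-through of the patch-embedding path above, using toy sizes (28x28 image, 14x14 patches, 32-dim embeddings) purely for illustration.

```python
import torch
import torch.nn as nn

batch, channels, image_size, patch_size, embed_dim = 2, 3, 28, 14, 32
pixel_values = torch.randn(batch, channels, image_size, image_size)

patch_embedding = nn.Conv2d(channels, embed_dim, kernel_size=patch_size, stride=patch_size, bias=False)
class_embedding = torch.randn(embed_dim)

patch_embeds = patch_embedding(pixel_values)             # [batch, embed_dim, grid, grid]
patch_embeds = patch_embeds.flatten(2).transpose(1, 2)   # [batch, grid*grid, embed_dim]
class_embeds = class_embedding.expand(batch, 1, -1)      # [batch, 1, embed_dim]
embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
print(embeddings.shape)  # torch.Size([2, 5, 32]) -> 1 class token + (28 // 14) ** 2 patches
```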
+
+
+def eager_attention_forward(
+ module: nn.Module,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_mask: Optional[torch.Tensor],
+ scaling: float,
+ dropout: float = 0.0,
+ **kwargs: Unpack[TransformersKwargs],
+):
+ key_states = repeat_kv(key, module.num_key_value_groups)
+ value_states = repeat_kv(value, module.num_key_value_groups)
+
+ attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
+ if attention_mask is not None:
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+ attn_weights = attn_weights + causal_mask
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+ attn_output = torch.matmul(attn_weights, value_states)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ return attn_output, attn_weights
+
+
+def rotate_half(x):
+ """Rotates half the hidden dims of the input."""
+ x1 = x[..., : x.shape[-1] // 2]
+ x2 = x[..., x.shape[-1] // 2 :]
+ return torch.cat((-x2, x1), dim=-1)
+
+
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+ """
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+ """
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+ if n_rep == 1:
+ return hidden_states
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
+def apply_rotary_pos_emb_vision(
+ q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
+) -> tuple[torch.Tensor, torch.Tensor]:
+ orig_q_dtype = q.dtype
+ orig_k_dtype = k.dtype
+ q, k = q.float(), k.float()
+ cos, sin = cos.unsqueeze(-2).float(), sin.unsqueeze(-2).float()
+ q_embed = (q * cos) + (rotate_half(q) * sin)
+ k_embed = (k * cos) + (rotate_half(k) * sin)
+ q_embed = q_embed.to(orig_q_dtype)
+ k_embed = k_embed.to(orig_k_dtype)
+ return q_embed, k_embed
+
+
+class MLCDAttention(nn.Module):
+ """Multi-headed attention with RoPE. Refer to papers:
+ - Attention is all you need:
+ https://huggingface.co/papers/1706.03762
+ - RoFormer: Enhanced Transformer with Rotary Position Embedding:
+ https://huggingface.co/papers/2104.09864
+ """
+
+ def __init__(self, config: MLCDVisionConfig):
+ super().__init__()
+ self.config = config
+ self.embed_dim = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.head_dim = self.embed_dim // self.num_heads
+ if self.head_dim * self.num_heads != self.embed_dim:
+ raise ValueError(
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
+ f" {self.num_heads})."
+ )
+ self.scale = self.head_dim**-0.5
+ self.dropout = config.attention_dropout
+ self.is_causal = False
+
+ self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
+ self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
+ self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
+ self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
+ self.num_key_value_groups = config.num_key_value_groups
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
+ attention_mask: Optional[torch.Tensor] = None,
+ **kwargs: Unpack[FlashAttentionKwargs],
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+ """Input shape: Batch x Time x Channel"""
+ batch_size, seq_length = hidden_states.shape[:-1]
+
+ # Each of shape: [batch_size, seq_length, num_heads, head_dim]
+ query_states = self.q_proj(hidden_states).reshape((batch_size, seq_length, self.num_heads, self.head_dim))
+ key_states = self.k_proj(hidden_states).reshape((batch_size, seq_length, self.num_heads, self.head_dim))
+ value_states = self.v_proj(hidden_states).reshape((batch_size, seq_length, self.num_heads, self.head_dim))
+
+ # Apply positional embeddings
+ cos = position_embeddings[0].unsqueeze(0).float()
+ sin = position_embeddings[1].unsqueeze(0).float()
+ query_states, key_states = apply_rotary_pos_emb_vision(query_states, key_states, cos, sin)
+
+ # Each of shape: [batch_size, num_heads, seq_length, head_dim]
+ query_states = query_states.permute(0, 2, 1, 3).contiguous()
+ key_states = key_states.permute(0, 2, 1, 3).contiguous()
+ value_states = value_states.permute(0, 2, 1, 3).contiguous()
+
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ attn_output, attn_weights = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ dropout=0.0 if not self.training else self.dropout,
+ scaling=self.scale,
+ is_causal=self.is_causal,
+ **kwargs,
+ )
+
+ attn_output = attn_output.permute(1, 0, 2, 3).contiguous() # [seq_length, batch_size, num_heads, head_dim]
+ attn_output = attn_output.view(seq_length, batch_size, -1) # [seq_length, batch_size, embedding_dim]
+ attn_output = self.out_proj(attn_output)
+ attn_output = attn_output.permute(1, 0, 2).contiguous() # [batch_size, seq_length, embedding_dim]
+ return attn_output, attn_weights
+
+
+class MLCDEncoderLayer(GradientCheckpointingLayer):
+ def __init__(self, config: MLCDVisionConfig):
+ super().__init__()
+ self.embed_dim = config.hidden_size
+ self.self_attn = MLCDAttention(config)
+ self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
+ self.mlp = MLCDMLP(config)
+ self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
+ attention_mask: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = False,
+ ) -> tuple[torch.FloatTensor]:
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`):
+ Input to the layer of shape `(batch, seq_len, embed_dim)`.
+ Represents the hidden states from the previous layer or the input embeddings.
+ position_embeddings (`tuple[torch.Tensor, torch.Tensor]`):
+ A tuple of two tensors, each of shape `(batch, seq_len, embed_dim)`.
+ Represents absolute positional embeddings for the query and key in the attention mechanism.
+ attention_mask (`torch.FloatTensor`):
+ Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by very large negative values.
+ output_attentions (`bool`, *optional*, defaults to `False`):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ """
+ residual = hidden_states
+
+ hidden_states = self.layer_norm1(hidden_states)
+ hidden_states, attn_weights = self.self_attn(
+ hidden_states=hidden_states,
+ position_embeddings=position_embeddings,
+ attention_mask=attention_mask,
+ output_attentions=output_attentions,
+ )
+ hidden_states = residual + hidden_states
+
+ residual = hidden_states
+ hidden_states = self.layer_norm2(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = residual + hidden_states
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (attn_weights,)
+
+ return outputs
+
+
+class MLCDEncoder(nn.Module):
+ """
+ Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
+ [`MLCDEncoderLayer`].
+
+ Args:
+ config: MLCDVisionConfig
+ """
+
+ def __init__(self, config: MLCDVisionConfig):
+ """Overwrite dummy `MLCDConfig` to `MLCDVisionConfig`."""
+ super().__init__()
+ self.config = config
+ self.layers = nn.ModuleList([MLCDEncoderLayer(config) for _ in range(config.num_hidden_layers)])
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ inputs_embeds: torch.FloatTensor,
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
+ attention_mask: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[tuple, BaseModelOutput]:
+ r"""
+ Args:
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+ than the model's internal embedding lookup matrix.
+ position_embeddings (`tuple[torch.Tensor, torch.Tensor]`):
+ A tuple of two tensors, each of shape `(batch, seq_len, embed_dim)`.
+ Represents absolute positional embeddings for the query and key in the attention mechanism.
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+ [What are attention masks?](../glossary#attention-mask)
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+ for more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+ """
+
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+
+ encoder_states = () if output_hidden_states else None
+ all_attentions = () if output_attentions else None
+
+ hidden_states = inputs_embeds
+ for idx, encoder_layer in enumerate(self.layers):
+ if output_hidden_states:
+ encoder_states = encoder_states + (hidden_states,)
+ layer_outputs = encoder_layer(
+ hidden_states=hidden_states,
+ position_embeddings=position_embeddings,
+ attention_mask=attention_mask,
+ output_attentions=output_attentions,
+ )
+
+ hidden_states = layer_outputs[0]
+
+ if output_attentions:
+ all_attentions = all_attentions + (layer_outputs[1],)
+
+ if output_hidden_states:
+ encoder_states = encoder_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
+ return BaseModelOutput(
+ last_hidden_state=hidden_states,
+ hidden_states=encoder_states,
+ attentions=all_attentions,
+ )
+
+
+class MLCDVisionTransformer(nn.Module):
+ def __init__(self, config: MLCDVisionConfig):
+ super().__init__()
+ self.config = config
+ embed_dim = config.hidden_size
+
+ self.embeddings = MLCDVisionEmbeddings(config)
+ self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
+ self.encoder = MLCDEncoder(config)
+ self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
+ self.vision_rotary_embedding = MLCDRotaryEmbedding(config.hidden_size // config.num_attention_heads // 2)
+ self.class_pos_emb = nn.Parameter(torch.randn(1, config.hidden_size // config.num_attention_heads // 2))
+
+ @auto_docstring
+ def forward(
+ self,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[tuple, BaseModelOutputWithPooling]:
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+
+ if pixel_values is None:
+ raise ValueError("You have to specify pixel_values")
+
+ num_patches_height = pixel_values.shape[-2] // self.config.patch_size
+ num_patches_width = pixel_values.shape[-1] // self.config.patch_size
+ rotary_pos_emb = self.vision_rotary_embedding(num_patches_height, num_patches_width)
+ rotary_pos_emb = rotary_pos_emb.to(self.class_pos_emb.device)
+ rotary_pos_emb = torch.cat([self.class_pos_emb, rotary_pos_emb], dim=0)
+ emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
+ position_embeddings = (emb.cos(), emb.sin())
+
+ hidden_states = self.embeddings(pixel_values)
+ hidden_states = self.pre_layrnorm(hidden_states)
+
+ encoder_outputs = self.encoder(
+ inputs_embeds=hidden_states,
+ position_embeddings=position_embeddings,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ last_hidden_state = encoder_outputs[0]
+ pooled_output = last_hidden_state[:, 0, :]
+ pooled_output = self.post_layernorm(pooled_output)
+
+ if not return_dict:
+ return (last_hidden_state, pooled_output) + encoder_outputs[1:]
+
+ return BaseModelOutputWithPooling(
+ last_hidden_state=last_hidden_state,
+ pooler_output=pooled_output,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ )
+
+
+@auto_docstring
+class MLCDPreTrainedModel(PreTrainedModel):
+ config: MLCDVisionConfig
+ base_model_prefix = "mlcd"
+ supports_gradient_checkpointing = True
+ _supports_flash_attn = True
+ _supports_sdpa = True
+
+ def _init_weights(self, module):
+ """Initialize the weights"""
+ factor = self.config.initializer_factor
+ if isinstance(module, MLCDVisionEmbeddings):
+ factor = self.config.initializer_factor
+ nn.init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor)
+ nn.init.normal_(module.patch_embedding.weight, std=module.config.initializer_range * factor)
+ elif isinstance(module, MLCDAttention):
+ factor = self.config.initializer_factor
+ in_proj_std = (module.embed_dim**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
+ out_proj_std = (module.embed_dim**-0.5) * factor
+ nn.init.normal_(module.q_proj.weight, std=in_proj_std)
+ nn.init.normal_(module.k_proj.weight, std=in_proj_std)
+ nn.init.normal_(module.v_proj.weight, std=in_proj_std)
+ nn.init.normal_(module.out_proj.weight, std=out_proj_std)
+ elif isinstance(module, MLCDMLP):
+ factor = self.config.initializer_factor
+ in_proj_std = (module.config.hidden_size**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
+ fc_std = (2 * module.config.hidden_size) ** -0.5 * factor
+ nn.init.normal_(module.fc1.weight, std=fc_std)
+ nn.init.normal_(module.fc2.weight, std=in_proj_std)
+ elif isinstance(module, MLCDVisionTransformer):
+ factor = self.config.initializer_factor
+ pos_emb_std = (module.config.hidden_size // module.config.num_attention_heads // 2) ** -0.5 * factor
+ nn.init.normal_(module.class_pos_emb, mean=0.0, std=pos_emb_std)
+ elif isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+ elif isinstance(module, nn.Linear) and module.bias is not None:
+ module.bias.data.zero_()
+
+
+@auto_docstring(
+ custom_intro="""
+ The vision model from MLCD without any head or projection on top.
+ """
+)
+class MLCDVisionModel(MLCDPreTrainedModel):
+ config: MLCDVisionConfig
+ main_input_name = "pixel_values"
+ _no_split_modules = ["MLCDEncoderLayer"]
+
+ def __init__(self, config: MLCDVisionConfig):
+ super().__init__(config)
+ self.vision_model = MLCDVisionTransformer(config)
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self) -> nn.Module:
+ return self.vision_model.embeddings.patch_embedding
+
+ @auto_docstring
+ def forward(
+ self,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[tuple, BaseModelOutputWithPooling]:
+ r"""
+ Example:
+
+ ```python
+ >>> import requests
+ >>> import torch
+ >>> from PIL import Image
+ >>> from transformers import AutoProcessor, MLCDVisionModel
+ >>> model = MLCDVisionModel.from_pretrained("DeepGlint-AI/mlcd-vit-bigG-patch14-448")
+ >>> processor = AutoProcessor.from_pretrained("DeepGlint-AI/mlcd-vit-bigG-patch14-448")
+
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+ >>> inputs = processor(images=image, return_tensors="pt")
+
+ >>> with torch.no_grad():
+ ... outputs = model(**inputs, output_attentions=True)
+
+ >>> features = outputs.last_hidden_state
+ >>> print(f"Extracted features shape: {features.shape}")
+ >>> print(f"Number of attention layers: {len(outputs.attentions)}")
+ >>> print(f"Attention shape: {outputs.attentions[0].shape}")
+ ```"""
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+
+ return self.vision_model(
+ pixel_values=pixel_values,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+
+__all__ = ["MLCDPreTrainedModel", "MLCDVisionModel"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/mlcd/modular_mlcd.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/mlcd/modular_mlcd.py
new file mode 100644
index 0000000000000000000000000000000000000000..fcc18ab2b1c8f123be84c502d534c6000882a683
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/mlcd/modular_mlcd.py
@@ -0,0 +1,529 @@
+# coding=utf-8
+# Copyright 2025 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Callable, Optional, Union
+
+import torch
+import torch.nn as nn
+
+from ...configuration_utils import PretrainedConfig
+from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
+from ...utils import auto_docstring, logging
+from ..clip.modeling_clip import (
+ CLIPMLP,
+ CLIPAttention,
+ CLIPEncoder,
+ CLIPEncoderLayer,
+ CLIPVisionEmbeddings,
+ CLIPVisionModel,
+ CLIPVisionTransformer,
+)
+from ..llama.modeling_llama import eager_attention_forward
+from ..qwen2_vl.modeling_qwen2_vl import VisionRotaryEmbedding, apply_rotary_pos_emb_vision
+
+
+logger = logging.get_logger(__name__)
+
+
+class MLCDVisionConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`MLCDVisionModel`]. It is used to instantiate a MLCD
+ vision encoder according to the specified arguments, defining the model architecture. Instantiating a configuration
+ with the defaults will yield a similar configuration to that of the vision encoder of the MLCD
+ [DeepGlint-AI/mlcd-vit-bigG-patch14-336](https://huggingface.co/DeepGlint-AI/mlcd-vit-bigG-patch14-336) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ hidden_size (`int`, *optional*, defaults to 1664):
+ Dimensionality of the encoder layers and the pooler layer.
+ intermediate_size (`int`, *optional*, defaults to 8192):
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
+ projection_dim (`int`, *optional*, defaults to 1024):
+ Dimensionality of text and vision projection layers.
+ num_hidden_layers (`int`, *optional*, defaults to 48):
+ Number of hidden layers in the Transformer encoder.
+ num_attention_heads (`int`, *optional*, defaults to 16):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ num_channels (`int`, *optional*, defaults to 3):
+ The number of input channels.
+ image_size (`int`, *optional*, defaults to 336):
+ The size (resolution) of each image.
+ patch_size (`int`, *optional*, defaults to 14):
+ The size (resolution) of each patch.
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+ `"relu"`, `"selu"`, `"gelu_new"` and `"quick_gelu"` are supported.
+ layer_norm_eps (`float`, *optional*, defaults to 1e-05):
+ The epsilon used by the layer normalization layers.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ initializer_factor (`float`, *optional*, defaults to 1.0):
+ A factor for initializing all weight matrices (should be kept to 1, used internally for initialization
+ testing).
+
+ Example:
+
+ ```python
+ >>> from transformers import MLCDVisionConfig, MLCDVisionModel
+
+ >>> # Initializing a MLCDVisionConfig with DeepGlint-AI/mlcd-vit-bigG-patch14-336 style configuration
+ >>> configuration = MLCDVisionConfig()
+
+ >>> # Initializing a MLCDVisionModel (with random weights) from the DeepGlint-AI/mlcd-vit-bigG-patch14-336 style configuration
+ >>> model = MLCDVisionModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "mlcd_vision_model"
+ base_config_key = "vision_config"
+
+ def __init__(
+ self,
+ hidden_size=1664,
+ intermediate_size=8192,
+ num_hidden_layers=48,
+ num_attention_heads=16,
+ num_key_value_groups=1,
+ num_channels=3,
+ image_size=336,
+ patch_size=14,
+ hidden_act="gelu",
+ layer_norm_eps=1e-5,
+ attention_dropout=0.0,
+ initializer_range=0.02,
+ initializer_factor=1.0,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.num_key_value_groups = num_key_value_groups
+ self.num_channels = num_channels
+ self.patch_size = patch_size
+ self.image_size = image_size
+ self.initializer_range = initializer_range
+ self.initializer_factor = initializer_factor
+ self.attention_dropout = attention_dropout
+ self.layer_norm_eps = layer_norm_eps
+ self.hidden_act = hidden_act
+
+
+class MLCDMLP(CLIPMLP):
+ pass
+
+
+class MLCDRotaryEmbedding(VisionRotaryEmbedding):
+ def forward(self, num_patches_height: int, num_patches_width: int) -> torch.Tensor:
+ """
+ Calculate the Rotary Position Embedding (RoPE) for MLCDVisionModel based on the grid size.
+
+ Args:
+ num_patches_height (int): Number of patches in the height dimension.
+ num_patches_width (int): Number of patches in the width dimension.
+
+ Returns:
+ torch.Tensor: Rotary positional embeddings for the given grid size.
+ """
+ # Generate position IDs for height and width dimensions
+ hpos_ids = (
+ torch.arange(num_patches_height, device=self.inv_freq.device).unsqueeze(1).expand(-1, num_patches_width)
+ )
+ wpos_ids = (
+ torch.arange(num_patches_width, device=self.inv_freq.device).unsqueeze(0).expand(num_patches_height, -1)
+ )
+
+ # Flatten and stack the position IDs
+ pos_ids = torch.stack([hpos_ids.flatten(), wpos_ids.flatten()], dim=-1)
+
+ # Generate the full rotary positional embeddings for the maximum grid size
+ max_grid_size = max(num_patches_height, num_patches_width)
+ seq = torch.arange(max_grid_size, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
+ rotary_pos_emb_full = torch.outer(seq, self.inv_freq)
+
+ # Select and flatten the embeddings based on the position IDs
+ rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
+
+ return rotary_pos_emb
+
+
+class MLCDVisionEmbeddings(CLIPVisionEmbeddings):
+ def __init__(self, config: MLCDVisionConfig):
+ super().__init__(config)
+ del self.position_embedding
+
+ def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
+ batch_size = pixel_values.shape[0]
+ target_dtype = self.patch_embedding.weight.dtype
+ # patch_embeds -> shape = [batch, embed_dim, grid, grid]
+ patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))
+ patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
+
+ class_embeds = self.class_embedding.expand(batch_size, 1, -1)
+ embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
+
+ return embeddings
+
+
+class MLCDAttention(CLIPAttention):
+ """Multi-headed attention with RoPE. Refer to papers:
+ - Attention is all you need:
+ https://huggingface.co/papers/1706.03762
+ - RoFormer: Enhanced Transformer with Rotary Position Embedding:
+ https://huggingface.co/papers/2104.09864
+ """
+
+ def __init__(self, config: MLCDVisionConfig):
+ super().__init__(config)
+ self.num_key_value_groups = config.num_key_value_groups
+ self.is_causal = False
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
+ attention_mask: Optional[torch.Tensor] = None,
+ **kwargs: Unpack[FlashAttentionKwargs],
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+ batch_size, seq_length = hidden_states.shape[:-1]
+
+ # Each of shape: [batch_size, seq_length, num_heads, head_dim]
+ query_states = self.q_proj(hidden_states).reshape((batch_size, seq_length, self.num_heads, self.head_dim))
+ key_states = self.k_proj(hidden_states).reshape((batch_size, seq_length, self.num_heads, self.head_dim))
+ value_states = self.v_proj(hidden_states).reshape((batch_size, seq_length, self.num_heads, self.head_dim))
+
+ # Apply positional embeddings
+ cos = position_embeddings[0].unsqueeze(0).float()
+ sin = position_embeddings[1].unsqueeze(0).float()
+ query_states, key_states = apply_rotary_pos_emb_vision(query_states, key_states, cos, sin)
+
+ # Each of shape: [batch_size, num_heads, seq_length, head_dim]
+ query_states = query_states.permute(0, 2, 1, 3).contiguous()
+ key_states = key_states.permute(0, 2, 1, 3).contiguous()
+ value_states = value_states.permute(0, 2, 1, 3).contiguous()
+
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ attn_output, attn_weights = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ dropout=0.0 if not self.training else self.dropout,
+ scaling=self.scale,
+ is_causal=self.is_causal,
+ **kwargs,
+ )
+
+ attn_output = attn_output.permute(1, 0, 2, 3).contiguous() # [seq_length, batch_size, num_heads, head_dim]
+ attn_output = attn_output.view(seq_length, batch_size, -1) # [seq_length, batch_size, embedding_dim]
+ attn_output = self.out_proj(attn_output)
+ attn_output = attn_output.permute(1, 0, 2).contiguous() # [batch_size, seq_length, embedding_dim]
+ return attn_output, attn_weights
+
+
+class MLCDEncoderLayer(CLIPEncoderLayer):
+ def __init__(self, config: MLCDVisionConfig):
+ super().__init__(config)
+ self.self_attn = MLCDAttention(config)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
+ attention_mask: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = False,
+ ) -> tuple[torch.FloatTensor]:
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`):
+ Input to the layer of shape `(batch, seq_len, embed_dim)`.
+ Represents the hidden states from the previous layer or the input embeddings.
+ position_embeddings (`tuple[torch.Tensor, torch.Tensor]`):
+ A tuple of two tensors, each of shape `(batch, seq_len, embed_dim)`.
+ Represents absolute positional embeddings for the query and key in the attention mechanism.
+ attention_mask (`torch.FloatTensor`):
+ Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by very large negative values.
+ output_attentions (`bool`, *optional*, defaults to `False`):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ """
+ residual = hidden_states
+
+ hidden_states = self.layer_norm1(hidden_states)
+ hidden_states, attn_weights = self.self_attn(
+ hidden_states=hidden_states,
+ position_embeddings=position_embeddings,
+ attention_mask=attention_mask,
+ output_attentions=output_attentions,
+ )
+ hidden_states = residual + hidden_states
+
+ residual = hidden_states
+ hidden_states = self.layer_norm2(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = residual + hidden_states
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (attn_weights,)
+
+ return outputs
+
+
+class MLCDEncoder(CLIPEncoder):
+ """
+ Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
+ [`MLCDEncoderLayer`].
+
+ Args:
+ config: MLCDVisionConfig
+ """
+
+ def __init__(self, config: MLCDVisionConfig):
+ """Overwrite the dummy `MLCDConfig` annotation with `MLCDVisionConfig`."""
+ super().__init__(config)
+
+ def forward(
+ self,
+ inputs_embeds: torch.FloatTensor,
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
+ attention_mask: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[tuple, BaseModelOutput]:
+ r"""
+ Args:
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+ than the model's internal embedding lookup matrix.
+ position_embeddings (`tuple[torch.Tensor, torch.Tensor]`):
+ A tuple of two tensors (the cosine and sine components of the rotary position embedding),
+ each of shape `(seq_len, head_dim)`. They are applied to the query and key states in the attention mechanism.
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+ [What are attention masks?](../glossary#attention-mask)
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+ for more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+ """
+
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+
+ encoder_states = () if output_hidden_states else None
+ all_attentions = () if output_attentions else None
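+ # When requested, `encoder_states` ends up with num_hidden_layers + 1 entries (the input
+ # embeddings plus the output of every layer) and `all_attentions` with one tensor per layer.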
+
+ hidden_states = inputs_embeds
+ for idx, encoder_layer in enumerate(self.layers):
+ if output_hidden_states:
+ encoder_states = encoder_states + (hidden_states,)
+ layer_outputs = encoder_layer(
+ hidden_states=hidden_states,
+ position_embeddings=position_embeddings,
+ attention_mask=attention_mask,
+ output_attentions=output_attentions,
+ )
+
+ hidden_states = layer_outputs[0]
+
+ if output_attentions:
+ all_attentions = all_attentions + (layer_outputs[1],)
+
+ if output_hidden_states:
+ encoder_states = encoder_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
+ return BaseModelOutput(
+ last_hidden_state=hidden_states,
+ hidden_states=encoder_states,
+ attentions=all_attentions,
+ )
+
+
+class MLCDVisionTransformer(CLIPVisionTransformer):
+ def __init__(self, config: MLCDVisionConfig):
+ super().__init__(config)
+ self.vision_rotary_embedding = MLCDRotaryEmbedding(config.hidden_size // config.num_attention_heads // 2)
+ self.class_pos_emb = nn.Parameter(torch.randn(1, config.hidden_size // config.num_attention_heads // 2))
+
+ @auto_docstring
+ def forward(
+ self,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[tuple, BaseModelOutputWithPooling]:
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+
+ if pixel_values is None:
+ raise ValueError("You have to specify pixel_values")
+
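+ # Build the rotary (cos, sin) table for this image grid: the learned `class_pos_emb` row is
+ # prepended for the class token, the half-dimension frequencies are duplicated to cover the
+ # full head dimension, and the resulting `(1 + num_patches, head_dim)` table is shared by
+ # every attention layer.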
+ num_patches_height = pixel_values.shape[-2] // self.config.patch_size
+ num_patches_width = pixel_values.shape[-1] // self.config.patch_size
+ rotary_pos_emb = self.vision_rotary_embedding(num_patches_height, num_patches_width)
+ rotary_pos_emb = rotary_pos_emb.to(self.class_pos_emb.device)
+ rotary_pos_emb = torch.cat([self.class_pos_emb, rotary_pos_emb], dim=0)
+ emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
+ position_embeddings = (emb.cos(), emb.sin())
+
+ hidden_states = self.embeddings(pixel_values)
+ hidden_states = self.pre_layrnorm(hidden_states)
+
+ encoder_outputs = self.encoder(
+ inputs_embeds=hidden_states,
+ position_embeddings=position_embeddings,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ last_hidden_state = encoder_outputs[0]
+ pooled_output = last_hidden_state[:, 0, :]
+ pooled_output = self.post_layernorm(pooled_output)
+
+ if not return_dict:
+ return (last_hidden_state, pooled_output) + encoder_outputs[1:]
+
+ return BaseModelOutputWithPooling(
+ last_hidden_state=last_hidden_state,
+ pooler_output=pooled_output,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ )
+
+
+@auto_docstring
+class MLCDPreTrainedModel(PreTrainedModel):
+ config: MLCDVisionConfig
+ base_model_prefix = "mlcd"
+ supports_gradient_checkpointing = True
+ _supports_flash_attn = True
+ _supports_sdpa = True
+
+ def _init_weights(self, module):
+ """Initialize the weights"""
+ factor = self.config.initializer_factor
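+ # Follows the CLIP-style initialization scheme: projection weights are scaled by roughly
+ # 1/sqrt(2 * num_hidden_layers) so residual branches start small relative to the skip path.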
+ if isinstance(module, MLCDVisionEmbeddings):
+ factor = self.config.initializer_factor
+ nn.init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor)
+ nn.init.normal_(module.patch_embedding.weight, std=module.config.initializer_range * factor)
+ elif isinstance(module, MLCDAttention):
+ factor = self.config.initializer_factor
+ in_proj_std = (module.embed_dim**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
+ out_proj_std = (module.embed_dim**-0.5) * factor
+ nn.init.normal_(module.q_proj.weight, std=in_proj_std)
+ nn.init.normal_(module.k_proj.weight, std=in_proj_std)
+ nn.init.normal_(module.v_proj.weight, std=in_proj_std)
+ nn.init.normal_(module.out_proj.weight, std=out_proj_std)
+ elif isinstance(module, MLCDMLP):
+ factor = self.config.initializer_factor
+ in_proj_std = (module.config.hidden_size**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
+ fc_std = (2 * module.config.hidden_size) ** -0.5 * factor
+ nn.init.normal_(module.fc1.weight, std=fc_std)
+ nn.init.normal_(module.fc2.weight, std=in_proj_std)
+ elif isinstance(module, MLCDVisionTransformer):
+ factor = self.config.initializer_factor
+ pos_emb_std = (module.config.hidden_size // module.config.num_attention_heads // 2) ** -0.5 * factor
+ nn.init.normal_(module.class_pos_emb, mean=0.0, std=pos_emb_std)
+ elif isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+ elif isinstance(module, nn.Linear) and module.bias is not None:
+ module.bias.data.zero_()
+
+
+class MLCDVisionModel(CLIPVisionModel):
+ @auto_docstring
+ def forward(
+ self,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[tuple, BaseModelOutputWithPooling]:
+ r"""
+ Example:
+
+ ```python
+ >>> import requests
+ >>> from PIL import Image
+ >>> from transformers import AutoProcessor, MLCDVisionModel
+ >>> model = MLCDVisionModel.from_pretrained("DeepGlint-AI/mlcd-vit-bigG-patch14-448")
+ >>> processor = AutoProcessor.from_pretrained("DeepGlint-AI/mlcd-vit-bigG-patch14-448")
+
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+ >>> inputs = processor(images=image, return_tensors="pt")
+
+ >>> with torch.no_grad():
+ ... outputs = model(**inputs, output_attentions=True)
+
+ >>> features = outputs.last_hidden_state
+ >>> print(f"Extracted features shape: {features.shape}")
+ >>> print(f"Number of attention layers: {len(outputs.attentions)}")
+ >>> print(f"Attention shape: {outputs.attentions[0].shape}")
+ ```"""
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+
+ return self.vision_model(
+ pixel_values=pixel_values,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+
+__all__ = [
+ "MLCDVisionConfig",
+ "MLCDPreTrainedModel",
+ "MLCDVisionModel",
+]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/mobilenet_v1/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/mobilenet_v1/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..bce83216c35546fe01cf96c738cbe9c2f3582486
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/mobilenet_v1/__init__.py
@@ -0,0 +1,30 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_mobilenet_v1 import *
+ from .feature_extraction_mobilenet_v1 import *
+ from .image_processing_mobilenet_v1 import *
+ from .image_processing_mobilenet_v1_fast import *
+ from .modeling_mobilenet_v1 import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/mobilenet_v1/configuration_mobilenet_v1.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/mobilenet_v1/configuration_mobilenet_v1.py
new file mode 100644
index 0000000000000000000000000000000000000000..c18d72d55003dc88e8fbcfe2ee80a18975d23dbc
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/mobilenet_v1/configuration_mobilenet_v1.py
@@ -0,0 +1,126 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""MobileNetV1 model configuration"""
+
+from collections import OrderedDict
+from collections.abc import Mapping
+
+from packaging import version
+
+from ...configuration_utils import PretrainedConfig
+from ...onnx import OnnxConfig
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class MobileNetV1Config(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`MobileNetV1Model`]. It is used to instantiate a
+ MobileNetV1 model according to the specified arguments, defining the model architecture. Instantiating a
+ configuration with the defaults will yield a similar configuration to that of the MobileNetV1
+ [google/mobilenet_v1_1.0_224](https://huggingface.co/google/mobilenet_v1_1.0_224) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ num_channels (`int`, *optional*, defaults to 3):
+ The number of input channels.
+ image_size (`int`, *optional*, defaults to 224):
+ The size (resolution) of each image.
+ depth_multiplier (`float`, *optional*, defaults to 1.0):
+ Shrinks or expands the number of channels in each layer. Default is 1.0, which starts the network with 32
+ channels. This is sometimes also called "alpha" or "width multiplier".
+ min_depth (`int`, *optional*, defaults to 8):
+ All layers will have at least this many channels.
+ hidden_act (`str` or `function`, *optional*, defaults to `"relu6"`):
+ The non-linear activation function (function or string) in the Transformer encoder and convolution layers.
+ tf_padding (`bool`, *optional*, defaults to `True`):
+ Whether to use TensorFlow padding rules on the convolution layers.
+ classifier_dropout_prob (`float`, *optional*, defaults to 0.999):
+ The dropout ratio for attached classifiers.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ layer_norm_eps (`float`, *optional*, defaults to 0.001):
+ The epsilon used by the layer normalization layers.
+
+ Example:
+
+ ```python
+ >>> from transformers import MobileNetV1Config, MobileNetV1Model
+
+ >>> # Initializing a "mobilenet_v1_1.0_224" style configuration
+ >>> configuration = MobileNetV1Config()
+
+ >>> # Initializing a model from the "mobilenet_v1_1.0_224" style configuration
+ >>> model = MobileNetV1Model(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "mobilenet_v1"
+
+ def __init__(
+ self,
+ num_channels=3,
+ image_size=224,
+ depth_multiplier=1.0,
+ min_depth=8,
+ hidden_act="relu6",
+ tf_padding=True,
+ classifier_dropout_prob=0.999,
+ initializer_range=0.02,
+ layer_norm_eps=0.001,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+
+ if depth_multiplier <= 0:
+ raise ValueError("depth_multiplier must be greater than zero.")
+
+ self.num_channels = num_channels
+ self.image_size = image_size
+ self.depth_multiplier = depth_multiplier
+ self.min_depth = min_depth
+ self.hidden_act = hidden_act
+ self.tf_padding = tf_padding
+ self.classifier_dropout_prob = classifier_dropout_prob
+ self.initializer_range = initializer_range
+ self.layer_norm_eps = layer_norm_eps
+
+
+class MobileNetV1OnnxConfig(OnnxConfig):
+ torch_onnx_minimum_version = version.parse("1.11")
+
+ @property
+ def inputs(self) -> Mapping[str, Mapping[int, str]]:
+ return OrderedDict([("pixel_values", {0: "batch"})])
+
+ @property
+ def outputs(self) -> Mapping[str, Mapping[int, str]]:
+ if self.task == "image-classification":
+ return OrderedDict([("logits", {0: "batch"})])
+ else:
+ return OrderedDict([("last_hidden_state", {0: "batch"}), ("pooler_output", {0: "batch"})])
+
+ @property
+ def atol_for_validation(self) -> float:
+ return 1e-4
+
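+# Illustrative note (not part of the library): given the axis mappings above and assuming the
+# base `OnnxConfig(config, task=...)` constructor,
+#
+#   MobileNetV1OnnxConfig(MobileNetV1Config()).inputs
+#       -> OrderedDict([("pixel_values", {0: "batch"})])
+#   MobileNetV1OnnxConfig(MobileNetV1Config(), task="image-classification").outputs
+#       -> OrderedDict([("logits", {0: "batch"})])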
+
+__all__ = ["MobileNetV1Config", "MobileNetV1OnnxConfig"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/mobilenet_v1/feature_extraction_mobilenet_v1.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/mobilenet_v1/feature_extraction_mobilenet_v1.py
new file mode 100644
index 0000000000000000000000000000000000000000..02a5401bc145996d1126641ee656180a48c92e20
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/mobilenet_v1/feature_extraction_mobilenet_v1.py
@@ -0,0 +1,38 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Feature extractor class for MobileNetV1."""
+
+import warnings
+
+from ...utils import logging
+from ...utils.import_utils import requires
+from .image_processing_mobilenet_v1 import MobileNetV1ImageProcessor
+
+
+logger = logging.get_logger(__name__)
+
+
+@requires(backends=("vision",))
+class MobileNetV1FeatureExtractor(MobileNetV1ImageProcessor):
+ def __init__(self, *args, **kwargs) -> None:
+ warnings.warn(
+ "The class MobileNetV1FeatureExtractor is deprecated and will be removed in version 5 of Transformers."
+ " Please use MobileNetV1ImageProcessor instead.",
+ FutureWarning,
+ )
+ super().__init__(*args, **kwargs)
+
+
+__all__ = ["MobileNetV1FeatureExtractor"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/mobilenet_v1/image_processing_mobilenet_v1.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/mobilenet_v1/image_processing_mobilenet_v1.py
new file mode 100644
index 0000000000000000000000000000000000000000..6fa3f443c53b1a315a917a8167b100128b45046f
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/mobilenet_v1/image_processing_mobilenet_v1.py
@@ -0,0 +1,307 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Image processor class for MobileNetV1."""
+
+from typing import Optional, Union
+
+import numpy as np
+
+from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
+from ...image_transforms import (
+ get_resize_output_image_size,
+ resize,
+ to_channel_dimension_format,
+)
+from ...image_utils import (
+ IMAGENET_STANDARD_MEAN,
+ IMAGENET_STANDARD_STD,
+ ChannelDimension,
+ ImageInput,
+ PILImageResampling,
+ infer_channel_dimension_format,
+ is_scaled_image,
+ make_flat_list_of_images,
+ to_numpy_array,
+ valid_images,
+ validate_preprocess_arguments,
+)
+from ...utils import TensorType, filter_out_non_signature_kwargs, logging
+from ...utils.import_utils import requires
+
+
+logger = logging.get_logger(__name__)
+
+
+@requires(backends=("vision",))
+class MobileNetV1ImageProcessor(BaseImageProcessor):
+ r"""
+ Constructs a MobileNetV1 image processor.
+
+ Args:
+ do_resize (`bool`, *optional*, defaults to `True`):
+ Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by
+ `do_resize` in the `preprocess` method.
+ size (`dict[str, int]`, *optional*, defaults to `{"shortest_edge": 256}`):
+ Size of the image after resizing. The shortest edge of the image is resized to size["shortest_edge"], with
+ the longest edge resized to keep the input aspect ratio. Can be overridden by `size` in the `preprocess`
+ method.
+ resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
+ Resampling filter to use if resizing the image. Can be overridden by the `resample` parameter in the
+ `preprocess` method.
+ do_center_crop (`bool`, *optional*, defaults to `True`):
+ Whether to center crop the image. If the input size is smaller than `crop_size` along any edge, the image
+ is padded with 0's and then center cropped. Can be overridden by the `do_center_crop` parameter in the
+ `preprocess` method.
+ crop_size (`dict[str, int]`, *optional*, defaults to `{"height": 224, "width": 224}`):
+ Desired output size when applying center-cropping. Only has an effect if `do_center_crop` is set to `True`.
+ Can be overridden by the `crop_size` parameter in the `preprocess` method.
+ do_rescale (`bool`, *optional*, defaults to `True`):
+ Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the `do_rescale`
+ parameter in the `preprocess` method.
+ rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
+ Scale factor to use if rescaling the image. Can be overridden by the `rescale_factor` parameter in the
+ `preprocess` method.
+ do_normalize (`bool`, *optional*, defaults to `True`):
+ Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
+ method.
+ image_mean (`float` or `list[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`):
+ Mean to use if normalizing the image. This is a float or a list of floats with length equal to the
+ number of channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
+ image_std (`float` or `list[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`):
+ Standard deviation to use if normalizing the image. This is a float or a list of floats with length
+ equal to the number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
+ """
+
+ model_input_names = ["pixel_values"]
+
+ def __init__(
+ self,
+ do_resize: bool = True,
+ size: Optional[dict[str, int]] = None,
+ resample: PILImageResampling = PILImageResampling.BILINEAR,
+ do_center_crop: bool = True,
+ crop_size: Optional[dict[str, int]] = None,
+ do_rescale: bool = True,
+ rescale_factor: Union[int, float] = 1 / 255,
+ do_normalize: bool = True,
+ image_mean: Optional[Union[float, list[float]]] = None,
+ image_std: Optional[Union[float, list[float]]] = None,
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+ size = size if size is not None else {"shortest_edge": 256}
+ size = get_size_dict(size, default_to_square=False)
+ crop_size = crop_size if crop_size is not None else {"height": 224, "width": 224}
+ crop_size = get_size_dict(crop_size)
+ self.do_resize = do_resize
+ self.size = size
+ self.resample = resample
+ self.do_center_crop = do_center_crop
+ self.crop_size = crop_size
+ self.do_rescale = do_rescale
+ self.rescale_factor = rescale_factor
+ self.do_normalize = do_normalize
+ self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
+ self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
+
+ # Copied from transformers.models.clip.image_processing_clip.CLIPImageProcessor.resize
+ def resize(
+ self,
+ image: np.ndarray,
+ size: dict[str, int],
+ resample: PILImageResampling = PILImageResampling.BICUBIC,
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ **kwargs,
+ ) -> np.ndarray:
+ """
+ Resize an image. The shortest edge of the image is resized to size["shortest_edge"], with the longest edge
+ resized to keep the input aspect ratio.
+
+ Args:
+ image (`np.ndarray`):
+ Image to resize.
+ size (`dict[str, int]`):
+ Size of the output image.
+ resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
+ Resampling filter to use when resizing the image.
+ data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format of the image. If not provided, it will be the same as the input image.
+ input_data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format of the input image. If not provided, it will be inferred.
+ """
+ default_to_square = True
+ if "shortest_edge" in size:
+ size = size["shortest_edge"]
+ default_to_square = False
+ elif "height" in size and "width" in size:
+ size = (size["height"], size["width"])
+ else:
+ raise ValueError("Size must contain either 'shortest_edge' or 'height' and 'width'.")
+
+ output_size = get_resize_output_image_size(
+ image,
+ size=size,
+ default_to_square=default_to_square,
+ input_data_format=input_data_format,
+ )
+ return resize(
+ image,
+ size=output_size,
+ resample=resample,
+ data_format=data_format,
+ input_data_format=input_data_format,
+ **kwargs,
+ )
+
+ @filter_out_non_signature_kwargs()
+ def preprocess(
+ self,
+ images: ImageInput,
+ do_resize: Optional[bool] = None,
+ size: Optional[dict[str, int]] = None,
+ resample: Optional[PILImageResampling] = None,
+ do_center_crop: Optional[bool] = None,
+ crop_size: Optional[dict[str, int]] = None,
+ do_rescale: Optional[bool] = None,
+ rescale_factor: Optional[float] = None,
+ do_normalize: Optional[bool] = None,
+ image_mean: Optional[Union[float, list[float]]] = None,
+ image_std: Optional[Union[float, list[float]]] = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ ):
+ """
+ Preprocess an image or batch of images.
+
+ Args:
+ images (`ImageInput`):
+ Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
+ passing in images with pixel values between 0 and 1, set `do_rescale=False`.
+ do_resize (`bool`, *optional*, defaults to `self.do_resize`):
+ Whether to resize the image.
+ size (`dict[str, int]`, *optional*, defaults to `self.size`):
+ Size of the image after resizing. Shortest edge of the image is resized to size["shortest_edge"], with
+ the longest edge resized to keep the input aspect ratio.
+ resample (`PILImageResampling` filter, *optional*, defaults to `self.resample`):
+ `PILImageResampling` filter to use if resizing the image e.g. `PILImageResampling.BILINEAR`. Only has
+ an effect if `do_resize` is set to `True`.
+ do_center_crop (`bool`, *optional*, defaults to `self.do_center_crop`):
+ Whether to center crop the image.
+ crop_size (`dict[str, int]`, *optional*, defaults to `self.crop_size`):
+ Size of the center crop. Only has an effect if `do_center_crop` is set to `True`.
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
+ Whether to rescale the image values to the range [0, 1].
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
+ Rescale factor to rescale the image by if `do_rescale` is set to `True`.
+ do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
+ Whether to normalize the image.
+ image_mean (`float` or `list[float]`, *optional*, defaults to `self.image_mean`):
+ Image mean to use if `do_normalize` is set to `True`.
+ image_std (`float` or `list[float]`, *optional*, defaults to `self.image_std`):
+ Image standard deviation to use if `do_normalize` is set to `True`.
+ return_tensors (`str` or `TensorType`, *optional*):
+ The type of tensors to return. Can be one of:
+ - Unset: Return a list of `np.ndarray`.
+ - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
+ - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
+ data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
+ The channel dimension format for the output image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - Unset: Use the channel dimension format of the input image.
+ input_data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
+ from the input image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
+ """
+ do_resize = do_resize if do_resize is not None else self.do_resize
+ size = size if size is not None else self.size
+ size = get_size_dict(size, default_to_square=False)
+ resample = resample if resample is not None else self.resample
+ do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop
+ crop_size = crop_size if crop_size is not None else self.crop_size
+ crop_size = get_size_dict(crop_size)
+ do_rescale = do_rescale if do_rescale is not None else self.do_rescale
+ rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
+ do_normalize = do_normalize if do_normalize is not None else self.do_normalize
+ image_mean = image_mean if image_mean is not None else self.image_mean
+ image_std = image_std if image_std is not None else self.image_std
+
+ images = make_flat_list_of_images(images)
+
+ if not valid_images(images):
+ raise ValueError(
+ "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
+ "torch.Tensor, tf.Tensor or jax.ndarray."
+ )
+ validate_preprocess_arguments(
+ do_rescale=do_rescale,
+ rescale_factor=rescale_factor,
+ do_normalize=do_normalize,
+ image_mean=image_mean,
+ image_std=image_std,
+ do_center_crop=do_center_crop,
+ crop_size=crop_size,
+ do_resize=do_resize,
+ size=size,
+ resample=resample,
+ )
+
+ # All transformations expect numpy arrays.
+ images = [to_numpy_array(image) for image in images]
+
+ if do_rescale and is_scaled_image(images[0]):
+ logger.warning_once(
+ "It looks like you are trying to rescale already rescaled images. If the input"
+ " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
+ )
+
+ if input_data_format is None:
+ # We assume that all images have the same channel dimension format.
+ input_data_format = infer_channel_dimension_format(images[0])
+
+ all_images = []
+ for image in images:
+ if do_resize:
+ image = self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
+
+ if do_center_crop:
+ image = self.center_crop(image=image, size=crop_size, input_data_format=input_data_format)
+
+ if do_rescale:
+ image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
+
+ if do_normalize:
+ image = self.normalize(
+ image=image, mean=image_mean, std=image_std, input_data_format=input_data_format
+ )
+
+ all_images.append(image)
+ images = [
+ to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
+ for image in all_images
+ ]
+
+ data = {"pixel_values": images}
+ return BatchFeature(data=data, tensor_type=return_tensors)
+
+
+__all__ = ["MobileNetV1ImageProcessor"]
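+# Illustrative usage sketch (not part of the library): with the defaults above a single RGB
+# image is resized so its shortest edge is 256, center-cropped to 224x224, rescaled by 1/255
+# and normalized with the ImageNet mean/std, e.g.
+#
+#   pixel_values = MobileNetV1ImageProcessor()(images=image, return_tensors="pt")["pixel_values"]
+#   # pixel_values.shape == torch.Size([1, 3, 224, 224])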
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/mobilenet_v1/image_processing_mobilenet_v1_fast.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/mobilenet_v1/image_processing_mobilenet_v1_fast.py
new file mode 100644
index 0000000000000000000000000000000000000000..e716553a6d102fcd5d32af50f9e28c5397812287
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/mobilenet_v1/image_processing_mobilenet_v1_fast.py
@@ -0,0 +1,43 @@
+# coding=utf-8
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Fast Image processor class for MobileNetV1."""
+
+from ...image_processing_utils_fast import (
+ BaseImageProcessorFast,
+ DefaultFastImageProcessorKwargs,
+ Unpack,
+)
+from ...image_utils import IMAGENET_STANDARD_MEAN, IMAGENET_STANDARD_STD, PILImageResampling
+from ...utils import auto_docstring
+
+
+@auto_docstring
+class MobileNetV1ImageProcessorFast(BaseImageProcessorFast):
+ resample = PILImageResampling.BILINEAR
+ image_mean = IMAGENET_STANDARD_MEAN
+ image_std = IMAGENET_STANDARD_STD
+ size = {"shortest_edge": 256}
+ default_to_square = False
+ crop_size = {"height": 224, "width": 224}
+ do_resize = True
+ do_center_crop = True
+ do_rescale = True
+ do_normalize = True
+
+ def __init__(self, **kwargs: Unpack[DefaultFastImageProcessorKwargs]) -> None:
+ super().__init__(**kwargs)
+
+
+__all__ = ["MobileNetV1ImageProcessorFast"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/mobilenet_v1/modeling_mobilenet_v1.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/mobilenet_v1/modeling_mobilenet_v1.py
new file mode 100644
index 0000000000000000000000000000000000000000..25997a46790c009bc62e5b4323d378386e3c1718
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/mobilenet_v1/modeling_mobilenet_v1.py
@@ -0,0 +1,414 @@
+# coding=utf-8
+# Copyright 2022 Apple Inc. and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch MobileNetV1 model."""
+
+from typing import Optional, Union
+
+import torch
+from torch import nn
+
+from ...activations import ACT2FN
+from ...modeling_outputs import BaseModelOutputWithPoolingAndNoAttention, ImageClassifierOutputWithNoAttention
+from ...modeling_utils import PreTrainedModel
+from ...utils import auto_docstring, logging
+from .configuration_mobilenet_v1 import MobileNetV1Config
+
+
+logger = logging.get_logger(__name__)
+
+
+def _build_tf_to_pytorch_map(model, config, tf_weights=None):
+ """
+ A map of modules from TF to PyTorch.
+ """
+
+ tf_to_pt_map = {}
+
+ if isinstance(model, MobileNetV1ForImageClassification):
+ backbone = model.mobilenet_v1
+ else:
+ backbone = model
+
+ prefix = "MobilenetV1/Conv2d_0/"
+ tf_to_pt_map[prefix + "weights"] = backbone.conv_stem.convolution.weight
+ tf_to_pt_map[prefix + "BatchNorm/beta"] = backbone.conv_stem.normalization.bias
+ tf_to_pt_map[prefix + "BatchNorm/gamma"] = backbone.conv_stem.normalization.weight
+ tf_to_pt_map[prefix + "BatchNorm/moving_mean"] = backbone.conv_stem.normalization.running_mean
+ tf_to_pt_map[prefix + "BatchNorm/moving_variance"] = backbone.conv_stem.normalization.running_var
+
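+ # For example, PyTorch block 0 (backbone.layer[0] depthwise + backbone.layer[1] pointwise)
+ # maps to the TF scopes "MobilenetV1/Conv2d_1_depthwise/*" and "MobilenetV1/Conv2d_1_pointwise/*".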
+ for i in range(13):
+ tf_index = i + 1
+ pt_index = i * 2
+
+ pointer = backbone.layer[pt_index]
+ prefix = f"MobilenetV1/Conv2d_{tf_index}_depthwise/"
+ tf_to_pt_map[prefix + "depthwise_weights"] = pointer.convolution.weight
+ tf_to_pt_map[prefix + "BatchNorm/beta"] = pointer.normalization.bias
+ tf_to_pt_map[prefix + "BatchNorm/gamma"] = pointer.normalization.weight
+ tf_to_pt_map[prefix + "BatchNorm/moving_mean"] = pointer.normalization.running_mean
+ tf_to_pt_map[prefix + "BatchNorm/moving_variance"] = pointer.normalization.running_var
+
+ pointer = backbone.layer[pt_index + 1]
+ prefix = f"MobilenetV1/Conv2d_{tf_index}_pointwise/"
+ tf_to_pt_map[prefix + "weights"] = pointer.convolution.weight
+ tf_to_pt_map[prefix + "BatchNorm/beta"] = pointer.normalization.bias
+ tf_to_pt_map[prefix + "BatchNorm/gamma"] = pointer.normalization.weight
+ tf_to_pt_map[prefix + "BatchNorm/moving_mean"] = pointer.normalization.running_mean
+ tf_to_pt_map[prefix + "BatchNorm/moving_variance"] = pointer.normalization.running_var
+
+ if isinstance(model, MobileNetV1ForImageClassification):
+ prefix = "MobilenetV1/Logits/Conv2d_1c_1x1/"
+ tf_to_pt_map[prefix + "weights"] = model.classifier.weight
+ tf_to_pt_map[prefix + "biases"] = model.classifier.bias
+
+ return tf_to_pt_map
+
+
+def load_tf_weights_in_mobilenet_v1(model, config, tf_checkpoint_path):
+ """Load TensorFlow checkpoints in a PyTorch model."""
+ try:
+ import numpy as np
+ import tensorflow as tf
+ except ImportError:
+ logger.error(
+ "Loading TensorFlow models in PyTorch requires TensorFlow to be installed. Please see "
+ "https://www.tensorflow.org/install/ for installation instructions."
+ )
+ raise
+
+ # Load weights from TF model
+ init_vars = tf.train.list_variables(tf_checkpoint_path)
+ tf_weights = {}
+ for name, shape in init_vars:
+ logger.info(f"Loading TF weight {name} with shape {shape}")
+ array = tf.train.load_variable(tf_checkpoint_path, name)
+ tf_weights[name] = array
+
+ # Build TF to PyTorch weights loading map
+ tf_to_pt_map = _build_tf_to_pytorch_map(model, config, tf_weights)
+
+ for name, pointer in tf_to_pt_map.items():
+ logger.info(f"Importing {name}")
+ if name not in tf_weights:
+ logger.info(f"{name} not in tf pre-trained weights, skipping")
+ continue
+
+ array = tf_weights[name]
+
+ if "depthwise_weights" in name:
+ logger.info("Transposing depthwise")
+ array = np.transpose(array, (2, 3, 0, 1))
+ elif "weights" in name:
+ logger.info("Transposing")
+ if len(pointer.shape) == 2: # copying into linear layer
+ array = array.squeeze().transpose()
+ else:
+ array = np.transpose(array, (3, 2, 0, 1))
+
+ if pointer.shape != array.shape:
+ raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched")
+
+ logger.info(f"Initialize PyTorch weight {name} {array.shape}")
+ pointer.data = torch.from_numpy(array)
+
+ tf_weights.pop(name, None)
+ tf_weights.pop(name + "/RMSProp", None)
+ tf_weights.pop(name + "/RMSProp_1", None)
+ tf_weights.pop(name + "/ExponentialMovingAverage", None)
+
+ logger.info(f"Weights not copied to PyTorch model: {', '.join(tf_weights.keys())}")
+ return model
+
+
+def apply_tf_padding(features: torch.Tensor, conv_layer: nn.Conv2d) -> torch.Tensor:
+ """
+ Apply TensorFlow-style "SAME" padding to a convolution layer. See the notes at:
+ https://www.tensorflow.org/api_docs/python/tf/nn#notes_on_padding_2
+ """
+ in_height, in_width = features.shape[-2:]
+ stride_height, stride_width = conv_layer.stride
+ kernel_height, kernel_width = conv_layer.kernel_size
+
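+ # Worked example (illustrative): for a 224x224 input, a 3x3 kernel and stride 2 the input
+ # divides evenly, so pad_along_height = max(3 - 2, 0) = 1, i.e. one extra row/column is added
+ # on the bottom/right only, matching TensorFlow's asymmetric "SAME" rule.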
+ if in_height % stride_height == 0:
+ pad_along_height = max(kernel_height - stride_height, 0)
+ else:
+ pad_along_height = max(kernel_height - (in_height % stride_height), 0)
+
+ if in_width % stride_width == 0:
+ pad_along_width = max(kernel_width - stride_width, 0)
+ else:
+ pad_along_width = max(kernel_width - (in_width % stride_width), 0)
+
+ pad_left = pad_along_width // 2
+ pad_right = pad_along_width - pad_left
+ pad_top = pad_along_height // 2
+ pad_bottom = pad_along_height - pad_top
+
+ padding = (pad_left, pad_right, pad_top, pad_bottom)
+ return nn.functional.pad(features, padding, "constant", 0.0)
+
+
+class MobileNetV1ConvLayer(nn.Module):
+ def __init__(
+ self,
+ config: MobileNetV1Config,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: int,
+ stride: Optional[int] = 1,
+ groups: Optional[int] = 1,
+ bias: bool = False,
+ use_normalization: Optional[bool] = True,
+ use_activation: Optional[Union[bool, str]] = True,
+ ) -> None:
+ super().__init__()
+ self.config = config
+
+ if in_channels % groups != 0:
+ raise ValueError(f"Input channels ({in_channels}) are not divisible by {groups} groups.")
+ if out_channels % groups != 0:
+ raise ValueError(f"Output channels ({out_channels}) are not divisible by {groups} groups.")
+
+ padding = 0 if config.tf_padding else int((kernel_size - 1) / 2)
+
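+ # With groups == in_channels this is the 3x3 depthwise convolution of a depthwise-separable
+ # block; with groups == 1 and kernel_size == 1 it acts as the pointwise projection
+ # (see how MobileNetV1Model assembles its 13 blocks below).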
+ self.convolution = nn.Conv2d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=padding,
+ groups=groups,
+ bias=bias,
+ padding_mode="zeros",
+ )
+
+ if use_normalization:
+ self.normalization = nn.BatchNorm2d(
+ num_features=out_channels,
+ eps=config.layer_norm_eps,
+ momentum=0.9997,
+ affine=True,
+ track_running_stats=True,
+ )
+ else:
+ self.normalization = None
+
+ if use_activation:
+ if isinstance(use_activation, str):
+ self.activation = ACT2FN[use_activation]
+ elif isinstance(config.hidden_act, str):
+ self.activation = ACT2FN[config.hidden_act]
+ else:
+ self.activation = config.hidden_act
+ else:
+ self.activation = None
+
+ def forward(self, features: torch.Tensor) -> torch.Tensor:
+ if self.config.tf_padding:
+ features = apply_tf_padding(features, self.convolution)
+ features = self.convolution(features)
+ if self.normalization is not None:
+ features = self.normalization(features)
+ if self.activation is not None:
+ features = self.activation(features)
+ return features
+
+
+@auto_docstring
+class MobileNetV1PreTrainedModel(PreTrainedModel):
+ config: MobileNetV1Config
+ load_tf_weights = load_tf_weights_in_mobilenet_v1
+ base_model_prefix = "mobilenet_v1"
+ main_input_name = "pixel_values"
+ supports_gradient_checkpointing = False
+ _no_split_modules = []
+
+ def _init_weights(self, module: Union[nn.Linear, nn.Conv2d]) -> None:
+ """Initialize the weights"""
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.BatchNorm2d):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+
+
+@auto_docstring
+class MobileNetV1Model(MobileNetV1PreTrainedModel):
+ def __init__(self, config: MobileNetV1Config, add_pooling_layer: bool = True):
+ r"""
+ add_pooling_layer (`bool`, *optional*, defaults to `True`):
+ Whether to add a pooling layer.
+ """
+ super().__init__(config)
+ self.config = config
+
+ depth = 32
+ out_channels = max(int(depth * config.depth_multiplier), config.min_depth)
+
+ self.conv_stem = MobileNetV1ConvLayer(
+ config,
+ in_channels=config.num_channels,
+ out_channels=out_channels,
+ kernel_size=3,
+ stride=2,
+ )
+
+ strides = [1, 2, 1, 2, 1, 2, 1, 1, 1, 1, 1, 2, 1]
+
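+ # With the default depth_multiplier of 1.0 this loop yields the canonical MobileNetV1
+ # progression of pointwise output channels over the 13 depthwise-separable blocks:
+ # 64, 128, 128, 256, 256, 512, 512, 512, 512, 512, 512, 1024, 1024, halving the spatial
+ # resolution at every stride-2 block.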
+ self.layer = nn.ModuleList()
+ for i in range(13):
+ in_channels = out_channels
+
+ if strides[i] == 2 or i == 0:
+ depth *= 2
+ out_channels = max(int(depth * config.depth_multiplier), config.min_depth)
+
+ self.layer.append(
+ MobileNetV1ConvLayer(
+ config,
+ in_channels=in_channels,
+ out_channels=in_channels,
+ kernel_size=3,
+ stride=strides[i],
+ groups=in_channels,
+ )
+ )
+
+ self.layer.append(
+ MobileNetV1ConvLayer(
+ config,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=1,
+ )
+ )
+
+ self.pooler = nn.AdaptiveAvgPool2d((1, 1)) if add_pooling_layer else None
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def _prune_heads(self, heads_to_prune):
+ raise NotImplementedError
+
+ @auto_docstring
+ def forward(
+ self,
+ pixel_values: Optional[torch.Tensor] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[tuple, BaseModelOutputWithPoolingAndNoAttention]:
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if pixel_values is None:
+ raise ValueError("You have to specify pixel_values")
+
+ hidden_states = self.conv_stem(pixel_values)
+
+ all_hidden_states = () if output_hidden_states else None
+
+ for i, layer_module in enumerate(self.layer):
+ hidden_states = layer_module(hidden_states)
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ last_hidden_state = hidden_states
+
+ if self.pooler is not None:
+ pooled_output = torch.flatten(self.pooler(last_hidden_state), start_dim=1)
+ else:
+ pooled_output = None
+
+ if not return_dict:
+ return tuple(v for v in [last_hidden_state, pooled_output, all_hidden_states] if v is not None)
+
+ return BaseModelOutputWithPoolingAndNoAttention(
+ last_hidden_state=last_hidden_state,
+ pooler_output=pooled_output,
+ hidden_states=all_hidden_states,
+ )
+
+
+@auto_docstring(
+ custom_intro="""
+ MobileNetV1 model with an image classification head on top (a linear layer on top of the pooled features), e.g. for
+ ImageNet.
+ """
+)
+class MobileNetV1ForImageClassification(MobileNetV1PreTrainedModel):
+ def __init__(self, config: MobileNetV1Config) -> None:
+ super().__init__(config)
+
+ self.num_labels = config.num_labels
+ self.mobilenet_v1 = MobileNetV1Model(config)
+
+ last_hidden_size = self.mobilenet_v1.layer[-1].convolution.out_channels
+
+ # Classifier head
+ self.dropout = nn.Dropout(config.classifier_dropout_prob, inplace=True)
+ self.classifier = nn.Linear(last_hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity()
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @auto_docstring
+ def forward(
+ self,
+ pixel_values: Optional[torch.Tensor] = None,
+ output_hidden_states: Optional[bool] = None,
+ labels: Optional[torch.Tensor] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[tuple, ImageClassifierOutputWithNoAttention]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss). If
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.mobilenet_v1(pixel_values, output_hidden_states=output_hidden_states, return_dict=return_dict)
+
+ pooled_output = outputs.pooler_output if return_dict else outputs[1]
+
+ logits = self.classifier(self.dropout(pooled_output))
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(labels, logits, self.config)
+
+ if not return_dict:
+ output = (logits,) + outputs[2:]
+ return ((loss,) + output) if loss is not None else output
+
+ return ImageClassifierOutputWithNoAttention(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ )
+
+
+__all__ = [
+ "MobileNetV1ForImageClassification",
+ "MobileNetV1Model",
+ "MobileNetV1PreTrainedModel",
+ "load_tf_weights_in_mobilenet_v1",
+]
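+# Illustrative usage sketch (not part of the library); the checkpoint name mirrors the one
+# referenced in the configuration docstring and is an assumption here:
+#
+#   from transformers import AutoImageProcessor, MobileNetV1ForImageClassification
+#
+#   processor = AutoImageProcessor.from_pretrained("google/mobilenet_v1_1.0_224")
+#   model = MobileNetV1ForImageClassification.from_pretrained("google/mobilenet_v1_1.0_224")
+#   inputs = processor(images=image, return_tensors="pt")
+#   predicted_class = model(**inputs).logits.argmax(-1).item()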
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/mobilevit/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/mobilevit/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..6750449a3eae900890eeaffca5c766ba9cba3339
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/mobilevit/__init__.py
@@ -0,0 +1,31 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_mobilevit import *
+ from .feature_extraction_mobilevit import *
+ from .image_processing_mobilevit import *
+ from .image_processing_mobilevit_fast import *
+ from .modeling_mobilevit import *
+ from .modeling_tf_mobilevit import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/mobilevit/configuration_mobilevit.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/mobilevit/configuration_mobilevit.py
new file mode 100644
index 0000000000000000000000000000000000000000..1bb22e1590b2d0de857c1e5c801311900fa58b7f
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/mobilevit/configuration_mobilevit.py
@@ -0,0 +1,172 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""MobileViT model configuration"""
+
+from collections import OrderedDict
+from collections.abc import Mapping
+
+from packaging import version
+
+from ...configuration_utils import PretrainedConfig
+from ...onnx import OnnxConfig
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class MobileViTConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`MobileViTModel`]. It is used to instantiate a
+ MobileViT model according to the specified arguments, defining the model architecture. Instantiating a
+ configuration with the defaults will yield a similar configuration to that of the MobileViT
+ [apple/mobilevit-small](https://huggingface.co/apple/mobilevit-small) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ num_channels (`int`, *optional*, defaults to 3):
+ The number of input channels.
+ image_size (`int`, *optional*, defaults to 256):
+ The size (resolution) of each image.
+ patch_size (`int`, *optional*, defaults to 2):
+ The size (resolution) of each patch.
+ hidden_sizes (`list[int]`, *optional*, defaults to `[144, 192, 240]`):
+ Dimensionality (hidden size) of the Transformer encoders at each stage.
+ neck_hidden_sizes (`list[int]`, *optional*, defaults to `[16, 32, 64, 96, 128, 160, 640]`):
+ The number of channels for the feature maps of the backbone.
+ num_attention_heads (`int`, *optional*, defaults to 4):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ mlp_ratio (`float`, *optional*, defaults to 2.0):
+ The ratio of the number of channels in the output of the MLP to the number of channels in the input.
+ expand_ratio (`float`, *optional*, defaults to 4.0):
+ Expansion factor for the MobileNetv2 layers.
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+ The non-linear activation function (function or string) in the Transformer encoder and convolution layers.
+ conv_kernel_size (`int`, *optional*, defaults to 3):
+ The size of the convolutional kernel in the MobileViT layer.
+ output_stride (`int`, *optional*, defaults to 32):
+ The ratio of the spatial resolution of the output to the resolution of the input image.
+ hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
+ The dropout probability for all fully connected layers in the Transformer encoder.
+ attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ classifier_dropout_prob (`float`, *optional*, defaults to 0.1):
+ The dropout ratio for attached classifiers.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ layer_norm_eps (`float`, *optional*, defaults to 1e-05):
+ The epsilon used by the layer normalization layers.
+ qkv_bias (`bool`, *optional*, defaults to `True`):
+ Whether to add a bias to the queries, keys and values.
+ aspp_out_channels (`int`, *optional*, defaults to 256):
+ Number of output channels used in the ASPP layer for semantic segmentation.
+ atrous_rates (`list[int]`, *optional*, defaults to `[6, 12, 18]`):
+ Dilation (atrous) factors used in the ASPP layer for semantic segmentation.
+ aspp_dropout_prob (`float`, *optional*, defaults to 0.1):
+ The dropout ratio for the ASPP layer for semantic segmentation.
+ semantic_loss_ignore_index (`int`, *optional*, defaults to 255):
+ The index that is ignored by the loss function of the semantic segmentation model.
+
+ Example:
+
+ ```python
+ >>> from transformers import MobileViTConfig, MobileViTModel
+
+ >>> # Initializing a mobilevit-small style configuration
+ >>> configuration = MobileViTConfig()
+
+ >>> # Initializing a model from the mobilevit-small style configuration
+ >>> model = MobileViTModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "mobilevit"
+
+ def __init__(
+ self,
+ num_channels=3,
+ image_size=256,
+ patch_size=2,
+ hidden_sizes=[144, 192, 240],
+ neck_hidden_sizes=[16, 32, 64, 96, 128, 160, 640],
+ num_attention_heads=4,
+ mlp_ratio=2.0,
+ expand_ratio=4.0,
+ hidden_act="silu",
+ conv_kernel_size=3,
+ output_stride=32,
+ hidden_dropout_prob=0.1,
+ attention_probs_dropout_prob=0.0,
+ classifier_dropout_prob=0.1,
+ initializer_range=0.02,
+ layer_norm_eps=1e-5,
+ qkv_bias=True,
+ aspp_out_channels=256,
+ atrous_rates=[6, 12, 18],
+ aspp_dropout_prob=0.1,
+ semantic_loss_ignore_index=255,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+
+ self.num_channels = num_channels
+ self.image_size = image_size
+ self.patch_size = patch_size
+ self.hidden_sizes = hidden_sizes
+ self.neck_hidden_sizes = neck_hidden_sizes
+ self.num_attention_heads = num_attention_heads
+ self.mlp_ratio = mlp_ratio
+ self.expand_ratio = expand_ratio
+ self.hidden_act = hidden_act
+ self.conv_kernel_size = conv_kernel_size
+ self.output_stride = output_stride
+ self.hidden_dropout_prob = hidden_dropout_prob
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
+ self.classifier_dropout_prob = classifier_dropout_prob
+ self.initializer_range = initializer_range
+ self.layer_norm_eps = layer_norm_eps
+ self.qkv_bias = qkv_bias
+
+ # decode head attributes for semantic segmentation
+ self.aspp_out_channels = aspp_out_channels
+ self.atrous_rates = atrous_rates
+ self.aspp_dropout_prob = aspp_dropout_prob
+ self.semantic_loss_ignore_index = semantic_loss_ignore_index
+
+
+class MobileViTOnnxConfig(OnnxConfig):
+ torch_onnx_minimum_version = version.parse("1.11")
+
+ @property
+ def inputs(self) -> Mapping[str, Mapping[int, str]]:
+ return OrderedDict([("pixel_values", {0: "batch", 1: "num_channels", 2: "height", 3: "width"})])
+
+ @property
+ def outputs(self) -> Mapping[str, Mapping[int, str]]:
+ if self.task == "image-classification":
+ return OrderedDict([("logits", {0: "batch"})])
+ else:
+ return OrderedDict([("last_hidden_state", {0: "batch"}), ("pooler_output", {0: "batch"})])
+
+ @property
+ def atol_for_validation(self) -> float:
+ return 1e-4
+
+
+__all__ = ["MobileViTConfig", "MobileViTOnnxConfig"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/mobilevit/feature_extraction_mobilevit.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/mobilevit/feature_extraction_mobilevit.py
new file mode 100644
index 0000000000000000000000000000000000000000..6c220df918647341a289e39ef0f885b80dcd9df3
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/mobilevit/feature_extraction_mobilevit.py
@@ -0,0 +1,38 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Feature extractor class for MobileViT."""
+
+import warnings
+
+from ...utils import logging
+from ...utils.import_utils import requires
+from .image_processing_mobilevit import MobileViTImageProcessor
+
+
+logger = logging.get_logger(__name__)
+
+
+@requires(backends=("vision",))
+class MobileViTFeatureExtractor(MobileViTImageProcessor):
+ def __init__(self, *args, **kwargs) -> None:
+ warnings.warn(
+ "The class MobileViTFeatureExtractor is deprecated and will be removed in version 5 of Transformers."
+ " Please use MobileViTImageProcessor instead.",
+ FutureWarning,
+ )
+ super().__init__(*args, **kwargs)
+
+
+__all__ = ["MobileViTFeatureExtractor"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/mobilevit/image_processing_mobilevit.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/mobilevit/image_processing_mobilevit.py
new file mode 100644
index 0000000000000000000000000000000000000000..5411023c31047251b94a97b429a4529ae4241672
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/mobilevit/image_processing_mobilevit.py
@@ -0,0 +1,518 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Image processor class for MobileViT."""
+
+from typing import Optional, Union
+
+import numpy as np
+
+from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
+from ...image_transforms import flip_channel_order, get_resize_output_image_size, resize, to_channel_dimension_format
+from ...image_utils import (
+ ChannelDimension,
+ ImageInput,
+ PILImageResampling,
+ infer_channel_dimension_format,
+ is_scaled_image,
+ make_flat_list_of_images,
+ to_numpy_array,
+ valid_images,
+ validate_preprocess_arguments,
+)
+from ...utils import (
+ TensorType,
+ filter_out_non_signature_kwargs,
+ is_torch_available,
+ is_torch_tensor,
+ is_vision_available,
+ logging,
+)
+from ...utils.import_utils import requires
+
+
+if is_vision_available():
+ import PIL
+
+if is_torch_available():
+ import torch
+
+
+logger = logging.get_logger(__name__)
+
+
+@requires(backends=("vision",))
+class MobileViTImageProcessor(BaseImageProcessor):
+ r"""
+ Constructs a MobileViT image processor.
+
+ Args:
+ do_resize (`bool`, *optional*, defaults to `True`):
+ Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by the
+ `do_resize` parameter in the `preprocess` method.
+ size (`dict[str, int]`, *optional*, defaults to `{"shortest_edge": 224}`):
+ Controls the size of the output image after resizing. Can be overridden by the `size` parameter in the
+ `preprocess` method.
+ resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`):
+ Defines the resampling filter to use if resizing the image. Can be overridden by the `resample` parameter
+ in the `preprocess` method.
+ do_rescale (`bool`, *optional*, defaults to `True`):
+ Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the `do_rescale`
+ parameter in the `preprocess` method.
+ rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
+ Scale factor to use if rescaling the image. Can be overridden by the `rescale_factor` parameter in the
+ `preprocess` method.
+ do_center_crop (`bool`, *optional*, defaults to `True`):
+ Whether to crop the input at the center. If the input size is smaller than `crop_size` along any edge, the
+ image is padded with 0's and then center cropped. Can be overridden by the `do_center_crop` parameter in
+ the `preprocess` method.
+ crop_size (`dict[str, int]`, *optional*, defaults to `{"height": 256, "width": 256}`):
+ Desired output size `(size["height"], size["width"])` when applying center-cropping. Can be overridden by
+ the `crop_size` parameter in the `preprocess` method.
+ do_flip_channel_order (`bool`, *optional*, defaults to `True`):
+ Whether to flip the color channels from RGB to BGR. Can be overridden by the `do_flip_channel_order`
+ parameter in the `preprocess` method.
+ do_reduce_labels (`bool`, *optional*, defaults to `False`):
+ Whether or not to reduce all label values of segmentation maps by 1. Usually used for datasets where 0 is
+ used for background, and background itself is not included in all classes of a dataset (e.g. ADE20k). The
+ background label will be replaced by 255. Can be overridden by the `do_reduce_labels` parameter in the
+ `preprocess` method.
+ """
+
+ model_input_names = ["pixel_values"]
+
+ def __init__(
+ self,
+ do_resize: bool = True,
+ size: Optional[dict[str, int]] = None,
+ resample: PILImageResampling = PILImageResampling.BILINEAR,
+ do_rescale: bool = True,
+ rescale_factor: Union[int, float] = 1 / 255,
+ do_center_crop: bool = True,
+ crop_size: Optional[dict[str, int]] = None,
+ do_flip_channel_order: bool = True,
+ do_reduce_labels: bool = False,
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+ size = size if size is not None else {"shortest_edge": 224}
+ size = get_size_dict(size, default_to_square=False)
+ crop_size = crop_size if crop_size is not None else {"height": 256, "width": 256}
+ crop_size = get_size_dict(crop_size, param_name="crop_size")
+
+ self.do_resize = do_resize
+ self.size = size
+ self.resample = resample
+ self.do_rescale = do_rescale
+ self.rescale_factor = rescale_factor
+ self.do_center_crop = do_center_crop
+ self.crop_size = crop_size
+ self.do_flip_channel_order = do_flip_channel_order
+ self.do_reduce_labels = do_reduce_labels
+
+ # Copied from transformers.models.mobilenet_v1.image_processing_mobilenet_v1.MobileNetV1ImageProcessor.resize with PILImageResampling.BICUBIC->PILImageResampling.BILINEAR
+ def resize(
+ self,
+ image: np.ndarray,
+ size: dict[str, int],
+ resample: PILImageResampling = PILImageResampling.BILINEAR,
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ **kwargs,
+ ) -> np.ndarray:
+ """
+ Resize an image. The shortest edge of the image is resized to size["shortest_edge"], with the longest edge
+ resized to keep the input aspect ratio.
+
+ Args:
+ image (`np.ndarray`):
+ Image to resize.
+ size (`dict[str, int]`):
+ Size of the output image.
+ resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
+ Resampling filter to use when resizing the image.
+ data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format of the image. If not provided, it will be the same as the input image.
+ input_data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format of the input image. If not provided, it will be inferred.
+ """
+ default_to_square = True
+ if "shortest_edge" in size:
+ size = size["shortest_edge"]
+ default_to_square = False
+ elif "height" in size and "width" in size:
+ size = (size["height"], size["width"])
+ else:
+ raise ValueError("Size must contain either 'shortest_edge' or 'height' and 'width'.")
+
+ output_size = get_resize_output_image_size(
+ image,
+ size=size,
+ default_to_square=default_to_square,
+ input_data_format=input_data_format,
+ )
+ return resize(
+ image,
+ size=output_size,
+ resample=resample,
+ data_format=data_format,
+ input_data_format=input_data_format,
+ **kwargs,
+ )
+
+ def flip_channel_order(
+ self,
+ image: np.ndarray,
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ ) -> np.ndarray:
+ """
+ Flip the color channels from RGB to BGR or vice versa.
+
+ Args:
+ image (`np.ndarray`):
+ The image, represented as a numpy array.
+ data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format of the image. If not provided, it will be the same as the input image.
+ input_data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format of the input image. If not provided, it will be inferred.
+ """
+ return flip_channel_order(image, data_format=data_format, input_data_format=input_data_format)
+
+ # Copied from transformers.models.beit.image_processing_beit.BeitImageProcessor.reduce_label
+ def reduce_label(self, label: ImageInput) -> np.ndarray:
+ label = to_numpy_array(label)
+ # Avoid using underflow conversion
+ label[label == 0] = 255
+ label = label - 1
+ label[label == 254] = 255
+ return label
+
+ def __call__(self, images, segmentation_maps=None, **kwargs):
+ """
+ Preprocesses a batch of images and optionally segmentation maps.
+
+ Overrides the `__call__` method of the `Preprocessor` class so that both images and segmentation maps can be
+ passed in as positional arguments.
+ """
+ return super().__call__(images, segmentation_maps=segmentation_maps, **kwargs)
+
+ def _preprocess(
+ self,
+ image: ImageInput,
+ do_reduce_labels: bool,
+ do_resize: bool,
+ do_rescale: bool,
+ do_center_crop: bool,
+ do_flip_channel_order: bool,
+ size: Optional[dict[str, int]] = None,
+ resample: Optional[PILImageResampling] = None,
+ rescale_factor: Optional[float] = None,
+ crop_size: Optional[dict[str, int]] = None,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ ):
+ if do_reduce_labels:
+ image = self.reduce_label(image)
+
+ if do_resize:
+ image = self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
+
+ if do_rescale:
+ image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
+
+ if do_center_crop:
+ image = self.center_crop(image=image, size=crop_size, input_data_format=input_data_format)
+
+ if do_flip_channel_order:
+ image = self.flip_channel_order(image, input_data_format=input_data_format)
+
+ return image
+
+ def _preprocess_image(
+ self,
+ image: ImageInput,
+ do_resize: Optional[bool] = None,
+ size: Optional[dict[str, int]] = None,
+ resample: Optional[PILImageResampling] = None,
+ do_rescale: Optional[bool] = None,
+ rescale_factor: Optional[float] = None,
+ do_center_crop: Optional[bool] = None,
+ crop_size: Optional[dict[str, int]] = None,
+ do_flip_channel_order: Optional[bool] = None,
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ ) -> np.ndarray:
+ """Preprocesses a single image."""
+ # All transformations expect numpy arrays.
+ image = to_numpy_array(image)
+ if do_rescale and is_scaled_image(image):
+ logger.warning_once(
+ "It looks like you are trying to rescale already rescaled images. If the input"
+ " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
+ )
+ if input_data_format is None:
+ input_data_format = infer_channel_dimension_format(image)
+
+ image = self._preprocess(
+ image=image,
+ do_reduce_labels=False,
+ do_resize=do_resize,
+ size=size,
+ resample=resample,
+ do_rescale=do_rescale,
+ rescale_factor=rescale_factor,
+ do_center_crop=do_center_crop,
+ crop_size=crop_size,
+ do_flip_channel_order=do_flip_channel_order,
+ input_data_format=input_data_format,
+ )
+
+ image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
+
+ return image
+
+ def _preprocess_mask(
+ self,
+ segmentation_map: ImageInput,
+ do_reduce_labels: Optional[bool] = None,
+ do_resize: Optional[bool] = None,
+ size: Optional[dict[str, int]] = None,
+ do_center_crop: Optional[bool] = None,
+ crop_size: Optional[dict[str, int]] = None,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ ) -> np.ndarray:
+ """Preprocesses a single mask."""
+ segmentation_map = to_numpy_array(segmentation_map)
+ # Add channel dimension if missing - needed for certain transformations
+ if segmentation_map.ndim == 2:
+ added_channel_dim = True
+ segmentation_map = segmentation_map[None, ...]
+ input_data_format = ChannelDimension.FIRST
+ else:
+ added_channel_dim = False
+ if input_data_format is None:
+ input_data_format = infer_channel_dimension_format(segmentation_map, num_channels=1)
+
+ segmentation_map = self._preprocess(
+ image=segmentation_map,
+ do_reduce_labels=do_reduce_labels,
+ do_resize=do_resize,
+ size=size,
+ resample=PILImageResampling.NEAREST,
+ do_rescale=False,
+ do_center_crop=do_center_crop,
+ crop_size=crop_size,
+ do_flip_channel_order=False,
+ input_data_format=input_data_format,
+ )
+ # Remove extra channel dimension if added for processing
+ if added_channel_dim:
+ segmentation_map = segmentation_map.squeeze(0)
+ segmentation_map = segmentation_map.astype(np.int64)
+ return segmentation_map
+
+ @filter_out_non_signature_kwargs()
+ def preprocess(
+ self,
+ images: ImageInput,
+ segmentation_maps: Optional[ImageInput] = None,
+ do_resize: Optional[bool] = None,
+ size: Optional[dict[str, int]] = None,
+ resample: Optional[PILImageResampling] = None,
+ do_rescale: Optional[bool] = None,
+ rescale_factor: Optional[float] = None,
+ do_center_crop: Optional[bool] = None,
+ crop_size: Optional[dict[str, int]] = None,
+ do_flip_channel_order: Optional[bool] = None,
+ do_reduce_labels: Optional[bool] = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ data_format: ChannelDimension = ChannelDimension.FIRST,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ ) -> BatchFeature:
+ """
+ Preprocess an image or batch of images.
+
+ Args:
+ images (`ImageInput`):
+ Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
+ passing in images with pixel values between 0 and 1, set `do_rescale=False`.
+ segmentation_maps (`ImageInput`, *optional*):
+ Segmentation map to preprocess.
+ do_resize (`bool`, *optional*, defaults to `self.do_resize`):
+ Whether to resize the image.
+ size (`dict[str, int]`, *optional*, defaults to `self.size`):
+ Size of the image after resizing.
+ resample (`int`, *optional*, defaults to `self.resample`):
+ Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
+ has an effect if `do_resize` is set to `True`.
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
+ Whether to rescale the image by the specified rescale factor.
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
+ Rescale factor to rescale the image by if `do_rescale` is set to `True`.
+ do_center_crop (`bool`, *optional*, defaults to `self.do_center_crop`):
+ Whether to center crop the image.
+ crop_size (`dict[str, int]`, *optional*, defaults to `self.crop_size`):
+ Size of the center crop if `do_center_crop` is set to `True`.
+ do_flip_channel_order (`bool`, *optional*, defaults to `self.do_flip_channel_order`):
+ Whether to flip the channel order of the image.
+ do_reduce_labels (`bool`, *optional*, defaults to `self.do_reduce_labels`):
+ Whether or not to reduce all label values of segmentation maps by 1. Usually used for datasets where 0
+ is used for background, and background itself is not included in all classes of a dataset (e.g.
+ ADE20k). The background label will be replaced by 255.
+ return_tensors (`str` or `TensorType`, *optional*):
+ The type of tensors to return. Can be one of:
+ - Unset: Return a list of `np.ndarray`.
+ - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
+ - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
+ data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
+ The channel dimension format for the output image. Can be one of:
+ - `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ input_data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
+ from the input image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
+ """
+ do_resize = do_resize if do_resize is not None else self.do_resize
+ resample = resample if resample is not None else self.resample
+ do_rescale = do_rescale if do_rescale is not None else self.do_rescale
+ rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
+ do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop
+ do_flip_channel_order = (
+ do_flip_channel_order if do_flip_channel_order is not None else self.do_flip_channel_order
+ )
+
+ size = size if size is not None else self.size
+ size = get_size_dict(size, default_to_square=False)
+ crop_size = crop_size if crop_size is not None else self.crop_size
+ crop_size = get_size_dict(crop_size, param_name="crop_size")
+
+ do_reduce_labels = do_reduce_labels if do_reduce_labels is not None else self.do_reduce_labels
+
+ images = make_flat_list_of_images(images)
+
+ if segmentation_maps is not None:
+ segmentation_maps = make_flat_list_of_images(segmentation_maps, expected_ndims=2)
+
+ if not valid_images(images):
+ raise ValueError(
+ "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
+ "torch.Tensor, tf.Tensor or jax.ndarray."
+ )
+
+ if segmentation_maps is not None and not valid_images(segmentation_maps):
+ raise ValueError(
+ "Invalid segmentation map type. Must be of type PIL.Image.Image, numpy.ndarray, "
+ "torch.Tensor, tf.Tensor or jax.ndarray."
+ )
+
+ validate_preprocess_arguments(
+ do_rescale=do_rescale,
+ rescale_factor=rescale_factor,
+ do_center_crop=do_center_crop,
+ crop_size=crop_size,
+ do_resize=do_resize,
+ size=size,
+ resample=resample,
+ )
+
+ images = [
+ self._preprocess_image(
+ image=img,
+ do_resize=do_resize,
+ size=size,
+ resample=resample,
+ do_rescale=do_rescale,
+ rescale_factor=rescale_factor,
+ do_center_crop=do_center_crop,
+ crop_size=crop_size,
+ do_flip_channel_order=do_flip_channel_order,
+ data_format=data_format,
+ input_data_format=input_data_format,
+ )
+ for img in images
+ ]
+
+ data = {"pixel_values": images}
+
+ if segmentation_maps is not None:
+ segmentation_maps = [
+ self._preprocess_mask(
+ segmentation_map=segmentation_map,
+ do_reduce_labels=do_reduce_labels,
+ do_resize=do_resize,
+ size=size,
+ do_center_crop=do_center_crop,
+ crop_size=crop_size,
+ input_data_format=input_data_format,
+ )
+ for segmentation_map in segmentation_maps
+ ]
+
+ data["labels"] = segmentation_maps
+
+ return BatchFeature(data=data, tensor_type=return_tensors)
+
+ # Copied from transformers.models.beit.image_processing_beit.BeitImageProcessor.post_process_semantic_segmentation with Beit->MobileViT
+ def post_process_semantic_segmentation(self, outputs, target_sizes: Optional[list[tuple]] = None):
+ """
+ Converts the output of [`MobileViTForSemanticSegmentation`] into semantic segmentation maps. Only supports PyTorch.
+
+ Args:
+ outputs ([`MobileViTForSemanticSegmentation`]):
+ Raw outputs of the model.
+ target_sizes (`list[Tuple]` of length `batch_size`, *optional*):
+ List of tuples corresponding to the requested final size (height, width) of each prediction. If unset,
+ predictions will not be resized.
+
+ Returns:
+ semantic_segmentation: `list[torch.Tensor]` of length `batch_size`, where each item is a semantic
+ segmentation map of shape (height, width) corresponding to the target_sizes entry (if `target_sizes` is
+ specified). Each entry of each `torch.Tensor` corresponds to a semantic class id.
+ """
+ # TODO: add support for other frameworks
+ logits = outputs.logits
+
+ # Resize logits and compute semantic segmentation maps
+ if target_sizes is not None:
+ if len(logits) != len(target_sizes):
+ raise ValueError(
+ "Make sure that you pass in as many target sizes as the batch dimension of the logits"
+ )
+
+ if is_torch_tensor(target_sizes):
+ target_sizes = target_sizes.numpy()
+
+ semantic_segmentation = []
+
+ for idx in range(len(logits)):
+ resized_logits = torch.nn.functional.interpolate(
+ logits[idx].unsqueeze(dim=0), size=target_sizes[idx], mode="bilinear", align_corners=False
+ )
+ semantic_map = resized_logits[0].argmax(dim=0)
+ semantic_segmentation.append(semantic_map)
+ else:
+ semantic_segmentation = logits.argmax(dim=1)
+ semantic_segmentation = [semantic_segmentation[i] for i in range(semantic_segmentation.shape[0])]
+
+ return semantic_segmentation
+
+
+__all__ = ["MobileViTImageProcessor"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/mobilevit/image_processing_mobilevit_fast.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/mobilevit/image_processing_mobilevit_fast.py
new file mode 100644
index 0000000000000000000000000000000000000000..fab16ecfdc878f3b6b9f3ecdae97eb474d8792d6
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/mobilevit/image_processing_mobilevit_fast.py
@@ -0,0 +1,246 @@
+# coding=utf-8
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Fast Image processor class for MobileViT."""
+
+from typing import Optional, Union
+
+import torch
+from torchvision.transforms.v2 import functional as F
+
+from ...image_processing_utils import BatchFeature
+from ...image_processing_utils_fast import (
+ BaseImageProcessorFast,
+ DefaultFastImageProcessorKwargs,
+ group_images_by_shape,
+ reorder_images,
+)
+from ...image_utils import (
+ ChannelDimension,
+ ImageInput,
+ PILImageResampling,
+ SizeDict,
+ is_torch_tensor,
+)
+from ...processing_utils import Unpack
+from ...utils import (
+ TensorType,
+ auto_docstring,
+)
+
+
+class MobileVitFastImageProcessorKwargs(DefaultFastImageProcessorKwargs):
+ """
+ do_flip_channel_order (`bool`, *optional*, defaults to `self.do_flip_channel_order`):
+ Whether to flip the color channels from RGB to BGR or vice versa.
+ do_reduce_labels (`bool`, *optional*, defaults to `self.do_reduce_labels`):
+ Whether or not to reduce all label values of segmentation maps by 1. Usually used for datasets where 0
+ is used for background, and background itself is not included in all classes of a dataset (e.g.
+ ADE20k). The background label will be replaced by 255.
+ """
+
+ do_flip_channel_order: Optional[bool]
+ do_reduce_labels: Optional[bool]
+
+
+@auto_docstring
+class MobileViTImageProcessorFast(BaseImageProcessorFast):
+ resample = PILImageResampling.BILINEAR
+ size = {"shortest_edge": 224}
+ default_to_square = False
+ crop_size = {"height": 256, "width": 256}
+ do_resize = True
+ do_center_crop = True
+ do_rescale = True
+ do_normalize = None
+ do_convert_rgb = None
+ do_flip_channel_order = True
+ do_reduce_labels = False
+ valid_kwargs = MobileVitFastImageProcessorKwargs
+
+ def __init__(self, **kwargs: Unpack[MobileVitFastImageProcessorKwargs]):
+ super().__init__(**kwargs)
+
+ # Copied from transformers.models.beit.image_processing_beit_fast.BeitImageProcessorFast.reduce_label
+ def reduce_label(self, labels: list["torch.Tensor"]):
+ for idx in range(len(labels)):
+ label = labels[idx]
+ label = torch.where(label == 0, torch.tensor(255, dtype=label.dtype), label)
+ label = label - 1
+ label = torch.where(label == 254, torch.tensor(255, dtype=label.dtype), label)
+ labels[idx] = label
+
+ return labels
+
+ @auto_docstring
+ def preprocess(
+ self,
+ images: ImageInput,
+ segmentation_maps: Optional[ImageInput] = None,
+ **kwargs: Unpack[MobileVitFastImageProcessorKwargs],
+ ) -> BatchFeature:
+ r"""
+ segmentation_maps (`ImageInput`, *optional*):
+ The segmentation maps to preprocess.
+ """
+ return super().preprocess(images, segmentation_maps, **kwargs)
+
+ def _preprocess_image_like_inputs(
+ self,
+ images: ImageInput,
+ segmentation_maps: Optional[ImageInput],
+ do_convert_rgb: bool,
+ input_data_format: ChannelDimension,
+ device: Optional[Union[str, "torch.device"]] = None,
+ **kwargs: Unpack[MobileVitFastImageProcessorKwargs],
+ ) -> BatchFeature:
+ """
+ Preprocess image-like inputs.
+ """
+ images = self._prepare_image_like_inputs(
+ images=images, do_convert_rgb=do_convert_rgb, input_data_format=input_data_format, device=device
+ )
+ images_kwargs = kwargs.copy()
+ images_kwargs["do_reduce_labels"] = False
+ batch_feature = self._preprocess(images, **images_kwargs)
+
+ if segmentation_maps is not None:
+ processed_segmentation_maps = self._prepare_image_like_inputs(
+ images=segmentation_maps,
+ expected_ndims=2,
+ do_convert_rgb=False,
+ input_data_format=ChannelDimension.FIRST,
+ )
+
+ segmentation_maps_kwargs = kwargs.copy()
+ segmentation_maps_kwargs.update(
+ {
+ "do_rescale": False,
+ "do_flip_channel_order": False,
+ # Nearest interpolation is used for segmentation maps instead of BILINEAR.
+ "interpolation": F.InterpolationMode.NEAREST_EXACT,
+ }
+ )
+
+ processed_segmentation_maps = self._preprocess(
+ images=processed_segmentation_maps, **segmentation_maps_kwargs
+ ).pixel_values
+ batch_feature["labels"] = processed_segmentation_maps.squeeze(1).to(torch.int64)
+
+ return batch_feature
+
+ def _preprocess(
+ self,
+ images: list["torch.Tensor"],
+ do_reduce_labels: bool,
+ do_resize: bool,
+ size: Optional[SizeDict],
+ interpolation: Optional["F.InterpolationMode"],
+ do_rescale: bool,
+ rescale_factor: Optional[float],
+ do_center_crop: bool,
+ crop_size: Optional[SizeDict],
+ do_flip_channel_order: bool,
+ disable_grouping: bool,
+ return_tensors: Optional[Union[str, TensorType]],
+ **kwargs,
+ ) -> BatchFeature:
+ processed_images = []
+
+ if do_reduce_labels:
+ images = self.reduce_label(images)
+
+ # Group images by shape for more efficient batch processing
+ grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
+ resized_images_grouped = {}
+
+ # Process each group of images with the same shape
+ for shape, stacked_images in grouped_images.items():
+ if do_resize:
+ stacked_images = self.resize(image=stacked_images, size=size, interpolation=interpolation)
+ resized_images_grouped[shape] = stacked_images
+
+ # Reorder images to original sequence
+ resized_images = reorder_images(resized_images_grouped, grouped_images_index)
+
+ # Group again after resizing (in case resize produced different sizes)
+ grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping)
+ processed_images_grouped = {}
+
+ for shape, stacked_images in grouped_images.items():
+ if do_center_crop:
+ stacked_images = self.center_crop(image=stacked_images, size=crop_size)
+ if do_rescale:
+ stacked_images = self.rescale(image=stacked_images, scale=rescale_factor)
+ if do_flip_channel_order:
+ # For batched images, we need to handle them all at once
+ if stacked_images.ndim > 3 and stacked_images.shape[1] >= 3:
+ # Flip RGB → BGR for batched images
+ flipped = stacked_images.clone()
+ flipped[:, 0:3] = stacked_images[:, [2, 1, 0], ...]
+ stacked_images = flipped
+
+ processed_images_grouped[shape] = stacked_images
+
+ processed_images = reorder_images(processed_images_grouped, grouped_images_index)
+
+ # Stack all processed images if return_tensors is specified
+ processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
+
+ return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)
+
+ def post_process_semantic_segmentation(self, outputs, target_sizes: Optional[list[tuple]] = None):
+ """
+ Converts the output of [`MobileViTForSemanticSegmentation`] into semantic segmentation maps. Only supports PyTorch.
+
+ Args:
+ outputs ([`MobileViTForSemanticSegmentation`]):
+ Raw outputs of the model.
+ target_sizes (`list[Tuple]` of length `batch_size`, *optional*):
+ List of tuples corresponding to the requested final size (height, width) of each prediction. If unset,
+ predictions will not be resized.
+
+ Returns:
+ semantic_segmentation: `list[torch.Tensor]` of length `batch_size`, where each item is a semantic
+ segmentation map of shape (height, width) corresponding to the target_sizes entry (if `target_sizes` is
+ specified). Each entry of each `torch.Tensor` corresponds to a semantic class id.
+ """
+ logits = outputs.logits
+
+ # Resize logits and compute semantic segmentation maps
+ if target_sizes is not None:
+ if len(logits) != len(target_sizes):
+ raise ValueError(
+ "Make sure that you pass in as many target sizes as the batch dimension of the logits"
+ )
+
+ if is_torch_tensor(target_sizes):
+ target_sizes = target_sizes.numpy()
+
+ semantic_segmentation = []
+
+ for idx in range(len(logits)):
+ resized_logits = torch.nn.functional.interpolate(
+ logits[idx].unsqueeze(dim=0), size=target_sizes[idx], mode="bilinear", align_corners=False
+ )
+ semantic_map = resized_logits[0].argmax(dim=0)
+ semantic_segmentation.append(semantic_map)
+ else:
+ semantic_segmentation = logits.argmax(dim=1)
+ semantic_segmentation = [semantic_segmentation[i] for i in range(semantic_segmentation.shape[0])]
+
+ return semantic_segmentation
+
+
+__all__ = ["MobileViTImageProcessorFast"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/mobilevit/modeling_mobilevit.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/mobilevit/modeling_mobilevit.py
new file mode 100644
index 0000000000000000000000000000000000000000..415c33a7cb8527a9df0439a62fc5026cd5d92e00
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/mobilevit/modeling_mobilevit.py
@@ -0,0 +1,998 @@
+# coding=utf-8
+# Copyright 2022 Apple Inc. and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Original license: https://github.com/apple/ml-cvnets/blob/main/LICENSE
+"""PyTorch MobileViT model."""
+
+import math
+from typing import Optional, Union
+
+import torch
+from torch import nn
+from torch.nn import CrossEntropyLoss
+
+from ...activations import ACT2FN
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_outputs import (
+ BaseModelOutputWithNoAttention,
+ BaseModelOutputWithPoolingAndNoAttention,
+ ImageClassifierOutputWithNoAttention,
+ SemanticSegmenterOutput,
+)
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
+from ...utils import auto_docstring, logging, torch_int
+from .configuration_mobilevit import MobileViTConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+def make_divisible(value: int, divisor: int = 8, min_value: Optional[int] = None) -> int:
+ """
+ Ensure that all layers have a channel count that is divisible by `divisor`. This function is taken from the
+ original TensorFlow repo. It can be seen here:
+ https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
+ """
+ if min_value is None:
+ min_value = divisor
+ new_value = max(min_value, int(value + divisor / 2) // divisor * divisor)
+ # Make sure that round down does not go down by more than 10%.
+ if new_value < 0.9 * value:
+ new_value += divisor
+ return int(new_value)
+
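For intuition on the 10% guard in `make_divisible` above, a couple of hand-worked values (illustrative only, not taken from the file):

```python
from transformers.models.mobilevit.modeling_mobilevit import make_divisible

assert make_divisible(33) == 32  # 32 is within 10% of 33, so the round-down is kept
assert make_divisible(9) == 16   # rounding down to 8 would lose more than 10%, so bump up
assert make_divisible(64) == 64  # already a multiple of 8
```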
+
+class MobileViTConvLayer(nn.Module):
+ def __init__(
+ self,
+ config: MobileViTConfig,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: int,
+ stride: int = 1,
+ groups: int = 1,
+ bias: bool = False,
+ dilation: int = 1,
+ use_normalization: bool = True,
+ use_activation: Union[bool, str] = True,
+ ) -> None:
+ super().__init__()
+ padding = int((kernel_size - 1) / 2) * dilation
+
+ if in_channels % groups != 0:
+ raise ValueError(f"Input channels ({in_channels}) are not divisible by {groups} groups.")
+ if out_channels % groups != 0:
+ raise ValueError(f"Output channels ({out_channels}) are not divisible by {groups} groups.")
+
+ self.convolution = nn.Conv2d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=padding,
+ dilation=dilation,
+ groups=groups,
+ bias=bias,
+ padding_mode="zeros",
+ )
+
+ if use_normalization:
+ self.normalization = nn.BatchNorm2d(
+ num_features=out_channels,
+ eps=1e-5,
+ momentum=0.1,
+ affine=True,
+ track_running_stats=True,
+ )
+ else:
+ self.normalization = None
+
+ if use_activation:
+ if isinstance(use_activation, str):
+ self.activation = ACT2FN[use_activation]
+ elif isinstance(config.hidden_act, str):
+ self.activation = ACT2FN[config.hidden_act]
+ else:
+ self.activation = config.hidden_act
+ else:
+ self.activation = None
+
+ def forward(self, features: torch.Tensor) -> torch.Tensor:
+ features = self.convolution(features)
+ if self.normalization is not None:
+ features = self.normalization(features)
+ if self.activation is not None:
+ features = self.activation(features)
+ return features
+
+
+class MobileViTInvertedResidual(nn.Module):
+ """
+ Inverted residual block (MobileNetv2): https://huggingface.co/papers/1801.04381
+ """
+
+ def __init__(
+ self, config: MobileViTConfig, in_channels: int, out_channels: int, stride: int, dilation: int = 1
+ ) -> None:
+ super().__init__()
+ expanded_channels = make_divisible(int(round(in_channels * config.expand_ratio)), 8)
+
+ if stride not in [1, 2]:
+ raise ValueError(f"Invalid stride {stride}.")
+
+ self.use_residual = (stride == 1) and (in_channels == out_channels)
+
+ self.expand_1x1 = MobileViTConvLayer(
+ config, in_channels=in_channels, out_channels=expanded_channels, kernel_size=1
+ )
+
+ self.conv_3x3 = MobileViTConvLayer(
+ config,
+ in_channels=expanded_channels,
+ out_channels=expanded_channels,
+ kernel_size=3,
+ stride=stride,
+ groups=expanded_channels,
+ dilation=dilation,
+ )
+
+ self.reduce_1x1 = MobileViTConvLayer(
+ config,
+ in_channels=expanded_channels,
+ out_channels=out_channels,
+ kernel_size=1,
+ use_activation=False,
+ )
+
+ def forward(self, features: torch.Tensor) -> torch.Tensor:
+ residual = features
+
+ features = self.expand_1x1(features)
+ features = self.conv_3x3(features)
+ features = self.reduce_1x1(features)
+
+ return residual + features if self.use_residual else features
+
+
+class MobileViTMobileNetLayer(nn.Module):
+ def __init__(
+ self, config: MobileViTConfig, in_channels: int, out_channels: int, stride: int = 1, num_stages: int = 1
+ ) -> None:
+ super().__init__()
+
+ self.layer = nn.ModuleList()
+ for i in range(num_stages):
+ layer = MobileViTInvertedResidual(
+ config,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ stride=stride if i == 0 else 1,
+ )
+ self.layer.append(layer)
+ in_channels = out_channels
+
+ def forward(self, features: torch.Tensor) -> torch.Tensor:
+ for layer_module in self.layer:
+ features = layer_module(features)
+ return features
+
+
+class MobileViTSelfAttention(nn.Module):
+ def __init__(self, config: MobileViTConfig, hidden_size: int) -> None:
+ super().__init__()
+
+ if hidden_size % config.num_attention_heads != 0:
+ raise ValueError(
+ f"The hidden size {hidden_size} is not a multiple of the number of attention "
+ f"heads {config.num_attention_heads}."
+ )
+
+ self.num_attention_heads = config.num_attention_heads
+ self.attention_head_size = int(hidden_size / config.num_attention_heads)
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
+
+ self.query = nn.Linear(hidden_size, self.all_head_size, bias=config.qkv_bias)
+ self.key = nn.Linear(hidden_size, self.all_head_size, bias=config.qkv_bias)
+ self.value = nn.Linear(hidden_size, self.all_head_size, bias=config.qkv_bias)
+
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ batch_size, seq_length, _ = hidden_states.shape
+ query_layer = (
+ self.query(hidden_states)
+ .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
+ .transpose(1, 2)
+ )
+ key_layer = (
+ self.key(hidden_states)
+ .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
+ .transpose(1, 2)
+ )
+ value_layer = (
+ self.value(hidden_states)
+ .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
+ .transpose(1, 2)
+ )
+
+ # Take the dot product between "query" and "key" to get the raw attention scores.
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
+
+ # Normalize the attention scores to probabilities.
+ attention_probs = nn.functional.softmax(attention_scores, dim=-1)
+
+ # This is actually dropping out entire tokens to attend to, which might
+ # seem a bit unusual, but is taken from the original Transformer paper.
+ attention_probs = self.dropout(attention_probs)
+
+ context_layer = torch.matmul(attention_probs, value_layer)
+
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+ context_layer = context_layer.view(*new_context_layer_shape)
+ return context_layer
+
+
+class MobileViTSelfOutput(nn.Module):
+ def __init__(self, config: MobileViTConfig, hidden_size: int) -> None:
+ super().__init__()
+ self.dense = nn.Linear(hidden_size, hidden_size)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ return hidden_states
+
+
+class MobileViTAttention(nn.Module):
+ def __init__(self, config: MobileViTConfig, hidden_size: int) -> None:
+ super().__init__()
+ self.attention = MobileViTSelfAttention(config, hidden_size)
+ self.output = MobileViTSelfOutput(config, hidden_size)
+ self.pruned_heads = set()
+
+ def prune_heads(self, heads: set[int]) -> None:
+ if len(heads) == 0:
+ return
+ heads, index = find_pruneable_heads_and_indices(
+ heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads
+ )
+
+ # Prune linear layers
+ self.attention.query = prune_linear_layer(self.attention.query, index)
+ self.attention.key = prune_linear_layer(self.attention.key, index)
+ self.attention.value = prune_linear_layer(self.attention.value, index)
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
+
+ # Update hyper params and store pruned heads
+ self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads)
+ self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
+ self.pruned_heads = self.pruned_heads.union(heads)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ self_outputs = self.attention(hidden_states)
+ attention_output = self.output(self_outputs)
+ return attention_output
+
+
+class MobileViTIntermediate(nn.Module):
+ def __init__(self, config: MobileViTConfig, hidden_size: int, intermediate_size: int) -> None:
+ super().__init__()
+ self.dense = nn.Linear(hidden_size, intermediate_size)
+ if isinstance(config.hidden_act, str):
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
+ else:
+ self.intermediate_act_fn = config.hidden_act
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.intermediate_act_fn(hidden_states)
+ return hidden_states
+
+
+class MobileViTOutput(nn.Module):
+ def __init__(self, config: MobileViTConfig, hidden_size: int, intermediate_size: int) -> None:
+ super().__init__()
+ self.dense = nn.Linear(intermediate_size, hidden_size)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = hidden_states + input_tensor
+ return hidden_states
+
+
+class MobileViTTransformerLayer(nn.Module):
+ def __init__(self, config: MobileViTConfig, hidden_size: int, intermediate_size: int) -> None:
+ super().__init__()
+ self.attention = MobileViTAttention(config, hidden_size)
+ self.intermediate = MobileViTIntermediate(config, hidden_size, intermediate_size)
+ self.output = MobileViTOutput(config, hidden_size, intermediate_size)
+ self.layernorm_before = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps)
+ self.layernorm_after = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ attention_output = self.attention(self.layernorm_before(hidden_states))
+ hidden_states = attention_output + hidden_states
+
+ layer_output = self.layernorm_after(hidden_states)
+ layer_output = self.intermediate(layer_output)
+ layer_output = self.output(layer_output, hidden_states)
+ return layer_output
+
+
+class MobileViTTransformer(nn.Module):
+ def __init__(self, config: MobileViTConfig, hidden_size: int, num_stages: int) -> None:
+ super().__init__()
+
+ self.layer = nn.ModuleList()
+ for _ in range(num_stages):
+ transformer_layer = MobileViTTransformerLayer(
+ config,
+ hidden_size=hidden_size,
+ intermediate_size=int(hidden_size * config.mlp_ratio),
+ )
+ self.layer.append(transformer_layer)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ for layer_module in self.layer:
+ hidden_states = layer_module(hidden_states)
+ return hidden_states
+
+
+class MobileViTLayer(GradientCheckpointingLayer):
+ """
+ MobileViT block: https://huggingface.co/papers/2110.02178
+ """
+
+ def __init__(
+ self,
+ config: MobileViTConfig,
+ in_channels: int,
+ out_channels: int,
+ stride: int,
+ hidden_size: int,
+ num_stages: int,
+ dilation: int = 1,
+ ) -> None:
+ super().__init__()
+ self.patch_width = config.patch_size
+ self.patch_height = config.patch_size
+
+ if stride == 2:
+ self.downsampling_layer = MobileViTInvertedResidual(
+ config,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ stride=stride if dilation == 1 else 1,
+ dilation=dilation // 2 if dilation > 1 else 1,
+ )
+ in_channels = out_channels
+ else:
+ self.downsampling_layer = None
+
+ self.conv_kxk = MobileViTConvLayer(
+ config,
+ in_channels=in_channels,
+ out_channels=in_channels,
+ kernel_size=config.conv_kernel_size,
+ )
+
+ self.conv_1x1 = MobileViTConvLayer(
+ config,
+ in_channels=in_channels,
+ out_channels=hidden_size,
+ kernel_size=1,
+ use_normalization=False,
+ use_activation=False,
+ )
+
+ self.transformer = MobileViTTransformer(
+ config,
+ hidden_size=hidden_size,
+ num_stages=num_stages,
+ )
+
+ self.layernorm = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps)
+
+ self.conv_projection = MobileViTConvLayer(
+ config, in_channels=hidden_size, out_channels=in_channels, kernel_size=1
+ )
+
+ self.fusion = MobileViTConvLayer(
+ config, in_channels=2 * in_channels, out_channels=in_channels, kernel_size=config.conv_kernel_size
+ )
+
+ def unfolding(self, features: torch.Tensor) -> tuple[torch.Tensor, dict]:
+ patch_width, patch_height = self.patch_width, self.patch_height
+ patch_area = int(patch_width * patch_height)
+
+ batch_size, channels, orig_height, orig_width = features.shape
+
+ new_height = (
+ torch_int(torch.ceil(orig_height / patch_height) * patch_height)
+ if torch.jit.is_tracing()
+ else int(math.ceil(orig_height / patch_height) * patch_height)
+ )
+ new_width = (
+ torch_int(torch.ceil(orig_width / patch_width) * patch_width)
+ if torch.jit.is_tracing()
+ else int(math.ceil(orig_width / patch_width) * patch_width)
+ )
+
+ interpolate = False
+ if new_width != orig_width or new_height != orig_height:
+ # Note: Padding can be done, but then it needs to be handled in attention function.
+ features = nn.functional.interpolate(
+ features, size=(new_height, new_width), mode="bilinear", align_corners=False
+ )
+ interpolate = True
+
+ # number of patches along width and height
+ num_patch_width = new_width // patch_width
+ num_patch_height = new_height // patch_height
+ num_patches = num_patch_height * num_patch_width
+
+ # convert from shape (batch_size, channels, orig_height, orig_width)
+ # to the shape (batch_size * patch_area, num_patches, channels)
+ patches = features.reshape(
+ batch_size * channels * num_patch_height, patch_height, num_patch_width, patch_width
+ )
+ patches = patches.transpose(1, 2)
+ patches = patches.reshape(batch_size, channels, num_patches, patch_area)
+ patches = patches.transpose(1, 3)
+ patches = patches.reshape(batch_size * patch_area, num_patches, -1)
+
+ info_dict = {
+ "orig_size": (orig_height, orig_width),
+ "batch_size": batch_size,
+ "channels": channels,
+ "interpolate": interpolate,
+ "num_patches": num_patches,
+ "num_patches_width": num_patch_width,
+ "num_patches_height": num_patch_height,
+ }
+ return patches, info_dict
+
+ def folding(self, patches: torch.Tensor, info_dict: dict) -> torch.Tensor:
+ patch_width, patch_height = self.patch_width, self.patch_height
+ patch_area = int(patch_width * patch_height)
+
+ batch_size = info_dict["batch_size"]
+ channels = info_dict["channels"]
+ num_patches = info_dict["num_patches"]
+ num_patch_height = info_dict["num_patches_height"]
+ num_patch_width = info_dict["num_patches_width"]
+
+ # convert from shape (batch_size * patch_area, num_patches, channels)
+ # back to shape (batch_size, channels, orig_height, orig_width)
+ features = patches.contiguous().view(batch_size, patch_area, num_patches, -1)
+ features = features.transpose(1, 3)
+ features = features.reshape(
+ batch_size * channels * num_patch_height, num_patch_width, patch_height, patch_width
+ )
+ features = features.transpose(1, 2)
+ features = features.reshape(
+ batch_size, channels, num_patch_height * patch_height, num_patch_width * patch_width
+ )
+
+ if info_dict["interpolate"]:
+ features = nn.functional.interpolate(
+ features, size=info_dict["orig_size"], mode="bilinear", align_corners=False
+ )
+
+ return features
+
+ def forward(self, features: torch.Tensor) -> torch.Tensor:
+ # reduce spatial dimensions if needed
+ if self.downsampling_layer:
+ features = self.downsampling_layer(features)
+
+ residual = features
+
+ # local representation
+ features = self.conv_kxk(features)
+ features = self.conv_1x1(features)
+
+ # convert feature map to patches
+ patches, info_dict = self.unfolding(features)
+
+ # learn global representations
+ patches = self.transformer(patches)
+ patches = self.layernorm(patches)
+
+ # convert patches back to feature maps
+ features = self.folding(patches, info_dict)
+
+ features = self.conv_projection(features)
+ features = self.fusion(torch.cat((residual, features), dim=1))
+ return features
+
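To see what `MobileViTLayer.unfolding` does to tensor shapes, here is a standalone sketch of the same reshape sequence on an arbitrary (batch, channels, height, width) feature map with 2x2 patches; the sizes are illustrative only:

```python
import torch

batch_size, channels, height, width = 1, 4, 8, 8
patch_h = patch_w = 2
num_patch_h, num_patch_w = height // patch_h, width // patch_w
num_patches, patch_area = num_patch_h * num_patch_w, patch_h * patch_w

features = torch.randn(batch_size, channels, height, width)
patches = features.reshape(batch_size * channels * num_patch_h, patch_h, num_patch_w, patch_w)
patches = patches.transpose(1, 2)
patches = patches.reshape(batch_size, channels, num_patches, patch_area)
patches = patches.transpose(1, 3)
patches = patches.reshape(batch_size * patch_area, num_patches, -1)

# The transformer then attends over num_patches tokens for each intra-patch position.
print(patches.shape)  # torch.Size([4, 16, 4]) == (batch_size * patch_area, num_patches, channels)
```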
+
+class MobileViTEncoder(nn.Module):
+ def __init__(self, config: MobileViTConfig) -> None:
+ super().__init__()
+ self.config = config
+
+ self.layer = nn.ModuleList()
+ self.gradient_checkpointing = False
+
+ # segmentation architectures like DeepLab and PSPNet modify the strides
+ # of the classification backbones
+ dilate_layer_4 = dilate_layer_5 = False
+ if config.output_stride == 8:
+ dilate_layer_4 = True
+ dilate_layer_5 = True
+ elif config.output_stride == 16:
+ dilate_layer_5 = True
+
+ dilation = 1
+
+ layer_1 = MobileViTMobileNetLayer(
+ config,
+ in_channels=config.neck_hidden_sizes[0],
+ out_channels=config.neck_hidden_sizes[1],
+ stride=1,
+ num_stages=1,
+ )
+ self.layer.append(layer_1)
+
+ layer_2 = MobileViTMobileNetLayer(
+ config,
+ in_channels=config.neck_hidden_sizes[1],
+ out_channels=config.neck_hidden_sizes[2],
+ stride=2,
+ num_stages=3,
+ )
+ self.layer.append(layer_2)
+
+ layer_3 = MobileViTLayer(
+ config,
+ in_channels=config.neck_hidden_sizes[2],
+ out_channels=config.neck_hidden_sizes[3],
+ stride=2,
+ hidden_size=config.hidden_sizes[0],
+ num_stages=2,
+ )
+ self.layer.append(layer_3)
+
+ if dilate_layer_4:
+ dilation *= 2
+
+ layer_4 = MobileViTLayer(
+ config,
+ in_channels=config.neck_hidden_sizes[3],
+ out_channels=config.neck_hidden_sizes[4],
+ stride=2,
+ hidden_size=config.hidden_sizes[1],
+ num_stages=4,
+ dilation=dilation,
+ )
+ self.layer.append(layer_4)
+
+ if dilate_layer_5:
+ dilation *= 2
+
+ layer_5 = MobileViTLayer(
+ config,
+ in_channels=config.neck_hidden_sizes[4],
+ out_channels=config.neck_hidden_sizes[5],
+ stride=2,
+ hidden_size=config.hidden_sizes[2],
+ num_stages=3,
+ dilation=dilation,
+ )
+ self.layer.append(layer_5)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ output_hidden_states: bool = False,
+ return_dict: bool = True,
+ ) -> Union[tuple, BaseModelOutputWithNoAttention]:
+ all_hidden_states = () if output_hidden_states else None
+
+ for i, layer_module in enumerate(self.layer):
+ hidden_states = layer_module(hidden_states)
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, all_hidden_states] if v is not None)
+
+ return BaseModelOutputWithNoAttention(last_hidden_state=hidden_states, hidden_states=all_hidden_states)
+
+
+@auto_docstring
+class MobileViTPreTrainedModel(PreTrainedModel):
+ config: MobileViTConfig
+ base_model_prefix = "mobilevit"
+ main_input_name = "pixel_values"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["MobileViTLayer"]
+
+ def _init_weights(self, module: nn.Module) -> None:
+ """Initialize the weights"""
+ if isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)):
+ # Slightly different from the TF version which uses truncated_normal for initialization
+ # cf https://github.com/pytorch/pytorch/pull/5617
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+
+
+@auto_docstring
+class MobileViTModel(MobileViTPreTrainedModel):
+ def __init__(self, config: MobileViTConfig, expand_output: bool = True):
+ r"""
+ expand_output (`bool`, *optional*, defaults to `True`):
+ Whether to expand the output of the model using a 1x1 convolution. If `True`, the model will apply an additional
+ 1x1 convolution to expand the output channels from `config.neck_hidden_sizes[5]` to `config.neck_hidden_sizes[6]`.
+ """
+ super().__init__(config)
+ self.config = config
+ self.expand_output = expand_output
+
+ self.conv_stem = MobileViTConvLayer(
+ config,
+ in_channels=config.num_channels,
+ out_channels=config.neck_hidden_sizes[0],
+ kernel_size=3,
+ stride=2,
+ )
+
+ self.encoder = MobileViTEncoder(config)
+
+ if self.expand_output:
+ self.conv_1x1_exp = MobileViTConvLayer(
+ config,
+ in_channels=config.neck_hidden_sizes[5],
+ out_channels=config.neck_hidden_sizes[6],
+ kernel_size=1,
+ )
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def _prune_heads(self, heads_to_prune):
+ """Prunes heads of the model.
+ heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base class PreTrainedModel
+ """
+ for layer_index, heads in heads_to_prune.items():
+ mobilevit_layer = self.encoder.layer[layer_index]
+ if isinstance(mobilevit_layer, MobileViTLayer):
+ for transformer_layer in mobilevit_layer.transformer.layer:
+ transformer_layer.attention.prune_heads(heads)
+
+ @auto_docstring
+ def forward(
+ self,
+ pixel_values: Optional[torch.Tensor] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[tuple, BaseModelOutputWithPoolingAndNoAttention]:
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if pixel_values is None:
+ raise ValueError("You have to specify pixel_values")
+
+ embedding_output = self.conv_stem(pixel_values)
+
+ encoder_outputs = self.encoder(
+ embedding_output,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ if self.expand_output:
+ last_hidden_state = self.conv_1x1_exp(encoder_outputs[0])
+
+ # global average pooling: (batch_size, channels, height, width) -> (batch_size, channels)
+ pooled_output = torch.mean(last_hidden_state, dim=[-2, -1], keepdim=False)
+ else:
+ last_hidden_state = encoder_outputs[0]
+ pooled_output = None
+
+ if not return_dict:
+ output = (last_hidden_state, pooled_output) if pooled_output is not None else (last_hidden_state,)
+ return output + encoder_outputs[1:]
+
+ return BaseModelOutputWithPoolingAndNoAttention(
+ last_hidden_state=last_hidden_state,
+ pooler_output=pooled_output,
+ hidden_states=encoder_outputs.hidden_states,
+ )
+
+
+@auto_docstring(
+ custom_intro="""
+ MobileViT model with an image classification head on top (a linear layer on top of the pooled features), e.g. for
+ ImageNet.
+ """
+)
+class MobileViTForImageClassification(MobileViTPreTrainedModel):
+ def __init__(self, config: MobileViTConfig) -> None:
+ super().__init__(config)
+
+ self.num_labels = config.num_labels
+ self.mobilevit = MobileViTModel(config)
+
+ # Classifier head
+ self.dropout = nn.Dropout(config.classifier_dropout_prob, inplace=True)
+ self.classifier = (
+ nn.Linear(config.neck_hidden_sizes[-1], config.num_labels) if config.num_labels > 0 else nn.Identity()
+ )
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @auto_docstring
+ def forward(
+ self,
+ pixel_values: Optional[torch.Tensor] = None,
+ output_hidden_states: Optional[bool] = None,
+ labels: Optional[torch.Tensor] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[tuple, ImageClassifierOutputWithNoAttention]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss). If
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.mobilevit(pixel_values, output_hidden_states=output_hidden_states, return_dict=return_dict)
+
+ pooled_output = outputs.pooler_output if return_dict else outputs[1]
+
+ logits = self.classifier(self.dropout(pooled_output))
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(labels, logits, self.config)
+
+ if not return_dict:
+ output = (logits,) + outputs[2:]
+ return ((loss,) + output) if loss is not None else output
+
+ return ImageClassifierOutputWithNoAttention(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ )
+
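A forward-pass sketch for the classification head above, using a randomly initialized model and a random pixel tensor; real use would load pretrained weights via `from_pretrained`, and the 10-label configuration is an arbitrary assumption:

```python
import torch

from transformers import MobileViTConfig, MobileViTForImageClassification

config = MobileViTConfig(num_labels=10)
model = MobileViTForImageClassification(config).eval()  # random weights, illustration only

pixel_values = torch.randn(1, 3, config.image_size, config.image_size)
with torch.no_grad():
    logits = model(pixel_values=pixel_values).logits
print(logits.shape)  # (1, 10)
```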
+
+class MobileViTASPPPooling(nn.Module):
+ def __init__(self, config: MobileViTConfig, in_channels: int, out_channels: int) -> None:
+ super().__init__()
+
+ self.global_pool = nn.AdaptiveAvgPool2d(output_size=1)
+
+ self.conv_1x1 = MobileViTConvLayer(
+ config,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=1,
+ stride=1,
+ use_normalization=True,
+ use_activation="relu",
+ )
+
+ def forward(self, features: torch.Tensor) -> torch.Tensor:
+ spatial_size = features.shape[-2:]
+ features = self.global_pool(features)
+ features = self.conv_1x1(features)
+ features = nn.functional.interpolate(features, size=spatial_size, mode="bilinear", align_corners=False)
+ return features
+
+
+class MobileViTASPP(nn.Module):
+ """
+ ASPP module defined in DeepLab papers: https://huggingface.co/papers/1606.00915, https://huggingface.co/papers/1706.05587
+ """
+
+ def __init__(self, config: MobileViTConfig) -> None:
+ super().__init__()
+
+ in_channels = config.neck_hidden_sizes[-2]
+ out_channels = config.aspp_out_channels
+
+ if len(config.atrous_rates) != 3:
+ raise ValueError("Expected 3 values for atrous_rates")
+
+ self.convs = nn.ModuleList()
+
+ in_projection = MobileViTConvLayer(
+ config,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=1,
+ use_activation="relu",
+ )
+ self.convs.append(in_projection)
+
+ self.convs.extend(
+ [
+ MobileViTConvLayer(
+ config,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=3,
+ dilation=rate,
+ use_activation="relu",
+ )
+ for rate in config.atrous_rates
+ ]
+ )
+
+ pool_layer = MobileViTASPPPooling(config, in_channels, out_channels)
+ self.convs.append(pool_layer)
+
+ self.project = MobileViTConvLayer(
+ config, in_channels=5 * out_channels, out_channels=out_channels, kernel_size=1, use_activation="relu"
+ )
+
+ self.dropout = nn.Dropout(p=config.aspp_dropout_prob)
+
+ def forward(self, features: torch.Tensor) -> torch.Tensor:
+ pyramid = []
+ for conv in self.convs:
+ pyramid.append(conv(features))
+ pyramid = torch.cat(pyramid, dim=1)
+
+ pooled_features = self.project(pyramid)
+ pooled_features = self.dropout(pooled_features)
+ return pooled_features
+
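To make the channel bookkeeping in `MobileViTASPP` concrete: the 1x1 projection, the three dilated 3x3 convolutions, and the pooling branch each emit `aspp_out_channels` feature maps at the input's spatial size, which is why `self.project` expects `5 * out_channels` input channels. A small shape sketch under default `MobileViTConfig` values; the 32x32 input size is arbitrary, and importing from the internal module is for illustration only:

```python
>>> import torch
>>> from transformers import MobileViTConfig
>>> from transformers.models.mobilevit.modeling_mobilevit import MobileViTASPP

>>> config = MobileViTConfig()  # default atrous_rates == [6, 12, 18], aspp_out_channels == 256
>>> aspp = MobileViTASPP(config)
>>> _ = aspp.eval()  # eval mode so the BatchNorm in the pooled branch works on a (1, C, 1, 1) input

>>> features = torch.randn(1, config.neck_hidden_sizes[-2], 32, 32)  # (batch, channels, height, width)
>>> with torch.no_grad():
...     fused = aspp(features)
>>> fused.shape  # every branch preserves the spatial size, then project fuses 5 * 256 channels back to 256
torch.Size([1, 256, 32, 32])
```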
+
+class MobileViTDeepLabV3(nn.Module):
+ """
+ DeepLabv3 architecture: https://huggingface.co/papers/1706.05587
+ """
+
+ def __init__(self, config: MobileViTConfig) -> None:
+ super().__init__()
+ self.aspp = MobileViTASPP(config)
+
+ self.dropout = nn.Dropout2d(config.classifier_dropout_prob)
+
+ self.classifier = MobileViTConvLayer(
+ config,
+ in_channels=config.aspp_out_channels,
+ out_channels=config.num_labels,
+ kernel_size=1,
+ use_normalization=False,
+ use_activation=False,
+ bias=True,
+ )
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ features = self.aspp(hidden_states[-1])
+ features = self.dropout(features)
+ features = self.classifier(features)
+ return features
+
+
+@auto_docstring(
+ custom_intro="""
+ MobileViT model with a semantic segmentation head on top, e.g. for Pascal VOC.
+ """
+)
+class MobileViTForSemanticSegmentation(MobileViTPreTrainedModel):
+ def __init__(self, config: MobileViTConfig) -> None:
+ super().__init__(config)
+
+ self.num_labels = config.num_labels
+ self.mobilevit = MobileViTModel(config, expand_output=False)
+ self.segmentation_head = MobileViTDeepLabV3(config)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @auto_docstring
+ def forward(
+ self,
+ pixel_values: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[tuple, SemanticSegmenterOutput]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*):
+ Ground truth semantic segmentation maps for computing the loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels > 1`, a classification loss is computed (Cross-Entropy).
+
+ Examples:
+
+ ```python
+ >>> import requests
+ >>> import torch
+ >>> from PIL import Image
+ >>> from transformers import AutoImageProcessor, MobileViTForSemanticSegmentation
+
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> image_processor = AutoImageProcessor.from_pretrained("apple/deeplabv3-mobilevit-small")
+ >>> model = MobileViTForSemanticSegmentation.from_pretrained("apple/deeplabv3-mobilevit-small")
+
+ >>> inputs = image_processor(images=image, return_tensors="pt")
+
+ >>> with torch.no_grad():
+ ... outputs = model(**inputs)
+
+ >>> # logits are of shape (batch_size, num_labels, height, width)
+ >>> logits = outputs.logits
+ ```"""
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if labels is not None and self.config.num_labels == 1:
+ raise ValueError("The number of labels should be greater than one")
+
+ outputs = self.mobilevit(
+ pixel_values,
+ output_hidden_states=True, # we need the intermediate hidden states
+ return_dict=return_dict,
+ )
+
+ encoder_hidden_states = outputs.hidden_states if return_dict else outputs[1]
+
+ logits = self.segmentation_head(encoder_hidden_states)
+
+ loss = None
+ if labels is not None:
+ # upsample logits to the images' original size
+ upsampled_logits = nn.functional.interpolate(
+ logits, size=labels.shape[-2:], mode="bilinear", align_corners=False
+ )
+ loss_fct = CrossEntropyLoss(ignore_index=self.config.semantic_loss_ignore_index)
+ loss = loss_fct(upsampled_logits, labels)
+
+ if not return_dict:
+ if output_hidden_states:
+ output = (logits,) + outputs[1:]
+ else:
+ output = (logits,) + outputs[2:]
+ return ((loss,) + output) if loss is not None else output
+
+ return SemanticSegmenterOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states if output_hidden_states else None,
+ attentions=None,
+ )
+
+
+__all__ = [
+ "MobileViTForImageClassification",
+ "MobileViTForSemanticSegmentation",
+ "MobileViTModel",
+ "MobileViTPreTrainedModel",
+]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/mobilevit/modeling_tf_mobilevit.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/mobilevit/modeling_tf_mobilevit.py
new file mode 100644
index 0000000000000000000000000000000000000000..dcad0f302a8ebfdc0679f140eb6dd7e139f215c2
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/mobilevit/modeling_tf_mobilevit.py
@@ -0,0 +1,1376 @@
+# coding=utf-8
+# Copyright 2022 Apple Inc. and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Original license: https://github.com/apple/ml-cvnets/blob/main/LICENSE
+"""TensorFlow 2.0 MobileViT model."""
+
+from __future__ import annotations
+
+import tensorflow as tf
+
+from ...activations_tf import get_tf_activation
+from ...file_utils import (
+ add_code_sample_docstrings,
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ replace_return_docstrings,
+)
+from ...modeling_tf_outputs import (
+ TFBaseModelOutput,
+ TFBaseModelOutputWithPooling,
+ TFImageClassifierOutputWithNoAttention,
+ TFSemanticSegmenterOutputWithNoAttention,
+)
+from ...modeling_tf_utils import (
+ TFPreTrainedModel,
+ TFSequenceClassificationLoss,
+ keras,
+ keras_serializable,
+ unpack_inputs,
+)
+from ...tf_utils import shape_list, stable_softmax
+from ...utils import logging
+from .configuration_mobilevit import MobileViTConfig
+
+
+logger = logging.get_logger(__name__)
+
+# General docstring
+_CONFIG_FOR_DOC = "MobileViTConfig"
+
+# Base docstring
+_CHECKPOINT_FOR_DOC = "apple/mobilevit-small"
+_EXPECTED_OUTPUT_SHAPE = [1, 640, 8, 8]
+
+# Image classification docstring
+_IMAGE_CLASS_CHECKPOINT = "apple/mobilevit-small"
+_IMAGE_CLASS_EXPECTED_OUTPUT = "tabby, tabby cat"
+
+
+def make_divisible(value: int, divisor: int = 8, min_value: int | None = None) -> int:
+ """
+ Ensure that all layers have a channel count that is divisible by `divisor`. This function is taken from the
+ original TensorFlow repo. It can be seen here:
+ https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
+ """
+ if min_value is None:
+ min_value = divisor
+ new_value = max(min_value, int(value + divisor / 2) // divisor * divisor)
+ # Make sure that round down does not go down by more than 10%.
+ if new_value < 0.9 * value:
+ new_value += divisor
+ return int(new_value)
+
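The rounding rule in `make_divisible` has two parts: snap to the nearest multiple of `divisor`, but never drop more than 10% of the requested channel count. A few hand-checked values with the default `divisor=8` illustrate both cases:

```python
>>> make_divisible(16)  # already a multiple of 8
16
>>> make_divisible(30)  # rounds up to the nearest multiple of 8
32
>>> make_divisible(17)  # rounding down to 16 loses less than 10%, so it is allowed
16
>>> make_divisible(10)  # rounding down to 8 would lose 20%, so bump up by one divisor instead
16
```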
+
+class TFMobileViTConvLayer(keras.layers.Layer):
+ def __init__(
+ self,
+ config: MobileViTConfig,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: int,
+ stride: int = 1,
+ groups: int = 1,
+ bias: bool = False,
+ dilation: int = 1,
+ use_normalization: bool = True,
+ use_activation: bool | str = True,
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+ logger.warning(
+ f"\n{self.__class__.__name__} has backpropagation operations that are NOT supported on CPU. If you wish "
+ "to train/fine-tune this model, you need a GPU or a TPU"
+ )
+
+ padding = int((kernel_size - 1) / 2) * dilation
+ self.padding = keras.layers.ZeroPadding2D(padding)
+
+ if out_channels % groups != 0:
+ raise ValueError(f"Output channels ({out_channels}) are not divisible by {groups} groups.")
+
+ self.convolution = keras.layers.Conv2D(
+ filters=out_channels,
+ kernel_size=kernel_size,
+ strides=stride,
+ padding="VALID",
+ dilation_rate=dilation,
+ groups=groups,
+ use_bias=bias,
+ name="convolution",
+ )
+
+ if use_normalization:
+ self.normalization = keras.layers.BatchNormalization(epsilon=1e-5, momentum=0.1, name="normalization")
+ else:
+ self.normalization = None
+
+ if use_activation:
+ if isinstance(use_activation, str):
+ self.activation = get_tf_activation(use_activation)
+ elif isinstance(config.hidden_act, str):
+ self.activation = get_tf_activation(config.hidden_act)
+ else:
+ self.activation = config.hidden_act
+ else:
+ self.activation = None
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+
+ def call(self, features: tf.Tensor, training: bool = False) -> tf.Tensor:
+ padded_features = self.padding(features)
+ features = self.convolution(padded_features)
+ if self.normalization is not None:
+ features = self.normalization(features, training=training)
+ if self.activation is not None:
+ features = self.activation(features)
+ return features
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "convolution", None) is not None:
+ with tf.name_scope(self.convolution.name):
+ self.convolution.build([None, None, None, self.in_channels])
+ if getattr(self, "normalization", None) is not None:
+ if hasattr(self.normalization, "name"):
+ with tf.name_scope(self.normalization.name):
+ self.normalization.build([None, None, None, self.out_channels])
+
+
+class TFMobileViTInvertedResidual(keras.layers.Layer):
+ """
+ Inverted residual block (MobileNetv2): https://huggingface.co/papers/1801.04381
+ """
+
+ def __init__(
+ self, config: MobileViTConfig, in_channels: int, out_channels: int, stride: int, dilation: int = 1, **kwargs
+ ) -> None:
+ super().__init__(**kwargs)
+ expanded_channels = make_divisible(int(round(in_channels * config.expand_ratio)), 8)
+
+ if stride not in [1, 2]:
+ raise ValueError(f"Invalid stride {stride}.")
+
+ self.use_residual = (stride == 1) and (in_channels == out_channels)
+
+ self.expand_1x1 = TFMobileViTConvLayer(
+ config, in_channels=in_channels, out_channels=expanded_channels, kernel_size=1, name="expand_1x1"
+ )
+
+ self.conv_3x3 = TFMobileViTConvLayer(
+ config,
+ in_channels=expanded_channels,
+ out_channels=expanded_channels,
+ kernel_size=3,
+ stride=stride,
+ groups=expanded_channels,
+ dilation=dilation,
+ name="conv_3x3",
+ )
+
+ self.reduce_1x1 = TFMobileViTConvLayer(
+ config,
+ in_channels=expanded_channels,
+ out_channels=out_channels,
+ kernel_size=1,
+ use_activation=False,
+ name="reduce_1x1",
+ )
+
+ def call(self, features: tf.Tensor, training: bool = False) -> tf.Tensor:
+ residual = features
+
+ features = self.expand_1x1(features, training=training)
+ features = self.conv_3x3(features, training=training)
+ features = self.reduce_1x1(features, training=training)
+
+ return residual + features if self.use_residual else features
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "expand_1x1", None) is not None:
+ with tf.name_scope(self.expand_1x1.name):
+ self.expand_1x1.build(None)
+ if getattr(self, "conv_3x3", None) is not None:
+ with tf.name_scope(self.conv_3x3.name):
+ self.conv_3x3.build(None)
+ if getattr(self, "reduce_1x1", None) is not None:
+ with tf.name_scope(self.reduce_1x1.name):
+ self.reduce_1x1.build(None)
+
+
+class TFMobileViTMobileNetLayer(keras.layers.Layer):
+ def __init__(
+ self,
+ config: MobileViTConfig,
+ in_channels: int,
+ out_channels: int,
+ stride: int = 1,
+ num_stages: int = 1,
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+
+ self.layers = []
+ for i in range(num_stages):
+ layer = TFMobileViTInvertedResidual(
+ config,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ stride=stride if i == 0 else 1,
+ name=f"layer.{i}",
+ )
+ self.layers.append(layer)
+ in_channels = out_channels
+
+ def call(self, features: tf.Tensor, training: bool = False) -> tf.Tensor:
+ for layer_module in self.layers:
+ features = layer_module(features, training=training)
+ return features
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "layers", None) is not None:
+ for layer_module in self.layers:
+ with tf.name_scope(layer_module.name):
+ layer_module.build(None)
+
+
+class TFMobileViTSelfAttention(keras.layers.Layer):
+ def __init__(self, config: MobileViTConfig, hidden_size: int, **kwargs) -> None:
+ super().__init__(**kwargs)
+
+ if hidden_size % config.num_attention_heads != 0:
+ raise ValueError(
+ f"The hidden size {hidden_size} is not a multiple of the number of attention "
+ f"heads {config.num_attention_heads}."
+ )
+
+ self.num_attention_heads = config.num_attention_heads
+ self.attention_head_size = int(hidden_size / config.num_attention_heads)
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
+ scale = tf.cast(self.attention_head_size, dtype=tf.float32)
+ self.scale = tf.math.sqrt(scale)
+
+ self.query = keras.layers.Dense(self.all_head_size, use_bias=config.qkv_bias, name="query")
+ self.key = keras.layers.Dense(self.all_head_size, use_bias=config.qkv_bias, name="key")
+ self.value = keras.layers.Dense(self.all_head_size, use_bias=config.qkv_bias, name="value")
+
+ self.dropout = keras.layers.Dropout(config.attention_probs_dropout_prob)
+ self.hidden_size = hidden_size
+
+ def transpose_for_scores(self, x: tf.Tensor) -> tf.Tensor:
+ batch_size = tf.shape(x)[0]
+ x = tf.reshape(x, shape=(batch_size, -1, self.num_attention_heads, self.attention_head_size))
+ return tf.transpose(x, perm=[0, 2, 1, 3])
+
+ def call(self, hidden_states: tf.Tensor, training: bool = False) -> tf.Tensor:
+ batch_size = tf.shape(hidden_states)[0]
+
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
+ query_layer = self.transpose_for_scores(self.query(hidden_states))
+
+ # Take the dot product between "query" and "key" to get the raw attention scores.
+ attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True)
+ attention_scores = attention_scores / self.scale
+
+ # Normalize the attention scores to probabilities.
+ attention_probs = stable_softmax(attention_scores, axis=-1)
+
+ # This is actually dropping out entire tokens to attend to, which might
+ # seem a bit unusual, but is taken from the original Transformer paper.
+ attention_probs = self.dropout(attention_probs, training=training)
+
+ context_layer = tf.matmul(attention_probs, value_layer)
+
+ context_layer = tf.transpose(context_layer, perm=[0, 2, 1, 3])
+ context_layer = tf.reshape(context_layer, shape=(batch_size, -1, self.all_head_size))
+ return context_layer
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "query", None) is not None:
+ with tf.name_scope(self.query.name):
+ self.query.build([None, None, self.hidden_size])
+ if getattr(self, "key", None) is not None:
+ with tf.name_scope(self.key.name):
+ self.key.build([None, None, self.hidden_size])
+ if getattr(self, "value", None) is not None:
+ with tf.name_scope(self.value.name):
+ self.value.build([None, None, self.hidden_size])
+
+
+class TFMobileViTSelfOutput(keras.layers.Layer):
+ def __init__(self, config: MobileViTConfig, hidden_size: int, **kwargs) -> None:
+ super().__init__(**kwargs)
+ self.dense = keras.layers.Dense(hidden_size, name="dense")
+ self.dropout = keras.layers.Dropout(config.hidden_dropout_prob)
+ self.hidden_size = hidden_size
+
+ def call(self, hidden_states: tf.Tensor, training: bool = False) -> tf.Tensor:
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states, training=training)
+ return hidden_states
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "dense", None) is not None:
+ with tf.name_scope(self.dense.name):
+ self.dense.build([None, None, self.hidden_size])
+
+
+class TFMobileViTAttention(keras.layers.Layer):
+ def __init__(self, config: MobileViTConfig, hidden_size: int, **kwargs) -> None:
+ super().__init__(**kwargs)
+ self.attention = TFMobileViTSelfAttention(config, hidden_size, name="attention")
+ self.dense_output = TFMobileViTSelfOutput(config, hidden_size, name="output")
+
+ def prune_heads(self, heads):
+ raise NotImplementedError
+
+ def call(self, hidden_states: tf.Tensor, training: bool = False) -> tf.Tensor:
+ self_outputs = self.attention(hidden_states, training=training)
+ attention_output = self.dense_output(self_outputs, training=training)
+ return attention_output
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "attention", None) is not None:
+ with tf.name_scope(self.attention.name):
+ self.attention.build(None)
+ if getattr(self, "dense_output", None) is not None:
+ with tf.name_scope(self.dense_output.name):
+ self.dense_output.build(None)
+
+
+class TFMobileViTIntermediate(keras.layers.Layer):
+ def __init__(self, config: MobileViTConfig, hidden_size: int, intermediate_size: int, **kwargs) -> None:
+ super().__init__(**kwargs)
+ self.dense = keras.layers.Dense(intermediate_size, name="dense")
+ if isinstance(config.hidden_act, str):
+ self.intermediate_act_fn = get_tf_activation(config.hidden_act)
+ else:
+ self.intermediate_act_fn = config.hidden_act
+ self.hidden_size = hidden_size
+
+ def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.intermediate_act_fn(hidden_states)
+ return hidden_states
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "dense", None) is not None:
+ with tf.name_scope(self.dense.name):
+ self.dense.build([None, None, self.hidden_size])
+
+
+class TFMobileViTOutput(keras.layers.Layer):
+ def __init__(self, config: MobileViTConfig, hidden_size: int, intermediate_size: int, **kwargs) -> None:
+ super().__init__(**kwargs)
+ self.dense = keras.layers.Dense(hidden_size, name="dense")
+ self.dropout = keras.layers.Dropout(config.hidden_dropout_prob)
+ self.intermediate_size = intermediate_size
+
+ def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor:
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states, training=training)
+ hidden_states = hidden_states + input_tensor
+ return hidden_states
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "dense", None) is not None:
+ with tf.name_scope(self.dense.name):
+ self.dense.build([None, None, self.intermediate_size])
+
+
+class TFMobileViTTransformerLayer(keras.layers.Layer):
+ def __init__(self, config: MobileViTConfig, hidden_size: int, intermediate_size: int, **kwargs) -> None:
+ super().__init__(**kwargs)
+ self.attention = TFMobileViTAttention(config, hidden_size, name="attention")
+ self.intermediate = TFMobileViTIntermediate(config, hidden_size, intermediate_size, name="intermediate")
+ self.mobilevit_output = TFMobileViTOutput(config, hidden_size, intermediate_size, name="output")
+ self.layernorm_before = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm_before")
+ self.layernorm_after = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm_after")
+ self.hidden_size = hidden_size
+
+ def call(self, hidden_states: tf.Tensor, training: bool = False) -> tf.Tensor:
+ attention_output = self.attention(self.layernorm_before(hidden_states), training=training)
+ hidden_states = attention_output + hidden_states
+
+ layer_output = self.layernorm_after(hidden_states)
+ layer_output = self.intermediate(layer_output)
+ layer_output = self.mobilevit_output(layer_output, hidden_states, training=training)
+ return layer_output
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "attention", None) is not None:
+ with tf.name_scope(self.attention.name):
+ self.attention.build(None)
+ if getattr(self, "intermediate", None) is not None:
+ with tf.name_scope(self.intermediate.name):
+ self.intermediate.build(None)
+ if getattr(self, "mobilevit_output", None) is not None:
+ with tf.name_scope(self.mobilevit_output.name):
+ self.mobilevit_output.build(None)
+ if getattr(self, "layernorm_before", None) is not None:
+ with tf.name_scope(self.layernorm_before.name):
+ self.layernorm_before.build([None, None, self.hidden_size])
+ if getattr(self, "layernorm_after", None) is not None:
+ with tf.name_scope(self.layernorm_after.name):
+ self.layernorm_after.build([None, None, self.hidden_size])
+
+
+class TFMobileViTTransformer(keras.layers.Layer):
+ def __init__(self, config: MobileViTConfig, hidden_size: int, num_stages: int, **kwargs) -> None:
+ super().__init__(**kwargs)
+
+ self.layers = []
+ for i in range(num_stages):
+ transformer_layer = TFMobileViTTransformerLayer(
+ config,
+ hidden_size=hidden_size,
+ intermediate_size=int(hidden_size * config.mlp_ratio),
+ name=f"layer.{i}",
+ )
+ self.layers.append(transformer_layer)
+
+ def call(self, hidden_states: tf.Tensor, training: bool = False) -> tf.Tensor:
+ for layer_module in self.layers:
+ hidden_states = layer_module(hidden_states, training=training)
+ return hidden_states
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "layers", None) is not None:
+ for layer_module in self.layers:
+ with tf.name_scope(layer_module.name):
+ layer_module.build(None)
+
+
+class TFMobileViTLayer(keras.layers.Layer):
+ """
+ MobileViT block: https://huggingface.co/papers/2110.02178
+ """
+
+ def __init__(
+ self,
+ config: MobileViTConfig,
+ in_channels: int,
+ out_channels: int,
+ stride: int,
+ hidden_size: int,
+ num_stages: int,
+ dilation: int = 1,
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+ self.patch_width = config.patch_size
+ self.patch_height = config.patch_size
+
+ if stride == 2:
+ self.downsampling_layer = TFMobileViTInvertedResidual(
+ config,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ stride=stride if dilation == 1 else 1,
+ dilation=dilation // 2 if dilation > 1 else 1,
+ name="downsampling_layer",
+ )
+ in_channels = out_channels
+ else:
+ self.downsampling_layer = None
+
+ self.conv_kxk = TFMobileViTConvLayer(
+ config,
+ in_channels=in_channels,
+ out_channels=in_channels,
+ kernel_size=config.conv_kernel_size,
+ name="conv_kxk",
+ )
+
+ self.conv_1x1 = TFMobileViTConvLayer(
+ config,
+ in_channels=in_channels,
+ out_channels=hidden_size,
+ kernel_size=1,
+ use_normalization=False,
+ use_activation=False,
+ name="conv_1x1",
+ )
+
+ self.transformer = TFMobileViTTransformer(
+ config, hidden_size=hidden_size, num_stages=num_stages, name="transformer"
+ )
+
+ self.layernorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm")
+
+ self.conv_projection = TFMobileViTConvLayer(
+ config, in_channels=hidden_size, out_channels=in_channels, kernel_size=1, name="conv_projection"
+ )
+
+ self.fusion = TFMobileViTConvLayer(
+ config,
+ in_channels=2 * in_channels,
+ out_channels=in_channels,
+ kernel_size=config.conv_kernel_size,
+ name="fusion",
+ )
+ self.hidden_size = hidden_size
+
+ def unfolding(self, features: tf.Tensor) -> tuple[tf.Tensor, dict]:
+ patch_width, patch_height = self.patch_width, self.patch_height
+ patch_area = tf.cast(patch_width * patch_height, "int32")
+
+ batch_size = tf.shape(features)[0]
+ orig_height = tf.shape(features)[1]
+ orig_width = tf.shape(features)[2]
+ channels = tf.shape(features)[3]
+
+ new_height = tf.cast(tf.math.ceil(orig_height / patch_height) * patch_height, "int32")
+ new_width = tf.cast(tf.math.ceil(orig_width / patch_width) * patch_width, "int32")
+
+ interpolate = new_width != orig_width or new_height != orig_height
+ if interpolate:
+ # Note: Padding can be done, but then it needs to be handled in attention function.
+ features = tf.image.resize(features, size=(new_height, new_width), method="bilinear")
+
+ # number of patches along width and height
+ num_patch_width = new_width // patch_width
+ num_patch_height = new_height // patch_height
+ num_patches = num_patch_height * num_patch_width
+
+ # convert from shape (batch_size, orig_height, orig_width, channels)
+ # to the shape (batch_size * patch_area, num_patches, channels)
+ features = tf.transpose(features, [0, 3, 1, 2])
+ patches = tf.reshape(
+ features, (batch_size * channels * num_patch_height, patch_height, num_patch_width, patch_width)
+ )
+ patches = tf.transpose(patches, [0, 2, 1, 3])
+ patches = tf.reshape(patches, (batch_size, channels, num_patches, patch_area))
+ patches = tf.transpose(patches, [0, 3, 2, 1])
+ patches = tf.reshape(patches, (batch_size * patch_area, num_patches, channels))
+
+ info_dict = {
+ "orig_size": (orig_height, orig_width),
+ "batch_size": batch_size,
+ "channels": channels,
+ "interpolate": interpolate,
+ "num_patches": num_patches,
+ "num_patches_width": num_patch_width,
+ "num_patches_height": num_patch_height,
+ }
+ return patches, info_dict
+
+ def folding(self, patches: tf.Tensor, info_dict: dict) -> tf.Tensor:
+ patch_width, patch_height = self.patch_width, self.patch_height
+ patch_area = int(patch_width * patch_height)
+
+ batch_size = info_dict["batch_size"]
+ channels = info_dict["channels"]
+ num_patches = info_dict["num_patches"]
+ num_patch_height = info_dict["num_patches_height"]
+ num_patch_width = info_dict["num_patches_width"]
+
+ # convert from shape (batch_size * patch_area, num_patches, channels)
+ # back to shape (batch_size, channels, orig_height, orig_width)
+ features = tf.reshape(patches, (batch_size, patch_area, num_patches, -1))
+ features = tf.transpose(features, perm=(0, 3, 2, 1))
+ features = tf.reshape(
+ features, (batch_size * channels * num_patch_height, num_patch_width, patch_height, patch_width)
+ )
+ features = tf.transpose(features, perm=(0, 2, 1, 3))
+ features = tf.reshape(
+ features, (batch_size, channels, num_patch_height * patch_height, num_patch_width * patch_width)
+ )
+ features = tf.transpose(features, perm=(0, 2, 3, 1))
+
+ if info_dict["interpolate"]:
+ features = tf.image.resize(features, size=info_dict["orig_size"], method="bilinear")
+
+ return features
+
+ def call(self, features: tf.Tensor, training: bool = False) -> tf.Tensor:
+ # reduce spatial dimensions if needed
+ if self.downsampling_layer:
+ features = self.downsampling_layer(features, training=training)
+
+ residual = features
+
+ # local representation
+ features = self.conv_kxk(features, training=training)
+ features = self.conv_1x1(features, training=training)
+
+ # convert feature map to patches
+ patches, info_dict = self.unfolding(features)
+
+ # learn global representations
+ patches = self.transformer(patches, training=training)
+ patches = self.layernorm(patches)
+
+ # convert patches back to feature maps
+ features = self.folding(patches, info_dict)
+
+ features = self.conv_projection(features, training=training)
+ features = self.fusion(tf.concat([residual, features], axis=-1), training=training)
+ return features
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "conv_kxk", None) is not None:
+ with tf.name_scope(self.conv_kxk.name):
+ self.conv_kxk.build(None)
+ if getattr(self, "conv_1x1", None) is not None:
+ with tf.name_scope(self.conv_1x1.name):
+ self.conv_1x1.build(None)
+ if getattr(self, "transformer", None) is not None:
+ with tf.name_scope(self.transformer.name):
+ self.transformer.build(None)
+ if getattr(self, "layernorm", None) is not None:
+ with tf.name_scope(self.layernorm.name):
+ self.layernorm.build([None, None, self.hidden_size])
+ if getattr(self, "conv_projection", None) is not None:
+ with tf.name_scope(self.conv_projection.name):
+ self.conv_projection.build(None)
+ if getattr(self, "fusion", None) is not None:
+ with tf.name_scope(self.fusion.name):
+ self.fusion.build(None)
+ if getattr(self, "downsampling_layer", None) is not None:
+ with tf.name_scope(self.downsampling_layer.name):
+ self.downsampling_layer.build(None)
+
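The `unfolding`/`folding` pair above is the heart of the MobileViT block: it rearranges an NHWC feature map into `(batch * patch_area, num_patches, channels)` sequences for the transformer and then back into a feature map. A small round-trip sketch, assuming TensorFlow is installed and the default 2x2 patch size; the constructor arguments are illustrative and not taken from any checkpoint:

```python
>>> import tensorflow as tf
>>> from transformers import MobileViTConfig
>>> from transformers.models.mobilevit.modeling_tf_mobilevit import TFMobileViTLayer

>>> config = MobileViTConfig()  # patch_size defaults to 2
>>> layer = TFMobileViTLayer(
...     config, in_channels=64, out_channels=64, stride=1, hidden_size=96, num_stages=2
... )

>>> features = tf.random.normal((1, 16, 16, 64))  # NHWC, spatial dims divisible by the patch size
>>> patches, info = layer.unfolding(features)
>>> patches.shape  # (batch * patch_area, num_patches, channels) = (1 * 4, 8 * 8, 64)
TensorShape([4, 64, 64])
>>> layer.folding(patches, info).shape  # restores the original feature map layout
TensorShape([1, 16, 16, 64])
```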
+
+class TFMobileViTEncoder(keras.layers.Layer):
+ def __init__(self, config: MobileViTConfig, **kwargs) -> None:
+ super().__init__(**kwargs)
+ self.config = config
+
+ self.layers = []
+
+ # segmentation architectures like DeepLab and PSPNet modify the strides
+ # of the classification backbones
+ dilate_layer_4 = dilate_layer_5 = False
+ if config.output_stride == 8:
+ dilate_layer_4 = True
+ dilate_layer_5 = True
+ elif config.output_stride == 16:
+ dilate_layer_5 = True
+
+ dilation = 1
+
+ layer_1 = TFMobileViTMobileNetLayer(
+ config,
+ in_channels=config.neck_hidden_sizes[0],
+ out_channels=config.neck_hidden_sizes[1],
+ stride=1,
+ num_stages=1,
+ name="layer.0",
+ )
+ self.layers.append(layer_1)
+
+ layer_2 = TFMobileViTMobileNetLayer(
+ config,
+ in_channels=config.neck_hidden_sizes[1],
+ out_channels=config.neck_hidden_sizes[2],
+ stride=2,
+ num_stages=3,
+ name="layer.1",
+ )
+ self.layers.append(layer_2)
+
+ layer_3 = TFMobileViTLayer(
+ config,
+ in_channels=config.neck_hidden_sizes[2],
+ out_channels=config.neck_hidden_sizes[3],
+ stride=2,
+ hidden_size=config.hidden_sizes[0],
+ num_stages=2,
+ name="layer.2",
+ )
+ self.layers.append(layer_3)
+
+ if dilate_layer_4:
+ dilation *= 2
+
+ layer_4 = TFMobileViTLayer(
+ config,
+ in_channels=config.neck_hidden_sizes[3],
+ out_channels=config.neck_hidden_sizes[4],
+ stride=2,
+ hidden_size=config.hidden_sizes[1],
+ num_stages=4,
+ dilation=dilation,
+ name="layer.3",
+ )
+ self.layers.append(layer_4)
+
+ if dilate_layer_5:
+ dilation *= 2
+
+ layer_5 = TFMobileViTLayer(
+ config,
+ in_channels=config.neck_hidden_sizes[4],
+ out_channels=config.neck_hidden_sizes[5],
+ stride=2,
+ hidden_size=config.hidden_sizes[2],
+ num_stages=3,
+ dilation=dilation,
+ name="layer.4",
+ )
+ self.layers.append(layer_5)
+
+ def call(
+ self,
+ hidden_states: tf.Tensor,
+ output_hidden_states: bool = False,
+ return_dict: bool = True,
+ training: bool = False,
+ ) -> tuple | TFBaseModelOutput:
+ all_hidden_states = () if output_hidden_states else None
+
+ for i, layer_module in enumerate(self.layers):
+ hidden_states = layer_module(hidden_states, training=training)
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, all_hidden_states] if v is not None)
+
+ return TFBaseModelOutput(last_hidden_state=hidden_states, hidden_states=all_hidden_states)
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "layers", None) is not None:
+ for layer_module in self.layers:
+ with tf.name_scope(layer_module.name):
+ layer_module.build(None)
+
+
+@keras_serializable
+class TFMobileViTMainLayer(keras.layers.Layer):
+ config_class = MobileViTConfig
+
+ def __init__(self, config: MobileViTConfig, expand_output: bool = True, **kwargs):
+ super().__init__(**kwargs)
+ self.config = config
+ self.expand_output = expand_output
+
+ self.conv_stem = TFMobileViTConvLayer(
+ config,
+ in_channels=config.num_channels,
+ out_channels=config.neck_hidden_sizes[0],
+ kernel_size=3,
+ stride=2,
+ name="conv_stem",
+ )
+
+ self.encoder = TFMobileViTEncoder(config, name="encoder")
+
+ if self.expand_output:
+ self.conv_1x1_exp = TFMobileViTConvLayer(
+ config,
+ in_channels=config.neck_hidden_sizes[5],
+ out_channels=config.neck_hidden_sizes[6],
+ kernel_size=1,
+ name="conv_1x1_exp",
+ )
+
+ self.pooler = keras.layers.GlobalAveragePooling2D(data_format="channels_first", name="pooler")
+
+ def _prune_heads(self, heads_to_prune):
+ """
+ Prunes heads of the model. `heads_to_prune`: dict of {layer_num: list of heads to prune in this layer}. See the
+ base class `PreTrainedModel`.
+ """
+ raise NotImplementedError
+
+ @unpack_inputs
+ def call(
+ self,
+ pixel_values: tf.Tensor | None = None,
+ output_hidden_states: bool | None = None,
+ return_dict: bool | None = None,
+ training: bool = False,
+ ) -> tuple[tf.Tensor] | TFBaseModelOutputWithPooling:
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # When running on CPU, `keras.layers.Conv2D` doesn't support `NCHW` format.
+ # So change the input format from `NCHW` to `NHWC`.
+ # shape = (batch_size, in_height, in_width, in_channels=num_channels)
+ pixel_values = tf.transpose(pixel_values, perm=(0, 2, 3, 1))
+
+ embedding_output = self.conv_stem(pixel_values, training=training)
+
+ encoder_outputs = self.encoder(
+ embedding_output, output_hidden_states=output_hidden_states, return_dict=return_dict, training=training
+ )
+
+ if self.expand_output:
+ last_hidden_state = self.conv_1x1_exp(encoder_outputs[0])
+
+ # Change to NCHW output format to have uniformity in the modules
+ last_hidden_state = tf.transpose(last_hidden_state, perm=[0, 3, 1, 2])
+
+ # global average pooling: (batch_size, channels, height, width) -> (batch_size, channels)
+ pooled_output = self.pooler(last_hidden_state)
+ else:
+ last_hidden_state = encoder_outputs[0]
+ # Change to NCHW output format to have uniformity in the modules
+ last_hidden_state = tf.transpose(last_hidden_state, perm=[0, 3, 1, 2])
+ pooled_output = None
+
+ if not return_dict:
+ output = (last_hidden_state, pooled_output) if pooled_output is not None else (last_hidden_state,)
+
+ # Change to NCHW output format to have uniformity in the modules
+ if not self.expand_output:
+ remaining_encoder_outputs = encoder_outputs[1:]
+ remaining_encoder_outputs = tuple(
+ tf.transpose(h, perm=(0, 3, 1, 2)) for h in remaining_encoder_outputs[0]
+ )
+ remaining_encoder_outputs = (remaining_encoder_outputs,)
+ return output + remaining_encoder_outputs
+ else:
+ return output + encoder_outputs[1:]
+
+ # Change the other hidden state outputs to NCHW as well
+ if output_hidden_states:
+ hidden_states = tuple(tf.transpose(h, perm=(0, 3, 1, 2)) for h in encoder_outputs[1])
+
+ return TFBaseModelOutputWithPooling(
+ last_hidden_state=last_hidden_state,
+ pooler_output=pooled_output,
+ hidden_states=hidden_states if output_hidden_states else encoder_outputs.hidden_states,
+ )
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "conv_stem", None) is not None:
+ with tf.name_scope(self.conv_stem.name):
+ self.conv_stem.build(None)
+ if getattr(self, "encoder", None) is not None:
+ with tf.name_scope(self.encoder.name):
+ self.encoder.build(None)
+ if getattr(self, "pooler", None) is not None:
+ with tf.name_scope(self.pooler.name):
+ self.pooler.build([None, None, None, None])
+ if getattr(self, "conv_1x1_exp", None) is not None:
+ with tf.name_scope(self.conv_1x1_exp.name):
+ self.conv_1x1_exp.build(None)
+
+
+class TFMobileViTPreTrainedModel(TFPreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = MobileViTConfig
+ base_model_prefix = "mobilevit"
+ main_input_name = "pixel_values"
+
+
+MOBILEVIT_START_DOCSTRING = r"""
+ This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
+ etc.)
+
+ This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it
+ as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and
+ behavior.
+
+
+
+ TensorFlow models and layers in `transformers` accept two formats as input:
+
+ - having all inputs as keyword arguments (like PyTorch models), or
+ - having all inputs as a list, tuple or dict in the first positional argument.
+
+ The reason the second format is supported is that Keras methods prefer this format when passing inputs to models
+ and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just
+ pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second
+ format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with
+ the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first
+ positional argument:
+
+ - a single Tensor with `pixel_values` only and nothing else: `model(pixel_values)`
+ - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
+ `model([pixel_values, attention_mask])` or `model([pixel_values, attention_mask, token_type_ids])`
+ - a dictionary with one or several input Tensors associated to the input names given in the docstring:
+ `model({"pixel_values": pixel_values, "token_type_ids": token_type_ids})`
+
+ Note that when creating models and layers with
+ [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry
+ about any of this, as you can just pass inputs like you would to any other Python function!
+
+
+
+ Parameters:
+ config ([`MobileViTConfig`]): Model configuration class with all the parameters of the model.
+ Initializing with a config file does not load the weights associated with the model, only the
+ configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+MOBILEVIT_INPUTS_DOCSTRING = r"""
+ Args:
+ pixel_values (`np.ndarray`, `tf.Tensor`, `list[tf.Tensor]`, `dict[str, tf.Tensor]` or `dict[str, np.ndarray]` and each example must have the shape `(batch_size, num_channels, height, width)`):
+ Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
+ [`MobileViTImageProcessor.__call__`] for details.
+
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail. This argument can be used only in eager mode, in graph mode the value in the config will be
+ used instead.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in
+ eager mode, in graph mode the value will always be set to True.
+"""
+
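The three input formats listed in `MOBILEVIT_START_DOCSTRING` correspond to the calls sketched below. The `apple/mobilevit-small` checkpoint and the 256x256 input size (which yields the documented `(1, 640, 8, 8)` output) are assumptions for illustration; if the checkpoint only ships PyTorch weights, `from_pt=True` would also be needed.

```python
>>> import tensorflow as tf
>>> from transformers import TFMobileViTModel

>>> model = TFMobileViTModel.from_pretrained("apple/mobilevit-small")
>>> pixel_values = tf.random.uniform((1, 3, 256, 256))  # NCHW, as stated in MOBILEVIT_INPUTS_DOCSTRING

>>> out1 = model(pixel_values)                                # single positional tensor
>>> out2 = model({"pixel_values": pixel_values})              # dict keyed by input name
>>> out3 = model(pixel_values=pixel_values, training=False)   # keyword arguments
>>> out1.last_hidden_state.shape
TensorShape([1, 640, 8, 8])
```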
+
+@add_start_docstrings(
+ "The bare MobileViT model outputting raw hidden-states without any specific head on top.",
+ MOBILEVIT_START_DOCSTRING,
+)
+class TFMobileViTModel(TFMobileViTPreTrainedModel):
+ def __init__(self, config: MobileViTConfig, expand_output: bool = True, *inputs, **kwargs):
+ super().__init__(config, *inputs, **kwargs)
+ self.config = config
+ self.expand_output = expand_output
+
+ self.mobilevit = TFMobileViTMainLayer(config, expand_output=expand_output, name="mobilevit")
+
+ @unpack_inputs
+ @add_start_docstrings_to_model_forward(MOBILEVIT_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=TFBaseModelOutputWithPooling,
+ config_class=_CONFIG_FOR_DOC,
+ modality="vision",
+ expected_output=_EXPECTED_OUTPUT_SHAPE,
+ )
+ def call(
+ self,
+ pixel_values: tf.Tensor | None = None,
+ output_hidden_states: bool | None = None,
+ return_dict: bool | None = None,
+ training: bool = False,
+ ) -> tuple[tf.Tensor] | TFBaseModelOutputWithPooling:
+ output = self.mobilevit(pixel_values, output_hidden_states, return_dict, training=training)
+ return output
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "mobilevit", None) is not None:
+ with tf.name_scope(self.mobilevit.name):
+ self.mobilevit.build(None)
+
+
+@add_start_docstrings(
+ """
+ MobileViT model with an image classification head on top (a linear layer on top of the pooled features), e.g. for
+ ImageNet.
+ """,
+ MOBILEVIT_START_DOCSTRING,
+)
+class TFMobileViTForImageClassification(TFMobileViTPreTrainedModel, TFSequenceClassificationLoss):
+ def __init__(self, config: MobileViTConfig, *inputs, **kwargs) -> None:
+ super().__init__(config, *inputs, **kwargs)
+
+ self.num_labels = config.num_labels
+ self.mobilevit = TFMobileViTMainLayer(config, name="mobilevit")
+
+ # Classifier head
+ self.dropout = keras.layers.Dropout(config.classifier_dropout_prob)
+ self.classifier = (
+ keras.layers.Dense(config.num_labels, name="classifier") if config.num_labels > 0 else tf.identity
+ )
+ self.config = config
+
+ @unpack_inputs
+ @add_start_docstrings_to_model_forward(MOBILEVIT_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ checkpoint=_IMAGE_CLASS_CHECKPOINT,
+ output_type=TFImageClassifierOutputWithNoAttention,
+ config_class=_CONFIG_FOR_DOC,
+ expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
+ )
+ def call(
+ self,
+ pixel_values: tf.Tensor | None = None,
+ output_hidden_states: bool | None = None,
+ labels: tf.Tensor | None = None,
+ return_dict: bool | None = None,
+ training: bool | None = False,
+ ) -> tuple | TFImageClassifierOutputWithNoAttention:
+ r"""
+ labels (`tf.Tensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss). If
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.mobilevit(
+ pixel_values, output_hidden_states=output_hidden_states, return_dict=return_dict, training=training
+ )
+
+ pooled_output = outputs.pooler_output if return_dict else outputs[1]
+
+ logits = self.classifier(self.dropout(pooled_output, training=training))
+ loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits)
+
+ if not return_dict:
+ output = (logits,) + outputs[2:]
+ return ((loss,) + output) if loss is not None else output
+
+ return TFImageClassifierOutputWithNoAttention(loss=loss, logits=logits, hidden_states=outputs.hidden_states)
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "mobilevit", None) is not None:
+ with tf.name_scope(self.mobilevit.name):
+ self.mobilevit.build(None)
+ if getattr(self, "classifier", None) is not None:
+ if hasattr(self.classifier, "name"):
+ with tf.name_scope(self.classifier.name):
+ self.classifier.build([None, None, self.config.neck_hidden_sizes[-1]])
+
+
+class TFMobileViTASPPPooling(keras.layers.Layer):
+ def __init__(self, config: MobileViTConfig, in_channels: int, out_channels: int, **kwargs) -> None:
+ super().__init__(**kwargs)
+
+ self.global_pool = keras.layers.GlobalAveragePooling2D(keepdims=True, name="global_pool")
+
+ self.conv_1x1 = TFMobileViTConvLayer(
+ config,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=1,
+ stride=1,
+ use_normalization=True,
+ use_activation="relu",
+ name="conv_1x1",
+ )
+
+ def call(self, features: tf.Tensor, training: bool = False) -> tf.Tensor:
+ spatial_size = shape_list(features)[1:-1]
+ features = self.global_pool(features)
+ features = self.conv_1x1(features, training=training)
+ features = tf.image.resize(features, size=spatial_size, method="bilinear")
+ return features
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "global_pool", None) is not None:
+ with tf.name_scope(self.global_pool.name):
+ self.global_pool.build([None, None, None, None])
+ if getattr(self, "conv_1x1", None) is not None:
+ with tf.name_scope(self.conv_1x1.name):
+ self.conv_1x1.build(None)
+
+
+class TFMobileViTASPP(keras.layers.Layer):
+ """
+ ASPP module defined in DeepLab papers: https://huggingface.co/papers/1606.00915, https://huggingface.co/papers/1706.05587
+ """
+
+ def __init__(self, config: MobileViTConfig, **kwargs) -> None:
+ super().__init__(**kwargs)
+
+ in_channels = config.neck_hidden_sizes[-2]
+ out_channels = config.aspp_out_channels
+
+ if len(config.atrous_rates) != 3:
+ raise ValueError("Expected 3 values for atrous_rates")
+
+ self.convs = []
+
+ in_projection = TFMobileViTConvLayer(
+ config,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=1,
+ use_activation="relu",
+ name="convs.0",
+ )
+ self.convs.append(in_projection)
+
+ self.convs.extend(
+ [
+ TFMobileViTConvLayer(
+ config,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=3,
+ dilation=rate,
+ use_activation="relu",
+ name=f"convs.{i + 1}",
+ )
+ for i, rate in enumerate(config.atrous_rates)
+ ]
+ )
+
+ pool_layer = TFMobileViTASPPPooling(
+ config, in_channels, out_channels, name=f"convs.{len(config.atrous_rates) + 1}"
+ )
+ self.convs.append(pool_layer)
+
+ self.project = TFMobileViTConvLayer(
+ config,
+ in_channels=5 * out_channels,
+ out_channels=out_channels,
+ kernel_size=1,
+ use_activation="relu",
+ name="project",
+ )
+
+ self.dropout = keras.layers.Dropout(config.aspp_dropout_prob)
+
+ def call(self, features: tf.Tensor, training: bool = False) -> tf.Tensor:
+ # since the hidden states were transposed to have `(batch_size, channels, height, width)`
+ # layout we transpose them back to have `(batch_size, height, width, channels)` layout.
+ features = tf.transpose(features, perm=[0, 2, 3, 1])
+ pyramid = []
+ for conv in self.convs:
+ pyramid.append(conv(features, training=training))
+ pyramid = tf.concat(pyramid, axis=-1)
+
+ pooled_features = self.project(pyramid, training=training)
+ pooled_features = self.dropout(pooled_features, training=training)
+ return pooled_features
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "project", None) is not None:
+ with tf.name_scope(self.project.name):
+ self.project.build(None)
+ if getattr(self, "convs", None) is not None:
+ for conv in self.convs:
+ with tf.name_scope(conv.name):
+ conv.build(None)
+
+
+class TFMobileViTDeepLabV3(keras.layers.Layer):
+ """
+ DeepLabv3 architecture: https://huggingface.co/papers/1706.05587
+ """
+
+ def __init__(self, config: MobileViTConfig, **kwargs) -> None:
+ super().__init__(**kwargs)
+ self.aspp = TFMobileViTASPP(config, name="aspp")
+
+ self.dropout = keras.layers.Dropout(config.classifier_dropout_prob)
+
+ self.classifier = TFMobileViTConvLayer(
+ config,
+ in_channels=config.aspp_out_channels,
+ out_channels=config.num_labels,
+ kernel_size=1,
+ use_normalization=False,
+ use_activation=False,
+ bias=True,
+ name="classifier",
+ )
+
+ def call(self, hidden_states: tf.Tensor, training: bool = False) -> tf.Tensor:
+ features = self.aspp(hidden_states[-1], training=training)
+ features = self.dropout(features, training=training)
+ features = self.classifier(features, training=training)
+ return features
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "aspp", None) is not None:
+ with tf.name_scope(self.aspp.name):
+ self.aspp.build(None)
+ if getattr(self, "classifier", None) is not None:
+ with tf.name_scope(self.classifier.name):
+ self.classifier.build(None)
+
+
+@add_start_docstrings(
+ """
+ MobileViT model with a semantic segmentation head on top, e.g. for Pascal VOC.
+ """,
+ MOBILEVIT_START_DOCSTRING,
+)
+class TFMobileViTForSemanticSegmentation(TFMobileViTPreTrainedModel):
+ def __init__(self, config: MobileViTConfig, **kwargs) -> None:
+ super().__init__(config, **kwargs)
+
+ self.num_labels = config.num_labels
+ self.mobilevit = TFMobileViTMainLayer(config, expand_output=False, name="mobilevit")
+ self.segmentation_head = TFMobileViTDeepLabV3(config, name="segmentation_head")
+
+ def hf_compute_loss(self, logits, labels):
+ # upsample logits to the images' original size
+ # `labels` is of shape (batch_size, height, width)
+ label_interp_shape = shape_list(labels)[1:]
+
+ upsampled_logits = tf.image.resize(logits, size=label_interp_shape, method="bilinear")
+ # compute weighted loss
+ loss_fct = keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction="none")
+
+ def masked_loss(real, pred):
+ unmasked_loss = loss_fct(real, pred)
+ mask = tf.cast(real != self.config.semantic_loss_ignore_index, dtype=unmasked_loss.dtype)
+ masked_loss = unmasked_loss * mask
+ # Reduction strategy in a similar spirit to
+ # https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_tf_utils.py#L210
+ reduced_masked_loss = tf.reduce_sum(masked_loss) / tf.reduce_sum(mask)
+ return tf.reshape(reduced_masked_loss, (1,))
+
+ return masked_loss(labels, upsampled_logits)
+
+ @unpack_inputs
+ @add_start_docstrings_to_model_forward(MOBILEVIT_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=TFSemanticSegmenterOutputWithNoAttention, config_class=_CONFIG_FOR_DOC)
+ def call(
+ self,
+ pixel_values: tf.Tensor | None = None,
+ labels: tf.Tensor | None = None,
+ output_hidden_states: bool | None = None,
+ return_dict: bool | None = None,
+ training: bool = False,
+ ) -> tuple | TFSemanticSegmenterOutputWithNoAttention:
+ r"""
+ labels (`tf.Tensor` of shape `(batch_size, height, width)`, *optional*):
+ Ground truth semantic segmentation maps for computing the loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels > 1`, a classification loss is computed (Cross-Entropy).
+
+ Returns:
+
+ Examples:
+
+ ```python
+ >>> from transformers import AutoImageProcessor, TFMobileViTForSemanticSegmentation
+ >>> from PIL import Image
+ >>> import requests
+
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> image_processor = AutoImageProcessor.from_pretrained("apple/deeplabv3-mobilevit-small")
+ >>> model = TFMobileViTForSemanticSegmentation.from_pretrained("apple/deeplabv3-mobilevit-small")
+
+ >>> inputs = image_processor(images=image, return_tensors="tf")
+
+ >>> outputs = model(**inputs)
+
+ >>> # logits are of shape (batch_size, num_labels, height, width)
+ >>> logits = outputs.logits
+ ```"""
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if labels is not None and not self.config.num_labels > 1:
+ raise ValueError("The number of labels should be greater than one")
+
+ outputs = self.mobilevit(
+ pixel_values,
+ output_hidden_states=True, # we need the intermediate hidden states
+ return_dict=return_dict,
+ training=training,
+ )
+
+ encoder_hidden_states = outputs.hidden_states if return_dict else outputs[1]
+
+ logits = self.segmentation_head(encoder_hidden_states, training=training)
+
+ loss = None
+ if labels is not None:
+ loss = self.hf_compute_loss(logits=logits, labels=labels)
+
+ # make logits of shape (batch_size, num_labels, height, width) to
+ # keep them consistent across APIs
+ logits = tf.transpose(logits, perm=[0, 3, 1, 2])
+
+ if not return_dict:
+ if output_hidden_states:
+ output = (logits,) + outputs[1:]
+ else:
+ output = (logits,) + outputs[2:]
+ return ((loss,) + output) if loss is not None else output
+
+ return TFSemanticSegmenterOutputWithNoAttention(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states if output_hidden_states else None,
+ )
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "mobilevit", None) is not None:
+ with tf.name_scope(self.mobilevit.name):
+ self.mobilevit.build(None)
+ if getattr(self, "segmentation_head", None) is not None:
+ with tf.name_scope(self.segmentation_head.name):
+ self.segmentation_head.build(None)
+
+
+__all__ = [
+ "TFMobileViTForImageClassification",
+ "TFMobileViTForSemanticSegmentation",
+ "TFMobileViTModel",
+ "TFMobileViTPreTrainedModel",
+]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/modernbert/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/modernbert/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..18317742981909460973138da806584ddfc4a390
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/modernbert/__init__.py
@@ -0,0 +1,27 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_modernbert import *
+ from .modeling_modernbert import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/modernbert/configuration_modernbert.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/modernbert/configuration_modernbert.py
new file mode 100644
index 0000000000000000000000000000000000000000..3b0da20ad203cb02375f40e14ea818b642c26c91
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/modernbert/configuration_modernbert.py
@@ -0,0 +1,224 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/modernbert/modular_modernbert.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_modernbert.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# Copyright 2024 Answer.AI, LightOn, and contributors, and the HuggingFace Inc. team. All rights reserved.
+#
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Literal
+
+from ...configuration_utils import PretrainedConfig
+
+
+class ModernBertConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`ModernBertModel`]. It is used to instantiate a ModernBert
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
+ defaults will yield a similar configuration to that of the ModernBERT-base.
+ e.g. [answerdotai/ModernBERT-base](https://huggingface.co/answerdotai/ModernBERT-base)
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 50368):
+ Vocabulary size of the ModernBert model. Defines the number of different tokens that can be represented by the
+ `input_ids` passed when calling [`ModernBertModel`]
+ hidden_size (`int`, *optional*, defaults to 768):
+ Dimension of the hidden representations.
+ intermediate_size (`int`, *optional*, defaults to 1152):
+ Dimension of the MLP representations.
+ num_hidden_layers (`int`, *optional*, defaults to 22):
+ Number of hidden layers in the Transformer encoder.
+ num_attention_heads (`int`, *optional*, defaults to 12):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ hidden_activation (`str` or `function`, *optional*, defaults to `"gelu"`):
+ The non-linear activation function (function or string) in the encoder. Will default to `"gelu"`
+ if not specified.
+ max_position_embeddings (`int`, *optional*, defaults to 8192):
+ The maximum sequence length that this model might ever be used with.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ initializer_cutoff_factor (`float`, *optional*, defaults to 2.0):
+ The cutoff factor for the truncated_normal_initializer for initializing all weight matrices.
+ norm_eps (`float`, *optional*, defaults to 1e-05):
+ The epsilon used by the layer normalization layers.
+ norm_bias (`bool`, *optional*, defaults to `False`):
+ Whether to use bias in the normalization layers.
+ pad_token_id (`int`, *optional*, defaults to 50283):
+ Padding token id.
+ eos_token_id (`int`, *optional*, defaults to 50282):
+ End of stream token id.
+ bos_token_id (`int`, *optional*, defaults to 50281):
+ Beginning of stream token id.
+ cls_token_id (`int`, *optional*, defaults to 50281):
+ Classification token id.
+ sep_token_id (`int`, *optional*, defaults to 50282):
+ Separation token id.
+ global_rope_theta (`float`, *optional*, defaults to 160000.0):
+ The base period of the global RoPE embeddings.
+ attention_bias (`bool`, *optional*, defaults to `False`):
+ Whether to use a bias in the query, key, value and output projection layers during self-attention.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ global_attn_every_n_layers (`int`, *optional*, defaults to 3):
+ The number of layers between global attention layers.
+ local_attention (`int`, *optional*, defaults to 128):
+ The window size for local attention.
+ local_rope_theta (`float`, *optional*, defaults to 10000.0):
+ The base period of the local RoPE embeddings.
+ embedding_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the embeddings.
+ mlp_bias (`bool`, *optional*, defaults to `False`):
+ Whether to use bias in the MLP layers.
+ mlp_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the MLP layers.
+ decoder_bias (`bool`, *optional*, defaults to `True`):
+ Whether to use bias in the decoder layers.
+ classifier_pooling (`str`, *optional*, defaults to `"cls"`):
+ The pooling method for the classifier. Should be either `"cls"` or `"mean"`. In local attention layers, the
+ CLS token doesn't attend to all tokens on long sequences.
+ classifier_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the classifier.
+ classifier_bias (`bool`, *optional*, defaults to `False`):
+ Whether to use bias in the classifier.
+ classifier_activation (`str`, *optional*, defaults to `"gelu"`):
+ The activation function for the classifier.
+ deterministic_flash_attn (`bool`, *optional*, defaults to `False`):
+ Whether to use deterministic flash attention. If `False`, inference will be faster but not deterministic.
+ sparse_prediction (`bool`, *optional*, defaults to `False`):
+ Whether to use sparse prediction for the masked language model instead of returning the full dense logits.
+ sparse_pred_ignore_index (`int`, *optional*, defaults to -100):
+ The index to ignore for the sparse prediction.
+ reference_compile (`bool`, *optional*):
+ Whether to compile the layers of the model which were compiled during pretraining. If `None`, then parts of
+ the model will be compiled if 1) `triton` is installed, 2) the model is not on MPS, 3) the model is not
+ shared between devices, and 4) the model is not resized after initialization. If `True`, then the model may
+ be faster in some scenarios.
+ repad_logits_with_grad (`bool`, *optional*, defaults to `False`):
+ When True, ModernBertForMaskedLM keeps track of the logits' gradient when repadding for output. This only
+ applies when using Flash Attention 2 with passed labels. Otherwise output logits always have a gradient.
+
+ Examples:
+
+ ```python
+ >>> from transformers import ModernBertModel, ModernBertConfig
+
+ >>> # Initializing a ModernBert style configuration
+ >>> configuration = ModernBertConfig()
+
+ >>> # Initializing a model from the modernbert-base style configuration
+ >>> model = ModernBertModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "modernbert"
+ attribute_map = {"rope_theta": "global_rope_theta"}
+ keys_to_ignore_at_inference = ["past_key_values"]
+
+ def __init__(
+ self,
+ vocab_size=50368,
+ hidden_size=768,
+ intermediate_size=1152,
+ num_hidden_layers=22,
+ num_attention_heads=12,
+ hidden_activation="gelu",
+ max_position_embeddings=8192,
+ initializer_range=0.02,
+ initializer_cutoff_factor=2.0,
+ norm_eps=1e-5,
+ norm_bias=False,
+ pad_token_id=50283,
+ eos_token_id=50282,
+ bos_token_id=50281,
+ cls_token_id=50281,
+ sep_token_id=50282,
+ global_rope_theta=160000.0,
+ attention_bias=False,
+ attention_dropout=0.0,
+ global_attn_every_n_layers=3,
+ local_attention=128,
+ local_rope_theta=10000.0,
+ embedding_dropout=0.0,
+ mlp_bias=False,
+ mlp_dropout=0.0,
+ decoder_bias=True,
+ classifier_pooling: Literal["cls", "mean"] = "cls",
+ classifier_dropout=0.0,
+ classifier_bias=False,
+ classifier_activation="gelu",
+ deterministic_flash_attn=False,
+ sparse_prediction=False,
+ sparse_pred_ignore_index=-100,
+ reference_compile=None,
+ repad_logits_with_grad=False,
+ **kwargs,
+ ):
+ super().__init__(
+ pad_token_id=pad_token_id,
+ bos_token_id=bos_token_id,
+ eos_token_id=eos_token_id,
+ cls_token_id=cls_token_id,
+ sep_token_id=sep_token_id,
+ **kwargs,
+ )
+ self.vocab_size = vocab_size
+ self.max_position_embeddings = max_position_embeddings
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.initializer_range = initializer_range
+ self.initializer_cutoff_factor = initializer_cutoff_factor
+ self.norm_eps = norm_eps
+ self.norm_bias = norm_bias
+ self.global_rope_theta = global_rope_theta
+ self.attention_bias = attention_bias
+ self.attention_dropout = attention_dropout
+ self.hidden_activation = hidden_activation
+ self.global_attn_every_n_layers = global_attn_every_n_layers
+ self.local_attention = local_attention
+ self.local_rope_theta = local_rope_theta
+ self.embedding_dropout = embedding_dropout
+ self.mlp_bias = mlp_bias
+ self.mlp_dropout = mlp_dropout
+ self.decoder_bias = decoder_bias
+ self.classifier_pooling = classifier_pooling
+ self.classifier_dropout = classifier_dropout
+ self.classifier_bias = classifier_bias
+ self.classifier_activation = classifier_activation
+ self.deterministic_flash_attn = deterministic_flash_attn
+ self.sparse_prediction = sparse_prediction
+ self.sparse_pred_ignore_index = sparse_pred_ignore_index
+ self.reference_compile = reference_compile
+ self.repad_logits_with_grad = repad_logits_with_grad
+
+ if self.classifier_pooling not in ["cls", "mean"]:
+ raise ValueError(
+ f'Invalid value for `classifier_pooling`, should be either "cls" or "mean", but is {self.classifier_pooling}.'
+ )
+
+ def to_dict(self):
+ output = super().to_dict()
+ output.pop("reference_compile", None)
+ return output
+
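+# A small sketch of the layer-alternation rule implied by `global_attn_every_n_layers` (illustrative
+# only, mirroring the modeling code's `layer_id % global_attn_every_n_layers == 0` check): with the
+# default of 3 and 22 layers, layers 0, 3, 6, ... use global attention and the others use local attention.
+#
+# >>> config = ModernBertConfig()
+# >>> global_layers = [i for i in range(config.num_hidden_layers) if i % config.global_attn_every_n_layers == 0]
+# >>> global_layers[:4]
+# [0, 3, 6, 9]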
+
+__all__ = ["ModernBertConfig"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/modernbert/modeling_modernbert.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/modernbert/modeling_modernbert.py
new file mode 100644
index 0000000000000000000000000000000000000000..00fbe19c3a63c7932d8710b38466d6720a4d46a3
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/modernbert/modeling_modernbert.py
@@ -0,0 +1,1572 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/modernbert/modular_modernbert.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_modernbert.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# Copyright 2024 Answer.AI, LightOn, and contributors, and the HuggingFace Inc. team. All rights reserved.
+#
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+import math
+from contextlib import nullcontext
+from typing import Optional, Union
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+
+from ...activations import ACT2FN
+from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_outputs import (
+ BaseModelOutput,
+ MaskedLMOutput,
+ MultipleChoiceModelOutput,
+ QuestionAnsweringModelOutput,
+ SequenceClassifierOutput,
+ TokenClassifierOutput,
+)
+from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
+from ...modeling_utils import PreTrainedModel
+from ...utils import auto_docstring, is_flash_attn_2_available, logging
+from ...utils.import_utils import is_triton_available
+from .configuration_modernbert import ModernBertConfig
+
+
+if is_flash_attn_2_available():
+ from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
+ from flash_attn.layers.rotary import RotaryEmbedding
+ from flash_attn.ops.triton.rotary import apply_rotary
+else:
+ RotaryEmbedding = object
+
+
+logger = logging.get_logger(__name__)
+
+
+class ApplyRotaryEmbUnpad(torch.autograd.Function):
+ @staticmethod
+ def forward(
+ ctx,
+ qkv,
+ cos,
+ sin,
+ cu_seqlens: Optional[torch.Tensor] = None,
+ max_seqlen: Optional[int] = None,
+ ):
+ # (total_nnz, 3, nheads, headdim)
+ qkv = qkv.contiguous()
+ total_nnz, _three, _nheads, headdim = qkv.shape
+ # We need qkv to be contiguous so that when we reshape to combine (3, nheads) dimensions,
+ # we get the same tensor
+ # qk = rearrange(qkv[:, :2], "b_s t h d -> b_s (t h) d")
+ qk = qkv[:, :2].view(total_nnz, -1, headdim)
+ apply_rotary(
+ qk,
+ cos,
+ sin,
+ seqlen_offsets=0,
+ cu_seqlens=cu_seqlens,
+ max_seqlen=max_seqlen,
+ interleaved=False,
+ inplace=True,
+ )
+
+ ctx.save_for_backward(cos, sin, cu_seqlens)
+ ctx.max_seqlen = max_seqlen
+ return qkv
+
+ @staticmethod
+ def backward(ctx, do):
+ cos, sin, cu_seqlens = ctx.saved_tensors
+ do = do.contiguous()
+ total_nnz, _three, _nheads, headdim = do.shape
+ # We need dqkv to be contiguous so that when we reshape to combine (3, nheads) dimensions,
+ # we get the same tensor
+ dqk = do[:, :2].view(total_nnz, -1, headdim)
+ apply_rotary(
+ dqk,
+ cos,
+ sin,
+ seqlen_offsets=0,
+ cu_seqlens=cu_seqlens,
+ max_seqlen=ctx.max_seqlen,
+ interleaved=False,
+ inplace=True,
+ conjugate=True,
+ )
+
+ return do, None, None, None, None, None, None
+
+
+def apply_rotary_unpadded(
+ qkv,
+ cos,
+ sin,
+ cu_seqlens: Optional[torch.Tensor] = None,
+ max_seqlen: Optional[int] = None,
+):
+ """
+ Arguments:
+ qkv: (total_nnz, 3, nheads, headdim) - input tensor for packed QKV.
+ cos, sin: (seqlen_rotary, rotary_dim / 2)
+ interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
+ of 1st half and 2nd half (GPT-NeoX style).
+ inplace: if True, apply rotary embedding in-place.
+ seqlen_offsets: (batch_size,) or int. Each sequence in x is shifted by this amount.
+ Most commonly used in inference when we have KV cache.
+ cu_seqlens: (batch + 1,) or None
+ max_seqlen: int
+ Return:
+ out: (total_nnz, dim)
+ rotary_dim must be <= headdim
+ Apply rotary embedding to the first rotary_dim of x.
+ """
+ return ApplyRotaryEmbUnpad.apply(qkv, cos, sin, cu_seqlens, max_seqlen)
+
+
+class ModernBertUnpaddedRotaryEmbedding(RotaryEmbedding):
+ """
+ The rotary position embeddings applied directly to unpadded sequences.
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ base: float = 10000.0,
+ max_seqlen: Optional[int] = None,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ """
+ max_seqlen: if max_seqlen, device, and dtype are provided, we precompute the cos_sin_cache
+ up to max_seqlen. If the max_seqlen, device, or dtype during training/inference differ,
+ the cos_sin_cache will be recomputed during the forward pass.
+ """
+ super().__init__(dim=dim, base=base, device=device, interleaved=False)
+ self.max_seqlen = max_seqlen
+
+ if max_seqlen is not None and device is not None and dtype is not None:
+ self._update_cos_sin_cache(max_seqlen, device=device, dtype=dtype)
+
+ def forward(
+ self,
+ qkv: torch.Tensor,
+ cu_seqlens: torch.Tensor,
+ max_seqlen: Optional[int] = None,
+ ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
+ """
+ Apply rotary embedding *inplace* to qkv.
+ qkv: (total_nnz, 3, nheads, headdim)
+ cu_seqlens: (batch + 1,) cumulative sequence lengths
+ max_seqlen: int max seq length in the batch
+ """
+ if max_seqlen is not None:
+ self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype)
+
+ qkv = apply_rotary_unpadded(
+ qkv,
+ self._cos_cached,
+ self._sin_cached,
+ cu_seqlens=cu_seqlens,
+ max_seqlen=max_seqlen,
+ )
+
+ return qkv
+
+ def extra_repr(self) -> str:
+ return f"dim={self.dim}, base={self.base}, scale_base={self.scale_base}"
+
+
+class ModernBertEmbeddings(nn.Module):
+ """
+ Same as BertEmbeddings with a tiny tweak for positional embeddings indexing.
+ """
+
+ def __init__(self, config: ModernBertConfig):
+ super().__init__()
+ self.config = config
+ self.tok_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
+ self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)
+ self.drop = nn.Dropout(config.embedding_dropout)
+
+ @torch.compile(dynamic=True)
+ def compiled_embeddings(self, input_ids: torch.LongTensor) -> torch.Tensor:
+ return self.drop(self.norm(self.tok_embeddings(input_ids)))
+
+ def forward(
+ self, input_ids: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.Tensor] = None
+ ) -> torch.Tensor:
+ if inputs_embeds is not None:
+ hidden_states = self.drop(self.norm(inputs_embeds))
+ else:
+ hidden_states = (
+ self.compiled_embeddings(input_ids)
+ if self.config.reference_compile
+ else self.drop(self.norm(self.tok_embeddings(input_ids)))
+ )
+ return hidden_states
+
+
+class ModernBertMLP(nn.Module):
+ """Applies the GLU at the end of each ModernBERT layer.
+
+ Compared to the default BERT architecture, this block replaces :class:`~transformers.model.bert.modeling_bert.BertIntermediate`
+ and :class:`~transformers.model.bert.modeling_bert.SelfOutput` with a single module that has similar functionality.
+ """
+
+ def __init__(self, config: ModernBertConfig):
+ super().__init__()
+ self.config = config
+ self.Wi = nn.Linear(config.hidden_size, int(config.intermediate_size) * 2, bias=config.mlp_bias)
+ self.act = ACT2FN[config.hidden_activation]
+ self.drop = nn.Dropout(config.mlp_dropout)
+ self.Wo = nn.Linear(config.intermediate_size, config.hidden_size, bias=config.mlp_bias)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ input, gate = self.Wi(hidden_states).chunk(2, dim=-1)
+ return self.Wo(self.drop(self.act(input) * gate))
+
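+# A tiny sketch of the gated-linear-unit (GLU) pattern used above (illustrative only): `Wi` produces a
+# double-width projection that is split into an input half and a gate half, and the output passed to
+# `Wo` is act(input) * gate.
+#
+# >>> import torch
+# >>> projected = torch.tensor([[2.0, 4.0, 10.0, 0.5]])  # pretend Wi output with intermediate_size=2
+# >>> inp, gate = projected.chunk(2, dim=-1)
+# >>> (torch.nn.functional.gelu(inp) * gate).shape
+# torch.Size([1, 2])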
+
+class ModernBertRotaryEmbedding(nn.Module):
+ inv_freq: torch.Tensor # fix linting for `register_buffer`
+
+ def __init__(self, config: ModernBertConfig, device=None):
+ super().__init__()
+ # BC: "rope_type" was originally "type"
+ if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
+ else:
+ self.rope_type = "default"
+ self.max_seq_len_cached = config.max_position_embeddings
+ self.original_max_seq_len = config.max_position_embeddings
+
+ self.config = config
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
+ self.original_inv_freq = self.inv_freq
+
+ @torch.no_grad()
+ @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
+ def forward(self, x, position_ids):
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
+ position_ids_expanded = position_ids[:, None, :].float()
+
+ device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
+ with torch.autocast(device_type=device_type, enabled=False): # Force float32
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+ emb = torch.cat((freqs, freqs), dim=-1)
+ cos = emb.cos() * self.attention_scaling
+ sin = emb.sin() * self.attention_scaling
+
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
+def rotate_half(x):
+ """Rotates half the hidden dims of the input."""
+ x1 = x[..., : x.shape[-1] // 2]
+ x2 = x[..., x.shape[-1] // 2 :]
+ return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+ """Applies Rotary Position Embedding to the query and key tensors.
+
+ Args:
+ q (`torch.Tensor`): The query tensor.
+ k (`torch.Tensor`): The key tensor.
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
+ position_ids (`torch.Tensor`, *optional*):
+ Deprecated and unused.
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+ Returns:
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+ """
+ cos = cos.unsqueeze(unsqueeze_dim)
+ sin = sin.unsqueeze(unsqueeze_dim)
+ q_embed = (q * cos) + (rotate_half(q) * sin)
+ k_embed = (k * cos) + (rotate_half(k) * sin)
+ return q_embed, k_embed
+
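+# A minimal numeric sketch of the rotate-half convention implemented above (illustrative only,
+# not used by the model code): with cos = 0 and sin = 1 (a 90-degree rotation), a head-dim-2
+# vector (x1, x2) is mapped to (-x2, x1) by `apply_rotary_pos_emb`.
+#
+# >>> import torch
+# >>> q = torch.tensor([[[[1.0, 2.0]]]])  # [batch=1, heads=1, seq_len=1, head_dim=2]
+# >>> cos = torch.zeros(1, 1, 2)          # [batch=1, seq_len=1, head_dim=2]
+# >>> sin = torch.ones(1, 1, 2)
+# >>> q_rot, _ = apply_rotary_pos_emb(q, q, cos, sin)
+# >>> q_rot.squeeze().tolist()
+# [-2.0, 1.0]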
+
+def eager_attention_forward(
+ module: "ModernBertAttention",
+ qkv: torch.Tensor,
+ attention_mask: torch.Tensor,
+ sliding_window_mask: torch.Tensor,
+ position_ids: Optional[torch.LongTensor],
+ local_attention: tuple[int, int],
+ bs: int,
+ dim: int,
+ output_attentions: Optional[bool] = False,
+ **_kwargs,
+) -> Union[tuple[torch.Tensor, torch.Tensor], tuple[torch.Tensor]]:
+ # qkv: [batch_size, seqlen, 3, nheads, headdim]
+ cos, sin = module.rotary_emb(qkv, position_ids=position_ids)
+ query, key, value = qkv.transpose(3, 1).unbind(dim=2)
+ # query, key, value: [batch_size, heads, seq_len, head_dim]
+ query, key = apply_rotary_pos_emb(query, key, cos, sin)
+
+ scale = module.head_dim**-0.5
+ attn_weights = torch.matmul(query, key.transpose(2, 3)) * scale
+
+ if local_attention != (-1, -1):
+ attention_mask = sliding_window_mask
+
+ attn_weights = attn_weights + attention_mask
+
+ # upcast attention to fp32
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+ attn_weights = nn.functional.dropout(attn_weights, p=module.attention_dropout, training=module.training)
+ attn_output = torch.matmul(attn_weights, value)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+ attn_output = attn_output.view(bs, -1, dim)
+ if output_attentions:
+ return (attn_output, attn_weights)
+ return (attn_output,)
+
+
+def flash_attention_forward(
+ module: "ModernBertAttention",
+ qkv: torch.Tensor,
+ rotary_emb: ModernBertUnpaddedRotaryEmbedding,
+ cu_seqlens: torch.Tensor,
+ max_seqlen: int,
+ local_attention: tuple[int, int],
+ bs: int,
+ dim: int,
+ target_dtype: torch.dtype = torch.bfloat16,
+ **_kwargs,
+) -> tuple[torch.Tensor]:
+ # (total_seqlen, 3, nheads, headdim)
+ qkv = rotary_emb(qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen)
+
+ convert_dtype = qkv.dtype not in (torch.float16, torch.bfloat16)
+ if convert_dtype:
+ # FA2 implementation only supports fp16 and bf16. If FA2 is supported,
+ # bfloat16 must be supported as of FA2 2.5.7. (Turing GPUs not supported)
+ orig_dtype = qkv.dtype
+ qkv = qkv.to(target_dtype)
+
+ attn = flash_attn_varlen_qkvpacked_func(
+ qkv,
+ cu_seqlens=cu_seqlens,
+ max_seqlen=max_seqlen,
+ dropout_p=module.attention_dropout if module.training else 0.0,
+ deterministic=module.deterministic_flash_attn,
+ window_size=local_attention,
+ )
+ attn = attn.to(orig_dtype) # type: ignore
+ else:
+ attn = flash_attn_varlen_qkvpacked_func(
+ qkv,
+ cu_seqlens=cu_seqlens,
+ max_seqlen=max_seqlen,
+ dropout_p=module.attention_dropout if module.training else 0.0,
+ deterministic=module.deterministic_flash_attn,
+ window_size=local_attention,
+ )
+ return (attn.view(bs, dim),)
+
+
+def sdpa_attention_forward(
+ module: "ModernBertAttention",
+ qkv: torch.Tensor,
+ attention_mask: torch.Tensor,
+ sliding_window_mask: torch.Tensor,
+ position_ids: Optional[torch.LongTensor],
+ local_attention: tuple[int, int],
+ bs: int,
+ dim: int,
+ **_kwargs,
+) -> tuple[torch.Tensor]:
+ # qkv: [batch_size, seqlen, 3, nheads, headdim]
+ cos, sin = module.rotary_emb(qkv, position_ids=position_ids)
+ query, key, value = qkv.transpose(3, 1).unbind(dim=2)
+ # query, key, value: [batch_size, heads, seq_len, head_dim]
+ query, key = apply_rotary_pos_emb(query, key, cos, sin)
+
+ if local_attention != (-1, -1):
+ attention_mask = sliding_window_mask
+
+ attn_output = (
+ F.scaled_dot_product_attention(
+ query,
+ key,
+ value,
+ dropout_p=module.attention_dropout if module.training else 0.0,
+ attn_mask=attention_mask,
+ )
+ .transpose(1, 2)
+ .contiguous()
+ )
+ attn_output = attn_output.view(bs, -1, dim)
+ return (attn_output,)
+
+
+MODERNBERT_ATTENTION_FUNCTION = {
+ "flash_attention_2": flash_attention_forward,
+ "eager": eager_attention_forward,
+ "sdpa": sdpa_attention_forward,
+}
+
+
+class ModernBertAttention(nn.Module):
+ """Performs multi-headed self attention on a batch of unpadded sequences.
+
+ If Flash Attention 2 is installed, this module uses Flash Attention to improve throughput.
+ If Flash Attention 2 is not installed, the implementation falls back to PyTorch's SDPA (or eager)
+ kernel, which operates on padded inputs and therefore adds some padding overhead.
+
+ See `forward` method for additional details.
+ """
+
+ def __init__(self, config: ModernBertConfig, layer_id: Optional[int] = None):
+ super().__init__()
+ self.config = config
+ self.layer_id = layer_id
+
+ if config.hidden_size % config.num_attention_heads != 0:
+ raise ValueError(
+ f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention heads ({config.num_attention_heads})"
+ )
+
+ self.attention_dropout = config.attention_dropout
+ self.deterministic_flash_attn = config.deterministic_flash_attn
+ self.num_heads = config.num_attention_heads
+ self.head_dim = config.hidden_size // config.num_attention_heads
+ self.all_head_size = self.head_dim * self.num_heads
+ self.Wqkv = nn.Linear(config.hidden_size, 3 * self.all_head_size, bias=config.attention_bias)
+
+ if layer_id % config.global_attn_every_n_layers != 0:
+ self.local_attention = (config.local_attention // 2, config.local_attention // 2)
+ rope_theta = config.local_rope_theta if config.local_rope_theta is not None else config.global_rope_theta
+ max_position_embeddings = config.local_attention
+ else:
+ self.local_attention = (-1, -1)
+ max_position_embeddings = config.max_position_embeddings
+ rope_theta = config.global_rope_theta
+
+ if config._attn_implementation == "flash_attention_2":
+ self.rotary_emb = ModernBertUnpaddedRotaryEmbedding(
+ dim=self.head_dim, max_seqlen=max_position_embeddings, base=rope_theta
+ )
+ else:
+ config_copy = copy.deepcopy(config)
+ config_copy.rope_theta = rope_theta
+ self.rotary_emb = ModernBertRotaryEmbedding(config=config_copy)
+
+ self.Wo = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attention_bias)
+ self.out_drop = nn.Dropout(config.attention_dropout) if config.attention_dropout > 0.0 else nn.Identity()
+ self.pruned_heads = set()
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ output_attentions: Optional[bool] = False,
+ **kwargs,
+ ) -> torch.Tensor:
+ qkv = self.Wqkv(hidden_states)
+
+ bs = hidden_states.shape[0]
+ if self.config._attn_implementation == "flash_attention_2":
+ qkv = qkv.view(-1, 3, self.num_heads, self.head_dim)
+ else:
+ qkv = qkv.view(bs, -1, 3, self.num_heads, self.head_dim)
+
+ attn_outputs = MODERNBERT_ATTENTION_FUNCTION[self.config._attn_implementation](
+ self,
+ qkv=qkv,
+ rotary_emb=self.rotary_emb,
+ local_attention=self.local_attention,
+ bs=bs,
+ dim=self.all_head_size,
+ output_attentions=output_attentions,
+ **kwargs,
+ )
+ hidden_states = attn_outputs[0]
+ hidden_states = self.out_drop(self.Wo(hidden_states))
+
+ return (hidden_states,) + attn_outputs[1:] # add attentions if outputted
+
+
+class ModernBertEncoderLayer(GradientCheckpointingLayer):
+ def __init__(self, config: ModernBertConfig, layer_id: Optional[int] = None):
+ super().__init__()
+ self.config = config
+ if layer_id == 0:
+ self.attn_norm = nn.Identity()
+ else:
+ self.attn_norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)
+ self.attn = ModernBertAttention(config=config, layer_id=layer_id)
+ self.mlp_norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)
+ self.mlp = ModernBertMLP(config)
+
+ @torch.compile(dynamic=True)
+ def compiled_mlp(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ return self.mlp(self.mlp_norm(hidden_states))
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ sliding_window_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ cu_seqlens: Optional[torch.Tensor] = None,
+ max_seqlen: Optional[int] = None,
+ output_attentions: Optional[bool] = False,
+ ) -> torch.Tensor:
+ attn_outputs = self.attn(
+ self.attn_norm(hidden_states),
+ attention_mask=attention_mask,
+ sliding_window_mask=sliding_window_mask,
+ position_ids=position_ids,
+ cu_seqlens=cu_seqlens,
+ max_seqlen=max_seqlen,
+ output_attentions=output_attentions,
+ )
+ hidden_states = hidden_states + attn_outputs[0]
+ mlp_output = (
+ self.compiled_mlp(hidden_states)
+ if self.config.reference_compile
+ else self.mlp(self.mlp_norm(hidden_states))
+ )
+ hidden_states = hidden_states + mlp_output
+
+ return (hidden_states,) + attn_outputs[1:] # add attentions if outputted
+
+
+@auto_docstring
+class ModernBertPreTrainedModel(PreTrainedModel):
+ config: ModernBertConfig
+ base_model_prefix = "model"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["ModernBertEmbeddings", "ModernBertEncoderLayer"]
+ _supports_flash_attn = True
+ _supports_sdpa = True
+ _supports_flex_attn = False
+
+ def _init_weights(self, module: nn.Module):
+ cutoff_factor = self.config.initializer_cutoff_factor
+ if cutoff_factor is None:
+ cutoff_factor = 3
+
+ def init_weight(module: nn.Module, std: float):
+ nn.init.trunc_normal_(
+ module.weight,
+ mean=0.0,
+ std=std,
+ a=-cutoff_factor * std,
+ b=cutoff_factor * std,
+ )
+
+ if isinstance(module, nn.Linear):
+ if module.bias is not None:
+ nn.init.zeros_(module.bias)
+
+ stds = {
+ "in": self.config.initializer_range,
+ "out": self.config.initializer_range / math.sqrt(2.0 * self.config.num_hidden_layers),
+ "embedding": self.config.initializer_range,
+ "final_out": self.config.hidden_size**-0.5,
+ }
+
+ if isinstance(module, ModernBertEmbeddings):
+ init_weight(module.tok_embeddings, stds["embedding"])
+ elif isinstance(module, ModernBertMLP):
+ init_weight(module.Wi, stds["in"])
+ init_weight(module.Wo, stds["out"])
+ elif isinstance(module, ModernBertAttention):
+ init_weight(module.Wqkv, stds["in"])
+ init_weight(module.Wo, stds["out"])
+ elif isinstance(module, ModernBertPredictionHead):
+ init_weight(module.dense, stds["out"])
+ elif isinstance(module, ModernBertForMaskedLM):
+ init_weight(module.decoder, stds["out"])
+ elif isinstance(
+ module,
+ (
+ ModernBertForSequenceClassification,
+ ModernBertForMultipleChoice,
+ ModernBertForTokenClassification,
+ ModernBertForQuestionAnswering,
+ ),
+ ):
+ init_weight(module.classifier, stds["final_out"])
+ elif isinstance(module, nn.LayerNorm):
+ module.weight.data.fill_(1.0)
+ if module.bias is not None:
+ module.bias.data.zero_()
+
+ def _check_and_adjust_attn_implementation(
+ self, attn_implementation: Optional[str], is_init_check: bool = False
+ ) -> str:
+ """
+ Checks and dispatches to the requested attention implementation.
+ """
+ # If the user didn't specify anything, try to use flash_attention_2 if available.
+ # Otherwise we fall back to the default SDPA -> Eager from the super() method.
+ # ModernBert's FA2 implementation correctly handles non-fp16/bf16 dtypes, we don't
+ # need the FA2 warning for non-fp16/bf16 dtypes so we set fp16 for the FA2 check.
+
+ try:
+ attn_implementation = (
+ "flash_attention_2"
+ if attn_implementation is None and self._flash_attn_2_can_dispatch()
+ else attn_implementation
+ )
+ except (ValueError, ImportError):
+ pass
+ return super()._check_and_adjust_attn_implementation(
+ attn_implementation=attn_implementation, is_init_check=is_init_check
+ )
+
+ def _maybe_set_compile(self):
+ if self.config.reference_compile is False:
+ return
+
+ if hasattr(self, "hf_device_map") and len(self.hf_device_map) > 1:
+ if self.config.reference_compile:
+ logger.warning_once(
+ "If `accelerate` split the model across devices, `torch.compile` will not work. "
+ "Falling back to non-compiled mode."
+ )
+ self.config.reference_compile = False
+
+ if self.device.type == "mps":
+ if self.config.reference_compile:
+ logger.warning_once(
+ "Compiling the model with `torch.compile` and using a `torch.mps` device is not supported. "
+ "Falling back to non-compiled mode."
+ )
+ self.config.reference_compile = False
+
+ if self.device.type == "cpu":
+ if self.config.reference_compile:
+ logger.warning_once(
+ "Compiling the model with `torch.compile` and using a `torch.cpu` device is not supported. "
+ "Falling back to non-compiled mode."
+ )
+ self.config.reference_compile = False
+
+ if self.config.reference_compile is None:
+ self.config.reference_compile = is_triton_available()
+
+ def resize_token_embeddings(self, *args, **kwargs):
+ model_embeds = super().resize_token_embeddings(*args, **kwargs)
+
+ if self.config.reference_compile in {True, None}:
+ if self.config.reference_compile:
+ logger.warning_once(
+ "Resizing token embeddings with `torch.compile` is not supported. Falling back to non-compiled mode."
+ )
+ self.config.reference_compile = False
+
+ return model_embeds
+
+
+def _unpad_modernbert_input(
+ inputs: torch.Tensor,
+ attention_mask: torch.Tensor,
+ position_ids: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, Optional[torch.Tensor], Optional[torch.Tensor]]:
+ """
+ Remove padding from input sequences.
+
+ Args:
+ inputs: (batch, seqlen, ...) or (batch, seqlen)
+ attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
+ position_ids: (batch, seqlen), int, position ids
+ labels: (batch, seqlen), int, labels
+
+ Returns:
+ unpadded_inputs: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.
+ indices: (total_nnz)
+ cu_seqlens: (batch + 1), the cumulative sequence lengths
+ max_seqlen_in_batch: int
+ unpadded_position_ids: (total_nnz) or None
+ unpadded_labels: (total_nnz) or None
+ """
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
+ max_seqlen_in_batch = int(seqlens_in_batch.max().item())
+ cu_seqlens = torch.nn.functional.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
+
+ if inputs.dim() == 2:
+ unpadded_inputs = inputs.flatten()[indices]
+ else:
+ batch, seqlen, *rest = inputs.shape
+ shape = batch * seqlen
+ unpadded_inputs = inputs.view(shape, *rest)[indices]
+
+ unpadded_position_ids = position_ids.flatten()[indices] if position_ids is not None else None
+ unpadded_labels = labels.flatten()[indices] if labels is not None else None
+
+ return unpadded_inputs, indices, cu_seqlens, max_seqlen_in_batch, unpadded_position_ids, unpadded_labels
+
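+# A small worked example of the unpadding above (illustrative only): for a toy batch of two
+# sequences of lengths 3 and 2 padded to length 4, the helper returns the flattened non-padding
+# tokens, their flat indices, the cumulative sequence lengths and the max sequence length.
+#
+# >>> import torch
+# >>> input_ids = torch.tensor([[5, 6, 7, 0], [8, 9, 0, 0]])
+# >>> attention_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])
+# >>> tokens, indices, cu_seqlens, max_seqlen, _, _ = _unpad_modernbert_input(input_ids, attention_mask)
+# >>> tokens.tolist(), indices.tolist(), cu_seqlens.tolist(), max_seqlen
+# ([5, 6, 7, 8, 9], [0, 1, 2, 4, 5], [0, 3, 5], 3)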
+
+def _pad_modernbert_output(
+ inputs: torch.Tensor,
+ indices: torch.Tensor,
+ batch: int,
+ seqlen: int,
+) -> torch.Tensor:
+ """
+ Add padding to sequences.
+
+ Args:
+ inputs: (total_nnz, ...) or (total_nnz,), where total_nnz = number of tokens selected in attention_mask.
+ indices: (total_nnz)
+ batch: int, batch size
+ seqlen: int, max sequence length
+
+ Returns:
+ padded_inputs: (batch, seqlen, ...) or (batch, seqlen)
+ """
+ if inputs.dim() == 1:
+ output = torch.zeros(batch * seqlen, dtype=inputs.dtype, device=inputs.device)
+ output[indices] = inputs
+ padded_inputs = output.view(batch, seqlen)
+ else:
+ _, *rest = inputs.shape
+ output = torch.zeros(batch * seqlen, *rest, dtype=inputs.dtype, device=inputs.device)
+ output[indices] = inputs
+ padded_inputs = output.view(batch, seqlen, *rest)
+
+ return padded_inputs
+
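+# Round-trip sketch (illustrative only): repadding the values from the unpadding example above
+# restores the original [batch, seqlen] layout, with zeros in the padding slots.
+#
+# >>> import torch
+# >>> tokens = torch.tensor([5, 6, 7, 8, 9])
+# >>> indices = torch.tensor([0, 1, 2, 4, 5])
+# >>> _pad_modernbert_output(inputs=tokens, indices=indices, batch=2, seqlen=4).tolist()
+# [[5, 6, 7, 0], [8, 9, 0, 0]]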
+
+@auto_docstring
+class ModernBertModel(ModernBertPreTrainedModel):
+ def __init__(self, config: ModernBertConfig):
+ super().__init__(config)
+ self.config = config
+ self.embeddings = ModernBertEmbeddings(config)
+ self.layers = nn.ModuleList(
+ [ModernBertEncoderLayer(config, layer_id) for layer_id in range(config.num_hidden_layers)]
+ )
+ self.final_norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)
+ self.gradient_checkpointing = False
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.embeddings.tok_embeddings
+
+ def set_input_embeddings(self, value):
+ self.embeddings.tok_embeddings = value
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ sliding_window_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ indices: Optional[torch.Tensor] = None,
+ cu_seqlens: Optional[torch.Tensor] = None,
+ max_seqlen: Optional[int] = None,
+ batch_size: Optional[int] = None,
+ seq_len: Optional[int] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[tuple[torch.Tensor, ...], BaseModelOutput]:
+ r"""
+ sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding or far-away tokens. In ModernBert, only every few layers
+ perform global attention, while the rest perform local attention. This mask is used to avoid attending to
+ far-away tokens in the local attention layers when not using Flash Attention.
+ indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*):
+ Indices of the non-padding tokens in the input sequence. Used for unpadding the output.
+ cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*):
+ Cumulative sequence lengths of the input sequences. Used to index the unpadded tensors.
+ max_seqlen (`int`, *optional*):
+ Maximum sequence length in the batch excluding padding tokens. Used to unpad input_ids and pad output tensors.
+ batch_size (`int`, *optional*):
+ Batch size of the input sequences. Used to pad the output tensors.
+ seq_len (`int`, *optional*):
+ Sequence length of the input sequences including padding tokens. Used to pad the output tensors.
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attentions = () if output_attentions else None
+
+ self._maybe_set_compile()
+
+ if input_ids is not None:
+ self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
+
+ if batch_size is None and seq_len is None:
+ if inputs_embeds is not None:
+ batch_size, seq_len = inputs_embeds.shape[:2]
+ else:
+ batch_size, seq_len = input_ids.shape[:2]
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
+
+ if attention_mask is None:
+ attention_mask = torch.ones((batch_size, seq_len), device=device, dtype=torch.bool)
+
+ repad = False
+ if self.config._attn_implementation == "flash_attention_2":
+ if indices is None and cu_seqlens is None and max_seqlen is None:
+ repad = True
+ if inputs_embeds is None:
+ with torch.no_grad():
+ input_ids, indices, cu_seqlens, max_seqlen, *_ = _unpad_modernbert_input(
+ inputs=input_ids, attention_mask=attention_mask
+ )
+ else:
+ inputs_embeds, indices, cu_seqlens, max_seqlen, *_ = _unpad_modernbert_input(
+ inputs=inputs_embeds, attention_mask=attention_mask
+ )
+ else:
+ if position_ids is None:
+ position_ids = torch.arange(seq_len, device=device).unsqueeze(0)
+
+ attention_mask, sliding_window_mask = self._update_attention_mask(
+ attention_mask, output_attentions=output_attentions
+ )
+
+ hidden_states = self.embeddings(input_ids=input_ids, inputs_embeds=inputs_embeds)
+
+ for encoder_layer in self.layers:
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ layer_outputs = encoder_layer(
+ hidden_states,
+ attention_mask=attention_mask,
+ sliding_window_mask=sliding_window_mask,
+ position_ids=position_ids,
+ cu_seqlens=cu_seqlens,
+ max_seqlen=max_seqlen,
+ output_attentions=output_attentions,
+ )
+ hidden_states = layer_outputs[0]
+ if output_attentions and len(layer_outputs) > 1:
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ hidden_states = self.final_norm(hidden_states)
+
+ if repad:
+ hidden_states = _pad_modernbert_output(
+ inputs=hidden_states, indices=indices, batch=batch_size, seqlen=seq_len
+ )
+ if all_hidden_states is not None:
+ all_hidden_states = tuple(
+ _pad_modernbert_output(inputs=hs, indices=indices, batch=batch_size, seqlen=seq_len)
+ for hs in all_hidden_states
+ )
+ # If the attention implementation is FA2 and there is no need for repadding, there might still be the batch
+ # dimension missing
+ elif (
+ self.config._attn_implementation == "flash_attention_2"
+ and all_hidden_states is not None
+ and all_hidden_states[-1].dim() == 2
+ ):
+ hidden_states = hidden_states.unsqueeze(0)
+ all_hidden_states = tuple(hs.unsqueeze(0) for hs in all_hidden_states)
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
+ return BaseModelOutput(
+ last_hidden_state=hidden_states,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attentions,
+ )
+
+ def _update_attention_mask(
+ self, attention_mask: torch.Tensor, output_attentions: bool
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ if output_attentions:
+ if self.config._attn_implementation == "sdpa":
+ logger.warning_once(
+ "Outputting attentions is only supported with the 'eager' attention implementation, "
+ 'not with "sdpa". Falling back to `attn_implementation="eager"`.'
+ )
+ self.config._attn_implementation = "eager"
+ elif self.config._attn_implementation != "eager":
+ logger.warning_once(
+ "Outputting attentions is only supported with the eager attention implementation, "
+ f'not with {self.config._attn_implementation}. Consider setting `attn_implementation="eager"`.'
+ " Setting `output_attentions=False`."
+ )
+
+ global_attention_mask = _prepare_4d_attention_mask(attention_mask, self.dtype)
+
+ # Create position indices
+ rows = torch.arange(global_attention_mask.shape[2]).unsqueeze(0)
+ # Calculate distance between positions
+ distance = torch.abs(rows - rows.T)
+
+ # Create sliding window mask (1 for positions within window, 0 outside)
+ window_mask = (
+ (distance <= self.config.local_attention // 2).unsqueeze(0).unsqueeze(0).to(attention_mask.device)
+ )
+ # Combine with existing mask
+ sliding_window_mask = global_attention_mask.masked_fill(window_mask.logical_not(), torch.finfo(self.dtype).min)
+
+ return global_attention_mask, sliding_window_mask
+
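+# A minimal usage sketch for the encoder above (illustrative only; the checkpoint name is the one
+# referenced in the configuration docstring, and the exact sequence length depends on tokenization):
+#
+# >>> from transformers import AutoTokenizer, ModernBertModel
+# >>> tokenizer = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-base")
+# >>> model = ModernBertModel.from_pretrained("answerdotai/ModernBERT-base")
+# >>> inputs = tokenizer("The capital of France is Paris.", return_tensors="pt")
+# >>> outputs = model(**inputs)
+# >>> outputs.last_hidden_state.shape  # [batch_size, sequence_length, hidden_size]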
+
+class ModernBertPredictionHead(nn.Module):
+ def __init__(self, config: ModernBertConfig):
+ super().__init__()
+ self.config = config
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size, config.classifier_bias)
+ self.act = ACT2FN[config.classifier_activation]
+ self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ return self.norm(self.act(self.dense(hidden_states)))
+
+
+@auto_docstring(
+ custom_intro="""
+ The ModernBert Model with a decoder head on top that is used for masked language modeling.
+ """
+)
+class ModernBertForMaskedLM(ModernBertPreTrainedModel):
+ _tied_weights_keys = ["decoder.weight"]
+
+ def __init__(self, config: ModernBertConfig):
+ super().__init__(config)
+ self.config = config
+ self.model = ModernBertModel(config)
+ self.head = ModernBertPredictionHead(config)
+ self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=config.decoder_bias)
+
+ self.sparse_prediction = self.config.sparse_prediction
+ self.sparse_pred_ignore_index = self.config.sparse_pred_ignore_index
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_output_embeddings(self):
+ return self.decoder
+
+ def set_output_embeddings(self, new_embeddings: nn.Linear):
+ self.decoder = new_embeddings
+
+ @torch.compile(dynamic=True)
+ def compiled_head(self, output: torch.Tensor) -> torch.Tensor:
+ return self.decoder(self.head(output))
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ sliding_window_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ indices: Optional[torch.Tensor] = None,
+ cu_seqlens: Optional[torch.Tensor] = None,
+ max_seqlen: Optional[int] = None,
+ batch_size: Optional[int] = None,
+ seq_len: Optional[int] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ **kwargs,
+ ) -> Union[tuple[torch.Tensor], MaskedLMOutput]:
+ r"""
+ sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding or far-away tokens. In ModernBert, only every few layers
+ perform global attention, while the rest perform local attention. This mask is used to avoid attending to
+ far-away tokens in the local attention layers when not using Flash Attention.
+ indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*):
+ Indices of the non-padding tokens in the input sequence. Used for unpadding the output.
+ cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*):
+ Cumulative sequence lengths of the input sequences. Used to index the unpadded tensors.
+ max_seqlen (`int`, *optional*):
+ Maximum sequence length in the batch excluding padding tokens. Used to unpad input_ids and pad output tensors.
+ batch_size (`int`, *optional*):
+ Batch size of the input sequences. Used to pad the output tensors.
+ seq_len (`int`, *optional*):
+ Sequence length of the input sequences including padding tokens. Used to pad the output tensors.
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ self._maybe_set_compile()
+
+ if self.config._attn_implementation == "flash_attention_2":
+ if indices is None and cu_seqlens is None and max_seqlen is None:
+ if batch_size is None and seq_len is None:
+ if inputs_embeds is not None:
+ batch_size, seq_len = inputs_embeds.shape[:2]
+ else:
+ batch_size, seq_len = input_ids.shape[:2]
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
+
+ if attention_mask is None:
+ attention_mask = torch.ones((batch_size, seq_len), device=device, dtype=torch.bool)
+
+ if inputs_embeds is None:
+ with torch.no_grad():
+ input_ids, indices, cu_seqlens, max_seqlen, position_ids, labels = _unpad_modernbert_input(
+ inputs=input_ids, attention_mask=attention_mask, position_ids=position_ids, labels=labels
+ )
+ else:
+ inputs_embeds, indices, cu_seqlens, max_seqlen, position_ids, labels = _unpad_modernbert_input(
+ inputs=inputs_embeds, attention_mask=attention_mask, position_ids=position_ids, labels=labels
+ )
+
+ outputs = self.model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ sliding_window_mask=sliding_window_mask,
+ position_ids=position_ids,
+ inputs_embeds=inputs_embeds,
+ indices=indices,
+ cu_seqlens=cu_seqlens,
+ max_seqlen=max_seqlen,
+ batch_size=batch_size,
+ seq_len=seq_len,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ last_hidden_state = outputs[0]
+
+ if self.sparse_prediction and labels is not None:
+ # flatten labels and output first
+ labels = labels.view(-1)
+ last_hidden_state = last_hidden_state.view(labels.shape[0], -1)
+
+ # then filter out the non-masked tokens
+ mask_tokens = labels != self.sparse_pred_ignore_index
+ last_hidden_state = last_hidden_state[mask_tokens]
+ labels = labels[mask_tokens]
+
+ logits = (
+ self.compiled_head(last_hidden_state)
+ if self.config.reference_compile
+ else self.decoder(self.head(last_hidden_state))
+ )
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(logits, labels, vocab_size=self.config.vocab_size, **kwargs)
+
+ if self.config._attn_implementation == "flash_attention_2":
+ # Logits padding
+ with nullcontext() if self.config.repad_logits_with_grad or labels is None else torch.no_grad():
+ logits = _pad_modernbert_output(inputs=logits, indices=indices, batch=batch_size, seqlen=seq_len)
+ # Hidden states padding
+ if getattr(outputs, "hidden_states", None) is not None:
+ padded_hidden_states = []
+ for hs in outputs.hidden_states:
+ if hs.dim() == 3 and hs.shape[0] == 1:
+ hs = hs.squeeze(0)
+ padded_hidden_states.append(
+ _pad_modernbert_output(inputs=hs, indices=indices, batch=batch_size, seqlen=seq_len)
+ )
+ outputs.hidden_states = tuple(padded_hidden_states)
+
+ if not return_dict:
+ output = (logits,)
+ return ((loss,) + output) if loss is not None else output
+
+ return MaskedLMOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
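+# A minimal masked-language-modeling sketch (illustrative only; the checkpoint name is the one
+# referenced in the configuration docstring, and the predicted token is not a guaranteed output):
+#
+# >>> import torch
+# >>> from transformers import AutoTokenizer, ModernBertForMaskedLM
+# >>> tokenizer = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-base")
+# >>> model = ModernBertForMaskedLM.from_pretrained("answerdotai/ModernBERT-base")
+# >>> inputs = tokenizer("The capital of France is [MASK].", return_tensors="pt")
+# >>> with torch.no_grad():
+# ...     logits = model(**inputs).logits
+# >>> mask_pos = (inputs.input_ids == tokenizer.mask_token_id).nonzero(as_tuple=True)[1]
+# >>> predicted_token = tokenizer.decode(logits[0, mask_pos].argmax(dim=-1))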
+
+@auto_docstring(
+ custom_intro="""
+ The ModernBert Model with a sequence classification head on top that performs pooling.
+ """
+)
+class ModernBertForSequenceClassification(ModernBertPreTrainedModel):
+ def __init__(self, config: ModernBertConfig):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+ self.config = config
+
+ self.model = ModernBertModel(config)
+ self.head = ModernBertPredictionHead(config)
+ self.drop = torch.nn.Dropout(config.classifier_dropout)
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ sliding_window_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ indices: Optional[torch.Tensor] = None,
+ cu_seqlens: Optional[torch.Tensor] = None,
+ max_seqlen: Optional[int] = None,
+ batch_size: Optional[int] = None,
+ seq_len: Optional[int] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ **kwargs,
+ ) -> Union[tuple[torch.Tensor], SequenceClassifierOutput]:
+ r"""
+ sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding or far-away tokens. In ModernBert, only every few layers
+ perform global attention, while the rest perform local attention. This mask is used to avoid attending to
+ far-away tokens in the local attention layers when not using Flash Attention.
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+ indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*):
+ Indices of the non-padding tokens in the input sequence. Used for unpadding the output.
+ cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*):
+ Cumulative sequence lengths of the input sequences. Used to index the unpadded tensors.
+ max_seqlen (`int`, *optional*):
+ Maximum sequence length in the batch excluding padding tokens. Used to unpad input_ids and pad output tensors.
+ batch_size (`int`, *optional*):
+ Batch size of the input sequences. Used to pad the output tensors.
+ seq_len (`int`, *optional*):
+ Sequence length of the input sequences including padding tokens. Used to pad the output tensors.
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ self._maybe_set_compile()
+
+ if input_ids is not None:
+ self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
+
+ if batch_size is None and seq_len is None:
+ if inputs_embeds is not None:
+ batch_size, seq_len = inputs_embeds.shape[:2]
+ else:
+ batch_size, seq_len = input_ids.shape[:2]
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
+
+ if attention_mask is None:
+ attention_mask = torch.ones((batch_size, seq_len), device=device, dtype=torch.bool)
+
+ outputs = self.model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ sliding_window_mask=sliding_window_mask,
+ position_ids=position_ids,
+ inputs_embeds=inputs_embeds,
+ indices=indices,
+ cu_seqlens=cu_seqlens,
+ max_seqlen=max_seqlen,
+ batch_size=batch_size,
+ seq_len=seq_len,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ last_hidden_state = outputs[0]
+
+ if self.config.classifier_pooling == "cls":
+ last_hidden_state = last_hidden_state[:, 0]
+ elif self.config.classifier_pooling == "mean":
+ last_hidden_state = (last_hidden_state * attention_mask.unsqueeze(-1)).sum(dim=1) / attention_mask.sum(
+ dim=1, keepdim=True
+ )
+
+ pooled_output = self.head(last_hidden_state)
+ pooled_output = self.drop(pooled_output)
+ logits = self.classifier(pooled_output)
+
+ loss = None
+ if labels is not None:
+ if self.config.problem_type is None:
+ if self.num_labels == 1:
+ self.config.problem_type = "regression"
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+ self.config.problem_type = "single_label_classification"
+ else:
+ self.config.problem_type = "multi_label_classification"
+
+ if self.config.problem_type == "regression":
+ loss_fct = MSELoss()
+ if self.num_labels == 1:
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
+ else:
+ loss = loss_fct(logits, labels)
+ elif self.config.problem_type == "single_label_classification":
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+ elif self.config.problem_type == "multi_label_classification":
+ loss_fct = BCEWithLogitsLoss()
+ loss = loss_fct(logits, labels)
+
+ if not return_dict:
+ output = (logits,)
+ return ((loss,) + output) if loss is not None else output
+
+ return SequenceClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
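+# A small sketch of the two pooling modes used above (illustrative only): "cls" keeps the hidden
+# state of the first token, while "mean" averages hidden states over the non-padding positions.
+#
+# >>> import torch
+# >>> hidden = torch.tensor([[[1.0, 1.0], [3.0, 3.0], [7.0, 7.0]]])  # [batch=1, seq_len=3, hidden=2]
+# >>> mask = torch.tensor([[1, 1, 0]])                               # the last position is padding
+# >>> cls_pooled = hidden[:, 0]
+# >>> mean_pooled = (hidden * mask.unsqueeze(-1)).sum(dim=1) / mask.sum(dim=1, keepdim=True)
+# >>> cls_pooled.tolist(), mean_pooled.tolist()
+# ([[1.0, 1.0]], [[2.0, 2.0]])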
+
+@auto_docstring(
+ custom_intro="""
+ The ModernBert Model with a token classification head on top, e.g. for Named Entity Recognition (NER) tasks.
+ """
+)
+class ModernBertForTokenClassification(ModernBertPreTrainedModel):
+ def __init__(self, config: ModernBertConfig):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+
+ self.model = ModernBertModel(config)
+ self.head = ModernBertPredictionHead(config)
+ self.drop = torch.nn.Dropout(config.classifier_dropout)
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ sliding_window_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ indices: Optional[torch.Tensor] = None,
+ cu_seqlens: Optional[torch.Tensor] = None,
+ max_seqlen: Optional[int] = None,
+ batch_size: Optional[int] = None,
+ seq_len: Optional[int] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[tuple[torch.Tensor], TokenClassifierOutput]:
+ r"""
+ sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding or far-away tokens. In ModernBert, only every few layers
+ perform global attention, while the rest perform local attention. This mask is used to avoid attending to
+ far-away tokens in the local attention layers when not using Flash Attention.
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
+ indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*):
+ Indices of the non-padding tokens in the input sequence. Used for unpadding the output.
+ cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*):
+ Cumulative sequence lengths of the input sequences. Used to index the unpadded tensors.
+ max_seqlen (`int`, *optional*):
+ Maximum sequence length in the batch excluding padding tokens. Used to unpad input_ids and pad output tensors.
+ batch_size (`int`, *optional*):
+ Batch size of the input sequences. Used to pad the output tensors.
+ seq_len (`int`, *optional*):
+ Sequence length of the input sequences including padding tokens. Used to pad the output tensors.
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ self._maybe_set_compile()
+
+ outputs = self.model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ sliding_window_mask=sliding_window_mask,
+ position_ids=position_ids,
+ inputs_embeds=inputs_embeds,
+ indices=indices,
+ cu_seqlens=cu_seqlens,
+ max_seqlen=max_seqlen,
+ batch_size=batch_size,
+ seq_len=seq_len,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ last_hidden_state = outputs[0]
+
+ last_hidden_state = self.head(last_hidden_state)
+ last_hidden_state = self.drop(last_hidden_state)
+ logits = self.classifier(last_hidden_state)
+
+ loss = None
+ if labels is not None:
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+
+ if not return_dict:
+ output = (logits,) + outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return TokenClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+@auto_docstring
+class ModernBertForQuestionAnswering(ModernBertPreTrainedModel):
+ def __init__(self, config: ModernBertConfig):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+
+ self.model = ModernBertModel(config)
+ self.head = ModernBertPredictionHead(config)
+ self.drop = torch.nn.Dropout(config.classifier_dropout)
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+
+ self.post_init()
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.Tensor],
+ attention_mask: Optional[torch.Tensor] = None,
+ sliding_window_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ start_positions: Optional[torch.Tensor] = None,
+ end_positions: Optional[torch.Tensor] = None,
+ indices: Optional[torch.Tensor] = None,
+ cu_seqlens: Optional[torch.Tensor] = None,
+ max_seqlen: Optional[int] = None,
+ batch_size: Optional[int] = None,
+ seq_len: Optional[int] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ **kwargs,
+ ) -> Union[tuple[torch.Tensor], QuestionAnsweringModelOutput]:
+ r"""
+ sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding or far-away tokens. In ModernBert, only every few layers
+ perform global attention, while the rest perform local attention. This mask is used to avoid attending to
+ far-away tokens in the local attention layers when not using Flash Attention.
+ indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*):
+ Indices of the non-padding tokens in the input sequence. Used for unpadding the output.
+ cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*):
+ Cumulative sequence lengths of the input sequences. Used to index the unpadded tensors.
+ max_seqlen (`int`, *optional*):
+ Maximum sequence length in the batch excluding padding tokens. Used to unpad input_ids and pad output tensors.
+ batch_size (`int`, *optional*):
+ Batch size of the input sequences. Used to pad the output tensors.
+ seq_len (`int`, *optional*):
+ Sequence length of the input sequences including padding tokens. Used to pad the output tensors.
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ self._maybe_set_compile()
+
+ outputs = self.model(
+ input_ids,
+ attention_mask=attention_mask,
+ sliding_window_mask=sliding_window_mask,
+ position_ids=position_ids,
+ indices=indices,
+ cu_seqlens=cu_seqlens,
+ max_seqlen=max_seqlen,
+ batch_size=batch_size,
+ seq_len=seq_len,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ last_hidden_state = outputs[0]
+
+ last_hidden_state = self.head(last_hidden_state)
+ last_hidden_state = self.drop(last_hidden_state)
+ logits = self.classifier(last_hidden_state)
+
+ start_logits, end_logits = logits.split(1, dim=-1)
+ start_logits = start_logits.squeeze(-1).contiguous()
+ end_logits = end_logits.squeeze(-1).contiguous()
+
+ loss = None
+ if start_positions is not None and end_positions is not None:
+ loss = self.loss_function(start_logits, end_logits, start_positions, end_positions, **kwargs)
+
+ if not return_dict:
+ output = (start_logits, end_logits) + outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return QuestionAnsweringModelOutput(
+ loss=loss,
+ start_logits=start_logits,
+ end_logits=end_logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+@auto_docstring(
+ custom_intro="""
+ The ModernBert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a softmax) e.g. for RocStories/SWAG tasks.
+ """
+)
+class ModernBertForMultipleChoice(ModernBertPreTrainedModel):
+ def __init__(self, config: ModernBertConfig):
+ super().__init__(config)
+ self.config = config
+
+ self.model = ModernBertModel(config)
+ self.head = ModernBertPredictionHead(config)
+ self.drop = torch.nn.Dropout(config.classifier_dropout)
+ self.classifier = nn.Linear(config.hidden_size, 1)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ sliding_window_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ indices: Optional[torch.Tensor] = None,
+ cu_seqlens: Optional[torch.Tensor] = None,
+ max_seqlen: Optional[int] = None,
+ batch_size: Optional[int] = None,
+ seq_len: Optional[int] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ **kwargs,
+ ) -> Union[tuple[torch.Tensor], MultipleChoiceModelOutput]:
+ r"""
+ sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding or far-away tokens. In ModernBert, only every few layers
+ perform global attention, while the rest perform local attention. This mask is used to avoid attending to
+ far-away tokens in the local attention layers when not using Flash Attention.
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
+ num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors.
+ indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*):
+ Indices of the non-padding tokens in the input sequence. Used for unpadding the output.
+ cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*):
+ Cumulative sequence lengths of the input sequences. Used to index the unpadded tensors.
+ max_seqlen (`int`, *optional*):
+ Maximum sequence length in the batch excluding padding tokens. Used to unpad input_ids and pad output tensors.
+ batch_size (`int`, *optional*):
+ Batch size of the input sequences. Used to pad the output tensors.
+ seq_len (`int`, *optional*):
+ Sequence length of the input sequences including padding tokens. Used to pad the output tensors.
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
+
+ input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
+ attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
+ position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
+ inputs_embeds = (
+ inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
+ if inputs_embeds is not None
+ else None
+ )
+
+ self._maybe_set_compile()
+
+ outputs = self.model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ sliding_window_mask=sliding_window_mask,
+ position_ids=position_ids,
+ inputs_embeds=inputs_embeds,
+ indices=indices,
+ cu_seqlens=cu_seqlens,
+ max_seqlen=max_seqlen,
+ batch_size=batch_size,
+ seq_len=seq_len,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ last_hidden_state = outputs[0] # shape (num_choices, seq_len, hidden_size)
+
+        # If classifier_pooling is "cls", isolate the CLS token
+ if self.config.classifier_pooling == "cls":
+ indices_0 = torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device)
+            # with left or right padding, the CLS token is the first non-pad token
+ if attention_mask is not None:
+ cls_mask = attention_mask.argmax(dim=-1).to(last_hidden_state.device)
+            # with no padding, the CLS token is the first token
+ else:
+ cls_mask = torch.tensor(0, dtype=torch.long, device=last_hidden_state.device)
+            # extract the CLS token's hidden state for the logits
+ last_hidden_state = last_hidden_state[indices_0, cls_mask]
+
+ # If classifier_pooling is "mean", pool the hidden states by averaging over the sequence length
+ elif self.config.classifier_pooling == "mean":
+ num_non_pad_tokens = attention_mask.sum(dim=1, keepdim=True)
+ last_hidden_state = (last_hidden_state * attention_mask.unsqueeze(-1)).sum(dim=1) / num_non_pad_tokens
+
+ pooled_output = self.head(last_hidden_state)
+ pooled_output = self.drop(pooled_output)
+ logits = self.classifier(pooled_output)
+
+ reshaped_logits = logits.view(-1, num_choices)
+
+ loss = None
+ if labels is not None:
+ loss_fct = nn.CrossEntropyLoss()
+ loss = loss_fct(reshaped_logits, labels)
+
+ if not return_dict:
+ output = (reshaped_logits,) + outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return MultipleChoiceModelOutput(
+ loss=loss,
+ logits=reshaped_logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+__all__ = [
+ "ModernBertModel",
+ "ModernBertPreTrainedModel",
+ "ModernBertForMaskedLM",
+ "ModernBertForSequenceClassification",
+ "ModernBertForTokenClassification",
+ "ModernBertForQuestionAnswering",
+ "ModernBertForMultipleChoice",
+]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/modernbert/modular_modernbert.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/modernbert/modular_modernbert.py
new file mode 100644
index 0000000000000000000000000000000000000000..5ac298f0959669c9f4201216d796df5291464e29
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/modernbert/modular_modernbert.py
@@ -0,0 +1,1698 @@
+# Copyright 2024 Answer.AI, LightOn, and contributors, and the HuggingFace Inc. team. All rights reserved.
+#
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+import math
+from contextlib import nullcontext
+from typing import Literal, Optional, Union
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+
+from ...activations import ACT2FN
+from ...configuration_utils import PretrainedConfig
+from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_outputs import (
+ BaseModelOutput,
+ MaskedLMOutput,
+ MultipleChoiceModelOutput,
+ QuestionAnsweringModelOutput,
+ SequenceClassifierOutput,
+ TokenClassifierOutput,
+)
+from ...modeling_utils import PreTrainedModel
+from ...utils import auto_docstring, is_flash_attn_2_available, logging
+from ...utils.import_utils import is_triton_available
+from ..gemma.modeling_gemma import GemmaRotaryEmbedding, apply_rotary_pos_emb
+
+
+if is_flash_attn_2_available():
+ from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
+ from flash_attn.layers.rotary import RotaryEmbedding
+ from flash_attn.ops.triton.rotary import apply_rotary
+else:
+ RotaryEmbedding = object
+
+
+logger = logging.get_logger(__name__)
+
+
+class ModernBertConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`ModernBertModel`]. It is used to instantiate an ModernBert
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
+ defaults will yield a similar configuration to that of the ModernBERT-base.
+ e.g. [answerdotai/ModernBERT-base](https://huggingface.co/answerdotai/ModernBERT-base)
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 50368):
+ Vocabulary size of the ModernBert model. Defines the number of different tokens that can be represented by the
+            `input_ids` passed when calling [`ModernBertModel`]
+ hidden_size (`int`, *optional*, defaults to 768):
+ Dimension of the hidden representations.
+ intermediate_size (`int`, *optional*, defaults to 1152):
+ Dimension of the MLP representations.
+ num_hidden_layers (`int`, *optional*, defaults to 22):
+            Number of hidden layers in the Transformer encoder.
+        num_attention_heads (`int`, *optional*, defaults to 12):
+            Number of attention heads for each attention layer in the Transformer encoder.
+        hidden_activation (`str` or `function`, *optional*, defaults to `"gelu"`):
+            The non-linear activation function (function or string) in the encoder. Will default to `"gelu"`
+            if not specified.
+ max_position_embeddings (`int`, *optional*, defaults to 8192):
+ The maximum sequence length that this model might ever be used with.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ initializer_cutoff_factor (`float`, *optional*, defaults to 2.0):
+ The cutoff factor for the truncated_normal_initializer for initializing all weight matrices.
+ norm_eps (`float`, *optional*, defaults to 1e-05):
+ The epsilon used by the rms normalization layers.
+ norm_bias (`bool`, *optional*, defaults to `False`):
+ Whether to use bias in the normalization layers.
+ pad_token_id (`int`, *optional*, defaults to 50283):
+ Padding token id.
+ eos_token_id (`int`, *optional*, defaults to 50282):
+ End of stream token id.
+ bos_token_id (`int`, *optional*, defaults to 50281):
+ Beginning of stream token id.
+ cls_token_id (`int`, *optional*, defaults to 50281):
+ Classification token id.
+ sep_token_id (`int`, *optional*, defaults to 50282):
+ Separation token id.
+ global_rope_theta (`float`, *optional*, defaults to 160000.0):
+ The base period of the global RoPE embeddings.
+ attention_bias (`bool`, *optional*, defaults to `False`):
+ Whether to use a bias in the query, key, value and output projection layers during self-attention.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ global_attn_every_n_layers (`int`, *optional*, defaults to 3):
+ The number of layers between global attention layers.
+ local_attention (`int`, *optional*, defaults to 128):
+ The window size for local attention.
+ local_rope_theta (`float`, *optional*, defaults to 10000.0):
+ The base period of the local RoPE embeddings.
+ embedding_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the embeddings.
+ mlp_bias (`bool`, *optional*, defaults to `False`):
+ Whether to use bias in the MLP layers.
+ mlp_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the MLP layers.
+ decoder_bias (`bool`, *optional*, defaults to `True`):
+ Whether to use bias in the decoder layers.
+ classifier_pooling (`str`, *optional*, defaults to `"cls"`):
+ The pooling method for the classifier. Should be either `"cls"` or `"mean"`. In local attention layers, the
+ CLS token doesn't attend to all tokens on long sequences.
+ classifier_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the classifier.
+ classifier_bias (`bool`, *optional*, defaults to `False`):
+ Whether to use bias in the classifier.
+ classifier_activation (`str`, *optional*, defaults to `"gelu"`):
+ The activation function for the classifier.
+ deterministic_flash_attn (`bool`, *optional*, defaults to `False`):
+ Whether to use deterministic flash attention. If `False`, inference will be faster but not deterministic.
+ sparse_prediction (`bool`, *optional*, defaults to `False`):
+ Whether to use sparse prediction for the masked language model instead of returning the full dense logits.
+ sparse_pred_ignore_index (`int`, *optional*, defaults to -100):
+ The index to ignore for the sparse prediction.
+ reference_compile (`bool`, *optional*):
+ Whether to compile the layers of the model which were compiled during pretraining. If `None`, then parts of
+ the model will be compiled if 1) `triton` is installed, 2) the model is not on MPS, 3) the model is not
+ shared between devices, and 4) the model is not resized after initialization. If `True`, then the model may
+ be faster in some scenarios.
+ repad_logits_with_grad (`bool`, *optional*, defaults to `False`):
+ When True, ModernBertForMaskedLM keeps track of the logits' gradient when repadding for output. This only
+ applies when using Flash Attention 2 with passed labels. Otherwise output logits always have a gradient.
+
+ Examples:
+
+ ```python
+ >>> from transformers import ModernBertModel, ModernBertConfig
+
+ >>> # Initializing a ModernBert style configuration
+ >>> configuration = ModernBertConfig()
+
+ >>> # Initializing a model from the modernbert-base style configuration
+ >>> model = ModernBertModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "modernbert"
+ attribute_map = {"rope_theta": "global_rope_theta"}
+ keys_to_ignore_at_inference = ["past_key_values"]
+
+ def __init__(
+ self,
+ vocab_size=50368,
+ hidden_size=768,
+ intermediate_size=1152,
+ num_hidden_layers=22,
+ num_attention_heads=12,
+ hidden_activation="gelu",
+ max_position_embeddings=8192,
+ initializer_range=0.02,
+ initializer_cutoff_factor=2.0,
+ norm_eps=1e-5,
+ norm_bias=False,
+ pad_token_id=50283,
+ eos_token_id=50282,
+ bos_token_id=50281,
+ cls_token_id=50281,
+ sep_token_id=50282,
+ global_rope_theta=160000.0,
+ attention_bias=False,
+ attention_dropout=0.0,
+ global_attn_every_n_layers=3,
+ local_attention=128,
+ local_rope_theta=10000.0,
+ embedding_dropout=0.0,
+ mlp_bias=False,
+ mlp_dropout=0.0,
+ decoder_bias=True,
+ classifier_pooling: Literal["cls", "mean"] = "cls",
+ classifier_dropout=0.0,
+ classifier_bias=False,
+ classifier_activation="gelu",
+ deterministic_flash_attn=False,
+ sparse_prediction=False,
+ sparse_pred_ignore_index=-100,
+ reference_compile=None,
+ repad_logits_with_grad=False,
+ **kwargs,
+ ):
+ super().__init__(
+ pad_token_id=pad_token_id,
+ bos_token_id=bos_token_id,
+ eos_token_id=eos_token_id,
+ cls_token_id=cls_token_id,
+ sep_token_id=sep_token_id,
+ **kwargs,
+ )
+ self.vocab_size = vocab_size
+ self.max_position_embeddings = max_position_embeddings
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.initializer_range = initializer_range
+ self.initializer_cutoff_factor = initializer_cutoff_factor
+ self.norm_eps = norm_eps
+ self.norm_bias = norm_bias
+ self.global_rope_theta = global_rope_theta
+ self.attention_bias = attention_bias
+ self.attention_dropout = attention_dropout
+ self.hidden_activation = hidden_activation
+ self.global_attn_every_n_layers = global_attn_every_n_layers
+ self.local_attention = local_attention
+ self.local_rope_theta = local_rope_theta
+ self.embedding_dropout = embedding_dropout
+ self.mlp_bias = mlp_bias
+ self.mlp_dropout = mlp_dropout
+ self.decoder_bias = decoder_bias
+ self.classifier_pooling = classifier_pooling
+ self.classifier_dropout = classifier_dropout
+ self.classifier_bias = classifier_bias
+ self.classifier_activation = classifier_activation
+ self.deterministic_flash_attn = deterministic_flash_attn
+ self.sparse_prediction = sparse_prediction
+ self.sparse_pred_ignore_index = sparse_pred_ignore_index
+ self.reference_compile = reference_compile
+ self.repad_logits_with_grad = repad_logits_with_grad
+
+ if self.classifier_pooling not in ["cls", "mean"]:
+ raise ValueError(
+ f'Invalid value for `classifier_pooling`, should be either "cls" or "mean", but is {self.classifier_pooling}.'
+ )
+
+ def to_dict(self):
+ output = super().to_dict()
+ output.pop("reference_compile", None)
+ return output
+
+
+def _unpad_modernbert_input(
+ inputs: torch.Tensor,
+ attention_mask: torch.Tensor,
+ position_ids: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, Optional[torch.Tensor], Optional[torch.Tensor]]:
+ """
+ Remove padding from input sequences.
+
+ Args:
+ inputs: (batch, seqlen, ...) or (batch, seqlen)
+ attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
+ position_ids: (batch, seqlen), int, position ids
+ labels: (batch, seqlen), int, labels
+
+ Returns:
+ unpadded_inputs: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.
+ indices: (total_nnz)
+ cu_seqlens: (batch + 1), the cumulative sequence lengths
+ max_seqlen_in_batch: int
+ unpadded_position_ids: (total_nnz) or None
+ unpadded_labels: (total_nnz) or None
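+
+    Example (illustrative only; assumes a batch of 2 sequences with 3 and 5 valid tokens, padded to length 5):
+
+    ```python
+    attention_mask = torch.tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]])
+    input_ids = torch.arange(10).view(2, 5)
+    unpadded, indices, cu_seqlens, max_len, _, _ = _unpad_modernbert_input(input_ids, attention_mask)
+    # unpadded.shape == (8,)               the 8 non-padding tokens, concatenated
+    # indices == [0, 1, 2, 5, 6, 7, 8, 9]  flat positions of the non-padding tokens
+    # cu_seqlens == [0, 3, 8], max_len == 5
+    ```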
+ """
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
+ max_seqlen_in_batch = int(seqlens_in_batch.max().item())
+ cu_seqlens = torch.nn.functional.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
+
+ if inputs.dim() == 2:
+ unpadded_inputs = inputs.flatten()[indices]
+ else:
+ batch, seqlen, *rest = inputs.shape
+ shape = batch * seqlen
+ unpadded_inputs = inputs.view(shape, *rest)[indices]
+
+ unpadded_position_ids = position_ids.flatten()[indices] if position_ids is not None else None
+ unpadded_labels = labels.flatten()[indices] if labels is not None else None
+
+ return unpadded_inputs, indices, cu_seqlens, max_seqlen_in_batch, unpadded_position_ids, unpadded_labels
+
+
+def _pad_modernbert_output(
+ inputs: torch.Tensor,
+ indices: torch.Tensor,
+ batch: int,
+ seqlen: int,
+) -> torch.Tensor:
+ """
+ Add padding to sequences.
+
+ Args:
+ inputs: (total_nnz, ...) or (total_nnz,), where total_nnz = number of tokens selected in attention_mask.
+ indices: (total_nnz)
+ batch: int, batch size
+ seqlen: int, max sequence length
+
+ Returns:
+ padded_inputs: (batch, seqlen, ...) or (batch, seqlen)
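+
+    Example (illustrative only; inverse of the `_unpad_modernbert_input` example above):
+
+    ```python
+    padded = _pad_modernbert_output(unpadded, indices, batch=2, seqlen=5)
+    # padded.shape == (2, 5); positions that were padding are filled with zeros
+    ```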
+ """
+ if inputs.dim() == 1:
+ output = torch.zeros(batch * seqlen, dtype=inputs.dtype, device=inputs.device)
+ output[indices] = inputs
+ padded_inputs = output.view(batch, seqlen)
+ else:
+ _, *rest = inputs.shape
+ output = torch.zeros(batch * seqlen, *rest, dtype=inputs.dtype, device=inputs.device)
+ output[indices] = inputs
+ padded_inputs = output.view(batch, seqlen, *rest)
+
+ return padded_inputs
+
+
+class ApplyRotaryEmbUnpad(torch.autograd.Function):
+ @staticmethod
+ def forward(
+ ctx,
+ qkv,
+ cos,
+ sin,
+ cu_seqlens: Optional[torch.Tensor] = None,
+ max_seqlen: Optional[int] = None,
+ ):
+ # (total_nnz, 3, nheads, headdim)
+ qkv = qkv.contiguous()
+ total_nnz, _three, _nheads, headdim = qkv.shape
+ # We need qkv to be contiguous so that when we reshape to combine (3, nheads) dimensions,
+ # we get the same tensor
+ # qk = rearrange(qkv[:, :2], "b_s t h d -> b_s (t h) d")
+ qk = qkv[:, :2].view(total_nnz, -1, headdim)
+ apply_rotary(
+ qk,
+ cos,
+ sin,
+ seqlen_offsets=0,
+ cu_seqlens=cu_seqlens,
+ max_seqlen=max_seqlen,
+ interleaved=False,
+ inplace=True,
+ )
+
+ ctx.save_for_backward(cos, sin, cu_seqlens)
+ ctx.max_seqlen = max_seqlen
+ return qkv
+
+ @staticmethod
+ def backward(ctx, do):
+ cos, sin, cu_seqlens = ctx.saved_tensors
+ do = do.contiguous()
+ total_nnz, _three, _nheads, headdim = do.shape
+ # We need dqkv to be contiguous so that when we reshape to combine (3, nheads) dimensions,
+ # we get the same tensor
+ dqk = do[:, :2].view(total_nnz, -1, headdim)
+ apply_rotary(
+ dqk,
+ cos,
+ sin,
+ seqlen_offsets=0,
+ cu_seqlens=cu_seqlens,
+ max_seqlen=ctx.max_seqlen,
+ interleaved=False,
+ inplace=True,
+ conjugate=True,
+ )
+
+ return do, None, None, None, None, None, None
+
+
+def apply_rotary_unpadded(
+ qkv,
+ cos,
+ sin,
+ cu_seqlens: Optional[torch.Tensor] = None,
+ max_seqlen: Optional[int] = None,
+):
+ """
+    Apply rotary embeddings, in-place and non-interleaved (GPT-NeoX style), to the query and key heads of the
+    packed QKV tensor of an unpadded batch. `rotary_dim` must be <= `headdim`.
+
+    Arguments:
+        qkv: (total_nnz, 3, nheads, headdim) - input tensor for packed QKV.
+        cos, sin: (seqlen_rotary, rotary_dim / 2)
+        cu_seqlens: (batch + 1,) or None - cumulative sequence lengths of the sequences in the batch.
+        max_seqlen: int - maximum sequence length in the batch.
+    Return:
+        qkv: (total_nnz, 3, nheads, headdim) - the input tensor with rotary embeddings applied to the query and key heads.
+ """
+ return ApplyRotaryEmbUnpad.apply(qkv, cos, sin, cu_seqlens, max_seqlen)
+
+
+class ModernBertUnpaddedRotaryEmbedding(RotaryEmbedding):
+ """
+ The rotary position embeddings applied directly to unpadded sequences.
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ base: float = 10000.0,
+ max_seqlen: Optional[int] = None,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ """
+ max_seqlen: if max_seqlen, device, and dtype are provided, we precompute the cos_sin_cache
+ up to max_seqlen. If the max_seqlen, device, or dtype during training/inference differ,
+ the cos_sin_cache will be recomputed during the forward pass.
+ """
+ super().__init__(dim=dim, base=base, device=device, interleaved=False)
+ self.max_seqlen = max_seqlen
+
+ if max_seqlen is not None and device is not None and dtype is not None:
+ self._update_cos_sin_cache(max_seqlen, device=device, dtype=dtype)
+
+ def forward(
+ self,
+ qkv: torch.Tensor,
+ cu_seqlens: torch.Tensor,
+ max_seqlen: Optional[int] = None,
+ ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
+ """
+ Apply rotary embedding *inplace* to qkv.
+ qkv: (total_nnz, 3, nheads, headdim)
+ cu_seqlens: (batch + 1,) cumulative sequence lengths
+ max_seqlen: int max seq length in the batch
+ """
+ if max_seqlen is not None:
+ self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype)
+
+ qkv = apply_rotary_unpadded(
+ qkv,
+ self._cos_cached,
+ self._sin_cached,
+ cu_seqlens=cu_seqlens,
+ max_seqlen=max_seqlen,
+ )
+
+ return qkv
+
+ def extra_repr(self) -> str:
+ return f"dim={self.dim}, base={self.base}, scale_base={self.scale_base}"
+
+
+class ModernBertEmbeddings(nn.Module):
+ """
+    Token embeddings with layer norm and dropout. Unlike `BertEmbeddings`, there are no learned position
+    embeddings; positional information is injected later through the rotary embeddings in the attention layers.
+ """
+
+ def __init__(self, config: ModernBertConfig):
+ super().__init__()
+ self.config = config
+ self.tok_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
+ self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)
+ self.drop = nn.Dropout(config.embedding_dropout)
+
+ @torch.compile(dynamic=True)
+ def compiled_embeddings(self, input_ids: torch.LongTensor) -> torch.Tensor:
+ return self.drop(self.norm(self.tok_embeddings(input_ids)))
+
+ def forward(
+ self, input_ids: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.Tensor] = None
+ ) -> torch.Tensor:
+ if inputs_embeds is not None:
+ hidden_states = self.drop(self.norm(inputs_embeds))
+ else:
+ hidden_states = (
+ self.compiled_embeddings(input_ids)
+ if self.config.reference_compile
+ else self.drop(self.norm(self.tok_embeddings(input_ids)))
+ )
+ return hidden_states
+
+
+class ModernBertMLP(nn.Module):
+ """Applies the GLU at the end of each ModernBERT layer.
+
+    Compared to the default BERT architecture, this block replaces :class:`~transformers.models.bert.modeling_bert.BertIntermediate`
+    and :class:`~transformers.models.bert.modeling_bert.BertOutput` with a single module that has similar functionality.
+ """
+
+ def __init__(self, config: ModernBertConfig):
+ super().__init__()
+ self.config = config
+ self.Wi = nn.Linear(config.hidden_size, int(config.intermediate_size) * 2, bias=config.mlp_bias)
+ self.act = ACT2FN[config.hidden_activation]
+ self.drop = nn.Dropout(config.mlp_dropout)
+ self.Wo = nn.Linear(config.intermediate_size, config.hidden_size, bias=config.mlp_bias)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
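+        # Wi projects to 2 * intermediate_size; the two halves act as the GLU input and gate:
+        # act(input) * gate is then projected back to hidden_size by Wo.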
+ input, gate = self.Wi(hidden_states).chunk(2, dim=-1)
+ return self.Wo(self.drop(self.act(input) * gate))
+
+
+class ModernBertRotaryEmbedding(GemmaRotaryEmbedding):
+ pass
+
+
+def eager_attention_forward(
+ module: "ModernBertAttention",
+ qkv: torch.Tensor,
+ attention_mask: torch.Tensor,
+ sliding_window_mask: torch.Tensor,
+ position_ids: Optional[torch.LongTensor],
+ local_attention: tuple[int, int],
+ bs: int,
+ dim: int,
+ output_attentions: Optional[bool] = False,
+ **_kwargs,
+) -> Union[tuple[torch.Tensor, torch.Tensor], tuple[torch.Tensor]]:
+ # qkv: [batch_size, seqlen, 3, nheads, headdim]
+ cos, sin = module.rotary_emb(qkv, position_ids=position_ids)
+ query, key, value = qkv.transpose(3, 1).unbind(dim=2)
+ # query, key, value: [batch_size, heads, seq_len, head_dim]
+ query, key = apply_rotary_pos_emb(query, key, cos, sin)
+
+ scale = module.head_dim**-0.5
+ attn_weights = torch.matmul(query, key.transpose(2, 3)) * scale
+
+ if local_attention != (-1, -1):
+ attention_mask = sliding_window_mask
+
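+    # the mask is additive: 0 where attention is allowed, a large negative value where it is masked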
+ attn_weights = attn_weights + attention_mask
+
+ # upcast attention to fp32
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+ attn_weights = nn.functional.dropout(attn_weights, p=module.attention_dropout, training=module.training)
+ attn_output = torch.matmul(attn_weights, value)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+ attn_output = attn_output.view(bs, -1, dim)
+ if output_attentions:
+ return (attn_output, attn_weights)
+ return (attn_output,)
+
+
+def flash_attention_forward(
+ module: "ModernBertAttention",
+ qkv: torch.Tensor,
+ rotary_emb: ModernBertUnpaddedRotaryEmbedding,
+ cu_seqlens: torch.Tensor,
+ max_seqlen: int,
+ local_attention: tuple[int, int],
+ bs: int,
+ dim: int,
+ target_dtype: torch.dtype = torch.bfloat16,
+ **_kwargs,
+) -> tuple[torch.Tensor]:
+ # (total_seqlen, 3, nheads, headdim)
+ qkv = rotary_emb(qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen)
+
+ convert_dtype = qkv.dtype not in (torch.float16, torch.bfloat16)
+ if convert_dtype:
+ # FA2 implementation only supports fp16 and bf16. If FA2 is supported,
+ # bfloat16 must be supported as of FA2 2.5.7. (Turing GPUs not supported)
+ orig_dtype = qkv.dtype
+ qkv = qkv.to(target_dtype)
+
+ attn = flash_attn_varlen_qkvpacked_func(
+ qkv,
+ cu_seqlens=cu_seqlens,
+ max_seqlen=max_seqlen,
+ dropout_p=module.attention_dropout if module.training else 0.0,
+ deterministic=module.deterministic_flash_attn,
+ window_size=local_attention,
+ )
+ attn = attn.to(orig_dtype) # type: ignore
+ else:
+ attn = flash_attn_varlen_qkvpacked_func(
+ qkv,
+ cu_seqlens=cu_seqlens,
+ max_seqlen=max_seqlen,
+ dropout_p=module.attention_dropout if module.training else 0.0,
+ deterministic=module.deterministic_flash_attn,
+ window_size=local_attention,
+ )
+ return (attn.view(bs, dim),)
+
+
+def sdpa_attention_forward(
+ module: "ModernBertAttention",
+ qkv: torch.Tensor,
+ attention_mask: torch.Tensor,
+ sliding_window_mask: torch.Tensor,
+ position_ids: Optional[torch.LongTensor],
+ local_attention: tuple[int, int],
+ bs: int,
+ dim: int,
+ **_kwargs,
+) -> tuple[torch.Tensor]:
+ # qkv: [batch_size, seqlen, 3, nheads, headdim]
+ cos, sin = module.rotary_emb(qkv, position_ids=position_ids)
+ query, key, value = qkv.transpose(3, 1).unbind(dim=2)
+ # query, key, value: [batch_size, heads, seq_len, head_dim]
+ query, key = apply_rotary_pos_emb(query, key, cos, sin)
+
+ if local_attention != (-1, -1):
+ attention_mask = sliding_window_mask
+
+ attn_output = (
+ F.scaled_dot_product_attention(
+ query,
+ key,
+ value,
+ dropout_p=module.attention_dropout if module.training else 0.0,
+ attn_mask=attention_mask,
+ )
+ .transpose(1, 2)
+ .contiguous()
+ )
+ attn_output = attn_output.view(bs, -1, dim)
+ return (attn_output,)
+
+
+MODERNBERT_ATTENTION_FUNCTION = {
+ "flash_attention_2": flash_attention_forward,
+ "eager": eager_attention_forward,
+ "sdpa": sdpa_attention_forward,
+}
+
+
+class ModernBertAttention(nn.Module):
+ """Performs multi-headed self attention on a batch of unpadded sequences.
+
+ If Flash Attention 2 is installed, this module uses Flash Attention to improve throughput.
+ If Flash Attention 2 is not installed, the implementation will use PyTorch's SDPA kernel,
+ which requires padding and unpadding inputs, adding some overhead.
+
+ See `forward` method for additional details.
+ """
+
+ def __init__(self, config: ModernBertConfig, layer_id: Optional[int] = None):
+ super().__init__()
+ self.config = config
+ self.layer_id = layer_id
+
+ if config.hidden_size % config.num_attention_heads != 0:
+ raise ValueError(
+ f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention heads ({config.num_attention_heads})"
+ )
+
+ self.attention_dropout = config.attention_dropout
+ self.deterministic_flash_attn = config.deterministic_flash_attn
+ self.num_heads = config.num_attention_heads
+ self.head_dim = config.hidden_size // config.num_attention_heads
+ self.all_head_size = self.head_dim * self.num_heads
+ self.Wqkv = nn.Linear(config.hidden_size, 3 * self.all_head_size, bias=config.attention_bias)
+
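+        # With global_attn_every_n_layers=3 (the default), layers 0, 3, 6, ... use global attention;
+        # the remaining layers use a local sliding window of config.local_attention tokens.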
+ if layer_id % config.global_attn_every_n_layers != 0:
+ self.local_attention = (config.local_attention // 2, config.local_attention // 2)
+ rope_theta = config.local_rope_theta if config.local_rope_theta is not None else config.global_rope_theta
+ max_position_embeddings = config.local_attention
+ else:
+ self.local_attention = (-1, -1)
+ max_position_embeddings = config.max_position_embeddings
+ rope_theta = config.global_rope_theta
+
+ if config._attn_implementation == "flash_attention_2":
+ self.rotary_emb = ModernBertUnpaddedRotaryEmbedding(
+ dim=self.head_dim, max_seqlen=max_position_embeddings, base=rope_theta
+ )
+ else:
+ config_copy = copy.deepcopy(config)
+ config_copy.rope_theta = rope_theta
+ self.rotary_emb = ModernBertRotaryEmbedding(config=config_copy)
+
+ self.Wo = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attention_bias)
+ self.out_drop = nn.Dropout(config.attention_dropout) if config.attention_dropout > 0.0 else nn.Identity()
+ self.pruned_heads = set()
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ output_attentions: Optional[bool] = False,
+ **kwargs,
+ ) -> torch.Tensor:
+ qkv = self.Wqkv(hidden_states)
+
+ bs = hidden_states.shape[0]
+ if self.config._attn_implementation == "flash_attention_2":
+ qkv = qkv.view(-1, 3, self.num_heads, self.head_dim)
+ else:
+ qkv = qkv.view(bs, -1, 3, self.num_heads, self.head_dim)
+
+ attn_outputs = MODERNBERT_ATTENTION_FUNCTION[self.config._attn_implementation](
+ self,
+ qkv=qkv,
+ rotary_emb=self.rotary_emb,
+ local_attention=self.local_attention,
+ bs=bs,
+ dim=self.all_head_size,
+ output_attentions=output_attentions,
+ **kwargs,
+ )
+ hidden_states = attn_outputs[0]
+ hidden_states = self.out_drop(self.Wo(hidden_states))
+
+ return (hidden_states,) + attn_outputs[1:] # add attentions if outputted
+
+
+class ModernBertEncoderLayer(GradientCheckpointingLayer):
+ def __init__(self, config: ModernBertConfig, layer_id: Optional[int] = None):
+ super().__init__()
+ self.config = config
+ if layer_id == 0:
+ self.attn_norm = nn.Identity()
+ else:
+ self.attn_norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)
+ self.attn = ModernBertAttention(config=config, layer_id=layer_id)
+ self.mlp_norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)
+ self.mlp = ModernBertMLP(config)
+
+ @torch.compile(dynamic=True)
+ def compiled_mlp(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ return self.mlp(self.mlp_norm(hidden_states))
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ sliding_window_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ cu_seqlens: Optional[torch.Tensor] = None,
+ max_seqlen: Optional[int] = None,
+ output_attentions: Optional[bool] = False,
+ ) -> torch.Tensor:
+ attn_outputs = self.attn(
+ self.attn_norm(hidden_states),
+ attention_mask=attention_mask,
+ sliding_window_mask=sliding_window_mask,
+ position_ids=position_ids,
+ cu_seqlens=cu_seqlens,
+ max_seqlen=max_seqlen,
+ output_attentions=output_attentions,
+ )
+ hidden_states = hidden_states + attn_outputs[0]
+ mlp_output = (
+ self.compiled_mlp(hidden_states)
+ if self.config.reference_compile
+ else self.mlp(self.mlp_norm(hidden_states))
+ )
+ hidden_states = hidden_states + mlp_output
+
+ return (hidden_states,) + attn_outputs[1:] # add attentions if outputted
+
+
+@auto_docstring
+class ModernBertPreTrainedModel(PreTrainedModel):
+ config: ModernBertConfig
+ base_model_prefix = "model"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["ModernBertEmbeddings", "ModernBertEncoderLayer"]
+ _supports_flash_attn = True
+ _supports_sdpa = True
+ _supports_flex_attn = False
+
+ def _init_weights(self, module: nn.Module):
+ cutoff_factor = self.config.initializer_cutoff_factor
+ if cutoff_factor is None:
+ cutoff_factor = 3
+
+ def init_weight(module: nn.Module, std: float):
+ nn.init.trunc_normal_(
+ module.weight,
+ mean=0.0,
+ std=std,
+ a=-cutoff_factor * std,
+ b=cutoff_factor * std,
+ )
+
+ if isinstance(module, nn.Linear):
+ if module.bias is not None:
+ nn.init.zeros_(module.bias)
+
+ stds = {
+ "in": self.config.initializer_range,
+ "out": self.config.initializer_range / math.sqrt(2.0 * self.config.num_hidden_layers),
+ "embedding": self.config.initializer_range,
+ "final_out": self.config.hidden_size**-0.5,
+ }
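+        # "out" projections are scaled down with depth (1 / sqrt(2 * num_layers)), similar to the
+        # residual-branch initialization used in GPT-2-style models.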
+
+ if isinstance(module, ModernBertEmbeddings):
+ init_weight(module.tok_embeddings, stds["embedding"])
+ elif isinstance(module, ModernBertMLP):
+ init_weight(module.Wi, stds["in"])
+ init_weight(module.Wo, stds["out"])
+ elif isinstance(module, ModernBertAttention):
+ init_weight(module.Wqkv, stds["in"])
+ init_weight(module.Wo, stds["out"])
+ elif isinstance(module, ModernBertPredictionHead):
+ init_weight(module.dense, stds["out"])
+ elif isinstance(module, ModernBertForMaskedLM):
+ init_weight(module.decoder, stds["out"])
+ elif isinstance(
+ module,
+ (
+ ModernBertForSequenceClassification,
+ ModernBertForMultipleChoice,
+ ModernBertForTokenClassification,
+ ModernBertForQuestionAnswering,
+ ),
+ ):
+ init_weight(module.classifier, stds["final_out"])
+ elif isinstance(module, nn.LayerNorm):
+ module.weight.data.fill_(1.0)
+ if module.bias is not None:
+ module.bias.data.zero_()
+
+ def _check_and_adjust_attn_implementation(
+ self, attn_implementation: Optional[str], is_init_check: bool = False
+ ) -> str:
+ """
+        Checks and dispatches to the requested attention implementation.
+ """
+ # If the user didn't specify anything, try to use flash_attention_2 if available.
+ # Otherwise we fall back to the default SDPA -> Eager from the super() method.
+        # ModernBert's FA2 implementation correctly handles non-fp16/bf16 dtypes, so we don't
+        # need the FA2 warning for non-fp16/bf16 dtypes; we therefore set fp16 for the FA2 check.
+
+ try:
+ attn_implementation = (
+ "flash_attention_2"
+ if attn_implementation is None and self._flash_attn_2_can_dispatch()
+ else attn_implementation
+ )
+ except (ValueError, ImportError):
+ pass
+ return super()._check_and_adjust_attn_implementation(
+ attn_implementation=attn_implementation, is_init_check=is_init_check
+ )
+
+ def _maybe_set_compile(self):
+ if self.config.reference_compile is False:
+ return
+
+ if hasattr(self, "hf_device_map") and len(self.hf_device_map) > 1:
+ if self.config.reference_compile:
+ logger.warning_once(
+ "If `accelerate` split the model across devices, `torch.compile` will not work. "
+ "Falling back to non-compiled mode."
+ )
+ self.config.reference_compile = False
+
+ if self.device.type == "mps":
+ if self.config.reference_compile:
+ logger.warning_once(
+ "Compiling the model with `torch.compile` and using a `torch.mps` device is not supported. "
+ "Falling back to non-compiled mode."
+ )
+ self.config.reference_compile = False
+
+ if self.device.type == "cpu":
+ if self.config.reference_compile:
+ logger.warning_once(
+ "Compiling the model with `torch.compile` and using a `torch.cpu` device is not supported. "
+ "Falling back to non-compiled mode."
+ )
+ self.config.reference_compile = False
+
+ if self.config.reference_compile is None:
+ self.config.reference_compile = is_triton_available()
+
+ def resize_token_embeddings(self, *args, **kwargs):
+ model_embeds = super().resize_token_embeddings(*args, **kwargs)
+
+ if self.config.reference_compile in {True, None}:
+ if self.config.reference_compile:
+ logger.warning_once(
+ "Resizing token embeddings with `torch.compile` is not supported. Falling back to non-compiled mode."
+ )
+ self.config.reference_compile = False
+
+ return model_embeds
+
+
+@auto_docstring
+class ModernBertModel(ModernBertPreTrainedModel):
+ def __init__(self, config: ModernBertConfig):
+ super().__init__(config)
+ self.config = config
+ self.embeddings = ModernBertEmbeddings(config)
+ self.layers = nn.ModuleList(
+ [ModernBertEncoderLayer(config, layer_id) for layer_id in range(config.num_hidden_layers)]
+ )
+ self.final_norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)
+ self.gradient_checkpointing = False
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.embeddings.tok_embeddings
+
+ def set_input_embeddings(self, value):
+ self.embeddings.tok_embeddings = value
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ sliding_window_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ indices: Optional[torch.Tensor] = None,
+ cu_seqlens: Optional[torch.Tensor] = None,
+ max_seqlen: Optional[int] = None,
+ batch_size: Optional[int] = None,
+ seq_len: Optional[int] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[tuple[torch.Tensor, ...], BaseModelOutput]:
+ r"""
+ sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding or far-away tokens. In ModernBert, only every few layers
+ perform global attention, while the rest perform local attention. This mask is used to avoid attending to
+ far-away tokens in the local attention layers when not using Flash Attention.
+ indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*):
+ Indices of the non-padding tokens in the input sequence. Used for unpadding the output.
+ cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*):
+ Cumulative sequence lengths of the input sequences. Used to index the unpadded tensors.
+ max_seqlen (`int`, *optional*):
+ Maximum sequence length in the batch excluding padding tokens. Used to unpad input_ids and pad output tensors.
+ batch_size (`int`, *optional*):
+ Batch size of the input sequences. Used to pad the output tensors.
+ seq_len (`int`, *optional*):
+ Sequence length of the input sequences including padding tokens. Used to pad the output tensors.
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attentions = () if output_attentions else None
+
+ self._maybe_set_compile()
+
+ if input_ids is not None:
+ self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
+
+ if batch_size is None and seq_len is None:
+ if inputs_embeds is not None:
+ batch_size, seq_len = inputs_embeds.shape[:2]
+ else:
+ batch_size, seq_len = input_ids.shape[:2]
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
+
+ if attention_mask is None:
+ attention_mask = torch.ones((batch_size, seq_len), device=device, dtype=torch.bool)
+
+ repad = False
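+        # With flash attention 2, padded inputs are unpadded here and the outputs are re-padded after the encoder.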
+ if self.config._attn_implementation == "flash_attention_2":
+ if indices is None and cu_seqlens is None and max_seqlen is None:
+ repad = True
+ if inputs_embeds is None:
+ with torch.no_grad():
+ input_ids, indices, cu_seqlens, max_seqlen, *_ = _unpad_modernbert_input(
+ inputs=input_ids, attention_mask=attention_mask
+ )
+ else:
+ inputs_embeds, indices, cu_seqlens, max_seqlen, *_ = _unpad_modernbert_input(
+ inputs=inputs_embeds, attention_mask=attention_mask
+ )
+ else:
+ if position_ids is None:
+ position_ids = torch.arange(seq_len, device=device).unsqueeze(0)
+
+ attention_mask, sliding_window_mask = self._update_attention_mask(
+ attention_mask, output_attentions=output_attentions
+ )
+
+ hidden_states = self.embeddings(input_ids=input_ids, inputs_embeds=inputs_embeds)
+
+ for encoder_layer in self.layers:
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ layer_outputs = encoder_layer(
+ hidden_states,
+ attention_mask=attention_mask,
+ sliding_window_mask=sliding_window_mask,
+ position_ids=position_ids,
+ cu_seqlens=cu_seqlens,
+ max_seqlen=max_seqlen,
+ output_attentions=output_attentions,
+ )
+ hidden_states = layer_outputs[0]
+ if output_attentions and len(layer_outputs) > 1:
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ hidden_states = self.final_norm(hidden_states)
+
+ if repad:
+ hidden_states = _pad_modernbert_output(
+ inputs=hidden_states, indices=indices, batch=batch_size, seqlen=seq_len
+ )
+ if all_hidden_states is not None:
+ all_hidden_states = tuple(
+ _pad_modernbert_output(inputs=hs, indices=indices, batch=batch_size, seqlen=seq_len)
+ for hs in all_hidden_states
+ )
+ # If the attention implementation is FA2 and there is no need for repadding, there might still be the batch
+ # dimension missing
+ elif (
+ self.config._attn_implementation == "flash_attention_2"
+ and all_hidden_states is not None
+ and all_hidden_states[-1].dim() == 2
+ ):
+ hidden_states = hidden_states.unsqueeze(0)
+ all_hidden_states = tuple(hs.unsqueeze(0) for hs in all_hidden_states)
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
+ return BaseModelOutput(
+ last_hidden_state=hidden_states,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attentions,
+ )
+
+    def _update_attention_mask(self, attention_mask: torch.Tensor, output_attentions: bool) -> tuple[torch.Tensor, torch.Tensor]:
+ if output_attentions:
+ if self.config._attn_implementation == "sdpa":
+ logger.warning_once(
+ "Outputting attentions is only supported with the 'eager' attention implementation, "
+ 'not with "sdpa". Falling back to `attn_implementation="eager"`.'
+ )
+ self.config._attn_implementation = "eager"
+ elif self.config._attn_implementation != "eager":
+ logger.warning_once(
+ "Outputting attentions is only supported with the eager attention implementation, "
+ f'not with {self.config._attn_implementation}. Consider setting `attn_implementation="eager"`.'
+ " Setting `output_attentions=False`."
+ )
+
+ global_attention_mask = _prepare_4d_attention_mask(attention_mask, self.dtype)
+
+ # Create position indices
+ rows = torch.arange(global_attention_mask.shape[2]).unsqueeze(0)
+ # Calculate distance between positions
+ distance = torch.abs(rows - rows.T)
+
+ # Create sliding window mask (1 for positions within window, 0 outside)
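+        # e.g. with local_attention=128, positions more than 64 tokens apart fall outside the local window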
+ window_mask = (
+ (distance <= self.config.local_attention // 2).unsqueeze(0).unsqueeze(0).to(attention_mask.device)
+ )
+ # Combine with existing mask
+ sliding_window_mask = global_attention_mask.masked_fill(window_mask.logical_not(), torch.finfo(self.dtype).min)
+
+ return global_attention_mask, sliding_window_mask
+
+
+class ModernBertPredictionHead(nn.Module):
+ def __init__(self, config: ModernBertConfig):
+ super().__init__()
+ self.config = config
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size, config.classifier_bias)
+ self.act = ACT2FN[config.classifier_activation]
+ self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
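+        # dense projection -> activation -> layer norm, applied before the task-specific decoder/classifier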
+ return self.norm(self.act(self.dense(hidden_states)))
+
+
+@auto_docstring(
+ custom_intro="""
+ The ModernBert Model with a decoder head on top that is used for masked language modeling.
+ """
+)
+class ModernBertForMaskedLM(ModernBertPreTrainedModel):
+ _tied_weights_keys = ["decoder.weight"]
+
+ def __init__(self, config: ModernBertConfig):
+ super().__init__(config)
+ self.config = config
+ self.model = ModernBertModel(config)
+ self.head = ModernBertPredictionHead(config)
+ self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=config.decoder_bias)
+
+ self.sparse_prediction = self.config.sparse_prediction
+ self.sparse_pred_ignore_index = self.config.sparse_pred_ignore_index
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_output_embeddings(self):
+ return self.decoder
+
+ def set_output_embeddings(self, new_embeddings: nn.Linear):
+ self.decoder = new_embeddings
+
+ @torch.compile(dynamic=True)
+ def compiled_head(self, output: torch.Tensor) -> torch.Tensor:
+ return self.decoder(self.head(output))
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ sliding_window_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ indices: Optional[torch.Tensor] = None,
+ cu_seqlens: Optional[torch.Tensor] = None,
+ max_seqlen: Optional[int] = None,
+ batch_size: Optional[int] = None,
+ seq_len: Optional[int] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ **kwargs,
+ ) -> Union[tuple[torch.Tensor], MaskedLMOutput]:
+ r"""
+ sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding or far-away tokens. In ModernBert, only every few layers
+ perform global attention, while the rest perform local attention. This mask is used to avoid attending to
+ far-away tokens in the local attention layers when not using Flash Attention.
+ indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*):
+ Indices of the non-padding tokens in the input sequence. Used for unpadding the output.
+ cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*):
+ Cumulative sequence lengths of the input sequences. Used to index the unpadded tensors.
+ max_seqlen (`int`, *optional*):
+ Maximum sequence length in the batch excluding padding tokens. Used to unpad input_ids and pad output tensors.
+ batch_size (`int`, *optional*):
+ Batch size of the input sequences. Used to pad the output tensors.
+ seq_len (`int`, *optional*):
+ Sequence length of the input sequences including padding tokens. Used to pad the output tensors.
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ self._maybe_set_compile()
+
+ if self.config._attn_implementation == "flash_attention_2":
+ if indices is None and cu_seqlens is None and max_seqlen is None:
+ if batch_size is None and seq_len is None:
+ if inputs_embeds is not None:
+ batch_size, seq_len = inputs_embeds.shape[:2]
+ else:
+ batch_size, seq_len = input_ids.shape[:2]
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
+
+ if attention_mask is None:
+ attention_mask = torch.ones((batch_size, seq_len), device=device, dtype=torch.bool)
+
+ if inputs_embeds is None:
+ with torch.no_grad():
+ input_ids, indices, cu_seqlens, max_seqlen, position_ids, labels = _unpad_modernbert_input(
+ inputs=input_ids, attention_mask=attention_mask, position_ids=position_ids, labels=labels
+ )
+ else:
+ inputs_embeds, indices, cu_seqlens, max_seqlen, position_ids, labels = _unpad_modernbert_input(
+ inputs=inputs_embeds, attention_mask=attention_mask, position_ids=position_ids, labels=labels
+ )
+
+ outputs = self.model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ sliding_window_mask=sliding_window_mask,
+ position_ids=position_ids,
+ inputs_embeds=inputs_embeds,
+ indices=indices,
+ cu_seqlens=cu_seqlens,
+ max_seqlen=max_seqlen,
+ batch_size=batch_size,
+ seq_len=seq_len,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ last_hidden_state = outputs[0]
+
+ if self.sparse_prediction and labels is not None:
+ # flatten labels and output first
+ labels = labels.view(-1)
+ last_hidden_state = last_hidden_state.view(labels.shape[0], -1)
+
+ # then filter out the non-masked tokens
+ mask_tokens = labels != self.sparse_pred_ignore_index
+ last_hidden_state = last_hidden_state[mask_tokens]
+ labels = labels[mask_tokens]
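+            # e.g. with 15% MLM masking, the decoder head now only runs on ~15% of the token positions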
+
+ logits = (
+ self.compiled_head(last_hidden_state)
+ if self.config.reference_compile
+ else self.decoder(self.head(last_hidden_state))
+ )
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(logits, labels, vocab_size=self.config.vocab_size, **kwargs)
+
+ if self.config._attn_implementation == "flash_attention_2":
+ # Logits padding
+ with nullcontext() if self.config.repad_logits_with_grad or labels is None else torch.no_grad():
+ logits = _pad_modernbert_output(inputs=logits, indices=indices, batch=batch_size, seqlen=seq_len)
+ # Hidden states padding
+ if getattr(outputs, "hidden_states", None) is not None:
+ padded_hidden_states = []
+ for hs in outputs.hidden_states:
+ if hs.dim() == 3 and hs.shape[0] == 1:
+ hs = hs.squeeze(0)
+ padded_hidden_states.append(
+ _pad_modernbert_output(inputs=hs, indices=indices, batch=batch_size, seqlen=seq_len)
+ )
+ outputs.hidden_states = tuple(padded_hidden_states)
+
+ if not return_dict:
+ output = (logits,)
+ return ((loss,) + output) if loss is not None else output
+
+ return MaskedLMOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+@auto_docstring(
+ custom_intro="""
+ The ModernBert Model with a sequence classification head on top that performs pooling.
+ """
+)
+class ModernBertForSequenceClassification(ModernBertPreTrainedModel):
+ def __init__(self, config: ModernBertConfig):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+ self.config = config
+
+ self.model = ModernBertModel(config)
+ self.head = ModernBertPredictionHead(config)
+ self.drop = torch.nn.Dropout(config.classifier_dropout)
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ sliding_window_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ indices: Optional[torch.Tensor] = None,
+ cu_seqlens: Optional[torch.Tensor] = None,
+ max_seqlen: Optional[int] = None,
+ batch_size: Optional[int] = None,
+ seq_len: Optional[int] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ **kwargs,
+ ) -> Union[tuple[torch.Tensor], SequenceClassifierOutput]:
+ r"""
+ sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding or far-away tokens. In ModernBert, only every few layers
+ perform global attention, while the rest perform local attention. This mask is used to avoid attending to
+ far-away tokens in the local attention layers when not using Flash Attention.
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), if
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+ indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*):
+ Indices of the non-padding tokens in the input sequence. Used for unpadding the output.
+ cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*):
+ Cumulative sequence lengths of the input sequences. Used to index the unpadded tensors.
+ max_seqlen (`int`, *optional*):
+ Maximum sequence length in the batch excluding padding tokens. Used to unpad input_ids and pad output tensors.
+ batch_size (`int`, *optional*):
+ Batch size of the input sequences. Used to pad the output tensors.
+ seq_len (`int`, *optional*):
+ Sequence length of the input sequences including padding tokens. Used to pad the output tensors.
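+
+        Example (a minimal sketch assuming the `answerdotai/ModernBERT-base` encoder checkpoint; the
+        classification head is freshly initialized here, so the logits are untrained):
+
+        ```python
+        >>> from transformers import AutoTokenizer, ModernBertForSequenceClassification
+
+        >>> tokenizer = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-base")
+        >>> model = ModernBertForSequenceClassification.from_pretrained("answerdotai/ModernBERT-base", num_labels=2)
+
+        >>> inputs = tokenizer("This movie was great!", return_tensors="pt")
+        >>> logits = model(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"]).logits
+        >>> logits.shape
+        torch.Size([1, 2])
+        ```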
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ self._maybe_set_compile()
+
+ if input_ids is not None:
+ self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
+
+ if batch_size is None and seq_len is None:
+ if inputs_embeds is not None:
+ batch_size, seq_len = inputs_embeds.shape[:2]
+ else:
+ batch_size, seq_len = input_ids.shape[:2]
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
+
+ if attention_mask is None:
+ attention_mask = torch.ones((batch_size, seq_len), device=device, dtype=torch.bool)
+
+ outputs = self.model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ sliding_window_mask=sliding_window_mask,
+ position_ids=position_ids,
+ inputs_embeds=inputs_embeds,
+ indices=indices,
+ cu_seqlens=cu_seqlens,
+ max_seqlen=max_seqlen,
+ batch_size=batch_size,
+ seq_len=seq_len,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ last_hidden_state = outputs[0]
+
+ if self.config.classifier_pooling == "cls":
+ last_hidden_state = last_hidden_state[:, 0]
+ elif self.config.classifier_pooling == "mean":
+ last_hidden_state = (last_hidden_state * attention_mask.unsqueeze(-1)).sum(dim=1) / attention_mask.sum(
+ dim=1, keepdim=True
+ )
+
+ pooled_output = self.head(last_hidden_state)
+ pooled_output = self.drop(pooled_output)
+ logits = self.classifier(pooled_output)
+
+ loss = None
+ if labels is not None:
+ if self.config.problem_type is None:
+ if self.num_labels == 1:
+ self.config.problem_type = "regression"
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+ self.config.problem_type = "single_label_classification"
+ else:
+ self.config.problem_type = "multi_label_classification"
+
+ if self.config.problem_type == "regression":
+ loss_fct = MSELoss()
+ if self.num_labels == 1:
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
+ else:
+ loss = loss_fct(logits, labels)
+ elif self.config.problem_type == "single_label_classification":
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+ elif self.config.problem_type == "multi_label_classification":
+ loss_fct = BCEWithLogitsLoss()
+ loss = loss_fct(logits, labels)
+
+ if not return_dict:
+ output = (logits,)
+ return ((loss,) + output) if loss is not None else output
+
+ return SequenceClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+@auto_docstring(
+ custom_intro="""
+ The ModernBert Model with a token classification head on top, e.g. for Named Entity Recognition (NER) tasks.
+ """
+)
+class ModernBertForTokenClassification(ModernBertPreTrainedModel):
+ def __init__(self, config: ModernBertConfig):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+
+ self.model = ModernBertModel(config)
+ self.head = ModernBertPredictionHead(config)
+ self.drop = torch.nn.Dropout(config.classifier_dropout)
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ sliding_window_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ indices: Optional[torch.Tensor] = None,
+ cu_seqlens: Optional[torch.Tensor] = None,
+ max_seqlen: Optional[int] = None,
+ batch_size: Optional[int] = None,
+ seq_len: Optional[int] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[tuple[torch.Tensor], TokenClassifierOutput]:
+ r"""
+ sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding or far-away tokens. In ModernBert, only every few layers
+ perform global attention, while the rest perform local attention. This mask is used to avoid attending to
+ far-away tokens in the local attention layers when not using Flash Attention.
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
+ indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*):
+ Indices of the non-padding tokens in the input sequence. Used for unpadding the output.
+ cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*):
+ Cumulative sequence lengths of the input sequences. Used to index the unpadded tensors.
+ max_seqlen (`int`, *optional*):
+ Maximum sequence length in the batch excluding padding tokens. Used to unpad input_ids and pad output tensors.
+ batch_size (`int`, *optional*):
+ Batch size of the input sequences. Used to pad the output tensors.
+ seq_len (`int`, *optional*):
+ Sequence length of the input sequences including padding tokens. Used to pad the output tensors.
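+
+        Example (a minimal sketch assuming the `answerdotai/ModernBERT-base` encoder checkpoint; the token
+        classification head is freshly initialized here, so the logits are untrained):
+
+        ```python
+        >>> from transformers import AutoTokenizer, ModernBertForTokenClassification
+
+        >>> tokenizer = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-base")
+        >>> model = ModernBertForTokenClassification.from_pretrained("answerdotai/ModernBERT-base", num_labels=5)
+
+        >>> inputs = tokenizer("HuggingFace is based in New York City", return_tensors="pt")
+        >>> logits = model(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"]).logits
+        >>> logits.shape[-1]  # one score per label for every token
+        5
+        ```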
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ self._maybe_set_compile()
+
+ outputs = self.model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ sliding_window_mask=sliding_window_mask,
+ position_ids=position_ids,
+ inputs_embeds=inputs_embeds,
+ indices=indices,
+ cu_seqlens=cu_seqlens,
+ max_seqlen=max_seqlen,
+ batch_size=batch_size,
+ seq_len=seq_len,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ last_hidden_state = outputs[0]
+
+ last_hidden_state = self.head(last_hidden_state)
+ last_hidden_state = self.drop(last_hidden_state)
+ logits = self.classifier(last_hidden_state)
+
+ loss = None
+ if labels is not None:
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+
+ if not return_dict:
+ output = (logits,) + outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return TokenClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+@auto_docstring
+class ModernBertForQuestionAnswering(ModernBertPreTrainedModel):
+ def __init__(self, config: ModernBertConfig):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+
+ self.model = ModernBertModel(config)
+ self.head = ModernBertPredictionHead(config)
+ self.drop = torch.nn.Dropout(config.classifier_dropout)
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+
+ self.post_init()
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.Tensor],
+ attention_mask: Optional[torch.Tensor] = None,
+ sliding_window_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ start_positions: Optional[torch.Tensor] = None,
+ end_positions: Optional[torch.Tensor] = None,
+ indices: Optional[torch.Tensor] = None,
+ cu_seqlens: Optional[torch.Tensor] = None,
+ max_seqlen: Optional[int] = None,
+ batch_size: Optional[int] = None,
+ seq_len: Optional[int] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ **kwargs,
+ ) -> Union[tuple[torch.Tensor], QuestionAnsweringModelOutput]:
+ r"""
+ sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding or far-away tokens. In ModernBert, only every few layers
+ perform global attention, while the rest perform local attention. This mask is used to avoid attending to
+ far-away tokens in the local attention layers when not using Flash Attention.
+ indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*):
+ Indices of the non-padding tokens in the input sequence. Used for unpadding the output.
+ cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*):
+ Cumulative sequence lengths of the input sequences. Used to index the unpadded tensors.
+ max_seqlen (`int`, *optional*):
+ Maximum sequence length in the batch excluding padding tokens. Used to unpad input_ids and pad output tensors.
+ batch_size (`int`, *optional*):
+ Batch size of the input sequences. Used to pad the output tensors.
+ seq_len (`int`, *optional*):
+ Sequence length of the input sequences including padding tokens. Used to pad the output tensors.
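+
+        Example (a minimal sketch assuming the `answerdotai/ModernBERT-base` encoder checkpoint; the span
+        prediction head is freshly initialized here, so the logits are untrained):
+
+        ```python
+        >>> from transformers import AutoTokenizer, ModernBertForQuestionAnswering
+
+        >>> tokenizer = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-base")
+        >>> model = ModernBertForQuestionAnswering.from_pretrained("answerdotai/ModernBERT-base")
+
+        >>> question, context = "What is the capital of France?", "The capital of France is Paris."
+        >>> inputs = tokenizer(question, context, return_tensors="pt")
+        >>> outputs = model(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"])
+        >>> outputs.start_logits.shape == outputs.end_logits.shape
+        True
+        ```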
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ self._maybe_set_compile()
+
+ outputs = self.model(
+ input_ids,
+ attention_mask=attention_mask,
+ sliding_window_mask=sliding_window_mask,
+ position_ids=position_ids,
+ indices=indices,
+ cu_seqlens=cu_seqlens,
+ max_seqlen=max_seqlen,
+ batch_size=batch_size,
+ seq_len=seq_len,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ last_hidden_state = outputs[0]
+
+ last_hidden_state = self.head(last_hidden_state)
+ last_hidden_state = self.drop(last_hidden_state)
+ logits = self.classifier(last_hidden_state)
+
+ start_logits, end_logits = logits.split(1, dim=-1)
+ start_logits = start_logits.squeeze(-1).contiguous()
+ end_logits = end_logits.squeeze(-1).contiguous()
+
+ loss = None
+ if start_positions is not None and end_positions is not None:
+ loss = self.loss_function(start_logits, end_logits, start_positions, end_positions, **kwargs)
+
+ if not return_dict:
+ output = (start_logits, end_logits) + outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return QuestionAnsweringModelOutput(
+ loss=loss,
+ start_logits=start_logits,
+ end_logits=end_logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+@auto_docstring(
+ custom_intro="""
+ The ModernBert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a softmax) e.g. for RocStories/SWAG tasks.
+ """
+)
+class ModernBertForMultipleChoice(ModernBertPreTrainedModel):
+ def __init__(self, config: ModernBertConfig):
+ super().__init__(config)
+ self.config = config
+
+ self.model = ModernBertModel(config)
+ self.head = ModernBertPredictionHead(config)
+ self.drop = torch.nn.Dropout(config.classifier_dropout)
+ self.classifier = nn.Linear(config.hidden_size, 1)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ sliding_window_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ indices: Optional[torch.Tensor] = None,
+ cu_seqlens: Optional[torch.Tensor] = None,
+ max_seqlen: Optional[int] = None,
+ batch_size: Optional[int] = None,
+ seq_len: Optional[int] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ **kwargs,
+ ) -> Union[tuple[torch.Tensor], MultipleChoiceModelOutput]:
+ r"""
+ sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding or far-away tokens. In ModernBert, only every few layers
+ perform global attention, while the rest perform local attention. This mask is used to avoid attending to
+ far-away tokens in the local attention layers when not using Flash Attention.
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
+ num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors.
+ indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*):
+ Indices of the non-padding tokens in the input sequence. Used for unpadding the output.
+ cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*):
+ Cumulative sequence lengths of the input sequences. Used to index the unpadded tensors.
+ max_seqlen (`int`, *optional*):
+ Maximum sequence length in the batch excluding padding tokens. Used to unpad input_ids and pad output tensors.
+ batch_size (`int`, *optional*):
+ Batch size of the input sequences. Used to pad the output tensors.
+ seq_len (`int`, *optional*):
+ Sequence length of the input sequences including padding tokens. Used to pad the output tensors.
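+
+        Example (a minimal sketch assuming the `answerdotai/ModernBERT-base` encoder checkpoint; the multiple
+        choice head is freshly initialized here, so the logits are untrained):
+
+        ```python
+        >>> from transformers import AutoTokenizer, ModernBertForMultipleChoice
+
+        >>> tokenizer = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-base")
+        >>> model = ModernBertForMultipleChoice.from_pretrained("answerdotai/ModernBERT-base")
+
+        >>> prompt = "France is a country in"
+        >>> choices = ["Europe.", "the Pacific Ocean."]
+        >>> inputs = tokenizer([prompt, prompt], choices, return_tensors="pt", padding=True)
+        >>> # add the num_choices dimension: (num_choices, seq_len) -> (1, num_choices, seq_len)
+        >>> outputs = model(
+        ...     input_ids=inputs["input_ids"].unsqueeze(0), attention_mask=inputs["attention_mask"].unsqueeze(0)
+        ... )
+        >>> outputs.logits.shape
+        torch.Size([1, 2])
+        ```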
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
+
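+        # Flatten the choices dimension into the batch dimension:
+        # (batch_size, num_choices, seq_len) -> (batch_size * num_choices, seq_len)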
+ input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
+ attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
+ position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
+ inputs_embeds = (
+ inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
+ if inputs_embeds is not None
+ else None
+ )
+
+ self._maybe_set_compile()
+
+ outputs = self.model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ sliding_window_mask=sliding_window_mask,
+ position_ids=position_ids,
+ inputs_embeds=inputs_embeds,
+ indices=indices,
+ cu_seqlens=cu_seqlens,
+ max_seqlen=max_seqlen,
+ batch_size=batch_size,
+ seq_len=seq_len,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+        last_hidden_state = outputs[0]  # shape (batch_size * num_choices, seq_len, hidden_size)
+
+        # If classifier_pooling is "cls", isolate the CLS token
+ if self.config.classifier_pooling == "cls":
+ indices_0 = torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device)
+            # for left or right padding, the CLS token is the first non-pad token
+ if attention_mask is not None:
+ cls_mask = attention_mask.argmax(dim=-1).to(last_hidden_state.device)
+            # if there is no padding, the CLS token is the first token
+ else:
+ cls_mask = torch.tensor(0, dtype=torch.long, device=last_hidden_state.device)
+            # extract the CLS token for the logits
+ last_hidden_state = last_hidden_state[indices_0, cls_mask]
+
+ # If classifier_pooling is "mean", pool the hidden states by averaging over the sequence length
+ elif self.config.classifier_pooling == "mean":
+ num_non_pad_tokens = attention_mask.sum(dim=1, keepdim=True)
+ last_hidden_state = (last_hidden_state * attention_mask.unsqueeze(-1)).sum(dim=1) / num_non_pad_tokens
+
+ pooled_output = self.head(last_hidden_state)
+ pooled_output = self.drop(pooled_output)
+ logits = self.classifier(pooled_output)
+
+ reshaped_logits = logits.view(-1, num_choices)
+
+ loss = None
+ if labels is not None:
+ loss_fct = nn.CrossEntropyLoss()
+ loss = loss_fct(reshaped_logits, labels)
+
+ if not return_dict:
+ output = (reshaped_logits,) + outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return MultipleChoiceModelOutput(
+ loss=loss,
+ logits=reshaped_logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+__all__ = [
+ "ModernBertConfig",
+ "ModernBertModel",
+ "ModernBertPreTrainedModel",
+ "ModernBertForMaskedLM",
+ "ModernBertForSequenceClassification",
+ "ModernBertForTokenClassification",
+ "ModernBertForQuestionAnswering",
+ "ModernBertForMultipleChoice",
+]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/modernbert_decoder/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/modernbert_decoder/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8f29efca0bf52a68509d3b58ad3ca456f216a042
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/modernbert_decoder/__init__.py
@@ -0,0 +1,27 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_modernbert_decoder import *
+ from .modeling_modernbert_decoder import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/modernbert_decoder/configuration_modernbert_decoder.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/modernbert_decoder/configuration_modernbert_decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..6a065a53dfc4b27d1f82cd0385c0c67636018c63
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/modernbert_decoder/configuration_modernbert_decoder.py
@@ -0,0 +1,208 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/modernbert_decoder/modular_modernbert_decoder.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_modernbert_decoder.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# Copyright 2025 Johns Hopkins University, LightOn, and the HuggingFace Inc. team. All rights reserved.
+#
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ...configuration_utils import PretrainedConfig
+
+
+class ModernBertDecoderConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`ModernBertDecoderModel`]. It is used to instantiate a ModernBert
+ decoder model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
+ defaults will yield a similar configuration to that of the ModernBERT-base decoder.
+ e.g. [blab-jhu/test-32m-dec](https://huggingface.co/blab-jhu/test-32m-dec)
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 50368):
+ Vocabulary size of the ModernBert decoder model. Defines the number of different tokens that can be represented by the
+ `inputs_ids` passed when calling [`ModernBertDecoderModel`]
+ hidden_size (`int`, *optional*, defaults to 768):
+ Dimension of the hidden representations.
+ intermediate_size (`int`, *optional*, defaults to 1152):
+ Dimension of the MLP representations.
+ num_hidden_layers (`int`, *optional*, defaults to 22):
+ Number of hidden layers in the Transformer decoder.
+ num_attention_heads (`int`, *optional*, defaults to 12):
+ Number of attention heads for each attention layer in the Transformer decoder.
+ hidden_activation (`str` or `function`, *optional*, defaults to `"gelu"`):
+ The non-linear activation function (function or string) in the decoder. Will default to `"gelu"`
+ if not specified.
+ max_position_embeddings (`int`, *optional*, defaults to 8192):
+ The maximum sequence length that this model might ever be used with.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ initializer_cutoff_factor (`float`, *optional*, defaults to 2.0):
+ The cutoff factor for the truncated_normal_initializer for initializing all weight matrices.
+ norm_eps (`float`, *optional*, defaults to 1e-05):
+ The epsilon used by the rms normalization layers.
+ norm_bias (`bool`, *optional*, defaults to `False`):
+ Whether to use bias in the normalization layers.
+ pad_token_id (`int`, *optional*, defaults to 50283):
+ Padding token id.
+ eos_token_id (`int`, *optional*, defaults to 50282):
+ End of stream token id.
+ bos_token_id (`int`, *optional*, defaults to 50281):
+ Beginning of stream token id.
+ cls_token_id (`int`, *optional*, defaults to 50281):
+ Classification token id.
+ sep_token_id (`int`, *optional*, defaults to 50282):
+ Separation token id.
+ global_rope_theta (`float`, *optional*, defaults to 160000.0):
+ The base period of the global RoPE embeddings.
+ attention_bias (`bool`, *optional*, defaults to `False`):
+ Whether to use a bias in the query, key, value and output projection layers during self-attention.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ embedding_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the embeddings.
+ mlp_bias (`bool`, *optional*, defaults to `False`):
+ Whether to use bias in the MLP layers.
+ mlp_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the MLP layers.
+ decoder_bias (`bool`, *optional*, defaults to `True`):
+ Whether to use bias in the decoder layers.
+ classifier_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the classifier.
+ classifier_bias (`bool`, *optional*, defaults to `False`):
+ Whether to use bias in the classifier.
+ classifier_activation (`str`, *optional*, defaults to `"gelu"`):
+ The activation function for the classifier.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
+ relevant if `config.is_decoder=True`.
+ local_attention (`int`, *optional*, defaults to 128):
+            The sliding window size for local attention. Only used for layers that use local attention. Note that to
+            match ModernBERT, the decoder uses half of this value as its effective sliding window (so 128 => 64).
+ global_attn_every_n_layers (`int`, *optional*, defaults to 3):
+ Every `global_attn_every_n_layers` layers will use global attention instead of local attention.
+ local_rope_theta (`float`, *optional*, defaults to 160000.0):
+            The base period of the local RoPE embeddings used by the sliding-window attention layers.
+ layer_types (`list`, *optional*):
+ List of layer types, one for each layer. If not specified, will be automatically generated based on
+ `global_attn_every_n_layers`. Should contain "full_attention" or "sliding_attention".
+
+ Examples:
+
+ ```python
+ >>> from transformers import ModernBertDecoderModel, ModernBertDecoderConfig
+
+ >>> # Initializing a ModernBert decoder style configuration
+ >>> configuration = ModernBertDecoderConfig()
+
+ >>> # Initializing a model from the modernbert-base decoder style configuration
+ >>> model = ModernBertDecoderModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
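+
+    >>> # When `layer_types` is not given, it is derived from `global_attn_every_n_layers` (default 3)
+    >>> configuration.layer_types[:3]
+    ['full_attention', 'sliding_attention', 'sliding_attention']
+
+    >>> # The effective sliding window is half of `local_attention` (128 -> 64)
+    >>> configuration.sliding_window
+    64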
+ ```"""
+
+ model_type = "modernbert-decoder"
+ attribute_map = {"rope_theta": "global_rope_theta"}
+ keys_to_ignore_at_inference = ["past_key_values"]
+
+ def __init__(
+ self,
+ vocab_size=50368,
+ hidden_size=768,
+ intermediate_size=1152,
+ num_hidden_layers=22,
+ num_attention_heads=12,
+ hidden_activation="gelu",
+ max_position_embeddings=8192,
+ initializer_range=0.02,
+ initializer_cutoff_factor=2.0,
+ norm_eps=1e-5,
+ norm_bias=False,
+ pad_token_id=50283,
+ eos_token_id=50282,
+ bos_token_id=50281,
+ cls_token_id=50281,
+ sep_token_id=50282,
+ global_rope_theta=160000.0,
+ attention_bias=False,
+ attention_dropout=0.0,
+ embedding_dropout=0.0,
+ mlp_bias=False,
+ mlp_dropout=0.0,
+ decoder_bias=True,
+ classifier_dropout=0.0,
+ classifier_bias=False,
+ classifier_activation="gelu",
+ use_cache=True,
+ local_attention=128,
+ global_attn_every_n_layers=3,
+ local_rope_theta=160000.0,
+ layer_types=None,
+ **kwargs,
+ ):
+ super().__init__(
+ pad_token_id=pad_token_id,
+ bos_token_id=bos_token_id,
+ eos_token_id=eos_token_id,
+ cls_token_id=cls_token_id,
+ sep_token_id=sep_token_id,
+ **kwargs,
+ )
+ self.vocab_size = vocab_size
+ self.max_position_embeddings = max_position_embeddings
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.initializer_range = initializer_range
+ self.initializer_cutoff_factor = initializer_cutoff_factor
+ self.norm_eps = norm_eps
+ self.norm_bias = norm_bias
+ self.global_rope_theta = global_rope_theta
+ self.attention_bias = attention_bias
+ self.attention_dropout = attention_dropout
+ self.hidden_activation = hidden_activation
+ self.embedding_dropout = embedding_dropout
+ self.mlp_bias = mlp_bias
+ self.mlp_dropout = mlp_dropout
+ self.decoder_bias = decoder_bias
+ self.classifier_dropout = classifier_dropout
+ self.classifier_bias = classifier_bias
+ self.classifier_activation = classifier_activation
+ self.use_cache = use_cache
+ self.global_attn_every_n_layers = global_attn_every_n_layers
+ self.local_rope_theta = local_rope_theta
+ # for consistency with ModernBert
+ self.reference_compile = False
+
+ # Set up layer_types for standardized layer type detection
+ self.layer_types = layer_types
+ if self.layer_types is None:
+ # Create layer_types based on the alternating pattern
+ self.layer_types = []
+ for layer_id in range(num_hidden_layers):
+ if layer_id % global_attn_every_n_layers != 0:
+ self.layer_types.append("sliding_attention")
+ else:
+ self.layer_types.append("full_attention")
+
+        # NOTE: to match ModernBERT, the effective sliding window is half of `local_attention` (e.g. 128 -> 64)
+ self.sliding_window = local_attention // 2 if local_attention else -1
+
+
+__all__ = ["ModernBertDecoderConfig"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/modernbert_decoder/modeling_modernbert_decoder.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/modernbert_decoder/modeling_modernbert_decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..140240cba6f744b3a27886c9f9aab35659d3b727
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/modernbert_decoder/modeling_modernbert_decoder.py
@@ -0,0 +1,733 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/modernbert_decoder/modular_modernbert_decoder.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_modernbert_decoder.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# Copyright 2025 Johns Hopkins University, LightOn, and the HuggingFace Inc. team. All rights reserved.
+#
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from collections.abc import Callable
+from typing import Optional, Union
+
+import torch
+from torch import nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+
+from ...activations import ACT2FN
+from ...cache_utils import Cache, DynamicCache
+from ...generation import GenerationMixin
+from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
+from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
+from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
+from ...utils.deprecation import deprecate_kwarg
+from ...utils.generic import check_model_inputs
+from .configuration_modernbert_decoder import ModernBertDecoderConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+class ModernBertDecoderEmbeddings(nn.Module):
+ """
+    Token embeddings with layer normalization and dropout. Positional information is injected later through
+    rotary position embeddings, so no learned position embeddings are added here.
+ """
+
+ def __init__(self, config: ModernBertDecoderConfig):
+ super().__init__()
+ self.config = config
+ self.tok_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
+ self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)
+ self.drop = nn.Dropout(config.embedding_dropout)
+
+ @torch.compile(dynamic=True)
+ def compiled_embeddings(self, input_ids: torch.LongTensor) -> torch.Tensor:
+ return self.drop(self.norm(self.tok_embeddings(input_ids)))
+
+ def forward(
+ self, input_ids: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.Tensor] = None
+ ) -> torch.Tensor:
+ if inputs_embeds is not None:
+ hidden_states = self.drop(self.norm(inputs_embeds))
+ else:
+ hidden_states = (
+ self.compiled_embeddings(input_ids)
+ if self.config.reference_compile
+ else self.drop(self.norm(self.tok_embeddings(input_ids)))
+ )
+ return hidden_states
+
+
+class ModernBertDecoderMLP(nn.Module):
+ """Applies the GLU at the end of each ModernBertDecoder layer.
+
+ Compared to the default BERT architecture, this block replaces :class:`~transformers.model.bert.modeling_bert.BertIntermediate`
+ and :class:`~transformers.model.bert.modeling_bert.SelfOutput` with a single module that has similar functionality.
+ """
+
+ def __init__(self, config: ModernBertDecoderConfig):
+ super().__init__()
+ self.config = config
+ self.Wi = nn.Linear(config.hidden_size, int(config.intermediate_size) * 2, bias=config.mlp_bias)
+ self.act = ACT2FN[config.hidden_activation]
+ self.drop = nn.Dropout(config.mlp_dropout)
+ self.Wo = nn.Linear(config.intermediate_size, config.hidden_size, bias=config.mlp_bias)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ input, gate = self.Wi(hidden_states).chunk(2, dim=-1)
+ return self.Wo(self.drop(self.act(input) * gate))
+
+
+class ModernBertDecoderRotaryEmbedding(nn.Module):
+ inv_freq: torch.Tensor # fix linting for `register_buffer`
+
+ def __init__(self, config: ModernBertDecoderConfig, device=None):
+ super().__init__()
+ # BC: "rope_type" was originally "type"
+ if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
+ else:
+ self.rope_type = "default"
+ self.max_seq_len_cached = config.max_position_embeddings
+ self.original_max_seq_len = config.max_position_embeddings
+
+ self.config = config
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
+ self.original_inv_freq = self.inv_freq
+
+ @torch.no_grad()
+ @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
+ def forward(self, x, position_ids):
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
+ position_ids_expanded = position_ids[:, None, :].float()
+
+ device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
+ with torch.autocast(device_type=device_type, enabled=False): # Force float32
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+ emb = torch.cat((freqs, freqs), dim=-1)
+ cos = emb.cos() * self.attention_scaling
+ sin = emb.sin() * self.attention_scaling
+
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
+def rotate_half(x):
+ """Rotates half the hidden dims of the input."""
+ x1 = x[..., : x.shape[-1] // 2]
+ x2 = x[..., x.shape[-1] // 2 :]
+ return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+ """Applies Rotary Position Embedding to the query and key tensors.
+
+ Args:
+ q (`torch.Tensor`): The query tensor.
+ k (`torch.Tensor`): The key tensor.
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
+ position_ids (`torch.Tensor`, *optional*):
+ Deprecated and unused.
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+ Returns:
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
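+
+    Example (a shape-only sketch with random values, purely illustrative):
+
+    ```python
+    >>> import torch
+    >>> q = torch.randn(1, 12, 8, 64)  # [batch_size, num_heads, seq_len, head_dim]
+    >>> k = torch.randn(1, 12, 8, 64)
+    >>> cos = torch.randn(1, 8, 64)  # [batch_size, seq_len, head_dim]
+    >>> sin = torch.randn(1, 8, 64)
+    >>> q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)
+    >>> q_rot.shape == q.shape and k_rot.shape == k.shape
+    True
+    ```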
+ """
+ cos = cos.unsqueeze(unsqueeze_dim)
+ sin = sin.unsqueeze(unsqueeze_dim)
+ q_embed = (q * cos) + (rotate_half(q) * sin)
+ k_embed = (k * cos) + (rotate_half(k) * sin)
+ return q_embed, k_embed
+
+
+def eager_attention_forward(
+ module: "ModernBertDecoderAttention",
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_mask: Optional[torch.Tensor],
+ dropout: float = 0.0,
+ scaling: Optional[float] = None,
+ sliding_window: Optional[int] = None,
+ **kwargs,
+) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+ """A simple eager attention implementation for ModernBERT decoder."""
+ if scaling is None:
+ scaling = module.head_dim**-0.5
+
+ # Compute attention scores
+ attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
+
+ # Use the pre-computed attention mask
+ causal_mask = attention_mask[:, :, :, : key.shape[-2]]
+ attn_weights = attn_weights + causal_mask
+
+ # upcast attention to fp32
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+ attn_output = torch.matmul(attn_weights, value)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+ return attn_output, attn_weights
+
+
+class ModernBertDecoderAttention(nn.Module):
+ """Performs causal multi-headed self attention for ModernBERT decoder.
+
+ It supports both local attention (sliding window) and global attention patterns.
+ """
+
+ def __init__(self, config: ModernBertDecoderConfig, layer_idx: Optional[int] = None):
+ super().__init__()
+ self.is_sliding = config.layer_types[layer_idx] == "sliding_attention"
+ self.config = config
+ self.layer_idx = layer_idx
+ self.head_dim = config.hidden_size // config.num_attention_heads
+ self.num_heads = config.num_attention_heads
+ self.all_head_size = self.head_dim * self.num_heads
+ self.scaling = self.head_dim**-0.5
+ self.attention_dropout = self.config.attention_dropout
+ self.is_causal = True
+
+ if config.hidden_size % config.num_attention_heads != 0:
+ raise ValueError(
+ f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention heads ({config.num_attention_heads})"
+ )
+
+        # NOTE: unlike ModernBERT (which fuses QKV into a single Wqkv), the decoder uses separate Q/K/V projections
+ self.q_proj = nn.Linear(self.config.hidden_size, self.all_head_size, bias=self.config.attention_bias)
+ self.k_proj = nn.Linear(self.config.hidden_size, self.all_head_size, bias=self.config.attention_bias)
+ self.v_proj = nn.Linear(self.config.hidden_size, self.all_head_size, bias=self.config.attention_bias)
+
+ self.Wo = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attention_bias)
+ self.out_drop = nn.Dropout(config.attention_dropout)
+
+ self.sliding_window = config.sliding_window if config.layer_types[layer_idx] == "sliding_attention" else None
+
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_embeddings: torch.Tensor,
+ attention_mask: Optional[torch.Tensor],
+ past_key_values: Optional[Cache] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+ input_shape = hidden_states.shape[:-1]
+ hidden_shape = (*input_shape, -1, self.head_dim)
+
+ query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+ key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+ value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+
+ cos, sin = position_embeddings
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_values is not None:
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ attn_output, attn_weights = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ dropout=self.attention_dropout if self.training else 0.0,
+ scaling=self.scaling,
+ sliding_window=self.sliding_window,
+ **kwargs,
+ )
+
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+ attn_output = self.out_drop(self.Wo(attn_output))
+ return attn_output, attn_weights
+
+
+class ModernBertDecoderLayer(GradientCheckpointingLayer):
+ def __init__(self, config: ModernBertDecoderConfig, layer_idx: Optional[int] = None):
+ super().__init__()
+ self.config = config
+ self.layer_idx = layer_idx
+ self.attention_type = config.layer_types[layer_idx]
+ self.attn_norm = (
+ nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)
+ if layer_idx != 0
+ else nn.Identity()
+ )
+ self.attn = ModernBertDecoderAttention(config=config, layer_idx=layer_idx)
+ self.mlp_norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)
+ self.mlp = ModernBertDecoderMLP(config)
+
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_embeddings_global: torch.Tensor,
+ position_embeddings_local: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ past_key_values: Optional[Cache] = None,
+ use_cache: Optional[bool] = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs,
+    ) -> torch.Tensor:
+ residual = hidden_states
+ hidden_states = self.attn_norm(hidden_states)
+
+ # apply global RoPE to non-sliding layer only
+ if self.attn.is_sliding:
+ position_embeddings = position_embeddings_local
+ else:
+ position_embeddings = position_embeddings_global
+
+ # Self Attention
+ attn_outputs = self.attn(
+ hidden_states=hidden_states,
+ position_embeddings=position_embeddings,
+ attention_mask=attention_mask,
+ past_key_values=past_key_values,
+ cache_position=cache_position,
+ **kwargs,
+ )
+ hidden_states = attn_outputs[0]
+
+ # Add residual connection
+ hidden_states = residual + hidden_states
+
+ # MLP
+ residual = hidden_states
+ hidden_states = self.mlp_norm(hidden_states)
+ mlp_output = self.mlp(hidden_states)
+ hidden_states = residual + mlp_output
+ return hidden_states
+
+
+class ModernBertDecoderPredictionHead(nn.Module):
+ def __init__(self, config: ModernBertDecoderConfig):
+ super().__init__()
+ self.config = config
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size, config.classifier_bias)
+ self.act = ACT2FN[config.classifier_activation]
+ self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ return self.norm(self.act(self.dense(hidden_states)))
+
+
+@auto_docstring
+class ModernBertDecoderPreTrainedModel(PreTrainedModel):
+ config: ModernBertDecoderConfig
+ base_model_prefix = "model"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["ModernBertDecoderLayer"]
+ _supports_flash_attn = True
+ _supports_sdpa = True
+ _supports_flex_attn = True
+ _skip_keys_device_placement = ["past_key_values"]
+ _supports_attention_backend = True
+ _can_record_outputs = {
+ "hidden_states": ModernBertDecoderLayer,
+ "attentions": ModernBertDecoderAttention,
+ }
+
+ def _init_weights(self, module: nn.Module):
+ cutoff_factor = self.config.initializer_cutoff_factor
+ if cutoff_factor is None:
+ cutoff_factor = 3
+
+ def init_weight(module: nn.Module, std: float):
+ nn.init.trunc_normal_(
+ module.weight,
+ mean=0.0,
+ std=std,
+ a=-cutoff_factor * std,
+ b=cutoff_factor * std,
+ )
+
+ if isinstance(module, nn.Linear):
+ if module.bias is not None:
+ nn.init.zeros_(module.bias)
+
+ stds = {
+ "in": self.config.initializer_range,
+ "out": self.config.initializer_range / math.sqrt(2.0 * self.config.num_hidden_layers),
+ "embedding": self.config.initializer_range,
+ "final_out": self.config.hidden_size**-0.5,
+ }
+
+ if isinstance(module, ModernBertDecoderEmbeddings):
+ init_weight(module.tok_embeddings, stds["embedding"])
+ elif isinstance(module, ModernBertDecoderMLP):
+ init_weight(module.Wi, stds["in"])
+ init_weight(module.Wo, stds["out"])
+ elif isinstance(module, ModernBertDecoderAttention):
+ init_weight(module.q_proj, stds["in"])
+ init_weight(module.k_proj, stds["in"])
+ init_weight(module.v_proj, stds["in"])
+ init_weight(module.Wo, stds["out"])
+ elif isinstance(module, ModernBertDecoderPredictionHead):
+ init_weight(module.dense, stds["out"])
+ elif isinstance(module, ModernBertDecoderForSequenceClassification):
+ init_weight(module.classifier, stds["final_out"])
+ elif isinstance(module, ModernBertDecoderForCausalLM):
+ init_weight(module.decoder, stds["out"])
+ elif isinstance(module, nn.LayerNorm):
+ module.weight.data.fill_(1.0)
+ if module.bias is not None:
+ module.bias.data.zero_()
+
+
+@auto_docstring
+class ModernBertDecoderModel(ModernBertDecoderPreTrainedModel):
+ def __init__(self, config: ModernBertDecoderConfig):
+ super().__init__(config)
+ self.config = config
+ self.embeddings = ModernBertDecoderEmbeddings(config)
+ self.layers = nn.ModuleList(
+ [ModernBertDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+ )
+ self.final_norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)
+ self.gradient_checkpointing = False
+
+ self.global_rotary_emb = ModernBertDecoderRotaryEmbedding(config=config)
+ self.local_rotary_emb = ModernBertDecoderRotaryEmbedding(config=config)
+
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.embeddings.tok_embeddings
+
+ def set_input_embeddings(self, value):
+ self.embeddings.tok_embeddings = value
+
+ @check_model_inputs
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ use_cache: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs,
+ ) -> Union[tuple[torch.Tensor, ...], BaseModelOutputWithPast]:
+ if (input_ids is None) == (inputs_embeds is None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ if input_ids is not None:
+ self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
+ batch_size, seq_length = input_ids.shape[:2]
+ else:
+ batch_size, seq_length = inputs_embeds.shape[:2]
+
+ # Handle past_key_values and cache setup
+ if use_cache and past_key_values is None and not self.training:
+ past_key_values = DynamicCache(config=self.config)
+
+ if cache_position is None:
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ cache_position = torch.arange(
+ past_seen_tokens,
+ past_seen_tokens + seq_length,
+ device=input_ids.device if input_ids is not None else inputs_embeds.device,
+ )
+
+ if position_ids is None:
+ position_ids = cache_position.unsqueeze(0).expand(batch_size, -1)
+
+ # Calculate embeddings
+ hidden_states = self.embeddings(input_ids=input_ids, inputs_embeds=inputs_embeds)
+
+ # It may already have been prepared by e.g. `generate`
+ if not isinstance(causal_mask_mapping := attention_mask, dict):
+ # Prepare mask arguments
+ mask_kwargs = {
+ "config": self.config,
+ "input_embeds": hidden_states,
+ "attention_mask": attention_mask,
+ "cache_position": cache_position,
+ "past_key_values": past_key_values,
+ "position_ids": position_ids,
+ }
+
+ causal_mask_mapping = {
+ "full_attention": create_causal_mask(**mask_kwargs),
+ "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs),
+ }
+
+ # create position embeddings to be shared across the decoder layers
+ position_embeddings_global = self.global_rotary_emb(hidden_states, position_ids)
+ position_embeddings_local = self.local_rotary_emb(hidden_states, position_ids)
+
+ for idx, decoder_layer in enumerate(self.layers):
+ hidden_states = decoder_layer(
+ hidden_states,
+ position_embeddings_global=position_embeddings_global,
+ position_embeddings_local=position_embeddings_local,
+ attention_mask=causal_mask_mapping[decoder_layer.attention_type],
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+ hidden_states = self.final_norm(hidden_states)
+
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=past_key_values,
+ )
+
+
+@auto_docstring(
+ custom_intro="""
+ The ModernBert Decoder Model with a language modeling head on top for causal language modeling (CLM).
+ """
+)
+class ModernBertDecoderForCausalLM(ModernBertDecoderPreTrainedModel, GenerationMixin):
+ _tied_weights_keys = ["decoder.weight"]
+
+ def __init__(self, config: ModernBertDecoderConfig):
+ super().__init__(config)
+ self.config = config
+ self.model = ModernBertDecoderModel(config)
+ self.lm_head = ModernBertDecoderPredictionHead(config)
+ self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=config.decoder_bias)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_output_embeddings(self):
+ return self.decoder
+
+ def set_output_embeddings(self, new_embeddings):
+ self.decoder = new_embeddings
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ **kwargs,
+ ) -> Union[tuple, CausalLMOutputWithPast]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+ Returns:
+ [`~modeling_outputs.CausalLMOutputWithPast`]
+ comprising various elements depending on the configuration and inputs.
+
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, ModernBertDecoderForCausalLM
+
+ >>> model = ModernBertDecoderForCausalLM.from_pretrained("blab-jhu/test-32m-dec")
+ >>> tokenizer = AutoTokenizer.from_pretrained("blab-jhu/test-32m-dec")
+
+ >>> prompt = "The capital of France is"
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+ >>> # Generate
+        >>> generate_ids = model.generate(inputs.input_ids, max_new_tokens=1)
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ "The capital of France is Paris"
+ ```
+ """
+        # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
+ outputs = self.model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ **kwargs,
+ )
+
+ hidden_states = outputs[0]
+ logits = self.decoder(self.lm_head(hidden_states))
+
+ loss = None
+ if labels is not None:
+ # Shift so that tokens < n predict n
+ shift_logits = logits[..., :-1, :].contiguous()
+ shift_labels = labels[..., 1:].contiguous()
+ # Flatten the tokens
+ loss_fct = CrossEntropyLoss()
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
+ shift_labels = shift_labels.view(-1)
+ # Enable model parallelism
+ shift_labels = shift_labels.to(shift_logits.device)
+ loss = loss_fct(shift_logits, shift_labels)
+
+ return CausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+@auto_docstring(
+ custom_intro="""
+ The ModernBert Decoder Model with a sequence classification head on top (linear layer).
+
+ [`ModernBertDecoderForSequenceClassification`] uses the last token in order to do the classification, as other causal models
+ (e.g. GPT-1, GPT-2) do.
+
+    Since it does classification on the last token, it needs to know the position of the last token. If a
+ `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
+ no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
+ padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
+ each row of the batch).
+ """
+)
+class ModernBertDecoderForSequenceClassification(ModernBertDecoderPreTrainedModel):
+ def __init__(self, config: ModernBertDecoderConfig):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+ self.model = ModernBertDecoderModel(config)
+
+ self.head = ModernBertDecoderPredictionHead(config)
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels, bias=config.classifier_bias)
+ self.drop = torch.nn.Dropout(config.classifier_dropout)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @can_return_tuple
+ @auto_docstring(checkpoint="blab-jhu/test-32m-dec")
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ **kwargs,
+ ) -> Union[tuple, SequenceClassifierOutputWithPast]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), if
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
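+
+        Example (a minimal sketch; the classification head on this checkpoint is freshly initialized, so the
+        logits are untrained):
+
+        ```python
+        >>> from transformers import AutoTokenizer, ModernBertDecoderForSequenceClassification
+
+        >>> tokenizer = AutoTokenizer.from_pretrained("blab-jhu/test-32m-dec")
+        >>> model = ModernBertDecoderForSequenceClassification.from_pretrained("blab-jhu/test-32m-dec", num_labels=2)
+
+        >>> inputs = tokenizer("This movie was great!", return_tensors="pt")
+        >>> # the logits are pooled from the last non-padding token of each sequence
+        >>> logits = model(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"]).logits
+        >>> logits.shape
+        torch.Size([1, 2])
+        ```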
+ """
+ transformer_outputs = self.model(
+ input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ **kwargs,
+ )
+ hidden_states = transformer_outputs[0]
+ hidden_states = self.drop(self.head(hidden_states))
+ logits = self.classifier(hidden_states)
+
+ if input_ids is not None:
+ batch_size, sequence_length = input_ids.shape[:2]
+ else:
+ batch_size, sequence_length = inputs_embeds.shape[:2]
+
+ if self.config.pad_token_id is None and batch_size != 1:
+ raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
+ if self.config.pad_token_id is None:
+ last_non_pad_token = -1
+ elif input_ids is not None:
+ # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
+ non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
+ token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
+ last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
+ else:
+ last_non_pad_token = -1
+ logger.warning_once(
+ f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
+ "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
+ )
+
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
+
+ loss = None
+ if labels is not None:
+ if self.config.problem_type is None:
+ if self.num_labels == 1:
+ self.config.problem_type = "regression"
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+ self.config.problem_type = "single_label_classification"
+ else:
+ self.config.problem_type = "multi_label_classification"
+
+ if self.config.problem_type == "regression":
+ loss_fct = MSELoss()
+ if self.num_labels == 1:
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
+ else:
+ loss = loss_fct(pooled_logits, labels)
+ elif self.config.problem_type == "single_label_classification":
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
+ elif self.config.problem_type == "multi_label_classification":
+ loss_fct = BCEWithLogitsLoss()
+ loss = loss_fct(pooled_logits, labels)
+
+ return SequenceClassifierOutputWithPast(
+ loss=loss,
+ logits=pooled_logits,
+ past_key_values=transformer_outputs.past_key_values,
+ hidden_states=transformer_outputs.hidden_states,
+ attentions=transformer_outputs.attentions,
+ )
+
+
+__all__ = [
+ "ModernBertDecoderModel",
+ "ModernBertDecoderPreTrainedModel",
+ "ModernBertDecoderForCausalLM",
+ "ModernBertDecoderForSequenceClassification",
+]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/modernbert_decoder/modular_modernbert_decoder.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/modernbert_decoder/modular_modernbert_decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..05bf9c98d01cec04c75bac108caa06f981eec24e
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/modernbert_decoder/modular_modernbert_decoder.py
@@ -0,0 +1,805 @@
+# Copyright 2025 Johns Hopkins University, LightOn, and the HuggingFace Inc. team. All rights reserved.
+#
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from collections.abc import Callable
+from typing import Optional, Union
+
+import torch
+from torch import nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+
+from ...cache_utils import Cache, DynamicCache
+from ...configuration_utils import PretrainedConfig
+from ...generation import GenerationMixin
+from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
+from ...processing_utils import Unpack
+from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
+from ...utils.deprecation import deprecate_kwarg
+from ...utils.generic import check_model_inputs
+from ..modernbert.modeling_modernbert import (
+ ModernBertEmbeddings,
+ ModernBertMLP,
+ ModernBertPredictionHead,
+ ModernBertPreTrainedModel,
+ ModernBertRotaryEmbedding,
+ apply_rotary_pos_emb,
+)
+
+
+logger = logging.get_logger(__name__)
+
+
+class ModernBertDecoderConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`ModernBertDecoderModel`]. It is used to instantiate a ModernBert
+ decoder model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
+ defaults will yield a similar configuration to that of the ModernBERT-base decoder.
+ e.g. [blab-jhu/test-32m-dec](https://huggingface.co/blab-jhu/test-32m-dec)
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 50368):
+ Vocabulary size of the ModernBert decoder model. Defines the number of different tokens that can be represented by the
+ `inputs_ids` passed when calling [`ModernBertDecoderModel`]
+ hidden_size (`int`, *optional*, defaults to 768):
+ Dimension of the hidden representations.
+ intermediate_size (`int`, *optional*, defaults to 1152):
+ Dimension of the MLP representations.
+ num_hidden_layers (`int`, *optional*, defaults to 22):
+ Number of hidden layers in the Transformer decoder.
+ num_attention_heads (`int`, *optional*, defaults to 12):
+ Number of attention heads for each attention layer in the Transformer decoder.
+ hidden_activation (`str` or `function`, *optional*, defaults to `"gelu"`):
+ The non-linear activation function (function or string) in the decoder. Will default to `"gelu"`
+ if not specified.
+ max_position_embeddings (`int`, *optional*, defaults to 8192):
+ The maximum sequence length that this model might ever be used with.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ initializer_cutoff_factor (`float`, *optional*, defaults to 2.0):
+ The cutoff factor for the truncated_normal_initializer for initializing all weight matrices.
+ norm_eps (`float`, *optional*, defaults to 1e-05):
+ The epsilon used by the rms normalization layers.
+ norm_bias (`bool`, *optional*, defaults to `False`):
+ Whether to use bias in the normalization layers.
+ pad_token_id (`int`, *optional*, defaults to 50283):
+ Padding token id.
+ eos_token_id (`int`, *optional*, defaults to 50282):
+ End of stream token id.
+ bos_token_id (`int`, *optional*, defaults to 50281):
+ Beginning of stream token id.
+ cls_token_id (`int`, *optional*, defaults to 50281):
+ Classification token id.
+ sep_token_id (`int`, *optional*, defaults to 50282):
+ Separation token id.
+ global_rope_theta (`float`, *optional*, defaults to 160000.0):
+ The base period of the global RoPE embeddings.
+ attention_bias (`bool`, *optional*, defaults to `False`):
+ Whether to use a bias in the query, key, value and output projection layers during self-attention.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ embedding_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the embeddings.
+ mlp_bias (`bool`, *optional*, defaults to `False`):
+ Whether to use bias in the MLP layers.
+ mlp_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the MLP layers.
+ decoder_bias (`bool`, *optional*, defaults to `True`):
+ Whether to use bias in the decoder layers.
+ classifier_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the classifier.
+ classifier_bias (`bool`, *optional*, defaults to `False`):
+ Whether to use bias in the classifier.
+ classifier_activation (`str`, *optional*, defaults to `"gelu"`):
+ The activation function for the classifier.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
+ relevant if `config.is_decoder=True`.
+ local_attention (`int`, *optional*, defaults to 128):
+ The sliding window size for local attention. Only used for layers that use local attention. Note that, to
+ match ModernBERT, the effective window used by the decoder is half of this value, so 128 => 64.
+ global_attn_every_n_layers (`int`, *optional*, defaults to 3):
+ Every `global_attn_every_n_layers` layers will use global attention instead of local attention.
+ local_rope_theta (`float`, *optional*, defaults to 160000.0):
+ The base period of the local RoPE embeddings. If not specified, defaults to 160000.0
+ layer_types (`list`, *optional*):
+ List of layer types, one for each layer. If not specified, will be automatically generated based on
+ `global_attn_every_n_layers`. Should contain "full_attention" or "sliding_attention".
+
+ Examples:
+
+ ```python
+ >>> from transformers import ModernBertDecoderModel, ModernBertDecoderConfig
+
+ >>> # Initializing a ModernBert decoder style configuration
+ >>> configuration = ModernBertDecoderConfig()
+
+ >>> # Initializing a model from the modernbert-base decoder style configuration
+ >>> model = ModernBertDecoderModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "modernbert-decoder"
+ attribute_map = {"rope_theta": "global_rope_theta"}
+ keys_to_ignore_at_inference = ["past_key_values"]
+
+ def __init__(
+ self,
+ vocab_size=50368,
+ hidden_size=768,
+ intermediate_size=1152,
+ num_hidden_layers=22,
+ num_attention_heads=12,
+ hidden_activation="gelu",
+ max_position_embeddings=8192,
+ initializer_range=0.02,
+ initializer_cutoff_factor=2.0,
+ norm_eps=1e-5,
+ norm_bias=False,
+ pad_token_id=50283,
+ eos_token_id=50282,
+ bos_token_id=50281,
+ cls_token_id=50281,
+ sep_token_id=50282,
+ global_rope_theta=160000.0,
+ attention_bias=False,
+ attention_dropout=0.0,
+ embedding_dropout=0.0,
+ mlp_bias=False,
+ mlp_dropout=0.0,
+ decoder_bias=True,
+ classifier_dropout=0.0,
+ classifier_bias=False,
+ classifier_activation="gelu",
+ use_cache=True,
+ local_attention=128,
+ global_attn_every_n_layers=3,
+ local_rope_theta=160000.0,
+ layer_types=None,
+ **kwargs,
+ ):
+ super().__init__(
+ pad_token_id=pad_token_id,
+ bos_token_id=bos_token_id,
+ eos_token_id=eos_token_id,
+ cls_token_id=cls_token_id,
+ sep_token_id=sep_token_id,
+ **kwargs,
+ )
+ self.vocab_size = vocab_size
+ self.max_position_embeddings = max_position_embeddings
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.initializer_range = initializer_range
+ self.initializer_cutoff_factor = initializer_cutoff_factor
+ self.norm_eps = norm_eps
+ self.norm_bias = norm_bias
+ self.global_rope_theta = global_rope_theta
+ self.attention_bias = attention_bias
+ self.attention_dropout = attention_dropout
+ self.hidden_activation = hidden_activation
+ self.embedding_dropout = embedding_dropout
+ self.mlp_bias = mlp_bias
+ self.mlp_dropout = mlp_dropout
+ self.decoder_bias = decoder_bias
+ self.classifier_dropout = classifier_dropout
+ self.classifier_bias = classifier_bias
+ self.classifier_activation = classifier_activation
+ self.use_cache = use_cache
+ self.global_attn_every_n_layers = global_attn_every_n_layers
+ self.local_rope_theta = local_rope_theta
+ # for consistency with ModernBert
+ self.reference_compile = False
+
+ # Set up layer_types for standardized layer type detection
+ self.layer_types = layer_types
+ if self.layer_types is None:
+ # Create layer_types based on the alternating pattern
+ self.layer_types = []
+ for layer_id in range(num_hidden_layers):
+ if layer_id % global_attn_every_n_layers != 0:
+ self.layer_types.append("sliding_attention")
+ else:
+ self.layer_types.append("full_attention")
+
+ # NOTE: the sliding window value matches ModernBERT's, but the effective window is only half of it
+ self.sliding_window = local_attention // 2 if local_attention else -1
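+ # With the defaults above (num_hidden_layers=22, global_attn_every_n_layers=3,
+ # local_attention=128), layer_types comes out as ["full_attention", "sliding_attention",
+ # "sliding_attention", "full_attention", ...] (layers 0, 3, 6, ... are global) and
+ # sliding_window is 128 // 2 = 64.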
+
+
+class ModernBertDecoderEmbeddings(ModernBertEmbeddings):
+ pass
+
+
+class ModernBertDecoderMLP(ModernBertMLP):
+ pass
+
+
+class ModernBertDecoderRotaryEmbedding(ModernBertRotaryEmbedding):
+ pass
+
+
+def eager_attention_forward(
+ module: "ModernBertDecoderAttention",
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_mask: Optional[torch.Tensor],
+ dropout: float = 0.0,
+ scaling: Optional[float] = None,
+ sliding_window: Optional[int] = None,
+ **kwargs,
+) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+ """A simple eager attention implementation for ModernBERT decoder."""
+ if scaling is None:
+ scaling = module.head_dim**-0.5
+
+ # Compute attention scores
+ attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
+
+ # Use the pre-computed attention mask
+ causal_mask = attention_mask[:, :, :, : key.shape[-2]]
+ attn_weights = attn_weights + causal_mask
+
+ # upcast attention to fp32
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+ attn_output = torch.matmul(attn_weights, value)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+ return attn_output, attn_weights
+
+
+class ModernBertDecoderAttention(nn.Module):
+ """Performs causal multi-headed self attention for ModernBERT decoder.
+
+ It supports both local attention (sliding window) and global attention patterns.
+ """
+
+ def __init__(self, config: ModernBertDecoderConfig, layer_idx: Optional[int] = None):
+ super().__init__()
+ self.is_sliding = config.layer_types[layer_idx] == "sliding_attention"
+ self.config = config
+ self.layer_idx = layer_idx
+ self.head_dim = config.hidden_size // config.num_attention_heads
+ self.num_heads = config.num_attention_heads
+ self.all_head_size = self.head_dim * self.num_heads
+ self.scaling = self.head_dim**-0.5
+ self.attention_dropout = self.config.attention_dropout
+ self.is_causal = True
+
+ if config.hidden_size % config.num_attention_heads != 0:
+ raise ValueError(
+ f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention heads ({config.num_attention_heads})"
+ )
+
+ # NOTE: unlike ModernBERT, which uses a single fused Wqkv projection, the decoder uses separate Q/K/V projections
+ self.q_proj = nn.Linear(self.config.hidden_size, self.all_head_size, bias=self.config.attention_bias)
+ self.k_proj = nn.Linear(self.config.hidden_size, self.all_head_size, bias=self.config.attention_bias)
+ self.v_proj = nn.Linear(self.config.hidden_size, self.all_head_size, bias=self.config.attention_bias)
+
+ self.Wo = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attention_bias)
+ self.out_drop = nn.Dropout(config.attention_dropout)
+
+ self.sliding_window = config.sliding_window if config.layer_types[layer_idx] == "sliding_attention" else None
+
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_embeddings: torch.Tensor,
+ attention_mask: Optional[torch.Tensor],
+ past_key_values: Optional[Cache] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+ input_shape = hidden_states.shape[:-1]
+ hidden_shape = (*input_shape, -1, self.head_dim)
+
+ query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+ key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+ value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+
+ cos, sin = position_embeddings
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_values is not None:
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ attn_output, attn_weights = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ dropout=self.attention_dropout if self.training else 0.0,
+ scaling=self.scaling,
+ sliding_window=self.sliding_window,
+ **kwargs,
+ )
+
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+ attn_output = self.out_drop(self.Wo(attn_output))
+ return attn_output, attn_weights
+
+
+class ModernBertDecoderLayer(GradientCheckpointingLayer):
+ def __init__(self, config: ModernBertDecoderConfig, layer_idx: Optional[int] = None):
+ super().__init__()
+ self.config = config
+ self.layer_idx = layer_idx
+ self.attention_type = config.layer_types[layer_idx]
+ self.attn_norm = (
+ nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)
+ if layer_idx != 0
+ else nn.Identity()
+ )
+ self.attn = ModernBertDecoderAttention(config=config, layer_idx=layer_idx)
+ self.mlp_norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)
+ self.mlp = ModernBertDecoderMLP(config)
+
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_embeddings_global: torch.Tensor,
+ position_embeddings_local: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ past_key_values: Optional[Cache] = None,
+ use_cache: Optional[bool] = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs,
+ ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
+ residual = hidden_states
+ hidden_states = self.attn_norm(hidden_states)
+
+ # sliding-attention layers use the local RoPE, all other layers use the global RoPE
+ if self.attn.is_sliding:
+ position_embeddings = position_embeddings_local
+ else:
+ position_embeddings = position_embeddings_global
+
+ # Self Attention
+ attn_outputs = self.attn(
+ hidden_states=hidden_states,
+ position_embeddings=position_embeddings,
+ attention_mask=attention_mask,
+ past_key_values=past_key_values,
+ cache_position=cache_position,
+ **kwargs,
+ )
+ hidden_states = attn_outputs[0]
+
+ # Add residual connection
+ hidden_states = residual + hidden_states
+
+ # MLP
+ residual = hidden_states
+ hidden_states = self.mlp_norm(hidden_states)
+ mlp_output = self.mlp(hidden_states)
+ hidden_states = residual + mlp_output
+ return hidden_states
+
+
+class ModernBertDecoderPredictionHead(ModernBertPredictionHead):
+ pass
+
+
+@auto_docstring
+class ModernBertDecoderPreTrainedModel(ModernBertPreTrainedModel):
+ _skip_keys_device_placement = ["past_key_values"]
+ _no_split_modules = ["ModernBertDecoderLayer"]
+ _supports_flex_attn = True
+ _supports_attention_backend = True
+ _can_record_outputs = {
+ "hidden_states": ModernBertDecoderLayer,
+ "attentions": ModernBertDecoderAttention,
+ }
+
+ def _init_weights(self, module: nn.Module):
+ cutoff_factor = self.config.initializer_cutoff_factor
+ if cutoff_factor is None:
+ cutoff_factor = 3
+
+ def init_weight(module: nn.Module, std: float):
+ nn.init.trunc_normal_(
+ module.weight,
+ mean=0.0,
+ std=std,
+ a=-cutoff_factor * std,
+ b=cutoff_factor * std,
+ )
+
+ if isinstance(module, nn.Linear):
+ if module.bias is not None:
+ nn.init.zeros_(module.bias)
+
+ stds = {
+ "in": self.config.initializer_range,
+ "out": self.config.initializer_range / math.sqrt(2.0 * self.config.num_hidden_layers),
+ "embedding": self.config.initializer_range,
+ "final_out": self.config.hidden_size**-0.5,
+ }
+
+ if isinstance(module, ModernBertDecoderEmbeddings):
+ init_weight(module.tok_embeddings, stds["embedding"])
+ elif isinstance(module, ModernBertDecoderMLP):
+ init_weight(module.Wi, stds["in"])
+ init_weight(module.Wo, stds["out"])
+ elif isinstance(module, ModernBertDecoderAttention):
+ init_weight(module.q_proj, stds["in"])
+ init_weight(module.k_proj, stds["in"])
+ init_weight(module.v_proj, stds["in"])
+ init_weight(module.Wo, stds["out"])
+ elif isinstance(module, ModernBertDecoderPredictionHead):
+ init_weight(module.dense, stds["out"])
+ elif isinstance(module, ModernBertDecoderForSequenceClassification):
+ init_weight(module.classifier, stds["final_out"])
+ elif isinstance(module, ModernBertDecoderForCausalLM):
+ init_weight(module.decoder, stds["out"])
+ elif isinstance(module, nn.LayerNorm):
+ module.weight.data.fill_(1.0)
+ if module.bias is not None:
+ module.bias.data.zero_()
+
+ def _check_and_adjust_attn_implementation(self, attn_implementation, is_init_check):
+ raise AttributeError("No need to inherit!")
+
+ def _maybe_set_compile(self):
+ raise AttributeError("No need to inherit!")
+
+ def resize_token_embeddings(self, *args, **kwargs):
+ raise AttributeError("No need to inherit!")
+
+
+@auto_docstring
+class ModernBertDecoderModel(ModernBertDecoderPreTrainedModel):
+ def __init__(self, config: ModernBertDecoderConfig):
+ super().__init__(config)
+ self.config = config
+ self.embeddings = ModernBertDecoderEmbeddings(config)
+ self.layers = nn.ModuleList(
+ [ModernBertDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+ )
+ self.final_norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)
+ self.gradient_checkpointing = False
+
+ self.global_rotary_emb = ModernBertDecoderRotaryEmbedding(config=config)
+ self.local_rotary_emb = ModernBertDecoderRotaryEmbedding(config=config)
+
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.embeddings.tok_embeddings
+
+ def set_input_embeddings(self, value):
+ self.embeddings.tok_embeddings = value
+
+ @check_model_inputs
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ use_cache: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs,
+ ) -> Union[tuple[torch.Tensor, ...], BaseModelOutputWithPast]:
+ if (input_ids is None) == (inputs_embeds is None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ if input_ids is not None:
+ self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
+ batch_size, seq_length = input_ids.shape[:2]
+ else:
+ batch_size, seq_length = inputs_embeds.shape[:2]
+
+ # Handle past_key_values and cache setup
+ if use_cache and past_key_values is None and not self.training:
+ past_key_values = DynamicCache(config=self.config)
+
+ if cache_position is None:
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ cache_position = torch.arange(
+ past_seen_tokens,
+ past_seen_tokens + seq_length,
+ device=input_ids.device if input_ids is not None else inputs_embeds.device,
+ )
+
+ if position_ids is None:
+ position_ids = cache_position.unsqueeze(0).expand(batch_size, -1)
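+ # e.g. with 5 tokens already in the cache and a new 3-token input, cache_position is
+ # [5, 6, 7] and position_ids broadcasts that range to every row of the batch.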
+
+ # Calculate embeddings
+ hidden_states = self.embeddings(input_ids=input_ids, inputs_embeds=inputs_embeds)
+
+ # It may already have been prepared by e.g. `generate`
+ if not isinstance(causal_mask_mapping := attention_mask, dict):
+ # Prepare mask arguments
+ mask_kwargs = {
+ "config": self.config,
+ "input_embeds": hidden_states,
+ "attention_mask": attention_mask,
+ "cache_position": cache_position,
+ "past_key_values": past_key_values,
+ "position_ids": position_ids,
+ }
+
+ causal_mask_mapping = {
+ "full_attention": create_causal_mask(**mask_kwargs),
+ "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs),
+ }
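+ # Each decoder layer then picks the mask matching its configured attention type
+ # ("full_attention" or "sliding_attention") in the loop below.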
+
+ # create position embeddings to be shared across the decoder layers
+ position_embeddings_global = self.global_rotary_emb(hidden_states, position_ids)
+ position_embeddings_local = self.local_rotary_emb(hidden_states, position_ids)
+
+ for idx, decoder_layer in enumerate(self.layers):
+ hidden_states = decoder_layer(
+ hidden_states,
+ position_embeddings_global=position_embeddings_global,
+ position_embeddings_local=position_embeddings_local,
+ attention_mask=causal_mask_mapping[decoder_layer.attention_type],
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+ hidden_states = self.final_norm(hidden_states)
+
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=past_key_values,
+ )
+
+
+@auto_docstring(
+ custom_intro="""
+ The ModernBert Decoder Model with a language modeling head on top for causal language modeling (CLM).
+ """
+)
+class ModernBertDecoderForCausalLM(ModernBertDecoderPreTrainedModel, GenerationMixin):
+ _tied_weights_keys = ["decoder.weight"]
+
+ def __init__(self, config: ModernBertDecoderConfig):
+ super().__init__(config)
+ self.config = config
+ self.model = ModernBertDecoderModel(config)
+ self.lm_head = ModernBertDecoderPredictionHead(config)
+ self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=config.decoder_bias)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_output_embeddings(self):
+ return self.decoder
+
+ def set_output_embeddings(self, new_embeddings):
+ self.decoder = new_embeddings
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ **kwargs,
+ ) -> Union[tuple, CausalLMOutputWithPast]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+ Returns:
+ [`~modeling_outputs.CausalLMOutputWithPast`]
+ comprising various elements depending on the configuration and inputs.
+
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, ModernBertDecoderForCausalLM
+
+ >>> model = ModernBertDecoderForCausalLM.from_pretrained("blab-jhu/test-32m-dec")
+ >>> tokenizer = AutoTokenizer.from_pretrained("blab-jhu/test-32m-dec")
+
+ >>> prompt = "The capital of France is"
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+ >>> # Generate
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ "The capital of France is Paris"
+ ```
+ """
+ # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
+ outputs = self.model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ **kwargs,
+ )
+
+ hidden_states = outputs[0]
+ logits = self.decoder(self.lm_head(hidden_states))
+
+ loss = None
+ if labels is not None:
+ # Shift so that tokens < n predict n
+ shift_logits = logits[..., :-1, :].contiguous()
+ shift_labels = labels[..., 1:].contiguous()
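+ # e.g. for labels [t0, t1, t2, t3], the logits at position i are scored against t_{i+1};
+ # the shifts above drop the last position's logits and the first label.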
+ # Flatten the tokens
+ loss_fct = CrossEntropyLoss()
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
+ shift_labels = shift_labels.view(-1)
+ # Enable model parallelism
+ shift_labels = shift_labels.to(shift_logits.device)
+ loss = loss_fct(shift_logits, shift_labels)
+
+ return CausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+@auto_docstring(
+ custom_intro="""
+ The ModernBert Decoder Model with a sequence classification head on top (linear layer).
+
+ [`ModernBertDecoderForSequenceClassification`] uses the last token in order to do the classification, as other causal models
+ (e.g. GPT-1, GPT-2) do.
+
+ Since it does classification on the last token, it needs to know the position of the last token. If a
+ `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
+ no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
+ padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
+ each row of the batch).
+ """
+)
+class ModernBertDecoderForSequenceClassification(ModernBertDecoderPreTrainedModel):
+ def __init__(self, config: ModernBertDecoderConfig):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+ self.model = ModernBertDecoderModel(config)
+
+ self.head = ModernBertDecoderPredictionHead(config)
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels, bias=config.classifier_bias)
+ self.drop = torch.nn.Dropout(config.classifier_dropout)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @can_return_tuple
+ @auto_docstring(checkpoint="blab-jhu/test-32m-dec")
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ **kwargs,
+ ) -> Union[tuple, SequenceClassifierOutputWithPast]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+ """
+ transformer_outputs = self.model(
+ input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ **kwargs,
+ )
+ hidden_states = transformer_outputs[0]
+ hidden_states = self.drop(self.head(hidden_states))
+ logits = self.classifier(hidden_states)
+
+ if input_ids is not None:
+ batch_size, sequence_length = input_ids.shape[:2]
+ else:
+ batch_size, sequence_length = inputs_embeds.shape[:2]
+
+ if self.config.pad_token_id is None and batch_size != 1:
+ raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
+ if self.config.pad_token_id is None:
+ last_non_pad_token = -1
+ elif input_ids is not None:
+ # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
+ non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
+ token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
+ last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
+ else:
+ last_non_pad_token = -1
+ logger.warning_once(
+ f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
+ "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
+ )
+
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
+
+ loss = None
+ if labels is not None:
+ if self.config.problem_type is None:
+ if self.num_labels == 1:
+ self.config.problem_type = "regression"
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+ self.config.problem_type = "single_label_classification"
+ else:
+ self.config.problem_type = "multi_label_classification"
+
+ if self.config.problem_type == "regression":
+ loss_fct = MSELoss()
+ if self.num_labels == 1:
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
+ else:
+ loss = loss_fct(pooled_logits, labels)
+ elif self.config.problem_type == "single_label_classification":
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
+ elif self.config.problem_type == "multi_label_classification":
+ loss_fct = BCEWithLogitsLoss()
+ loss = loss_fct(pooled_logits, labels)
+
+ return SequenceClassifierOutputWithPast(
+ loss=loss,
+ logits=pooled_logits,
+ past_key_values=transformer_outputs.past_key_values,
+ hidden_states=transformer_outputs.hidden_states,
+ attentions=transformer_outputs.attentions,
+ )
+
+
+__all__ = [
+ "ModernBertDecoderConfig",
+ "ModernBertDecoderModel",
+ "ModernBertDecoderPreTrainedModel",
+ "ModernBertDecoderForCausalLM",
+ "ModernBertDecoderForSequenceClassification",
+]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/moonshine/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/moonshine/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..fa3c1870c2a1d6f5d089df5a396d4b049e4a9d4e
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/moonshine/__init__.py
@@ -0,0 +1,27 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_moonshine import *
+ from .modeling_moonshine import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/moonshine/configuration_moonshine.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/moonshine/configuration_moonshine.py
new file mode 100644
index 0000000000000000000000000000000000000000..270a2e3e484023a5b7422e222905e9f08d01d94f
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/moonshine/configuration_moonshine.py
@@ -0,0 +1,230 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/moonshine/modular_moonshine.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_moonshine.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ...configuration_utils import PretrainedConfig
+from ...modeling_rope_utils import rope_config_validation
+
+
+class MoonshineConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`MoonshineModel`]. It is used to instantiate a Moonshine
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
+ defaults will yield a similar configuration to that of the Moonshine
+ [UsefulSensors/moonshine-tiny](https://huggingface.co/UsefulSensors/moonshine-tiny).
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 32768):
+ Vocabulary size of the Moonshine model. Defines the number of different tokens that can be represented by the
+ `inputs_ids` passed when calling [`MoonshineModel`].
+ hidden_size (`int`, *optional*, defaults to 288):
+ Dimension of the hidden representations.
+ intermediate_size (`int`, *optional*, defaults to 1152):
+ Dimension of the MLP representations.
+ encoder_num_hidden_layers (`int`, *optional*, defaults to 6):
+ Number of hidden layers in the Transformer encoder.
+ decoder_num_hidden_layers (`int`, *optional*, defaults to 6):
+ Number of hidden layers in the Transformer decoder.
+ encoder_num_attention_heads (`int`, *optional*, defaults to 8):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ decoder_num_attention_heads (`int`, *optional*, defaults to 8):
+ Number of attention heads for each attention layer in the Transformer decoder.
+ encoder_num_key_value_heads (`int`, *optional*):
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+ `encoder_num_key_value_heads=encoder_num_attention_heads`, the model will use Multi Head Attention (MHA), if
+ `encoder_num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+ by meanpooling all the original heads within that group. For more details, check out [this
+ paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to
+ `num_attention_heads`.
+ decoder_num_key_value_heads (`int`, *optional*):
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+ `decoder_num_key_value_heads=decoder_num_attention_heads`, the model will use Multi Head Attention (MHA), if
+ `decoder_num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+ by meanpooling all the original heads within that group. For more details, check out [this
+ paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to
+ `decoder_num_attention_heads`.
+ pad_head_dim_to_multiple_of (`int`, *optional*):
+ Pad head dimension in encoder and decoder to the next multiple of this value. Necessary for using certain
+ optimized attention implementations.
+ encoder_hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
+ The non-linear activation function (function or string) in the encoder.
+ decoder_hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+ The non-linear activation function (function or string) in the decoder.
+ max_position_embeddings (`int`, *optional*, defaults to 512):
+ The maximum sequence length that this model might ever be used with.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ decoder_start_token_id (`int`, *optional*, defaults to 1):
+ Corresponds to the "<|startoftranscript|>" token, which is automatically used when no `decoder_input_ids`
+ are provided to the `generate` function. It is used to guide the model's generation process depending on
+ the task.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return the last key/values attentions (not used by all models).
+ rope_theta (`float`, *optional*, defaults to 10000.0):
+ The base period of the RoPE embeddings.
+ rope_scaling (`Dict`, *optional*):
+ Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
+ and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
+ accordingly.
+ Expected contents:
+ `rope_type` (`str`):
+ The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
+ 'llama3'], with 'default' being the original RoPE implementation.
+ `factor` (`float`, *optional*):
+ Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
+ most scaling types, a `factor` of x will enable the model to handle sequences of length x *
+ original maximum pre-trained length.
+ `original_max_position_embeddings` (`int`, *optional*):
+ Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
+ pretraining.
+ `attention_factor` (`float`, *optional*):
+ Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
+ computation. If unspecified, it defaults to value recommended by the implementation, using the
+ `factor` field to infer the suggested value.
+ `beta_fast` (`float`, *optional*):
+ Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
+ ramp function. If unspecified, it defaults to 32.
+ `beta_slow` (`float`, *optional*):
+ Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
+ ramp function. If unspecified, it defaults to 1.
+ `short_factor` (`list[float]`, *optional*):
+ Only used with 'longrope'. The scaling factor to be applied to short contexts (<
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+ size divided by the number of attention heads divided by 2
+ `long_factor` (`list[float]`, *optional*):
+ Only used with 'longrope'. The scaling factor to be applied to long contexts (>
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+ size divided by the number of attention heads divided by 2
+ `low_freq_factor` (`float`, *optional*):
+ Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
+ `high_freq_factor` (`float`, *optional*):
+ Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
+ partial_rotary_factor (`float`, *optional*, defaults to 0.9):
+ Percentage of the query and keys which will have rotary embedding.
+ is_encoder_decoder (`bool`, *optional*, defaults to `True`):
+ Whether the model is used as an encoder/decoder or not.
+ attention_bias (`bool`, *optional*, defaults to `False`):
+ Whether to use a bias in the query, key, value and output projection layers during self-attention.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ bos_token_id (`int`, *optional*, defaults to 1):
+ Denotes beginning of sequences token id.
+ eos_token_id (`int`, *optional*, defaults to 2):
+ Denotes end of sequences token id.
+
+ Example:
+
+ ```python
+ >>> from transformers import MoonshineModel, MoonshineConfig
+
+ >>> # Initializing a Moonshine style configuration
+ >>> configuration = MoonshineConfig().from_pretrained("UsefulSensors/moonshine-tiny")
+
+ >>> # Initializing a model from the configuration
+ >>> model = MoonshineModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "moonshine"
+ keys_to_ignore_at_inference = ["past_key_values"]
+ attribute_map = {
+ "num_key_value_heads": "encoder_num_key_value_heads",
+ "num_attention_heads": "encoder_num_attention_heads",
+ "num_hidden_layers": "encoder_num_hidden_layers",
+ }
+
+ def __init__(
+ self,
+ vocab_size=32768,
+ hidden_size=288,
+ intermediate_size=1152,
+ encoder_num_hidden_layers=6,
+ decoder_num_hidden_layers=6,
+ encoder_num_attention_heads=8,
+ decoder_num_attention_heads=8,
+ encoder_num_key_value_heads=None,
+ decoder_num_key_value_heads=None,
+ pad_head_dim_to_multiple_of=None,
+ encoder_hidden_act="gelu",
+ decoder_hidden_act="silu",
+ max_position_embeddings=512,
+ initializer_range=0.02,
+ decoder_start_token_id=1,
+ use_cache=True,
+ rope_theta=10000.0,
+ rope_scaling=None,
+ partial_rotary_factor=0.9,
+ is_encoder_decoder=True,
+ attention_bias=False,
+ attention_dropout=0.0,
+ bos_token_id=1,
+ eos_token_id=2,
+ **kwargs,
+ ):
+ self.vocab_size = vocab_size
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.encoder_num_hidden_layers = encoder_num_hidden_layers
+ self.decoder_num_hidden_layers = decoder_num_hidden_layers
+ self.encoder_num_attention_heads = encoder_num_attention_heads
+ self.decoder_num_attention_heads = decoder_num_attention_heads
+
+ if encoder_num_key_value_heads is None:
+ encoder_num_key_value_heads = encoder_num_attention_heads
+ self.encoder_num_key_value_heads = encoder_num_key_value_heads
+
+ if decoder_num_key_value_heads is None:
+ decoder_num_key_value_heads = decoder_num_attention_heads
+ self.decoder_num_key_value_heads = decoder_num_key_value_heads
+
+ self.pad_head_dim_to_multiple_of = pad_head_dim_to_multiple_of
+
+ self.encoder_hidden_act = encoder_hidden_act
+ self.decoder_hidden_act = decoder_hidden_act
+ self.max_position_embeddings = max_position_embeddings
+ self.initializer_range = initializer_range
+ self.decoder_start_token_id = decoder_start_token_id
+ self.use_cache = use_cache
+ self.rope_theta = rope_theta
+ self.rope_scaling = rope_scaling
+ self.partial_rotary_factor = partial_rotary_factor
+ self.is_encoder_decoder = is_encoder_decoder
+ self.attention_bias = attention_bias
+ self.attention_dropout = attention_dropout
+
+ # Validate the correctness of rotary position embeddings parameters
+ rope_config_validation(self)
+
+ super().__init__(
+ bos_token_id=bos_token_id,
+ eos_token_id=eos_token_id,
+ is_encoder_decoder=is_encoder_decoder,
+ decoder_start_token_id=decoder_start_token_id,
+ **kwargs,
+ )
+
+
+__all__ = ["MoonshineConfig"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/moonshine/modeling_moonshine.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/moonshine/modeling_moonshine.py
new file mode 100644
index 0000000000000000000000000000000000000000..42b66aa185c892a45d16cc6d3980fcbef33b0636
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/moonshine/modeling_moonshine.py
@@ -0,0 +1,1097 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/moonshine/modular_moonshine.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_moonshine.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Callable, Optional, Union
+
+import numpy as np
+import torch
+import torch.nn as nn
+
+from transformers.utils.generic import OutputRecorder, check_model_inputs
+
+from ...activations import ACT2FN
+from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
+from ...generation import GenerationMixin
+from ...masking_utils import create_causal_mask
+from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa
+from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_outputs import (
+ BaseModelOutput,
+ BaseModelOutputWithPast,
+ BaseModelOutputWithPastAndCrossAttentions,
+ Seq2SeqLMOutput,
+ Seq2SeqModelOutput,
+)
+from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
+from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
+from ...utils.deprecation import deprecate_kwarg
+from .configuration_moonshine import MoonshineConfig
+
+
+class MoonshineEncoderMLP(nn.Module):
+ def __init__(self, config, hidden_act):
+ super().__init__()
+ self.config = config
+ self.activation_fn = ACT2FN[hidden_act]
+ self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
+ self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.fc1(hidden_states)
+ hidden_states = self.activation_fn(hidden_states)
+ hidden_states = self.fc2(hidden_states)
+ return hidden_states
+
+
+class MoonshineDecoderMLP(nn.Module):
+ def __init__(self, config, hidden_act):
+ super().__init__()
+ self.config = config
+ self.activation_fn = ACT2FN[hidden_act]
+ self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size * 2)
+ self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.fc1(hidden_states)
+ hidden_states, gate = hidden_states.chunk(2, dim=-1)
+ hidden_states = self.activation_fn(gate) * hidden_states
+ hidden_states = self.fc2(hidden_states)
+ return hidden_states
+
+
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+ """
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+ """
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+ if n_rep == 1:
+ return hidden_states
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
+def eager_attention_forward(
+ module: nn.Module,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_mask: Optional[torch.Tensor],
+ scaling: float,
+ dropout: float = 0.0,
+ **kwargs: Unpack[TransformersKwargs],
+):
+ key_states = repeat_kv(key, module.num_key_value_groups)
+ value_states = repeat_kv(value, module.num_key_value_groups)
+
+ attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
+ if attention_mask is not None:
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+ attn_weights = attn_weights + causal_mask
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+ attn_output = torch.matmul(attn_weights, value_states)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ return attn_output, attn_weights
+
+
+def rotate_half(x):
+ """Rotates half the hidden dims of the input."""
+ x1 = x[..., 0::2]
+ x2 = x[..., 1::2]
+ return torch.stack((-x2, x1), dim=-1).flatten(-2)
+
+
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+ """Applies Rotary Position Embedding to the query and key tensors.
+
+ Args:
+ q (`torch.Tensor`): The query tensor.
+ k (`torch.Tensor`): The key tensor.
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
+ position_ids (`torch.Tensor`, *optional*):
+ Deprecated and unused.
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+ Returns:
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+ """
+ cos = cos.unsqueeze(unsqueeze_dim)
+ sin = sin.unsqueeze(unsqueeze_dim)
+
+ # Interleave the cos/sin values rather than using the usual half/half split
+ cos = cos[..., : cos.shape[-1] // 2].repeat_interleave(2, dim=-1)
+ sin = sin[..., : sin.shape[-1] // 2].repeat_interleave(2, dim=-1)
+
+ # Keep half or full tensor for later concatenation
+ rotary_dim = cos.shape[-1]
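+ # e.g. with head_dim 36 and rotary_dim 32, only the first 32 channels of q/k are rotated
+ # below and the remaining 4 pass through unchanged (partial rotary embeddings).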
+ q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
+ k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:]
+
+ # Apply rotary embeddings on the first half or full tensor
+ q_embed = (q_rot * cos) + (rotate_half(q_rot) * sin)
+ k_embed = (k_rot * cos) + (rotate_half(k_rot) * sin)
+
+ # Concatenate back to full shape
+ q_embed = torch.cat([q_embed, q_pass], dim=-1)
+ k_embed = torch.cat([k_embed, k_pass], dim=-1)
+ return q_embed, k_embed
+
+
+class MoonshineAttention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(
+ self,
+ config: MoonshineConfig,
+ layer_idx: int,
+ is_causal: bool,
+ num_attention_heads: int,
+ num_key_value_heads: int,
+ ):
+ super().__init__()
+ config.update({"num_attention_heads": num_attention_heads, "num_key_value_heads": num_key_value_heads})
+ self.config = config
+ self.layer_idx = layer_idx
+ self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+ self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
+ self.scaling = self.head_dim**-0.5
+ self.attention_dropout = config.attention_dropout
+ self.is_causal = is_causal
+
+ self.q_proj = nn.Linear(
+ config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
+ )
+ self.k_proj = nn.Linear(
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
+ )
+ self.v_proj = nn.Linear(
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
+ )
+ self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
+
+ # Pad head dimension to the next specified multiple.
+ if self.config.pad_head_dim_to_multiple_of is not None:
+ target_multiple = self.config.pad_head_dim_to_multiple_of
+ target_head_dim = target_multiple * ((self.head_dim + target_multiple - 1) // target_multiple)
+ self.head_dim_padding = target_head_dim - self.head_dim
+ else:
+ self.head_dim_padding = 0
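+ # e.g. head_dim 36 with pad_head_dim_to_multiple_of=8 gives a target of 40, so q/k/v are
+ # padded by 4 channels before attention and the padding is sliced off again afterwards.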
+
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ past_key_values: Optional[Cache] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ key_value_states: Optional[torch.Tensor] = None,
+ **kwargs: Unpack[FlashAttentionKwargs],
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
+ bsz, q_len = hidden_states.shape[:-1]
+
+ query_states = (
+ self.q_proj(hidden_states).view(bsz, q_len, self.config.num_attention_heads, self.head_dim).transpose(1, 2)
+ )
+
+ is_cross_attention = key_value_states is not None
+ if past_key_values is not None:
+ is_updated = past_key_values.is_updated.get(self.layer_idx)
+ if is_cross_attention:
+ # after the first generated id, we can subsequently re-use all key/value_states from cache
+ past_key_values.is_updated[self.layer_idx] = True
+ past_key_values = past_key_values.cross_attention_cache
+ else:
+ past_key_values = past_key_values.self_attention_cache
+
+ # use key_value_states if cross attention
+ current_states = key_value_states if key_value_states is not None else hidden_states
+ if is_cross_attention and past_key_values and is_updated:
+ key_states = past_key_values.layers[self.layer_idx].keys
+ value_states = past_key_values.layers[self.layer_idx].values
+ else:
+ key_states = (
+ self.k_proj(current_states)
+ .view(bsz, -1, self.config.num_key_value_heads, self.head_dim)
+ .transpose(1, 2)
+ )
+ value_states = (
+ self.v_proj(current_states)
+ .view(bsz, -1, self.config.num_key_value_heads, self.head_dim)
+ .transpose(1, 2)
+ )
+ if is_cross_attention and past_key_values is not None:
+ key_states, value_states = past_key_values.update(
+ key_states, value_states, self.layer_idx, {"cache_position": cache_position}
+ )
+
+ if not is_cross_attention:
+ cos, sin = position_embeddings
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_values is not None:
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ key_states, value_states = past_key_values.update(
+ key_states, value_states, self.layer_idx, cache_kwargs
+ )
+
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ is_causal = self.is_causal and attention_mask is None and q_len > 1
+
+ if self.head_dim_padding > 0:
+ query_states = torch.nn.functional.pad(query_states, (0, self.head_dim_padding))
+ key_states = torch.nn.functional.pad(key_states, (0, self.head_dim_padding))
+ value_states = torch.nn.functional.pad(value_states, (0, self.head_dim_padding))
+
+ attn_output, attn_weights = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ dropout=0.0 if not self.training else self.attention_dropout,
+ scaling=self.scaling,
+ is_causal=is_causal,
+ **kwargs,
+ )
+
+ if self.head_dim_padding > 0:
+ attn_output = attn_output[..., : -self.head_dim_padding]
+
+ attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
+ attn_output = self.o_proj(attn_output)
+ return attn_output, attn_weights
+
+
+class MoonshineRotaryEmbedding(nn.Module):
+ inv_freq: torch.Tensor # fix linting for `register_buffer`
+
+ def __init__(self, config: MoonshineConfig, device=None):
+ super().__init__()
+ # BC: "rope_type" was originally "type"
+ if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
+ else:
+ self.rope_type = "default"
+ self.max_seq_len_cached = config.max_position_embeddings
+ self.original_max_seq_len = config.max_position_embeddings
+
+ self.config = config
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
+ self.original_inv_freq = self.inv_freq
+
+ @torch.no_grad()
+ @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
+ def forward(self, x, position_ids):
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
+ position_ids_expanded = position_ids[:, None, :].float()
+
+ device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
+ with torch.autocast(device_type=device_type, enabled=False): # Force float32
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+ emb = torch.cat((freqs, freqs), dim=-1)
+ cos = emb.cos() * self.attention_scaling
+ sin = emb.sin() * self.attention_scaling
+
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
+class MoonshineEncoderLayer(GradientCheckpointingLayer):
+ def __init__(self, config: MoonshineConfig, layer_idx: int):
+ super().__init__()
+ self.hidden_size = config.hidden_size
+
+ self.self_attn = MoonshineAttention(
+ config=config,
+ layer_idx=layer_idx,
+ is_causal=False,
+ num_attention_heads=config.encoder_num_attention_heads,
+ num_key_value_heads=config.encoder_num_key_value_heads,
+ )
+
+ self.mlp = MoonshineEncoderMLP(config, config.encoder_hidden_act)
+ self.input_layernorm = nn.LayerNorm(config.hidden_size, bias=False)
+ self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, bias=False)
+
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ use_cache: Optional[bool] = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> torch.Tensor:
+ residual = hidden_states
+ hidden_states = self.input_layernorm(hidden_states)
+ # Self Attention
+ hidden_states, _ = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ **kwargs,
+ )
+ hidden_states = residual + hidden_states
+
+ # Fully Connected
+ residual = hidden_states
+ hidden_states = self.post_attention_layernorm(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = residual + hidden_states
+ return hidden_states
+
+
+class MoonshineDecoderLayer(GradientCheckpointingLayer):
+ def __init__(self, config: MoonshineConfig, layer_idx: Optional[int] = None):
+ super().__init__()
+ self.hidden_size = config.hidden_size
+
+ self.self_attn = MoonshineAttention(
+ config=config,
+ layer_idx=layer_idx,
+ is_causal=True,
+ num_attention_heads=config.decoder_num_attention_heads,
+ num_key_value_heads=config.decoder_num_key_value_heads,
+ )
+ self.encoder_attn = MoonshineAttention(
+ config=config,
+ layer_idx=layer_idx,
+ is_causal=False,
+ num_attention_heads=config.decoder_num_attention_heads,
+ num_key_value_heads=config.decoder_num_key_value_heads,
+ )
+
+ self.mlp = MoonshineDecoderMLP(config, config.decoder_hidden_act)
+ self.input_layernorm = nn.LayerNorm(config.hidden_size, bias=False)
+ self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, bias=False)
+ self.final_layernorm = nn.LayerNorm(config.hidden_size, bias=False)
+
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ encoder_position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ use_cache: Optional[bool] = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
+ encoder_position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
+ residual = hidden_states
+ hidden_states = self.input_layernorm(hidden_states)
+
+ hidden_states, _ = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ **kwargs,
+ )
+ hidden_states = residual + hidden_states
+
+ if encoder_hidden_states is not None:
+ residual = hidden_states
+ hidden_states = self.post_attention_layernorm(hidden_states)
+ hidden_states, _ = self.encoder_attn(
+ hidden_states=hidden_states,
+ key_value_states=encoder_hidden_states,
+ attention_mask=encoder_attention_mask,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ )
+ hidden_states = residual + hidden_states
+
+ residual = hidden_states
+ hidden_states = self.final_layernorm(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = residual + hidden_states
+ return hidden_states
+
+
+@auto_docstring
+class MoonshinePreTrainedModel(PreTrainedModel):
+ config: MoonshineConfig
+ base_model_prefix = "model"
+ main_input_name = "input_values"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["MoonshineEncoderLayer", "MoonshineDecoderLayer"]
+ _supports_flash_attn = True
+ _supports_sdpa = True
+
+ _can_compile_fullgraph = True
+ # TODO (arthur): how do we separate cross / self attention when they come from different layers?
+
+ def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor):
+ """
+ Computes the output length of the convolutional layers
+ """
+ output_conv1_length = int((input_lengths - 127) / 64 + 1)
+ output_conv2_length = int((output_conv1_length - 7) / 3 + 1)
+ output_conv3_length = int((output_conv2_length - 3) / 2 + 1)
+
+ return output_conv3_length
+
+
+class MoonshineEncoder(MoonshinePreTrainedModel):
+ """
+ Transformer encoder consisting of *config.num_hidden_layers* layers. Each layer is a [`MoonshineEncoderLayer`]
+
+ Args:
+ config: MoonshineConfig
+ """
+
+ main_input_name = "input_values"
+ _can_record_outputs = {
+ "attentions": MoonshineAttention,
+ "hidden_states": MoonshineEncoderLayer,
+ }
+
+ def __init__(self, config: MoonshineConfig):
+ super().__init__(config)
+ self.config = config
+ embed_dim = config.hidden_size
+
+ self.conv1 = nn.Conv1d(1, embed_dim, kernel_size=127, stride=64, bias=False)
+ self.conv2 = nn.Conv1d(embed_dim, 2 * embed_dim, kernel_size=7, stride=3)
+ self.conv3 = nn.Conv1d(2 * embed_dim, embed_dim, kernel_size=3, stride=2)
+ self.groupnorm = nn.GroupNorm(num_groups=1, num_channels=embed_dim, eps=1e-5)
+ self.rotary_emb = MoonshineRotaryEmbedding(config=config)
+
+ self.layers = nn.ModuleList(
+ [MoonshineEncoderLayer(config, idx) for idx in range(config.encoder_num_hidden_layers)]
+ )
+ self.layer_norm = nn.LayerNorm(embed_dim, bias=False)
+ self.gradient_checkpointing = False
+ self.post_init()
+
+ def get_input_embeddings(self) -> nn.Module:
+ return self.conv1
+
+ def set_input_embeddings(self, value: nn.Module):
+ self.conv1 = value
+
+ @check_model_inputs
+ def forward(
+ self,
+ input_values: torch.FloatTensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> BaseModelOutputWithPast:
+ r"""
+ Args:
+ input_values (`torch.FloatTensor` of shape `(batch_size, audio_length)`):
+ Float values of the raw speech waveform. Raw speech waveform can be
+ obtained by loading a `.flac` or `.wav` audio file into an array of type `list[float]`, a
+ `numpy.ndarray` or a `torch.Tensor`, *e.g.* via the torchcodec library (`pip install torchcodec`) or
+ the soundfile library (`pip install soundfile`). To prepare the array into
+ `input_values`, the [`AutoFeatureExtractor`] should be used for padding
+ and conversion into a tensor of type `torch.FloatTensor`.
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding indices in `input_values`. Mask values selected in `[0, 1]`:
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+ [What are attention masks?](../glossary#attention-mask)
+ """
+ input_values = input_values.unsqueeze(1)
+ hidden_states = nn.functional.tanh(self.conv1(input_values))
+ hidden_states = self.groupnorm(hidden_states)
+ hidden_states = nn.functional.gelu(self.conv2(hidden_states))
+ hidden_states = nn.functional.gelu(self.conv3(hidden_states))
+ hidden_states = hidden_states.permute(0, 2, 1)
+
+ # attention mask downsampling
+ if attention_mask is not None:
+ mask_len = self._get_feat_extract_output_lengths(attention_mask.shape[-1])
+ downsample_stride = 64 * 3 * 2 # conv strides
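+ # each encoder output frame covers 64 * 3 * 2 = 384 input samples, so the padding mask is subsampled accordingly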
+ attention_mask = attention_mask[..., ::downsample_stride][..., :mask_len]
+ if self.config._attn_implementation == "flash_attention_2":
+ attention_mask = attention_mask if (attention_mask == 0.0).any() else None
+ elif self.config._attn_implementation == "sdpa":
+ attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, hidden_states.dtype)
+ else:
+ attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)
+
+ position_ids = torch.arange(0, hidden_states.shape[1], device=hidden_states.device).unsqueeze(0)
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+ for encoder_layer in self.layers:
+ hidden_states = encoder_layer(
+ hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ position_embeddings=position_embeddings,
+ **kwargs,
+ )
+
+ hidden_states = self.layer_norm(hidden_states)
+
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ )
+
+
+@auto_docstring
+class MoonshineDecoder(MoonshinePreTrainedModel):
+ main_input_name = "input_ids"
+ _can_record_outputs = {
+ "attentions": OutputRecorder(MoonshineAttention, index=1, layer_name="self_attn"),
+ "hidden_states": MoonshineDecoderLayer,
+ "cross_attentions": OutputRecorder(MoonshineAttention, index=1, layer_name="encoder_attn"),
+ }
+
+ def __init__(self, config: MoonshineConfig):
+ super().__init__(config)
+ self.padding_idx = config.pad_token_id
+ self.vocab_size = config.vocab_size
+
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+ self.layers = nn.ModuleList(
+ [MoonshineDecoderLayer(config, idx) for idx in range(config.decoder_num_hidden_layers)]
+ )
+ self.norm = nn.LayerNorm(config.hidden_size, bias=False)
+ self.rotary_emb = MoonshineRotaryEmbedding(config=config)
+ self.gradient_checkpointing = False
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @check_model_inputs
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> Union[tuple, BaseModelOutputWithPast]:
+ r"""
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
+ of the decoder.
+ encoder_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding indices in `encoder_hidden_states`. Mask values selected in `[0, 1]`:
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+ [What are attention masks?](../glossary#attention-mask)
+ """
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+
+ if use_cache and past_key_values is None:
+ past_key_values = EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
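+ # one cache for decoder self-attention and one for cross-attention over the encoder output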
+
+ if cache_position is None:
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ cache_position = torch.arange(
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+ )
+
+ if position_ids is None:
+ position_ids = cache_position.unsqueeze(0)
+
+ causal_mask = create_causal_mask(
+ config=self.config,
+ input_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ cache_position=cache_position,
+ past_key_values=past_key_values,
+ position_ids=position_ids,
+ )
+
+ hidden_states = inputs_embeds
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+ if encoder_attention_mask is not None:
+ mask_len = encoder_hidden_states.shape[-2]
+ downsample_stride = 64 * 3 * 2 # conv strides
+ encoder_attention_mask = encoder_attention_mask[..., ::downsample_stride][..., :mask_len]
+ if self.config._attn_implementation == "flash_attention_2":
+ encoder_attention_mask = encoder_attention_mask if (encoder_attention_mask == 0.0).any() else None
+ elif self.config._attn_implementation == "sdpa":
+ encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
+ encoder_attention_mask, hidden_states.dtype, hidden_states.shape[-2]
+ )
+ else:
+ encoder_attention_mask = _prepare_4d_attention_mask(
+ encoder_attention_mask, hidden_states.dtype, hidden_states.shape[-2]
+ )
+
+ for decoder_layer in self.layers:
+ hidden_states = decoder_layer(
+ hidden_states,
+ causal_mask,
+ encoder_hidden_states, # as a positional argument for gradient checkpointing
+ encoder_attention_mask=encoder_attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ **kwargs,
+ )
+
+ hidden_states = self.norm(hidden_states)
+
+ return BaseModelOutputWithPastAndCrossAttentions(
+ last_hidden_state=hidden_states,
+ past_key_values=past_key_values if use_cache else None,
+ )
+
+
+def _compute_mask_indices(
+ shape: tuple[int, int],
+ mask_prob: float,
+ mask_length: int,
+ attention_mask: Optional[torch.LongTensor] = None,
+ min_masks: int = 0,
+) -> np.ndarray:
+ """
+ Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for
+ ASR](https://huggingface.co/papers/1904.08779). Note that this method is not optimized to run on TPU and should be run on
+ CPU as part of the preprocessing during training.
+
+ Args:
+ shape: The shape for which to compute masks. This should be a tuple of size 2 where
+ the first element is the batch size and the second element is the length of the axis to span.
+ mask_prob: The percentage of the whole axis (between 0 and 1) which will be masked. The number of
+ independently generated mask spans of length `mask_length` is computed by
+ `mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the
+ actual percentage will be smaller.
+ mask_length: size of the mask
+ min_masks: minimum number of masked spans
+ attention_mask: A (right-padded) attention mask which independently shortens the feature axis of
+ each batch dimension.
+ """
+ batch_size, sequence_length = shape
+
+ if mask_length < 1:
+ raise ValueError("`mask_length` has to be bigger than 0.")
+
+ if mask_length > sequence_length:
+ raise ValueError(
+ f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}"
+ f" and `sequence_length`: {sequence_length}`"
+ )
+
+ # epsilon is used for probabilistic rounding
+ epsilon = np.random.rand(1).item()
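+ # epsilon makes the int() below round up with probability equal to the fractional part,
+ # e.g. mask_prob * input_length / mask_length = 2.3 is rounded to 3 roughly 30% of the time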
+
+ def compute_num_masked_span(input_length):
+ """Given input length, compute how many spans should be masked"""
+ num_masked_span = int(mask_prob * input_length / mask_length + epsilon)
+ num_masked_span = max(num_masked_span, min_masks)
+
+ # make sure num masked span <= sequence_length
+ if num_masked_span * mask_length > sequence_length:
+ num_masked_span = sequence_length // mask_length
+
+ # make sure num_masked span is also <= input_length - (mask_length - 1)
+ if input_length - (mask_length - 1) < num_masked_span:
+ num_masked_span = max(input_length - (mask_length - 1), 0)
+
+ return num_masked_span
+
+ # compute number of masked spans in batch
+ input_lengths = (
+ attention_mask.detach().sum(-1).tolist()
+ if attention_mask is not None
+ else [sequence_length for _ in range(batch_size)]
+ )
+
+ # SpecAugment mask to fill
+ spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool)
+ spec_aug_mask_idxs = []
+
+ max_num_masked_span = compute_num_masked_span(sequence_length)
+
+ if max_num_masked_span == 0:
+ return spec_aug_mask
+
+ for input_length in input_lengths:
+ # compute num of masked spans for this input
+ num_masked_span = compute_num_masked_span(input_length)
+
+ # get random indices to mask
+ spec_aug_mask_idx = np.random.choice(
+ np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False
+ )
+
+ # pick the first sampled index to serve as a dummy index for padding the vector,
+ # ensuring the same dimension for all batches despite probabilistic rounding
+ # (picking the first sample just pads those vectors twice)
+ if len(spec_aug_mask_idx) == 0:
+ # this case can only happen if `input_length` is strictly smaller than
+ # `sequence_length` in which case the last token has to be a padding
+ # token which we can use as a dummy mask id
+ dummy_mask_idx = sequence_length - 1
+ else:
+ dummy_mask_idx = spec_aug_mask_idx[0]
+
+ spec_aug_mask_idx = np.concatenate(
+ [spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx]
+ )
+ spec_aug_mask_idxs.append(spec_aug_mask_idx)
+
+ spec_aug_mask_idxs = np.array(spec_aug_mask_idxs)
+
+ # expand masked indices to masked spans
+ spec_aug_mask_idxs = np.broadcast_to(
+ spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length)
+ )
+ spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length)
+
+ # add offset to the starting indexes so that indexes now create a span
+ offsets = np.arange(mask_length)[None, None, :]
+ offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape(
+ batch_size, max_num_masked_span * mask_length
+ )
+ spec_aug_mask_idxs = spec_aug_mask_idxs + offsets
+
+ # ensure that we cannot have indices larger than sequence_length
+ if spec_aug_mask_idxs.max() > sequence_length - 1:
+ spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1
+
+ # scatter indices to mask
+ np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1)
+
+ return spec_aug_mask
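+ # illustrative example: _compute_mask_indices((2, 100), mask_prob=0.2, mask_length=10) returns a boolean
+ # (2, 100) array with two spans of length 10 masked per row (up to ~20% of positions, fewer if spans overlap)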
+
+
+@auto_docstring
+class MoonshineModel(MoonshinePreTrainedModel):
+ def __init__(self, config: MoonshineConfig):
+ super().__init__(config)
+
+ self.encoder = MoonshineEncoder(config)
+ self.decoder = MoonshineDecoder(config)
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.decoder.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.decoder.embed_tokens = value
+
+ def get_encoder(self):
+ return self.encoder
+
+ def freeze_encoder(self):
+ """
+ Calling this function will disable the gradient computation for the Moonshine encoder so that its parameters will
+ not be updated during training.
+ """
+ self.encoder._freeze_parameters()
+
+ def _mask_input_features(
+ self,
+ input_features: torch.FloatTensor,
+ attention_mask: Optional[torch.LongTensor] = None,
+ ):
+ """
+ Masks extracted features along time axis and/or along feature axis according to
+ [SpecAugment](https://huggingface.co/papers/1904.08779).
+ """
+
+ # `config.apply_spec_augment` can set masking to False
+ if not getattr(self.config, "apply_spec_augment", True):
+ return input_features
+
+ # generate indices & apply SpecAugment along time axis
+ batch_size, hidden_size, sequence_length = input_features.size()
+
+ if self.config.mask_time_prob > 0 and self.training:
+ # generate indices & apply SpecAugment along time axis
+ mask_time_indices = _compute_mask_indices(
+ (batch_size, sequence_length),
+ mask_prob=self.config.mask_time_prob,
+ mask_length=self.config.mask_time_length,
+ attention_mask=attention_mask,
+ min_masks=self.config.mask_time_min_masks,
+ )
+ mask_time_indices = torch.tensor(mask_time_indices, device=input_features.device, dtype=torch.bool)
+ mask_time_indices = mask_time_indices[:, None].expand(-1, hidden_size, -1)
+ input_features[mask_time_indices] = 0
+
+ if self.config.mask_feature_prob > 0 and self.training:
+ # generate indices & apply SpecAugment along feature axis
+ mask_feature_indices = _compute_mask_indices(
+ (batch_size, hidden_size),
+ mask_prob=self.config.mask_feature_prob,
+ mask_length=self.config.mask_feature_length,
+ min_masks=self.config.mask_feature_min_masks,
+ )
+ mask_feature_indices = torch.tensor(mask_feature_indices, device=input_features.device, dtype=torch.bool)
+ input_features[mask_feature_indices] = 0
+
+ return input_features
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_values: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.LongTensor] = None,
+ decoder_input_ids: Optional[torch.LongTensor] = None,
+ decoder_attention_mask: Optional[torch.LongTensor] = None,
+ encoder_outputs: Optional[tuple[tuple[torch.FloatTensor]]] = None,
+ past_key_values: Optional[Union[EncoderDecoderCache, tuple[torch.FloatTensor]]] = None,
+ decoder_inputs_embeds: Optional[tuple[torch.FloatTensor]] = None,
+ decoder_position_ids: Optional[tuple[torch.LongTensor]] = None,
+ use_cache: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> Seq2SeqModelOutput:
+ r"""
+ input_values (`torch.FloatTensor` of shape `(batch_size, audio_length)`):
+ Float values of the raw speech waveform. Raw speech waveform can be
+ obtained by loading a `.flac` or `.wav` audio file into an array of type `list[float]`, a
+ `numpy.ndarray` or a `torch.Tensor`, *e.g.* via the torchcodec library (`pip install torchcodec`) or
+ the soundfile library (`pip install soundfile`). To prepare the array into
+ `input_values`, the [`AutoFeatureExtractor`] should be used for padding
+ and conversion into a tensor of type `torch.FloatTensor`.
+ decoder_position_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`):
+ Indices of positions of each input sequence token in the position embeddings.
+ Used to calculate the position embeddings up to `config.decoder_config.max_position_embeddings`.
+
+ Example:
+
+ ```python
+ >>> import torch
+ >>> from transformers import AutoFeatureExtractor, MoonshineModel
+ >>> from datasets import load_dataset
+
+ >>> model = MoonshineModel.from_pretrained("UsefulSensors/moonshine-tiny")
+ >>> feature_extractor = AutoFeatureExtractor.from_pretrained("UsefulSensors/moonshine-tiny")
+ >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
+ >>> inputs = feature_extractor(ds[0]["audio"]["array"], return_tensors="pt")
+ >>> input_values = inputs.input_values
+ >>> decoder_input_ids = torch.tensor([[1, 1]]) * model.config.decoder_start_token_id
+ >>> last_hidden_state = model(input_values, decoder_input_ids=decoder_input_ids).last_hidden_state
+ >>> list(last_hidden_state.shape)
+ [1, 2, 288]
+ ```
+ """
+ if encoder_outputs is None:
+ encoder_outputs: BaseModelOutput = self.encoder(input_values, attention_mask=attention_mask, **kwargs)
+
+ decoder_outputs: BaseModelOutputWithPastAndCrossAttentions = self.decoder(
+ input_ids=decoder_input_ids,
+ attention_mask=decoder_attention_mask,
+ encoder_attention_mask=attention_mask,
+ encoder_hidden_states=encoder_outputs.last_hidden_state,
+ past_key_values=past_key_values,
+ inputs_embeds=decoder_inputs_embeds,
+ position_ids=decoder_position_ids,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+ return Seq2SeqModelOutput(
+ last_hidden_state=decoder_outputs.last_hidden_state,
+ past_key_values=decoder_outputs.past_key_values,
+ decoder_hidden_states=decoder_outputs.hidden_states,
+ decoder_attentions=decoder_outputs.attentions,
+ cross_attentions=decoder_outputs.cross_attentions,
+ encoder_last_hidden_state=encoder_outputs.last_hidden_state,
+ encoder_hidden_states=encoder_outputs.hidden_states,
+ encoder_attentions=encoder_outputs.attentions,
+ )
+
+
+def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
+ """
+ Shift input ids one token to the right.
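+
+ For example, labels `[[5, -100, -100]]` with `decoder_start_token_id=1` and `pad_token_id=0` become
+ `[[1, 5, 0]]`: tokens are shifted right, the first position is set to the decoder start token, and any
+ remaining `-100` is replaced by `pad_token_id`.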
+ """
+ shifted_input_ids = input_ids.new_zeros(input_ids.shape)
+ shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
+ shifted_input_ids[:, 0] = decoder_start_token_id
+
+ if pad_token_id is None:
+ raise ValueError("self.model.config.pad_token_id has to be defined.")
+ # replace possible -100 values in labels by `pad_token_id`
+ shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
+
+ return shifted_input_ids
+
+
+@auto_docstring(
+ custom_intro="""
+ The Moonshine Model with a language modeling head. Can be used for automatic speech recognition.
+ """
+)
+class MoonshineForConditionalGeneration(MoonshinePreTrainedModel, GenerationMixin):
+ _tied_weights_keys = ["proj_out.weight"]
+
+ def __init__(self, config: MoonshineConfig):
+ super().__init__(config)
+ self.model = MoonshineModel(config)
+ self.proj_out = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_encoder(self):
+ return self.model.get_encoder()
+
+ def get_decoder(self):
+ return self.model.get_decoder()
+
+ def get_output_embeddings(self):
+ return self.proj_out
+
+ def set_output_embeddings(self, new_embeddings):
+ self.proj_out = new_embeddings
+
+ def get_input_embeddings(self) -> nn.Module:
+ return self.model.get_input_embeddings()
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_values: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.LongTensor] = None,
+ decoder_input_ids: Optional[torch.LongTensor] = None,
+ decoder_attention_mask: Optional[torch.LongTensor] = None,
+ encoder_outputs: Optional[tuple[tuple[torch.FloatTensor]]] = None,
+ past_key_values: Optional[Union[EncoderDecoderCache, tuple[torch.FloatTensor]]] = None,
+ decoder_inputs_embeds: Optional[tuple[torch.FloatTensor]] = None,
+ decoder_position_ids: Optional[tuple[torch.LongTensor]] = None,
+ use_cache: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> Seq2SeqLMOutput:
+ r"""
+ input_values (`torch.FloatTensor` of shape `(batch_size, audio_length)`):
+ Float values of the raw speech waveform. Raw speech waveform can be
+ obtained by loading a `.flac` or `.wav` audio file into an array of type `list[float]`, a
+ `numpy.ndarray` or a `torch.Tensor`, *e.g.* via the torchcodec library (`pip install torchcodec`) or
+ the soundfile library (`pip install soundfile`). To prepare the array into
+ `input_values`, the [`AutoFeatureExtractor`] should be used for padding
+ and conversion into a tensor of type `torch.FloatTensor`.
+ decoder_position_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`):
+ Indices of positions of each input sequence token in the position embeddings.
+ Used to calculate the position embeddings up to `config.decoder_config.max_position_embeddings`.
+
+ Example:
+
+ ```python
+ >>> import torch
+ >>> from transformers import AutoProcessor, MoonshineForConditionalGeneration
+ >>> from datasets import load_dataset
+
+ >>> processor = AutoProcessor.from_pretrained("UsefulSensors/moonshine-tiny")
+ >>> model = MoonshineForConditionalGeneration.from_pretrained("UsefulSensors/moonshine-tiny")
+
+ >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
+
+ >>> inputs = processor(ds[0]["audio"]["array"], return_tensors="pt")
+ >>> input_values = inputs.input_values
+
+ >>> generated_ids = model.generate(input_values, max_new_tokens=100)
+
+ >>> transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
+ >>> transcription
+ 'Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.'
+ ```"""
+
+ if labels is not None:
+ if decoder_input_ids is None and decoder_inputs_embeds is None:
+ decoder_input_ids = shift_tokens_right(
+ labels, self.config.pad_token_id, self.config.decoder_start_token_id
+ )
+
+ outputs: Seq2SeqModelOutput = self.model(
+ input_values,
+ attention_mask=attention_mask,
+ decoder_input_ids=decoder_input_ids,
+ encoder_outputs=encoder_outputs,
+ decoder_attention_mask=decoder_attention_mask,
+ past_key_values=past_key_values,
+ decoder_inputs_embeds=decoder_inputs_embeds,
+ decoder_position_ids=decoder_position_ids,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ **kwargs,
+ )
+ logits = self.proj_out(outputs.last_hidden_state)
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size)
+
+ return Seq2SeqLMOutput(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ decoder_hidden_states=outputs.decoder_hidden_states,
+ decoder_attentions=outputs.decoder_attentions,
+ cross_attentions=outputs.cross_attentions,
+ encoder_last_hidden_state=outputs.encoder_last_hidden_state,
+ encoder_hidden_states=outputs.encoder_hidden_states,
+ encoder_attentions=outputs.encoder_attentions,
+ )
+
+
+__all__ = ["MoonshineModel", "MoonshinePreTrainedModel", "MoonshineForConditionalGeneration"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/moonshine/modular_moonshine.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/moonshine/modular_moonshine.py
new file mode 100644
index 0000000000000000000000000000000000000000..12b2ee647bb795cc899adf7a2df301438d3d9260
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/moonshine/modular_moonshine.py
@@ -0,0 +1,921 @@
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Callable, Optional, Union
+
+import torch
+import torch.nn as nn
+
+from transformers.utils.generic import OutputRecorder, check_model_inputs
+
+from ...activations import ACT2FN
+from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
+from ...configuration_utils import PretrainedConfig
+from ...generation import GenerationMixin
+from ...masking_utils import create_causal_mask
+from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa
+from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_outputs import (
+ BaseModelOutput,
+ BaseModelOutputWithPast,
+ BaseModelOutputWithPastAndCrossAttentions,
+ Seq2SeqLMOutput,
+ Seq2SeqModelOutput,
+)
+from ...modeling_rope_utils import rope_config_validation
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
+from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
+from ...utils.deprecation import deprecate_kwarg
+from ..glm.modeling_glm import GlmAttention, GlmRotaryEmbedding, apply_rotary_pos_emb
+from ..llama.modeling_llama import LlamaDecoderLayer, LlamaModel, eager_attention_forward
+from ..whisper.modeling_whisper import WhisperModel, shift_tokens_right
+
+
+logger = logging.get_logger(__name__)
+
+
+class MoonshineConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`MoonshineModel`]. It is used to instantiate a Moonshine
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
+ defaults will yield a similar configuration to that of the Moonshine
+ [UsefulSensors/moonshine-tiny](https://huggingface.co/UsefulSensors/moonshine-tiny).
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 32768):
+ Vocabulary size of the Moonshine model. Defines the number of different tokens that can be represented by the
+ `input_ids` passed when calling [`MoonshineModel`].
+ hidden_size (`int`, *optional*, defaults to 288):
+ Dimension of the hidden representations.
+ intermediate_size (`int`, *optional*, defaults to 1152):
+ Dimension of the MLP representations.
+ encoder_num_hidden_layers (`int`, *optional*, defaults to 6):
+ Number of hidden layers in the Transformer encoder.
+ decoder_num_hidden_layers (`int`, *optional*, defaults to 6):
+ Number of hidden layers in the Transformer decoder.
+ encoder_num_attention_heads (`int`, *optional*, defaults to 8):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ decoder_num_attention_heads (`int`, *optional*, defaults to 8):
+ Number of attention heads for each attention layer in the Transformer decoder.
+ encoder_num_key_value_heads (`int`, *optional*):
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+ `encoder_num_key_value_heads=encoder_num_attention_heads`, the model will use Multi Head Attention (MHA), if
+ `encoder_num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+ by meanpooling all the original heads within that group. For more details, check out [this
+ paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to
+ `num_attention_heads`.
+ decoder_num_key_value_heads (`int`, *optional*):
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+ `decoder_num_key_value_heads=decoder_num_attention_heads`, the model will use Multi Head Attention (MHA), if
+ `decoder_num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+ by meanpooling all the original heads within that group. For more details, check out [this
+ paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to
+ `decoder_num_attention_heads`.
+ pad_head_dim_to_multiple_of (`int`, *optional*):
+ Pad head dimension in encoder and decoder to the next multiple of this value. Necessary for using certain
+ optimized attention implementations.
+ encoder_hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
+ The non-linear activation function (function or string) in the encoder.
+ decoder_hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+ The non-linear activation function (function or string) in the decoder.
+ max_position_embeddings (`int`, *optional*, defaults to 512):
+ The maximum sequence length that this model might ever be used with.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ decoder_start_token_id (`int`, *optional*, defaults to 1):
+ Corresponds to the "<|startoftranscript|>" token, which is automatically used when no `decoder_input_ids`
+ are provided to the `generate` function. It is used to guide the model's generation process depending on
+ the task.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return the last key/values attentions (not used by all models).
+ rope_theta (`float`, *optional*, defaults to 10000.0):
+ The base period of the RoPE embeddings.
+ rope_scaling (`Dict`, *optional*):
+ Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply a new rope type
+ and expect the model to work on a longer `max_position_embeddings`, we recommend updating this value
+ accordingly.
+ Expected contents:
+ `rope_type` (`str`):
+ The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
+ 'llama3'], with 'default' being the original RoPE implementation.
+ `factor` (`float`, *optional*):
+ Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
+ most scaling types, a `factor` of x will enable the model to handle sequences of length x *
+ original maximum pre-trained length.
+ `original_max_position_embeddings` (`int`, *optional*):
+ Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
+ pretraining.
+ `attention_factor` (`float`, *optional*):
+ Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
+ computation. If unspecified, it defaults to the value recommended by the implementation, using the
+ `factor` field to infer the suggested value.
+ `beta_fast` (`float`, *optional*):
+ Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
+ ramp function. If unspecified, it defaults to 32.
+ `beta_slow` (`float`, *optional*):
+ Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
+ ramp function. If unspecified, it defaults to 1.
+ `short_factor` (`list[float]`, *optional*):
+ Only used with 'longrope'. The scaling factor to be applied to short contexts (<
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+ size divided by the number of attention heads divided by 2
+ `long_factor` (`list[float]`, *optional*):
+ Only used with 'longrope'. The scaling factor to be applied to long contexts (>
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+ size divided by the number of attention heads divided by 2
+ `low_freq_factor` (`float`, *optional*):
+ Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
+ `high_freq_factor` (`float`, *optional*):
+ Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
+ partial_rotary_factor (`float`, *optional*, defaults to 0.9):
+ Percentage of the query and keys which will have rotary embedding.
+ is_encoder_decoder (`bool`, *optional*, defaults to `True`):
+ Whether the model is used as an encoder/decoder or not.
+ attention_bias (`bool`, *optional*, defaults to `False`):
+ Whether to use a bias in the query, key, value and output projection layers during self-attention.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ bos_token_id (`int`, *optional*, defaults to 1):
+ Denotes the beginning-of-sequence token id.
+ eos_token_id (`int`, *optional*, defaults to 2):
+ Denotes the end-of-sequence token id.
+
+ Example:
+
+ ```python
+ >>> from transformers import MoonshineModel, MoonshineConfig
+
+ >>> # Initializing a Moonshine style configuration
+ >>> configuration = MoonshineConfig().from_pretrained("UsefulSensors/moonshine-tiny")
+
+ >>> # Initializing a model from the configuration
+ >>> model = MoonshineModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "moonshine"
+ keys_to_ignore_at_inference = ["past_key_values"]
+ attribute_map = {
+ "num_key_value_heads": "encoder_num_key_value_heads",
+ "num_attention_heads": "encoder_num_attention_heads",
+ "num_hidden_layers": "encoder_num_hidden_layers",
+ }
+
+ def __init__(
+ self,
+ vocab_size=32768,
+ hidden_size=288,
+ intermediate_size=1152,
+ encoder_num_hidden_layers=6,
+ decoder_num_hidden_layers=6,
+ encoder_num_attention_heads=8,
+ decoder_num_attention_heads=8,
+ encoder_num_key_value_heads=None,
+ decoder_num_key_value_heads=None,
+ pad_head_dim_to_multiple_of=None,
+ encoder_hidden_act="gelu",
+ decoder_hidden_act="silu",
+ max_position_embeddings=512,
+ initializer_range=0.02,
+ decoder_start_token_id=1,
+ use_cache=True,
+ rope_theta=10000.0,
+ rope_scaling=None,
+ partial_rotary_factor=0.9,
+ is_encoder_decoder=True,
+ attention_bias=False,
+ attention_dropout=0.0,
+ bos_token_id=1,
+ eos_token_id=2,
+ **kwargs,
+ ):
+ self.vocab_size = vocab_size
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.encoder_num_hidden_layers = encoder_num_hidden_layers
+ self.decoder_num_hidden_layers = decoder_num_hidden_layers
+ self.encoder_num_attention_heads = encoder_num_attention_heads
+ self.decoder_num_attention_heads = decoder_num_attention_heads
+
+ if encoder_num_key_value_heads is None:
+ encoder_num_key_value_heads = encoder_num_attention_heads
+ self.encoder_num_key_value_heads = encoder_num_key_value_heads
+
+ if decoder_num_key_value_heads is None:
+ decoder_num_key_value_heads = decoder_num_attention_heads
+ self.decoder_num_key_value_heads = decoder_num_key_value_heads
+
+ self.pad_head_dim_to_multiple_of = pad_head_dim_to_multiple_of
+
+ self.encoder_hidden_act = encoder_hidden_act
+ self.decoder_hidden_act = decoder_hidden_act
+ self.max_position_embeddings = max_position_embeddings
+ self.initializer_range = initializer_range
+ self.decoder_start_token_id = decoder_start_token_id
+ self.use_cache = use_cache
+ self.rope_theta = rope_theta
+ self.rope_scaling = rope_scaling
+ self.partial_rotary_factor = partial_rotary_factor
+ self.is_encoder_decoder = is_encoder_decoder
+ self.attention_bias = attention_bias
+ self.attention_dropout = attention_dropout
+
+ # Validate the correctness of rotary position embeddings parameters
+ rope_config_validation(self)
+
+ super().__init__(
+ bos_token_id=bos_token_id,
+ eos_token_id=eos_token_id,
+ is_encoder_decoder=is_encoder_decoder,
+ decoder_start_token_id=decoder_start_token_id,
+ **kwargs,
+ )
+
+
+class MoonshineEncoderMLP(nn.Module):
+ def __init__(self, config, hidden_act):
+ super().__init__()
+ self.config = config
+ self.activation_fn = ACT2FN[hidden_act]
+ self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
+ self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.fc1(hidden_states)
+ hidden_states = self.activation_fn(hidden_states)
+ hidden_states = self.fc2(hidden_states)
+ return hidden_states
+
+
+class MoonshineDecoderMLP(nn.Module):
+ def __init__(self, config, hidden_act):
+ super().__init__()
+ self.config = config
+ self.activation_fn = ACT2FN[hidden_act]
+ self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size * 2)
+ self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.fc1(hidden_states)
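+ # GLU-style gating: the doubled fc1 output is split below into a value half and a gate half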
+ hidden_states, gate = hidden_states.chunk(2, dim=-1)
+ hidden_states = self.activation_fn(gate) * hidden_states
+ hidden_states = self.fc2(hidden_states)
+ return hidden_states
+
+
+class MoonshineAttention(GlmAttention):
+ def __init__(
+ self,
+ config: MoonshineConfig,
+ layer_idx: int,
+ is_causal: bool,
+ num_attention_heads: int,
+ num_key_value_heads: int,
+ ):
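+ # the shared GlmAttention __init__ reads `num_attention_heads` / `num_key_value_heads` from the config,
+ # so the encoder- or decoder-specific values are written into the config before calling it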
+ config.update({"num_attention_heads": num_attention_heads, "num_key_value_heads": num_key_value_heads})
+ super().__init__(config, layer_idx)
+ self.is_causal = is_causal
+ self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+
+ # Pad head dimension to the next specified multiple.
+ if self.config.pad_head_dim_to_multiple_of is not None:
+ target_multiple = self.config.pad_head_dim_to_multiple_of
+ target_head_dim = target_multiple * ((self.head_dim + target_multiple - 1) // target_multiple)
+ self.head_dim_padding = target_head_dim - self.head_dim
+ else:
+ self.head_dim_padding = 0
+
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ past_key_values: Optional[Cache] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ key_value_states: Optional[torch.Tensor] = None,
+ **kwargs: Unpack[FlashAttentionKwargs],
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
+ bsz, q_len = hidden_states.shape[:-1]
+
+ query_states = (
+ self.q_proj(hidden_states).view(bsz, q_len, self.config.num_key_value_heads, self.head_dim).transpose(1, 2)
+ )
+
+ is_cross_attention = key_value_states is not None
+ if past_key_values is not None:
+ is_updated = past_key_values.is_updated.get(self.layer_idx)
+ if is_cross_attention:
+ # after the first generated id, we can subsequently re-use all key/value states from the cache
+ past_key_values.is_updated[self.layer_idx] = True
+ past_key_values = past_key_values.cross_attention_cache
+ else:
+ past_key_values = past_key_values.self_attention_cache
+
+ # use key_value_states if cross attention
+ current_states = key_value_states if key_value_states is not None else hidden_states
+ if is_cross_attention and past_key_values and is_updated:
+ key_states = past_key_values.layers[self.layer_idx].keys
+ value_states = past_key_values.layers[self.layer_idx].values
+ else:
+ key_states = (
+ self.k_proj(current_states)
+ .view(bsz, -1, self.config.num_key_value_heads, self.head_dim)
+ .transpose(1, 2)
+ )
+ value_states = (
+ self.v_proj(current_states)
+ .view(bsz, -1, self.config.num_key_value_heads, self.head_dim)
+ .transpose(1, 2)
+ )
+ if is_cross_attention and past_key_values is not None:
+ key_states, value_states = past_key_values.update(
+ key_states, value_states, self.layer_idx, {"cache_position": cache_position}
+ )
+
+ if not is_cross_attention:
+ cos, sin = position_embeddings
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_values is not None:
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ key_states, value_states = past_key_values.update(
+ key_states, value_states, self.layer_idx, cache_kwargs
+ )
+
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ is_causal = self.is_causal and attention_mask is None and q_len > 1
+
+ if self.head_dim_padding > 0:
+ query_states = torch.nn.functional.pad(query_states, (0, self.head_dim_padding))
+ key_states = torch.nn.functional.pad(key_states, (0, self.head_dim_padding))
+ value_states = torch.nn.functional.pad(value_states, (0, self.head_dim_padding))
+
+ attn_output, attn_weights = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ dropout=0.0 if not self.training else self.attention_dropout,
+ scaling=self.scaling,
+ is_causal=is_causal,
+ **kwargs,
+ )
+
+ if self.head_dim_padding > 0:
+ attn_output = attn_output[..., : -self.head_dim_padding]
+
+ attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
+ attn_output = self.o_proj(attn_output)
+ return attn_output, attn_weights
+
+
+class MoonshineRotaryEmbedding(GlmRotaryEmbedding):
+ pass
+
+
+class MoonshineEncoderLayer(LlamaDecoderLayer):
+ def __init__(self, config: MoonshineConfig, layer_idx: int):
+ super().__init__(config, layer_idx)
+
+ self.self_attn = MoonshineAttention(
+ config=config,
+ layer_idx=layer_idx,
+ is_causal=False,
+ num_attention_heads=config.encoder_num_attention_heads,
+ num_key_value_heads=config.encoder_num_key_value_heads,
+ )
+
+ self.mlp = MoonshineEncoderMLP(config, config.encoder_hidden_act)
+ self.input_layernorm = nn.LayerNorm(config.hidden_size, bias=False)
+ self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, bias=False)
+
+
+class MoonshineDecoderLayer(GradientCheckpointingLayer):
+ def __init__(self, config: MoonshineConfig, layer_idx: Optional[int] = None):
+ super().__init__()
+ self.hidden_size = config.hidden_size
+
+ self.self_attn = MoonshineAttention(
+ config=config,
+ layer_idx=layer_idx,
+ is_causal=True,
+ num_attention_heads=config.decoder_num_attention_heads,
+ num_key_value_heads=config.decoder_num_key_value_heads,
+ )
+ self.encoder_attn = MoonshineAttention(
+ config=config,
+ layer_idx=layer_idx,
+ is_causal=False,
+ num_attention_heads=config.decoder_num_attention_heads,
+ num_key_value_heads=config.decoder_num_key_value_heads,
+ )
+
+ self.mlp = MoonshineDecoderMLP(config, config.decoder_hidden_act)
+ self.input_layernorm = nn.LayerNorm(config.hidden_size, bias=False)
+ self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, bias=False)
+ self.final_layernorm = nn.LayerNorm(config.hidden_size, bias=False)
+
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ encoder_position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ use_cache: Optional[bool] = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
+ encoder_position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
+ residual = hidden_states
+ hidden_states = self.input_layernorm(hidden_states)
+
+ hidden_states, _ = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ **kwargs,
+ )
+ hidden_states = residual + hidden_states
+
+ if encoder_hidden_states is not None:
+ residual = hidden_states
+ hidden_states = self.post_attention_layernorm(hidden_states)
+ hidden_states, _ = self.encoder_attn(
+ hidden_states=hidden_states,
+ key_value_states=encoder_hidden_states,
+ attention_mask=encoder_attention_mask,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ )
+ hidden_states = residual + hidden_states
+
+ residual = hidden_states
+ hidden_states = self.final_layernorm(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = residual + hidden_states
+ return hidden_states
+
+
+@auto_docstring
+class MoonshinePreTrainedModel(PreTrainedModel):
+ config: MoonshineConfig
+ base_model_prefix = "model"
+ main_input_name = "input_values"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["MoonshineEncoderLayer", "MoonshineDecoderLayer"]
+ _supports_flash_attn = True
+ _supports_sdpa = True
+
+ _can_compile_fullgraph = True
+ # TODO (arthur): how do we separate cross / self attention when they come from different layers?
+
+ def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor):
+ """
+ Computes the output length of the convolutional layers
+ """
+ output_conv1_length = int((input_lengths - 127) / 64 + 1)
+ output_conv2_length = int((output_conv1_length - 7) / 3 + 1)
+ output_conv3_length = int((output_conv2_length - 3) / 2 + 1)
+
+ return output_conv3_length
+
+
+class MoonshineEncoder(MoonshinePreTrainedModel):
+ """
+ Transformer encoder consisting of *config.num_hidden_layers* layers. Each layer is a [`MoonshineEncoderLayer`]
+
+ Args:
+ config: MoonshineConfig
+ """
+
+ main_input_name = "input_values"
+ _can_record_outputs = {
+ "attentions": MoonshineAttention,
+ "hidden_states": MoonshineEncoderLayer,
+ }
+
+ def __init__(self, config: MoonshineConfig):
+ super().__init__(config)
+ self.config = config
+ embed_dim = config.hidden_size
+
+ self.conv1 = nn.Conv1d(1, embed_dim, kernel_size=127, stride=64, bias=False)
+ self.conv2 = nn.Conv1d(embed_dim, 2 * embed_dim, kernel_size=7, stride=3)
+ self.conv3 = nn.Conv1d(2 * embed_dim, embed_dim, kernel_size=3, stride=2)
+ self.groupnorm = nn.GroupNorm(num_groups=1, num_channels=embed_dim, eps=1e-5)
+ self.rotary_emb = MoonshineRotaryEmbedding(config=config)
+
+ self.layers = nn.ModuleList(
+ [MoonshineEncoderLayer(config, idx) for idx in range(config.encoder_num_hidden_layers)]
+ )
+ self.layer_norm = nn.LayerNorm(embed_dim, bias=False)
+ self.gradient_checkpointing = False
+ self.post_init()
+
+ def get_input_embeddings(self) -> nn.Module:
+ return self.conv1
+
+ def set_input_embeddings(self, value: nn.Module):
+ self.conv1 = value
+
+ @check_model_inputs
+ def forward(
+ self,
+ input_values: torch.FloatTensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> BaseModelOutputWithPast:
+ r"""
+ Args:
+ input_values (`torch.FloatTensor` of shape `(batch_size, audio_length)`):
+ Float values of the raw speech waveform. Raw speech waveform can be
+ obtained by loading a `.flac` or `.wav` audio file into an array of type `list[float]`, a
+ `numpy.ndarray` or a `torch.Tensor`, *e.g.* via the torchcodec library (`pip install torchcodec`) or
+ the soundfile library (`pip install soundfile`). To prepare the array into
+ `input_values`, the [`AutoFeatureExtractor`] should be used for padding
+ and conversion into a tensor of type `torch.FloatTensor`.
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding indices in `input_values`. Mask values selected in `[0, 1]`:
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+ [What are attention masks?](../glossary#attention-mask)
+ """
+ input_values = input_values.unsqueeze(1)
+ hidden_states = nn.functional.tanh(self.conv1(input_values))
+ hidden_states = self.groupnorm(hidden_states)
+ hidden_states = nn.functional.gelu(self.conv2(hidden_states))
+ hidden_states = nn.functional.gelu(self.conv3(hidden_states))
+ hidden_states = hidden_states.permute(0, 2, 1)
+
+ # attention mask downsampling
+ if attention_mask is not None:
+ mask_len = self._get_feat_extract_output_lengths(attention_mask.shape[-1])
+ downsample_stride = 64 * 3 * 2 # conv strides
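+ # The conv frontend reduces the time axis by 64 * 3 * 2 = 384, so the sample-level padding mask is
+ # subsampled with the same stride to obtain an approximate frame-level mask (assumes right-padding).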
+ attention_mask = attention_mask[..., ::downsample_stride][..., :mask_len]
+ if self.config._attn_implementation == "flash_attention_2":
+ attention_mask = attention_mask if (attention_mask == 0.0).any() else None
+ elif self.config._attn_implementation == "sdpa":
+ attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, hidden_states.dtype)
+ else:
+ attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)
+
+ position_ids = torch.arange(0, hidden_states.shape[1], device=hidden_states.device).unsqueeze(0)
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+ for encoder_layer in self.layers:
+ hidden_states = encoder_layer(
+ hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ position_embeddings=position_embeddings,
+ **kwargs,
+ )
+
+ hidden_states = self.layer_norm(hidden_states)
+
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ )
+
+
+class MoonshineDecoder(LlamaModel):
+ main_input_name = "input_ids"
+ _can_record_outputs = {
+ "attentions": OutputRecorder(MoonshineAttention, index=1, layer_name="self_attn"),
+ "hidden_states": MoonshineDecoderLayer,
+ "cross_attentions": OutputRecorder(MoonshineAttention, index=1, layer_name="encoder_attn"),
+ }
+
+ def __init__(self, config: MoonshineConfig):
+ super().__init__(config)
+ self.norm = nn.LayerNorm(config.hidden_size, bias=False)
+ self.layers = nn.ModuleList(
+ [MoonshineDecoderLayer(config, idx) for idx in range(config.decoder_num_hidden_layers)]
+ )
+
+ @check_model_inputs
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> Union[tuple, BaseModelOutputWithPastAndCrossAttentions]:
+ r"""
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
+ of the decoder.
+ encoder_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding indices in `encoder_hidden_states`. Mask values selected in `[0, 1]`:
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+ [What are attention masks?](../glossary#attention-mask)
+ """
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+
+ if use_cache and past_key_values is None:
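+ # one cache for decoder self-attention, one for cross-attention over the encoder states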
+ past_key_values = EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
+
+ if cache_position is None:
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ cache_position = torch.arange(
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+ )
+
+ if position_ids is None:
+ position_ids = cache_position.unsqueeze(0)
+
+ causal_mask = create_causal_mask(
+ config=self.config,
+ input_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ cache_position=cache_position,
+ past_key_values=past_key_values,
+ position_ids=position_ids,
+ )
+
+ hidden_states = inputs_embeds
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+ if encoder_attention_mask is not None:
+ mask_len = encoder_hidden_states.shape[-2]
+ downsample_stride = 64 * 3 * 2 # conv strides
+ encoder_attention_mask = encoder_attention_mask[..., ::downsample_stride][..., :mask_len]
+ if self.config._attn_implementation == "flash_attention_2":
+ encoder_attention_mask = encoder_attention_mask if (encoder_attention_mask == 0.0).any() else None
+ elif self.config._attn_implementation == "sdpa":
+ encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
+ encoder_attention_mask, hidden_states.dtype, hidden_states.shape[-2]
+ )
+ else:
+ encoder_attention_mask = _prepare_4d_attention_mask(
+ encoder_attention_mask, hidden_states.dtype, hidden_states.shape[-2]
+ )
+
+ for decoder_layer in self.layers:
+ hidden_states = decoder_layer(
+ hidden_states,
+ causal_mask,
+ encoder_hidden_states, # as a positional argument for gradient checkpointing
+ encoder_attention_mask=encoder_attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ **kwargs,
+ )
+
+ hidden_states = self.norm(hidden_states)
+
+ return BaseModelOutputWithPastAndCrossAttentions(
+ last_hidden_state=hidden_states,
+ past_key_values=past_key_values if use_cache else None,
+ )
+
+
+class MoonshineModel(WhisperModel):
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_values: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.LongTensor] = None,
+ decoder_input_ids: Optional[torch.LongTensor] = None,
+ decoder_attention_mask: Optional[torch.LongTensor] = None,
+ encoder_outputs: Optional[tuple[tuple[torch.FloatTensor]]] = None,
+ past_key_values: Optional[Union[EncoderDecoderCache, tuple[torch.FloatTensor]]] = None,
+ decoder_inputs_embeds: Optional[tuple[torch.FloatTensor]] = None,
+ decoder_position_ids: Optional[tuple[torch.LongTensor]] = None,
+ use_cache: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> Seq2SeqModelOutput:
+ r"""
+ input_values (`torch.FloatTensor` of shape `(batch_size, audio_length)`):
+ Float values of the raw speech waveform. Raw speech waveform can be
+ obtained by loading a `.flac` or `.wav` audio file into an array of type `list[float]`, a
+ `numpy.ndarray` or a `torch.Tensor`, *e.g.* via the torchcodec library (`pip install torchcodec`) or
+ the soundfile library (`pip install soundfile`). To prepare the array into
+ `input_values`, the [`AutoFeatureExtractor`] should be used for padding
+ and conversion into a tensor of type `torch.FloatTensor`.
+ decoder_position_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`):
+ Indices of positions of each input sequence token in the position embeddings.
+ Used to calculate the position embeddings up to `config.decoder_config.max_position_embeddings`.
+
+ Example:
+
+ ```python
+ >>> import torch
+ >>> from transformers import AutoFeatureExtractor, MoonshineModel
+ >>> from datasets import load_dataset
+
+ >>> model = MoonshineModel.from_pretrained("UsefulSensors/moonshine-tiny")
+ >>> feature_extractor = AutoFeatureExtractor.from_pretrained("UsefulSensors/moonshine-tiny")
+ >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
+ >>> inputs = feature_extractor(ds[0]["audio"]["array"], return_tensors="pt")
+ >>> input_values = inputs.input_values
+ >>> decoder_input_ids = torch.tensor([[1, 1]]) * model.config.decoder_start_token_id
+ >>> last_hidden_state = model(input_values, decoder_input_ids=decoder_input_ids).last_hidden_state
+ >>> list(last_hidden_state.shape)
+ [1, 2, 288]
+ ```
+ """
+ if encoder_outputs is None:
+ encoder_outputs: BaseModelOutput = self.encoder(input_values, attention_mask=attention_mask, **kwargs)
+
+ decoder_outputs: BaseModelOutputWithPastAndCrossAttentions = self.decoder(
+ input_ids=decoder_input_ids,
+ attention_mask=decoder_attention_mask,
+ encoder_attention_mask=attention_mask,
+ encoder_hidden_states=encoder_outputs.last_hidden_state,
+ past_key_values=past_key_values,
+ inputs_embeds=decoder_inputs_embeds,
+ position_ids=decoder_position_ids,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+ return Seq2SeqModelOutput(
+ last_hidden_state=decoder_outputs.last_hidden_state,
+ past_key_values=decoder_outputs.past_key_values,
+ decoder_hidden_states=decoder_outputs.hidden_states,
+ decoder_attentions=decoder_outputs.attentions,
+ cross_attentions=decoder_outputs.cross_attentions,
+ encoder_last_hidden_state=encoder_outputs.last_hidden_state,
+ encoder_hidden_states=encoder_outputs.hidden_states,
+ encoder_attentions=encoder_outputs.attentions,
+ )
+
+
+@auto_docstring(
+ custom_intro="""
+ The Moonshine Model with a language modeling head. Can be used for automatic speech recognition.
+ """
+)
+class MoonshineForConditionalGeneration(MoonshinePreTrainedModel, GenerationMixin):
+ _tied_weights_keys = ["proj_out.weight"]
+
+ def __init__(self, config: MoonshineConfig):
+ super().__init__(config)
+ self.model = MoonshineModel(config)
+ self.proj_out = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_encoder(self):
+ return self.model.get_encoder()
+
+ def get_decoder(self):
+ return self.model.get_decoder()
+
+ def get_output_embeddings(self):
+ return self.proj_out
+
+ def set_output_embeddings(self, new_embeddings):
+ self.proj_out = new_embeddings
+
+ def get_input_embeddings(self) -> nn.Module:
+ return self.model.get_input_embeddings()
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_values: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.LongTensor] = None,
+ decoder_input_ids: Optional[torch.LongTensor] = None,
+ decoder_attention_mask: Optional[torch.LongTensor] = None,
+ encoder_outputs: Optional[tuple[tuple[torch.FloatTensor]]] = None,
+ past_key_values: Optional[Union[EncoderDecoderCache, tuple[torch.FloatTensor]]] = None,
+ decoder_inputs_embeds: Optional[tuple[torch.FloatTensor]] = None,
+ decoder_position_ids: Optional[tuple[torch.LongTensor]] = None,
+ use_cache: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> Seq2SeqLMOutput:
+ r"""
+ input_values (`torch.FloatTensor` of shape `(batch_size, audio_length)`):
+ Float values of the raw speech waveform. Raw speech waveform can be
+ obtained by loading a `.flac` or `.wav` audio file into an array of type `list[float]`, a
+ `numpy.ndarray` or a `torch.Tensor`, *e.g.* via the torchcodec library (`pip install torchcodec`) or
+ the soundfile library (`pip install soundfile`). To prepare the array into
+ `input_values`, the [`AutoFeatureExtractor`] should be used for padding
+ and conversion into a tensor of type `torch.FloatTensor`.
+ decoder_position_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`):
+ Indices of positions of each input sequence token in the position embeddings.
+ Used to calculate the position embeddings up to `config.decoder_config.max_position_embeddings`.
+
+ Example:
+
+ ```python
+ >>> import torch
+ >>> from transformers import AutoProcessor, MoonshineForConditionalGeneration
+ >>> from datasets import load_dataset
+
+ >>> processor = AutoProcessor.from_pretrained("UsefulSensors/moonshine-tiny")
+ >>> model = MoonshineForConditionalGeneration.from_pretrained("UsefulSensors/moonshine-tiny")
+
+ >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
+
+ >>> inputs = processor(ds[0]["audio"]["array"], return_tensors="pt")
+ >>> input_values = inputs.input_values
+
+ >>> generated_ids = model.generate(input_values, max_new_tokens=100)
+
+ >>> transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
+ >>> transcription
+ 'Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.'
+ ```"""
+
+ if labels is not None:
+ if decoder_input_ids is None and decoder_inputs_embeds is None:
+ decoder_input_ids = shift_tokens_right(
+ labels, self.config.pad_token_id, self.config.decoder_start_token_id
+ )
+
+ outputs: Seq2SeqModelOutput = self.model(
+ input_values,
+ attention_mask=attention_mask,
+ decoder_input_ids=decoder_input_ids,
+ encoder_outputs=encoder_outputs,
+ decoder_attention_mask=decoder_attention_mask,
+ past_key_values=past_key_values,
+ decoder_inputs_embeds=decoder_inputs_embeds,
+ decoder_position_ids=decoder_position_ids,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ **kwargs,
+ )
+ logits = self.proj_out(outputs.last_hidden_state)
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size)
+
+ return Seq2SeqLMOutput(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ decoder_hidden_states=outputs.decoder_hidden_states,
+ decoder_attentions=outputs.decoder_attentions,
+ cross_attentions=outputs.cross_attentions,
+ encoder_last_hidden_state=outputs.encoder_last_hidden_state,
+ encoder_hidden_states=outputs.encoder_hidden_states,
+ encoder_attentions=outputs.encoder_attentions,
+ )
+
+
+__all__ = [
+ "MoonshineConfig",
+ "MoonshineModel",
+ "MoonshinePreTrainedModel",
+ "MoonshineForConditionalGeneration",
+]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/mra/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/mra/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2963ad0c97ba8c0c5bdc32ff9adca79ffc56cde3
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/mra/__init__.py
@@ -0,0 +1,27 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_mra import *
+ from .modeling_mra import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/mra/configuration_mra.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/mra/configuration_mra.py
new file mode 100644
index 0000000000000000000000000000000000000000..16b064c98f7e6a9ed4fe72d7149fb867380c904d
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/mra/configuration_mra.py
@@ -0,0 +1,137 @@
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""MRA model configuration"""
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class MraConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`MraModel`]. It is used to instantiate an MRA
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
+ defaults will yield a similar configuration to that of the Mra
+ [uw-madison/mra-base-512-4](https://huggingface.co/uw-madison/mra-base-512-4) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 50265):
+ Vocabulary size of the Mra model. Defines the number of different tokens that can be represented by the
+ `inputs_ids` passed when calling [`MraModel`].
+ hidden_size (`int`, *optional*, defaults to 768):
+ Dimension of the encoder layers and the pooler layer.
+ num_hidden_layers (`int`, *optional*, defaults to 12):
+ Number of hidden layers in the Transformer encoder.
+ num_attention_heads (`int`, *optional*, defaults to 12):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ intermediate_size (`int`, *optional*, defaults to 3072):
+ Dimension of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+ `"relu"`, `"selu"` and `"gelu_new"` are supported.
+ hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+ attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
+ The dropout ratio for the attention probabilities.
+ max_position_embeddings (`int`, *optional*, defaults to 512):
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
+ just in case (e.g., 512 or 1024 or 2048).
+ type_vocab_size (`int`, *optional*, defaults to 1):
+ The vocabulary size of the `token_type_ids` passed when calling [`MraModel`].
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ layer_norm_eps (`float`, *optional*, defaults to 1e-5):
+ The epsilon used by the layer normalization layers.
+ position_embedding_type (`str`, *optional*, defaults to `"absolute"`):
+ Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`.
+ block_per_row (`int`, *optional*, defaults to 4):
+ Used to set the budget for the high resolution scale.
+ approx_mode (`str`, *optional*, defaults to `"full"`):
+ Controls whether both low and high resolution approximations are used. Set to `"full"` for both low and
+ high resolution and `"sparse"` for only low resolution.
+ initial_prior_first_n_blocks (`int`, *optional*, defaults to 0):
+ The initial number of blocks for which high resolution is used.
+ initial_prior_diagonal_n_blocks (`int`, *optional*, defaults to 0):
+ The number of diagonal blocks for which high resolution is used.
+
+ Example:
+
+ ```python
+ >>> from transformers import MraConfig, MraModel
+
+ >>> # Initializing a Mra uw-madison/mra-base-512-4 style configuration
+ >>> configuration = MraConfig()
+
+ >>> # Initializing a model (with random weights) from the uw-madison/mra-base-512-4 style configuration
+ >>> model = MraModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "mra"
+
+ def __init__(
+ self,
+ vocab_size=50265,
+ hidden_size=768,
+ num_hidden_layers=12,
+ num_attention_heads=12,
+ intermediate_size=3072,
+ hidden_act="gelu",
+ hidden_dropout_prob=0.1,
+ attention_probs_dropout_prob=0.1,
+ max_position_embeddings=512,
+ type_vocab_size=1,
+ initializer_range=0.02,
+ layer_norm_eps=1e-5,
+ position_embedding_type="absolute",
+ block_per_row=4,
+ approx_mode="full",
+ initial_prior_first_n_blocks=0,
+ initial_prior_diagonal_n_blocks=0,
+ pad_token_id=1,
+ bos_token_id=0,
+ eos_token_id=2,
+ **kwargs,
+ ):
+ super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
+
+ self.vocab_size = vocab_size
+ self.max_position_embeddings = max_position_embeddings
+ self.hidden_size = hidden_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.intermediate_size = intermediate_size
+ self.hidden_act = hidden_act
+ self.hidden_dropout_prob = hidden_dropout_prob
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
+ self.initializer_range = initializer_range
+ self.type_vocab_size = type_vocab_size
+ self.layer_norm_eps = layer_norm_eps
+ self.position_embedding_type = position_embedding_type
+ self.block_per_row = block_per_row
+ self.approx_mode = approx_mode
+ self.initial_prior_first_n_blocks = initial_prior_first_n_blocks
+ self.initial_prior_diagonal_n_blocks = initial_prior_diagonal_n_blocks
+
+
+__all__ = ["MraConfig"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/mra/modeling_mra.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/mra/modeling_mra.py
new file mode 100644
index 0000000000000000000000000000000000000000..86bee4d09b5af6663283ac2e404bc39ea9b15e3d
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/mra/modeling_mra.py
@@ -0,0 +1,1391 @@
+# coding=utf-8
+# Copyright 2023 University of Wisconsin-Madison and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch MRA model."""
+
+import math
+from pathlib import Path
+from typing import Optional, Union
+
+import torch
+from torch import nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+from torch.utils.cpp_extension import load
+
+from ...activations import ACT2FN
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_outputs import (
+ BaseModelOutputWithCrossAttentions,
+ MaskedLMOutput,
+ MultipleChoiceModelOutput,
+ QuestionAnsweringModelOutput,
+ SequenceClassifierOutput,
+ TokenClassifierOutput,
+)
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
+from ...utils import auto_docstring, is_cuda_platform, is_ninja_available, is_torch_cuda_available, logging
+from .configuration_mra import MraConfig
+
+
+logger = logging.get_logger(__name__)
+
+mra_cuda_kernel = None
+
+
+def load_cuda_kernels():
+ global mra_cuda_kernel
+ src_folder = Path(__file__).resolve().parent.parent.parent / "kernels" / "mra"
+
+ def append_root(files):
+ return [src_folder / file for file in files]
+
+ src_files = append_root(["cuda_kernel.cu", "cuda_launch.cu", "torch_extension.cpp"])
+
+ mra_cuda_kernel = load("cuda_kernel", src_files, verbose=True)
+
+
+def sparse_max(sparse_qk_prod, indices, query_num_block, key_num_block):
+ """
+ Computes maximum values for softmax stability.
+ """
+ if len(sparse_qk_prod.size()) != 4:
+ raise ValueError("sparse_qk_prod must be a 4-dimensional tensor.")
+
+ if len(indices.size()) != 2:
+ raise ValueError("indices must be a 2-dimensional tensor.")
+
+ if sparse_qk_prod.size(2) != 32:
+ raise ValueError("The size of the second dimension of sparse_qk_prod must be 32.")
+
+ if sparse_qk_prod.size(3) != 32:
+ raise ValueError("The size of the third dimension of sparse_qk_prod must be 32.")
+
+ index_vals = sparse_qk_prod.max(dim=-2).values.transpose(-1, -2)
+ index_vals = index_vals.contiguous()
+
+ indices = indices.int()
+ indices = indices.contiguous()
+
+ max_vals, max_vals_scatter = mra_cuda_kernel.index_max(index_vals, indices, query_num_block, key_num_block)
+ max_vals_scatter = max_vals_scatter.transpose(-1, -2)[:, :, None, :]
+
+ return max_vals, max_vals_scatter
+
+
+def sparse_mask(mask, indices, block_size=32):
+ """
+ Converts attention mask to a sparse mask for high resolution logits.
+ """
+ if len(mask.size()) != 2:
+ raise ValueError("mask must be a 2-dimensional tensor.")
+
+ if len(indices.size()) != 2:
+ raise ValueError("indices must be a 2-dimensional tensor.")
+
+ if mask.shape[0] != indices.shape[0]:
+ raise ValueError("mask and indices must have the same size in the zero-th dimension.")
+
+ batch_size, seq_len = mask.shape
+ num_block = seq_len // block_size
+
+ batch_idx = torch.arange(indices.size(0), dtype=torch.long, device=indices.device)
+ mask = mask.reshape(batch_size, num_block, block_size)
+ mask = mask[batch_idx[:, None], (indices % num_block).long(), :]
+
+ return mask
+
+
+def mm_to_sparse(dense_query, dense_key, indices, block_size=32):
+ """
+ Performs Sampled Dense Matrix Multiplication.
+ """
+ batch_size, query_size, dim = dense_query.size()
+ _, key_size, dim = dense_key.size()
+
+ if query_size % block_size != 0:
+ raise ValueError("query_size (size of first dimension of dense_query) must be divisible by block_size.")
+
+ if key_size % block_size != 0:
+ raise ValueError("key_size (size of first dimension of dense_key) must be divisible by block_size.")
+
+ dense_query = dense_query.reshape(batch_size, query_size // block_size, block_size, dim).transpose(-1, -2)
+ dense_key = dense_key.reshape(batch_size, key_size // block_size, block_size, dim).transpose(-1, -2)
+
+ if len(dense_query.size()) != 4:
+ raise ValueError("dense_query must be a 4-dimensional tensor.")
+
+ if len(dense_key.size()) != 4:
+ raise ValueError("dense_key must be a 4-dimensional tensor.")
+
+ if len(indices.size()) != 2:
+ raise ValueError("indices must be a 2-dimensional tensor.")
+
+ if dense_query.size(3) != 32:
+ raise ValueError("The third dimension of dense_query must be 32.")
+
+ if dense_key.size(3) != 32:
+ raise ValueError("The third dimension of dense_key must be 32.")
+
+ dense_query = dense_query.contiguous()
+ dense_key = dense_key.contiguous()
+
+ indices = indices.int()
+ indices = indices.contiguous()
+
+ return mra_cuda_kernel.mm_to_sparse(dense_query, dense_key, indices.int())
+
+
+def sparse_dense_mm(sparse_query, indices, dense_key, query_num_block, block_size=32):
+ """
+ Performs matrix multiplication of a sparse matrix with a dense matrix.
+ """
+ batch_size, key_size, dim = dense_key.size()
+
+ if key_size % block_size != 0:
+ raise ValueError("key_size (size of first dimension of dense_key) must be divisible by block_size.")
+
+ if sparse_query.size(2) != block_size:
+ raise ValueError("The size of the second dimension of sparse_query must be equal to the block_size.")
+
+ if sparse_query.size(3) != block_size:
+ raise ValueError("The size of the third dimension of sparse_query must be equal to the block_size.")
+
+ dense_key = dense_key.reshape(batch_size, key_size // block_size, block_size, dim).transpose(-1, -2)
+
+ if len(sparse_query.size()) != 4:
+ raise ValueError("sparse_query must be a 4-dimensional tensor.")
+
+ if len(dense_key.size()) != 4:
+ raise ValueError("dense_key must be a 4-dimensional tensor.")
+
+ if len(indices.size()) != 2:
+ raise ValueError("indices must be a 2-dimensional tensor.")
+
+ if dense_key.size(3) != 32:
+ raise ValueError("The size of the third dimension of dense_key must be 32.")
+
+ sparse_query = sparse_query.contiguous()
+
+ indices = indices.int()
+ indices = indices.contiguous()
+ dense_key = dense_key.contiguous()
+
+ dense_qk_prod = mra_cuda_kernel.sparse_dense_mm(sparse_query, indices, dense_key, query_num_block)
+ dense_qk_prod = dense_qk_prod.transpose(-1, -2).reshape(batch_size, query_num_block * block_size, dim)
+ return dense_qk_prod
+
+
+def transpose_indices(indices, dim_1_block, dim_2_block):
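+ # `indices` flattens a (dim_1_block, dim_2_block) grid in row-major order; this returns the flat index of the same cell in the transposed (dim_2_block, dim_1_block) grid.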
+ return ((indices % dim_2_block) * dim_1_block + torch.div(indices, dim_2_block, rounding_mode="floor")).long()
+
+
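+# Autograd wrapper around the sampled dense-dense matmul kernel: the forward pass only computes the
+# (query_block, key_block) tiles listed in `indices`, and the backward pass routes gradients through the
+# matching sparse-dense products.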
+class MraSampledDenseMatMul(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, dense_query, dense_key, indices, block_size):
+ sparse_qk_prod = mm_to_sparse(dense_query, dense_key, indices, block_size)
+ ctx.save_for_backward(dense_query, dense_key, indices)
+ ctx.block_size = block_size
+ return sparse_qk_prod
+
+ @staticmethod
+ def backward(ctx, grad):
+ dense_query, dense_key, indices = ctx.saved_tensors
+ block_size = ctx.block_size
+ query_num_block = dense_query.size(1) // block_size
+ key_num_block = dense_key.size(1) // block_size
+ indices_T = transpose_indices(indices, query_num_block, key_num_block)
+ grad_key = sparse_dense_mm(grad.transpose(-1, -2), indices_T, dense_query, key_num_block)
+ grad_query = sparse_dense_mm(grad, indices, dense_key, query_num_block)
+ return grad_query, grad_key, None, None
+
+ @staticmethod
+ def operator_call(dense_query, dense_key, indices, block_size=32):
+ return MraSampledDenseMatMul.apply(dense_query, dense_key, indices, block_size)
+
+
+class MraSparseDenseMatMul(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, sparse_query, indices, dense_key, query_num_block):
+ sparse_qk_prod = sparse_dense_mm(sparse_query, indices, dense_key, query_num_block)
+ ctx.save_for_backward(sparse_query, indices, dense_key)
+ ctx.query_num_block = query_num_block
+ return sparse_qk_prod
+
+ @staticmethod
+ def backward(ctx, grad):
+ sparse_query, indices, dense_key = ctx.saved_tensors
+ query_num_block = ctx.query_num_block
+ key_num_block = dense_key.size(1) // sparse_query.size(-1)
+ indices_T = transpose_indices(indices, query_num_block, key_num_block)
+ grad_key = sparse_dense_mm(sparse_query.transpose(-1, -2), indices_T, grad, key_num_block)
+ grad_query = mm_to_sparse(grad, dense_key, indices)
+ return grad_query, None, grad_key, None
+
+ @staticmethod
+ def operator_call(sparse_query, indices, dense_key, query_num_block):
+ return MraSparseDenseMatMul.apply(sparse_query, indices, dense_key, query_num_block)
+
+
+class MraReduceSum:
+ @staticmethod
+ def operator_call(sparse_query, indices, query_num_block, key_num_block):
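+ # Sum each selected attention tile over its key dimension, then scatter-add the per-block results back onto the dense query axis; mra2_attention uses this as the softmax normalizer.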
+ batch_size, num_block, block_size, _ = sparse_query.size()
+
+ if len(sparse_query.size()) != 4:
+ raise ValueError("sparse_query must be a 4-dimensional tensor.")
+
+ if len(indices.size()) != 2:
+ raise ValueError("indices must be a 2-dimensional tensor.")
+
+ _, _, block_size, _ = sparse_query.size()
+ batch_size, num_block = indices.size()
+
+ sparse_query = sparse_query.sum(dim=2).reshape(batch_size * num_block, block_size)
+
+ batch_idx = torch.arange(indices.size(0), dtype=torch.long, device=indices.device)
+ global_idxes = (
+ torch.div(indices, key_num_block, rounding_mode="floor").long() + batch_idx[:, None] * query_num_block
+ ).reshape(batch_size * num_block)
+ temp = torch.zeros(
+ (batch_size * query_num_block, block_size), dtype=sparse_query.dtype, device=sparse_query.device
+ )
+ output = temp.index_add(0, global_idxes, sparse_query).reshape(batch_size, query_num_block, block_size)
+
+ output = output.reshape(batch_size, query_num_block * block_size)
+ return output
+
+
+def get_low_resolution_logit(query, key, block_size, mask=None, value=None):
+ """
+ Compute low resolution approximation.
+ """
+ batch_size, seq_len, head_dim = query.size()
+
+ num_block_per_row = seq_len // block_size
+
+ value_hat = None
+ if mask is not None:
+ token_count = mask.reshape(batch_size, num_block_per_row, block_size).sum(dim=-1)
+ query_hat = query.reshape(batch_size, num_block_per_row, block_size, head_dim).sum(dim=-2) / (
+ token_count[:, :, None] + 1e-6
+ )
+ key_hat = key.reshape(batch_size, num_block_per_row, block_size, head_dim).sum(dim=-2) / (
+ token_count[:, :, None] + 1e-6
+ )
+ if value is not None:
+ value_hat = value.reshape(batch_size, num_block_per_row, block_size, head_dim).sum(dim=-2) / (
+ token_count[:, :, None] + 1e-6
+ )
+ else:
+ token_count = block_size * torch.ones(batch_size, num_block_per_row, dtype=torch.float, device=query.device)
+ query_hat = query.reshape(batch_size, num_block_per_row, block_size, head_dim).mean(dim=-2)
+ key_hat = key.reshape(batch_size, num_block_per_row, block_size, head_dim).mean(dim=-2)
+ if value is not None:
+ value_hat = value.reshape(batch_size, num_block_per_row, block_size, head_dim).mean(dim=-2)
+
+ low_resolution_logit = torch.matmul(query_hat, key_hat.transpose(-1, -2)) / math.sqrt(head_dim)
+
+ low_resolution_logit_row_max = low_resolution_logit.max(dim=-1, keepdims=True).values
+
+ if mask is not None:
+ low_resolution_logit = (
+ low_resolution_logit - 1e4 * ((token_count[:, None, :] * token_count[:, :, None]) < 0.5).float()
+ )
+
+ return low_resolution_logit, token_count, low_resolution_logit_row_max, value_hat
+
+
+def get_block_idxes(
+ low_resolution_logit, num_blocks, approx_mode, initial_prior_first_n_blocks, initial_prior_diagonal_n_blocks
+):
+ """
+ Compute the indices of the subset of components to be used in the approximation.
+ """
+ batch_size, total_blocks_per_row, _ = low_resolution_logit.shape
+
+ if initial_prior_diagonal_n_blocks > 0:
+ offset = initial_prior_diagonal_n_blocks // 2
+ temp_mask = torch.ones(total_blocks_per_row, total_blocks_per_row, device=low_resolution_logit.device)
+ diagonal_mask = torch.tril(torch.triu(temp_mask, diagonal=-offset), diagonal=offset)
+ low_resolution_logit = low_resolution_logit + diagonal_mask[None, :, :] * 5e3
+
+ if initial_prior_first_n_blocks > 0:
+ low_resolution_logit[:, :initial_prior_first_n_blocks, :] = (
+ low_resolution_logit[:, :initial_prior_first_n_blocks, :] + 5e3
+ )
+ low_resolution_logit[:, :, :initial_prior_first_n_blocks] = (
+ low_resolution_logit[:, :, :initial_prior_first_n_blocks] + 5e3
+ )
+
+ top_k_vals = torch.topk(
+ low_resolution_logit.reshape(batch_size, -1), num_blocks, dim=-1, largest=True, sorted=False
+ )
+ indices = top_k_vals.indices
+
+ if approx_mode == "full":
+ threshold = top_k_vals.values.min(dim=-1).values
+ high_resolution_mask = (low_resolution_logit >= threshold[:, None, None]).float()
+ elif approx_mode == "sparse":
+ high_resolution_mask = None
+ else:
+ raise ValueError(f"{approx_mode} is not a valid approx_model value.")
+
+ return indices, high_resolution_mask
+
+
+def mra2_attention(
+ query,
+ key,
+ value,
+ mask,
+ num_blocks,
+ approx_mode,
+ block_size=32,
+ initial_prior_first_n_blocks=0,
+ initial_prior_diagonal_n_blocks=0,
+):
+ """
+ Use Mra to approximate self-attention.
+ """
+ if mra_cuda_kernel is None:
+ return torch.zeros_like(query).requires_grad_()
+
+ batch_size, num_head, seq_len, head_dim = query.size()
+ meta_batch = batch_size * num_head
+
+ if seq_len % block_size != 0:
+ raise ValueError("sequence length must be divisible by the block_size.")
+
+ num_block_per_row = seq_len // block_size
+
+ query = query.reshape(meta_batch, seq_len, head_dim)
+ key = key.reshape(meta_batch, seq_len, head_dim)
+ value = value.reshape(meta_batch, seq_len, head_dim)
+
+ if mask is not None:
+ query = query * mask[:, :, None]
+ key = key * mask[:, :, None]
+ value = value * mask[:, :, None]
+
+ if approx_mode == "full":
+ low_resolution_logit, token_count, low_resolution_logit_row_max, value_hat = get_low_resolution_logit(
+ query, key, block_size, mask, value
+ )
+ elif approx_mode == "sparse":
+ with torch.no_grad():
+ low_resolution_logit, token_count, low_resolution_logit_row_max, _ = get_low_resolution_logit(
+ query, key, block_size, mask
+ )
+ else:
+ raise Exception('approx_mode must be "full" or "sparse"')
+
+ with torch.no_grad():
+ low_resolution_logit_normalized = low_resolution_logit - low_resolution_logit_row_max
+ indices, high_resolution_mask = get_block_idxes(
+ low_resolution_logit_normalized,
+ num_blocks,
+ approx_mode,
+ initial_prior_first_n_blocks,
+ initial_prior_diagonal_n_blocks,
+ )
+
+ high_resolution_logit = MraSampledDenseMatMul.operator_call(
+ query, key, indices, block_size=block_size
+ ) / math.sqrt(head_dim)
+ max_vals, max_vals_scatter = sparse_max(high_resolution_logit, indices, num_block_per_row, num_block_per_row)
+ high_resolution_logit = high_resolution_logit - max_vals_scatter
+ if mask is not None:
+ high_resolution_logit = high_resolution_logit - 1e4 * (1 - sparse_mask(mask, indices)[:, :, :, None])
+ high_resolution_attn = torch.exp(high_resolution_logit)
+ high_resolution_attn_out = MraSparseDenseMatMul.operator_call(
+ high_resolution_attn, indices, value, num_block_per_row
+ )
+ high_resolution_normalizer = MraReduceSum.operator_call(
+ high_resolution_attn, indices, num_block_per_row, num_block_per_row
+ )
+
+ if approx_mode == "full":
+ low_resolution_attn = (
+ torch.exp(low_resolution_logit - low_resolution_logit_row_max - 1e4 * high_resolution_mask)
+ * token_count[:, None, :]
+ )
+
+ low_resolution_attn_out = (
+ torch.matmul(low_resolution_attn, value_hat)[:, :, None, :]
+ .repeat(1, 1, block_size, 1)
+ .reshape(meta_batch, seq_len, head_dim)
+ )
+ low_resolution_normalizer = (
+ low_resolution_attn.sum(dim=-1)[:, :, None].repeat(1, 1, block_size).reshape(meta_batch, seq_len)
+ )
+
+ log_correction = low_resolution_logit_row_max.repeat(1, 1, block_size).reshape(meta_batch, seq_len) - max_vals
+ if mask is not None:
+ log_correction = log_correction * mask
+
+ low_resolution_corr = torch.exp(log_correction * (log_correction <= 0).float())
+ low_resolution_attn_out = low_resolution_attn_out * low_resolution_corr[:, :, None]
+ low_resolution_normalizer = low_resolution_normalizer * low_resolution_corr
+
+ high_resolution_corr = torch.exp(-log_correction * (log_correction > 0).float())
+ high_resolution_attn_out = high_resolution_attn_out * high_resolution_corr[:, :, None]
+ high_resolution_normalizer = high_resolution_normalizer * high_resolution_corr
+
+ context_layer = (high_resolution_attn_out + low_resolution_attn_out) / (
+ high_resolution_normalizer[:, :, None] + low_resolution_normalizer[:, :, None] + 1e-6
+ )
+
+ elif approx_mode == "sparse":
+ context_layer = high_resolution_attn_out / (high_resolution_normalizer[:, :, None] + 1e-6)
+ else:
+ raise Exception('config.approx_mode must be "full" or "sparse"')
+
+ if mask is not None:
+ context_layer = context_layer * mask[:, :, None]
+
+ context_layer = context_layer.reshape(batch_size, num_head, seq_len, head_dim)
+
+ return context_layer
+
+
+class MraEmbeddings(nn.Module):
+ """Construct the embeddings from word, position and token_type embeddings."""
+
+ def __init__(self, config):
+ super().__init__()
+ self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
+ self.position_embeddings = nn.Embedding(config.max_position_embeddings + 2, config.hidden_size)
+ self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
+
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
+ # any TensorFlow checkpoint file
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
+ self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)) + 2)
+ self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
+ self.register_buffer(
+ "token_type_ids",
+ torch.zeros(self.position_ids.size(), dtype=torch.long, device=self.position_ids.device),
+ persistent=False,
+ )
+
+ def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
+ if input_ids is not None:
+ input_shape = input_ids.size()
+ else:
+ input_shape = inputs_embeds.size()[:-1]
+
+ seq_length = input_shape[1]
+
+ if position_ids is None:
+ position_ids = self.position_ids[:, :seq_length]
+
+ # Set token_type_ids to the registered buffer defined in the constructor (all zeros) when they are not passed.
+ # This usually happens when the ids are auto-generated; the buffer lets users trace the model without passing
+ # token_type_ids and solves issue #5664.
+ if token_type_ids is None:
+ if hasattr(self, "token_type_ids"):
+ buffered_token_type_ids = self.token_type_ids[:, :seq_length]
+ buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
+ token_type_ids = buffered_token_type_ids_expanded
+ else:
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
+
+ if inputs_embeds is None:
+ inputs_embeds = self.word_embeddings(input_ids)
+ token_type_embeddings = self.token_type_embeddings(token_type_ids)
+
+ embeddings = inputs_embeds + token_type_embeddings
+ if self.position_embedding_type == "absolute":
+ position_embeddings = self.position_embeddings(position_ids)
+ embeddings += position_embeddings
+ embeddings = self.LayerNorm(embeddings)
+ embeddings = self.dropout(embeddings)
+ return embeddings
+
+
+class MraSelfAttention(nn.Module):
+ def __init__(self, config, position_embedding_type=None):
+ super().__init__()
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
+ raise ValueError(
+ f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
+ f"heads ({config.num_attention_heads})"
+ )
+
+ kernel_loaded = mra_cuda_kernel is not None
+ if is_torch_cuda_available() and is_cuda_platform() and is_ninja_available() and not kernel_loaded:
+ try:
+ load_cuda_kernels()
+ except Exception as e:
+ logger.warning(f"Could not load the custom kernel for multi-scale deformable attention: {e}")
+
+ self.num_attention_heads = config.num_attention_heads
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
+
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
+
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
+ self.position_embedding_type = (
+ position_embedding_type if position_embedding_type is not None else config.position_embedding_type
+ )
+
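+ # high resolution budget: roughly `block_per_row` blocks per row of the block grid (blocks of 32 tokens), capped at the total number of blocks in the grid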
+ self.num_block = (config.max_position_embeddings // 32) * config.block_per_row
+ self.num_block = min(self.num_block, int((config.max_position_embeddings // 32) ** 2))
+
+ self.approx_mode = config.approx_mode
+ self.initial_prior_first_n_blocks = config.initial_prior_first_n_blocks
+ self.initial_prior_diagonal_n_blocks = config.initial_prior_diagonal_n_blocks
+
+ def forward(self, hidden_states, attention_mask=None):
+ batch_size, seq_len, _ = hidden_states.shape
+ query_layer = (
+ self.query(hidden_states)
+ .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
+ .transpose(1, 2)
+ )
+ key_layer = (
+ self.key(hidden_states)
+ .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
+ .transpose(1, 2)
+ )
+ value_layer = (
+ self.value(hidden_states)
+ .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
+ .transpose(1, 2)
+ )
+
+ # revert changes made by get_extended_attention_mask: map the additive mask (0 for kept tokens, -10000.0 for masked ones) back to a {1, 0} keep mask
+ attention_mask = 1.0 + attention_mask / 10000.0
+ attention_mask = (
+ attention_mask.squeeze()
+ .repeat(1, self.num_attention_heads, 1)
+ .reshape(batch_size * self.num_attention_heads, seq_len)
+ .int()
+ )
+
+ # The CUDA kernels are most efficient with inputs whose size is a multiple of a GPU's warp size (32). Inputs
+ # smaller than this are padded with zeros.
+ gpu_warp_size = 32
+
+ if self.attention_head_size < gpu_warp_size:
+ pad_size = batch_size, self.num_attention_heads, seq_len, gpu_warp_size - self.attention_head_size
+
+ query_layer = torch.cat([query_layer, torch.zeros(pad_size, device=query_layer.device)], dim=-1)
+ key_layer = torch.cat([key_layer, torch.zeros(pad_size, device=key_layer.device)], dim=-1)
+ value_layer = torch.cat([value_layer, torch.zeros(pad_size, device=value_layer.device)], dim=-1)
+
+ context_layer = mra2_attention(
+ query_layer.float(),
+ key_layer.float(),
+ value_layer.float(),
+ attention_mask.float(),
+ self.num_block,
+ approx_mode=self.approx_mode,
+ initial_prior_first_n_blocks=self.initial_prior_first_n_blocks,
+ initial_prior_diagonal_n_blocks=self.initial_prior_diagonal_n_blocks,
+ )
+
+ if self.attention_head_size < gpu_warp_size:
+ context_layer = context_layer[:, :, :, : self.attention_head_size]
+
+ context_layer = context_layer.reshape(batch_size, self.num_attention_heads, seq_len, self.attention_head_size)
+
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+ context_layer = context_layer.view(*new_context_layer_shape)
+
+ outputs = (context_layer,)
+
+ return outputs
+
+
+# Copied from transformers.models.bert.modeling_bert.BertSelfOutput
+class MraSelfOutput(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
+ return hidden_states
+
+
+class MraAttention(nn.Module):
+ def __init__(self, config, position_embedding_type=None):
+ super().__init__()
+ self.self = MraSelfAttention(config, position_embedding_type=position_embedding_type)
+ self.output = MraSelfOutput(config)
+ self.pruned_heads = set()
+
+ def prune_heads(self, heads):
+ if len(heads) == 0:
+ return
+ heads, index = find_pruneable_heads_and_indices(
+ heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
+ )
+
+ # Prune linear layers
+ self.self.query = prune_linear_layer(self.self.query, index)
+ self.self.key = prune_linear_layer(self.self.key, index)
+ self.self.value = prune_linear_layer(self.self.value, index)
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
+
+ # Update hyper params and store pruned heads
+ self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
+ self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
+ self.pruned_heads = self.pruned_heads.union(heads)
+
+ def forward(self, hidden_states, attention_mask=None):
+ self_outputs = self.self(hidden_states, attention_mask)
+ attention_output = self.output(self_outputs[0], hidden_states)
+ outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
+ return outputs
+
+
+# Copied from transformers.models.bert.modeling_bert.BertIntermediate
+class MraIntermediate(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
+ if isinstance(config.hidden_act, str):
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
+ else:
+ self.intermediate_act_fn = config.hidden_act
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.intermediate_act_fn(hidden_states)
+ return hidden_states
+
+
+# Copied from transformers.models.bert.modeling_bert.BertOutput
+class MraOutput(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
+ return hidden_states
+
+
+class MraLayer(GradientCheckpointingLayer):
+ def __init__(self, config):
+ super().__init__()
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
+ self.seq_len_dim = 1
+ self.attention = MraAttention(config)
+ self.add_cross_attention = config.add_cross_attention
+ self.intermediate = MraIntermediate(config)
+ self.output = MraOutput(config)
+
+ def forward(self, hidden_states, attention_mask=None):
+ self_attention_outputs = self.attention(hidden_states, attention_mask)
+ attention_output = self_attention_outputs[0]
+
+ outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
+
+ layer_output = apply_chunking_to_forward(
+ self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
+ )
+ outputs = (layer_output,) + outputs
+
+ return outputs
+
+ def feed_forward_chunk(self, attention_output):
+ intermediate_output = self.intermediate(attention_output)
+ layer_output = self.output(intermediate_output, attention_output)
+ return layer_output
+
+
+class MraEncoder(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.layer = nn.ModuleList([MraLayer(config) for _ in range(config.num_hidden_layers)])
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask=None,
+ head_mask=None,
+ output_hidden_states=False,
+ return_dict=True,
+ ):
+ all_hidden_states = () if output_hidden_states else None
+
+ for i, layer_module in enumerate(self.layer):
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ layer_outputs = layer_module(hidden_states, attention_mask)
+
+ hidden_states = layer_outputs[0]
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, all_hidden_states] if v is not None)
+ return BaseModelOutputWithCrossAttentions(
+ last_hidden_state=hidden_states,
+ hidden_states=all_hidden_states,
+ )
+
+
+# Copied from transformers.models.bert.modeling_bert.BertPredictionHeadTransform
+class MraPredictionHeadTransform(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ if isinstance(config.hidden_act, str):
+ self.transform_act_fn = ACT2FN[config.hidden_act]
+ else:
+ self.transform_act_fn = config.hidden_act
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.transform_act_fn(hidden_states)
+ hidden_states = self.LayerNorm(hidden_states)
+ return hidden_states
+
+
+# Copied from transformers.models.bert.modeling_bert.BertLMPredictionHead with Bert->Mra
+class MraLMPredictionHead(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.transform = MraPredictionHeadTransform(config)
+
+ # The output weights are the same as the input embeddings, but there is
+ # an output-only bias for each token.
+ self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+ self.bias = nn.Parameter(torch.zeros(config.vocab_size))
+
+ # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
+ self.decoder.bias = self.bias
+
+ def _tie_weights(self):
+ self.decoder.bias = self.bias
+
+ def forward(self, hidden_states):
+ hidden_states = self.transform(hidden_states)
+ hidden_states = self.decoder(hidden_states)
+ return hidden_states
+
+
+# Copied from transformers.models.bert.modeling_bert.BertOnlyMLMHead with Bert->Mra
+class MraOnlyMLMHead(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.predictions = MraLMPredictionHead(config)
+
+ def forward(self, sequence_output: torch.Tensor) -> torch.Tensor:
+ prediction_scores = self.predictions(sequence_output)
+ return prediction_scores
+
+
+@auto_docstring
+# Copied from transformers.models.yoso.modeling_yoso.YosoPreTrainedModel with Yoso->Mra,yoso->mra
+class MraPreTrainedModel(PreTrainedModel):
+ config: MraConfig
+ base_model_prefix = "mra"
+ supports_gradient_checkpointing = True
+
+ def _init_weights(self, module: nn.Module):
+ """Initialize the weights"""
+ std = self.config.initializer_range
+ if isinstance(module, nn.Linear):
+ # Slightly different from the TF version which uses truncated_normal for initialization
+ # cf https://github.com/pytorch/pytorch/pull/5617
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+ elif isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+ elif isinstance(module, MraLMPredictionHead):
+ module.bias.data.zero_()
+
+
+@auto_docstring
+class MraModel(MraPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.config = config
+
+ self.embeddings = MraEmbeddings(config)
+ self.encoder = MraEncoder(config)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.embeddings.word_embeddings
+
+ def set_input_embeddings(self, value):
+ self.embeddings.word_embeddings = value
+
+ def _prune_heads(self, heads_to_prune):
+ """
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
+ class PreTrainedModel
+ """
+ for layer, heads in heads_to_prune.items():
+ self.encoder.layer[layer].attention.prune_heads(heads)
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ token_type_ids: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[tuple, BaseModelOutputWithCrossAttentions]:
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+ elif input_ids is not None:
+ self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
+ input_shape = input_ids.size()
+ elif inputs_embeds is not None:
+ input_shape = inputs_embeds.size()[:-1]
+ else:
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+ batch_size, seq_length = input_shape
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
+
+ if attention_mask is None:
+ attention_mask = torch.ones((batch_size, seq_length), device=device)
+
+ if token_type_ids is None:
+ if hasattr(self.embeddings, "token_type_ids"):
+ buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
+ buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
+ token_type_ids = buffered_token_type_ids_expanded
+ else:
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
+
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
+ # ourselves in which case we just need to make it broadcastable to all heads.
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
+
+ # Prepare head mask if needed
+ # 1.0 in head_mask indicate we keep the head
+ # attention_probs has shape bsz x n_heads x N x N
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
+
+ embedding_output = self.embeddings(
+ input_ids=input_ids,
+ position_ids=position_ids,
+ token_type_ids=token_type_ids,
+ inputs_embeds=inputs_embeds,
+ )
+ encoder_outputs = self.encoder(
+ embedding_output,
+ attention_mask=extended_attention_mask,
+ head_mask=head_mask,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ sequence_output = encoder_outputs[0]
+
+ if not return_dict:
+ return (sequence_output,) + encoder_outputs[1:]
+
+ return BaseModelOutputWithCrossAttentions(
+ last_hidden_state=sequence_output,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ cross_attentions=encoder_outputs.cross_attentions,
+ )
+
+
+@auto_docstring
+class MraForMaskedLM(MraPreTrainedModel):
+ _tied_weights_keys = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"]
+
+ def __init__(self, config):
+ super().__init__(config)
+
+ self.mra = MraModel(config)
+ self.cls = MraOnlyMLMHead(config)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_output_embeddings(self):
+ return self.cls.predictions.decoder
+
+ def set_output_embeddings(self, new_embeddings):
+ self.cls.predictions.decoder = new_embeddings
+ self.cls.predictions.bias = new_embeddings.bias
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ token_type_ids: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[tuple, MaskedLMOutput]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
+ config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
+ loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.mra(
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ sequence_output = outputs[0]
+ prediction_scores = self.cls(sequence_output)
+
+ masked_lm_loss = None
+ if labels is not None:
+ loss_fct = CrossEntropyLoss() # -100 index = padding token
+ masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
+
+ if not return_dict:
+ output = (prediction_scores,) + outputs[1:]
+ return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
+
+ return MaskedLMOutput(
+ loss=masked_lm_loss,
+ logits=prediction_scores,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
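+ # Minimal usage sketch (illustrative, not part of the original module). The checkpoint name is an
+ # assumption; any MRA masked-LM checkpoint with a matching vocabulary would work the same way:
+ #
+ #   from transformers import AutoTokenizer, MraForMaskedLM
+ #
+ #   tokenizer = AutoTokenizer.from_pretrained("uw-madison/mra-base-512-4")
+ #   model = MraForMaskedLM.from_pretrained("uw-madison/mra-base-512-4")
+ #   inputs = tokenizer("Paris is the capital of [MASK].", return_tensors="pt")
+ #   logits = model(**inputs).logits
+ #   masked_index = (inputs.input_ids == tokenizer.mask_token_id).nonzero()[0, 1]
+ #   print(tokenizer.decode(logits[0, masked_index].argmax(-1).item()))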
+
+# Copied from transformers.models.yoso.modeling_yoso.YosoClassificationHead with Yoso->Mra
+class MraClassificationHead(nn.Module):
+ """Head for sentence-level classification tasks."""
+
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+ self.out_proj = nn.Linear(config.hidden_size, config.num_labels)
+
+ self.config = config
+
+ def forward(self, features, **kwargs):
+ x = features[:, 0, :] # take <s> token (equiv. to [CLS])
+ x = self.dropout(x)
+ x = self.dense(x)
+ x = ACT2FN[self.config.hidden_act](x)
+ x = self.dropout(x)
+ x = self.out_proj(x)
+ return x
+
+
+@auto_docstring(
+ custom_intro="""
+ MRA Model transformer with a sequence classification/regression head on top (a linear layer on top of
+ the pooled output) e.g. for GLUE tasks.
+ """
+)
+class MraForSequenceClassification(MraPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+ self.mra = MraModel(config)
+ self.classifier = MraClassificationHead(config)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ token_type_ids: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[tuple, SequenceClassifierOutput]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.mra(
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ sequence_output = outputs[0]
+ logits = self.classifier(sequence_output)
+
+ loss = None
+ if labels is not None:
+ if self.config.problem_type is None:
+ if self.num_labels == 1:
+ self.config.problem_type = "regression"
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+ self.config.problem_type = "single_label_classification"
+ else:
+ self.config.problem_type = "multi_label_classification"
+
+ if self.config.problem_type == "regression":
+ loss_fct = MSELoss()
+ if self.num_labels == 1:
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
+ else:
+ loss = loss_fct(logits, labels)
+ elif self.config.problem_type == "single_label_classification":
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+ elif self.config.problem_type == "multi_label_classification":
+ loss_fct = BCEWithLogitsLoss()
+ loss = loss_fct(logits, labels)
+ if not return_dict:
+ output = (logits,) + outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return SequenceClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+@auto_docstring
+class MraForMultipleChoice(MraPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+
+ self.mra = MraModel(config)
+ self.pre_classifier = nn.Linear(config.hidden_size, config.hidden_size)
+ self.classifier = nn.Linear(config.hidden_size, 1)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ token_type_ids: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[tuple, MultipleChoiceModelOutput]:
+ r"""
+ input_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ token_type_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
+ Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
+ 1]`:
+
+ - 0 corresponds to a *sentence A* token,
+ - 1 corresponds to a *sentence B* token.
+
+ [What are token type IDs?](../glossary#token-type-ids)
+ position_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+ config.max_position_embeddings - 1]`.
+
+ [What are position IDs?](../glossary#position-ids)
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_choices, sequence_length, hidden_size)`, *optional*):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+ is useful if you want more control over how to convert *input_ids* indices into associated vectors than the
+ model's internal embedding lookup matrix.
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
+ num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
+ `input_ids` above)
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
+
+ input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
+ attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
+ token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
+ position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
+ inputs_embeds = (
+ inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
+ if inputs_embeds is not None
+ else None
+ )
+
+ outputs = self.mra(
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ hidden_state = outputs[0] # (bs * num_choices, seq_len, dim)
+ pooled_output = hidden_state[:, 0] # (bs * num_choices, dim)
+ pooled_output = self.pre_classifier(pooled_output) # (bs * num_choices, dim)
+ pooled_output = nn.ReLU()(pooled_output) # (bs * num_choices, dim)
+ logits = self.classifier(pooled_output)
+
+ reshaped_logits = logits.view(-1, num_choices)
+
+ loss = None
+ if labels is not None:
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(reshaped_logits, labels)
+
+ if not return_dict:
+ output = (reshaped_logits,) + outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return MultipleChoiceModelOutput(
+ loss=loss,
+ logits=reshaped_logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
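+
+ # Shape note (illustrative, not part of the original module): with batch_size=2 and num_choices=4,
+ # `input_ids` of shape (2, 4, seq_len) is flattened to (8, seq_len) before the encoder, and the
+ # per-choice logits of shape (8, 1) are reshaped back to (2, 4) before the cross-entropy loss.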
+
+
+@auto_docstring
+class MraForTokenClassification(MraPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+
+ self.mra = MraModel(config)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ token_type_ids: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[tuple, TokenClassifierOutput]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.mra(
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ sequence_output = outputs[0]
+
+ sequence_output = self.dropout(sequence_output)
+ logits = self.classifier(sequence_output)
+
+ loss = None
+ if labels is not None:
+ loss_fct = CrossEntropyLoss()
+ # Only keep active parts of the loss
+ if attention_mask is not None:
+ active_loss = attention_mask.view(-1) == 1
+ active_logits = logits.view(-1, self.num_labels)
+ active_labels = torch.where(
+ active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
+ )
+ loss = loss_fct(active_logits, active_labels)
+ else:
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+
+ if not return_dict:
+ output = (logits,) + outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return TokenClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+@auto_docstring
+class MraForQuestionAnswering(MraPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+
+ config.num_labels = 2
+ self.num_labels = config.num_labels
+
+ self.mra = MraModel(config)
+ self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ token_type_ids: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ start_positions: Optional[torch.Tensor] = None,
+ end_positions: Optional[torch.Tensor] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[tuple, QuestionAnsweringModelOutput]:
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.mra(
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ sequence_output = outputs[0]
+
+ logits = self.qa_outputs(sequence_output)
+ start_logits, end_logits = logits.split(1, dim=-1)
+ start_logits = start_logits.squeeze(-1)
+ end_logits = end_logits.squeeze(-1)
+
+ total_loss = None
+ if start_positions is not None and end_positions is not None:
+ # If we are on multi-GPU, the split adds a dimension that we squeeze out
+ if len(start_positions.size()) > 1:
+ start_positions = start_positions.squeeze(-1)
+ if len(end_positions.size()) > 1:
+ end_positions = end_positions.squeeze(-1)
+ # sometimes the start/end positions are outside our model inputs; we ignore these terms
+ ignored_index = start_logits.size(1)
+ start_positions = start_positions.clamp(0, ignored_index)
+ end_positions = end_positions.clamp(0, ignored_index)
+
+ loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
+ start_loss = loss_fct(start_logits, start_positions)
+ end_loss = loss_fct(end_logits, end_positions)
+ total_loss = (start_loss + end_loss) / 2
+
+ if not return_dict:
+ output = (start_logits, end_logits) + outputs[1:]
+ return ((total_loss,) + output) if total_loss is not None else output
+
+ return QuestionAnsweringModelOutput(
+ loss=total_loss,
+ start_logits=start_logits,
+ end_logits=end_logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+__all__ = [
+ "MraForMaskedLM",
+ "MraForMultipleChoice",
+ "MraForQuestionAnswering",
+ "MraForSequenceClassification",
+ "MraForTokenClassification",
+ "MraLayer",
+ "MraModel",
+ "MraPreTrainedModel",
+]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/musicgen/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/musicgen/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..880274309cbab486a90df05548a0e4d3f2ea0925
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/musicgen/__init__.py
@@ -0,0 +1,28 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_musicgen import *
+ from .modeling_musicgen import *
+ from .processing_musicgen import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/musicgen/configuration_musicgen.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/musicgen/configuration_musicgen.py
new file mode 100644
index 0000000000000000000000000000000000000000..878cc122f17d31d7d0bea3d9218663cfc585af9a
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/musicgen/configuration_musicgen.py
@@ -0,0 +1,248 @@
+# coding=utf-8
+# Copyright 2023 Meta AI and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""MusicGen model configuration"""
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+from ..auto.configuration_auto import AutoConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+class MusicgenDecoderConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`MusicgenDecoder`]. It is used to instantiate a
+ MusicGen decoder according to the specified arguments, defining the model architecture. Instantiating a
+ configuration with the defaults will yield a similar configuration to that of the MusicGen
+ [facebook/musicgen-small](https://huggingface.co/facebook/musicgen-small) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 2048):
+ Vocabulary size of the MusicgenDecoder model. Defines the number of different tokens that can be
+ represented by the `input_ids` passed when calling [`MusicgenDecoder`].
+ hidden_size (`int`, *optional*, defaults to 1024):
+ Dimensionality of the layers and the pooler layer.
+ num_hidden_layers (`int`, *optional*, defaults to 24):
+ Number of decoder layers.
+ num_attention_heads (`int`, *optional*, defaults to 16):
+ Number of attention heads for each attention layer in the Transformer block.
+ ffn_dim (`int`, *optional*, defaults to 4096):
+ Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer block.
+ activation_function (`str` or `function`, *optional*, defaults to `"gelu"`):
+ The non-linear activation function (function or string) in the decoder and pooler. If string, `"gelu"`,
+ `"relu"`, `"silu"` and `"gelu_new"` are supported.
+ dropout (`float`, *optional*, defaults to 0.1):
+ The dropout probability for all fully connected layers in the embeddings, text_encoder, and pooler.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ activation_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for activations inside the fully connected layer.
+ max_position_embeddings (`int`, *optional*, defaults to 2048):
+ The maximum sequence length that this model might ever be used with. Typically, set this to something large
+ just in case (e.g., 512 or 1024 or 2048).
+ initializer_factor (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ layerdrop (`float`, *optional*, defaults to 0.0):
+ The LayerDrop probability for the decoder. See the [LayerDrop paper](https://huggingface.co/papers/1909.11556)
+ for more details.
+ scale_embedding (`bool`, *optional*, defaults to `False`):
+ Scale embeddings by dividing by sqrt(hidden_size).
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether the model should return the last key/values attentions (not used by all models).
+ num_codebooks (`int`, *optional*, defaults to 4):
+ The number of parallel codebooks forwarded to the model.
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+ Whether input and output word embeddings should be tied.
+ audio_channels (`int`, *optional*, defaults to 1):
+ Number of channels in the audio data. Either 1 for mono or 2 for stereo. Stereo models generate a separate
+ audio stream for the left/right output channels. Mono models generate a single audio stream output.
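+
+ Example (illustrative; the values shown simply mirror the defaults documented above):
+
+ ```python
+ >>> from transformers import MusicgenDecoderConfig
+
+ >>> # Initializing a decoder configuration in the facebook/musicgen-small style
+ >>> decoder_config = MusicgenDecoderConfig(hidden_size=1024, num_codebooks=4, audio_channels=1)
+ ```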
+ """
+
+ model_type = "musicgen_decoder"
+ base_config_key = "decoder_config"
+ keys_to_ignore_at_inference = ["past_key_values"]
+
+ def __init__(
+ self,
+ vocab_size=2048,
+ max_position_embeddings=2048,
+ num_hidden_layers=24,
+ ffn_dim=4096,
+ num_attention_heads=16,
+ layerdrop=0.0,
+ use_cache=True,
+ activation_function="gelu",
+ hidden_size=1024,
+ dropout=0.1,
+ attention_dropout=0.0,
+ activation_dropout=0.0,
+ initializer_factor=0.02,
+ scale_embedding=False,
+ num_codebooks=4,
+ audio_channels=1,
+ pad_token_id=2048,
+ bos_token_id=2048,
+ eos_token_id=None,
+ tie_word_embeddings=False,
+ **kwargs,
+ ):
+ self.vocab_size = vocab_size
+ self.max_position_embeddings = max_position_embeddings
+ self.hidden_size = hidden_size
+ self.ffn_dim = ffn_dim
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.dropout = dropout
+ self.attention_dropout = attention_dropout
+ self.activation_dropout = activation_dropout
+ self.activation_function = activation_function
+ self.initializer_factor = initializer_factor
+ self.layerdrop = layerdrop
+ self.use_cache = use_cache
+ self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True
+ self.num_codebooks = num_codebooks
+
+ if audio_channels not in [1, 2]:
+ raise ValueError(f"Expected 1 (mono) or 2 (stereo) audio channels, got {audio_channels} channels.")
+ self.audio_channels = audio_channels
+
+ super().__init__(
+ pad_token_id=pad_token_id,
+ bos_token_id=bos_token_id,
+ eos_token_id=eos_token_id,
+ tie_word_embeddings=tie_word_embeddings,
+ **kwargs,
+ )
+
+
+class MusicgenConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`MusicgenModel`]. It is used to instantiate a
+ MusicGen model according to the specified arguments, defining the text encoder, audio encoder and MusicGen decoder
+ configs.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ kwargs (*optional*):
+ Dictionary of keyword arguments. Notably:
+
+ - **text_encoder** ([`PretrainedConfig`], *optional*) -- An instance of a configuration object that
+ defines the text encoder config.
+ - **audio_encoder** ([`PretrainedConfig`], *optional*) -- An instance of a configuration object that
+ defines the audio encoder config.
+ - **decoder** ([`PretrainedConfig`], *optional*) -- An instance of a configuration object that defines
+ the decoder config.
+
+ Example:
+
+ ```python
+ >>> from transformers import (
+ ... MusicgenConfig,
+ ... MusicgenDecoderConfig,
+ ... T5Config,
+ ... EncodecConfig,
+ ... MusicgenForConditionalGeneration,
+ ... )
+
+ >>> # Initializing text encoder, audio encoder, and decoder model configurations
+ >>> text_encoder_config = T5Config()
+ >>> audio_encoder_config = EncodecConfig()
+ >>> decoder_config = MusicgenDecoderConfig()
+
+ >>> configuration = MusicgenConfig.from_sub_models_config(
+ ... text_encoder_config, audio_encoder_config, decoder_config
+ ... )
+
+ >>> # Initializing a MusicgenForConditionalGeneration (with random weights) from the facebook/musicgen-small style configuration
+ >>> model = MusicgenForConditionalGeneration(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ >>> config_text_encoder = model.config.text_encoder
+ >>> config_audio_encoder = model.config.audio_encoder
+ >>> config_decoder = model.config.decoder
+
+ >>> # Saving the model, including its configuration
+ >>> model.save_pretrained("musicgen-model")
+
+ >>> # loading model and config from pretrained folder
+ >>> musicgen_config = MusicgenConfig.from_pretrained("musicgen-model")
+ >>> model = MusicgenForConditionalGeneration.from_pretrained("musicgen-model", config=musicgen_config)
+ ```"""
+
+ model_type = "musicgen"
+ sub_configs = {
+ "text_encoder": AutoConfig,
+ "audio_encoder": AutoConfig,
+ "decoder": MusicgenDecoderConfig,
+ }
+ has_no_defaults_at_init = True
+
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+ if "text_encoder" not in kwargs or "audio_encoder" not in kwargs or "decoder" not in kwargs:
+ raise ValueError("Config has to be initialized with text_encoder, audio_encoder and decoder config")
+
+ text_encoder_config = kwargs.pop("text_encoder")
+ text_encoder_model_type = text_encoder_config.pop("model_type")
+
+ audio_encoder_config = kwargs.pop("audio_encoder")
+ audio_encoder_model_type = audio_encoder_config.pop("model_type")
+
+ decoder_config = kwargs.pop("decoder")
+
+ self.text_encoder = AutoConfig.for_model(text_encoder_model_type, **text_encoder_config)
+ self.audio_encoder = AutoConfig.for_model(audio_encoder_model_type, **audio_encoder_config)
+ self.decoder = MusicgenDecoderConfig(**decoder_config)
+ self.is_encoder_decoder = True
+ self.initializer_factor = self.decoder.initializer_factor
+
+ @classmethod
+ def from_sub_models_config(
+ cls,
+ text_encoder_config: PretrainedConfig,
+ audio_encoder_config: PretrainedConfig,
+ decoder_config: MusicgenDecoderConfig,
+ **kwargs,
+ ):
+ r"""
+ Instantiate a [`MusicgenConfig`] (or a derived class) from text encoder, audio encoder and decoder
+ configurations.
+
+ Returns:
+ [`MusicgenConfig`]: An instance of a configuration object
+ """
+
+ return cls(
+ text_encoder=text_encoder_config.to_dict(),
+ audio_encoder=audio_encoder_config.to_dict(),
+ decoder=decoder_config.to_dict(),
+ **kwargs,
+ )
+
+ @property
+ # This is a property because you might want to change the codec model on the fly
+ def sampling_rate(self):
+ return self.audio_encoder.sampling_rate
+
+
+__all__ = ["MusicgenConfig", "MusicgenDecoderConfig"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/musicgen/modeling_musicgen.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/musicgen/modeling_musicgen.py
new file mode 100644
index 0000000000000000000000000000000000000000..fec5fabc54703d69246d1e98659104a03187ae77
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/musicgen/modeling_musicgen.py
@@ -0,0 +1,2453 @@
+# coding=utf-8
+# Copyright 2023 Meta AI and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch Musicgen model."""
+
+import copy
+import inspect
+import math
+import random
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Any, Callable, Optional, Union
+
+import torch
+import torch.nn as nn
+from torch.nn import CrossEntropyLoss
+
+from ...activations import ACT2FN
+from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
+from ...generation import (
+ ClassifierFreeGuidanceLogitsProcessor,
+ GenerationConfig,
+ GenerationMixin,
+ GenerationMode,
+ LogitsProcessorList,
+ StoppingCriteriaList,
+)
+from ...modeling_attn_mask_utils import (
+ _prepare_4d_attention_mask,
+ _prepare_4d_attention_mask_for_sdpa,
+ _prepare_4d_causal_attention_mask,
+ _prepare_4d_causal_attention_mask_for_sdpa,
+)
+from ...modeling_flash_attention_utils import (
+ FlashAttentionKwargs,
+)
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_outputs import (
+ BaseModelOutput,
+ BaseModelOutputWithPastAndCrossAttentions,
+ CausalLMOutputWithCrossAttentions,
+ ModelOutput,
+ Seq2SeqLMOutput,
+)
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
+from ...utils import auto_docstring, is_torch_flex_attn_available, logging
+from ...utils.deprecation import deprecate_kwarg
+from ..auto.configuration_auto import AutoConfig
+from ..auto.modeling_auto import AutoModel
+from .configuration_musicgen import MusicgenConfig, MusicgenDecoderConfig
+
+
+if is_torch_flex_attn_available():
+ from ...integrations.flex_attention import make_flex_block_causal_mask
+
+
+if TYPE_CHECKING:
+ from ...generation.streamers import BaseStreamer
+
+logger = logging.get_logger(__name__)
+
+
+@dataclass
+@auto_docstring
+class MusicgenUnconditionalInput(ModelOutput):
+ r"""
+ encoder_outputs (`tuple[torch.FloatTensor]` of length 1, with tensor shape `(batch_size, sequence_length, hidden_size)`):
+ Sequence of hidden-states at the output of the last layer of the text encoder model.
+ attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Encoder attention mask to avoid performing attention on padding token indices. Mask values selected in `[0,
+ 1]`: 1 for tokens that are **not masked**, 0 for tokens that are **masked**.
+ guidance_scale (`float`, *optional*):
+ Guidance scale for classifier free guidance, setting the balance between the conditional logits (predicted
+ from the prompts) and the unconditional logits (predicted without prompts).
+ """
+
+ encoder_outputs: Optional[tuple[torch.FloatTensor]] = None
+ attention_mask: Optional[torch.LongTensor] = None
+ guidance_scale: Optional[float] = None
+
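+ # Guidance sketch (illustrative, not part of the original module): the classifier-free guidance
+ # logits processor combines the conditional and unconditional forward passes as
+ #   scores = uncond_logits + guidance_scale * (cond_logits - uncond_logits)
+ # so guidance_scale == 1 recovers the purely conditional logits.
+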
+
+def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
+ """
+ Shift input ids one token to the right.
+ """
+ # transpose to get (bsz, num_codebooks, seq_len)
+ input_ids = input_ids.transpose(1, 2)
+ shifted_input_ids = input_ids.new_zeros(input_ids.shape)
+ shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
+ if decoder_start_token_id is None:
+ raise ValueError("Make sure to set the decoder_start_token_id attribute of the model's configuration.")
+ shifted_input_ids[..., 0] = decoder_start_token_id
+
+ if pad_token_id is None:
+ raise ValueError("Make sure to set the pad_token_id attribute of the model's configuration.")
+ # replace possible -100 values in labels by `pad_token_id`
+ shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
+
+ return shifted_input_ids
+
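+ # Worked sketch (illustrative, not part of the original module): with decoder_start_token_id=2048 and
+ # pad_token_id=2048, labels [[[5], [7], [9]]] of shape (bsz=1, seq_len=3, num_codebooks=1) become,
+ # after the transpose and shift, [[[2048, 5, 7]]] of shape (1, 1, 3).
+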
+
+class MusicgenSinusoidalPositionalEmbedding(nn.Module):
+ """This module produces sinusoidal positional embeddings of any length."""
+
+ def __init__(self, num_positions: int, embedding_dim: int):
+ super().__init__()
+ self.embedding_dim = embedding_dim
+ self.make_weights(num_positions, embedding_dim)
+
+ def make_weights(self, num_embeddings: int, embedding_dim: int):
+ emb_weights = self.get_embedding(num_embeddings, embedding_dim)
+ if hasattr(self, "weights"):
+ # in forward put the weights on the correct dtype and device of the param
+ emb_weights = emb_weights.to(dtype=self.weights.dtype, device=self.weights.device)
+
+ self.register_buffer("weights", emb_weights, persistent=False)
+
+ @staticmethod
+ def get_embedding(num_embeddings: int, embedding_dim: int):
+ """
+ Build sinusoidal embeddings. This matches the implementation in tensor2tensor, but differs slightly from the
+ description in Section 3.5 of "Attention Is All You Need".
+ """
+ half_dim = embedding_dim // 2
+ emb = math.log(10000) / (half_dim - 1)
+ emb = torch.exp(torch.arange(half_dim, dtype=torch.int64).float() * -emb)
+ emb = torch.arange(num_embeddings, dtype=torch.int64).float().unsqueeze(1) * emb.unsqueeze(0)
+ emb = torch.cat([torch.cos(emb), torch.sin(emb)], dim=1).view(num_embeddings, -1)
+ if embedding_dim % 2 == 1:
+ # zero pad
+ emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
+ return emb.to(torch.get_default_dtype())
+
+ @torch.no_grad()
+ def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0):
+ bsz, codebooks, seq_len = input_ids.size()
+ # Create the position ids from the input token ids.
+ position_ids = (torch.arange(seq_len) + past_key_values_length).to(input_ids.device)
+ # expand embeddings if needed
+ if seq_len > self.weights.size(0):
+ self.make_weights(seq_len, self.embedding_dim)
+ return self.weights.index_select(0, position_ids.view(-1)).detach()
+
+
+# Copied from transformers.models.bart.modeling_bart.eager_attention_forward
+def eager_attention_forward(
+ module: nn.Module,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_mask: Optional[torch.Tensor],
+ scaling: Optional[float] = None,
+ dropout: float = 0.0,
+ head_mask: Optional[torch.Tensor] = None,
+ **kwargs,
+):
+ if scaling is None:
+ scaling = query.size(-1) ** -0.5
+
+ attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
+ if attention_mask is not None:
+ attn_weights = attn_weights + attention_mask
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+
+ if head_mask is not None:
+ attn_weights = attn_weights * head_mask.view(1, -1, 1, 1)
+
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+ attn_output = torch.matmul(attn_weights, value)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ return attn_output, attn_weights
+
+
+class MusicgenAttention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(
+ self,
+ embed_dim: int,
+ num_heads: int,
+ dropout: Optional[float] = 0.0,
+ is_decoder: Optional[bool] = False,
+ bias: Optional[bool] = True,
+ is_causal: Optional[bool] = False,
+ config: Optional[MusicgenConfig] = None,
+ layer_idx: Optional[int] = None,
+ ):
+ super().__init__()
+ self.embed_dim = embed_dim
+ self.num_heads = num_heads
+ self.dropout = dropout
+ self.head_dim = embed_dim // num_heads
+ self.config = config
+
+ if (self.head_dim * num_heads) != self.embed_dim:
+ raise ValueError(
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
+ f" and `num_heads`: {num_heads})."
+ )
+ self.scaling = self.head_dim**-0.5
+ self.is_decoder = is_decoder
+ self.is_causal = is_causal
+ self.layer_idx = layer_idx
+
+ self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+ self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+ self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+ self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ key_value_states: Optional[torch.Tensor] = None,
+ past_key_values: Optional[Cache] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ layer_head_mask: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = False,
+ cache_position: Optional[torch.Tensor] = None,
+ # TODO: we need a refactor so that the different attention modules can get their specific kwargs
+ # At the moment, encoder, decoder, and encoder-decoder attention kwargs are mixed together
+ **kwargs: Unpack[FlashAttentionKwargs],
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
+ """Input shape: Batch x Time x Channel"""
+
+ # if key_value_states are provided this layer is used as a cross-attention layer
+ # for the decoder
+ is_cross_attention = key_value_states is not None
+
+ # determine input shapes
+ bsz, tgt_len = hidden_states.shape[:-1]
+ src_len = key_value_states.shape[1] if is_cross_attention else tgt_len
+
+ q_input_shape = (bsz, tgt_len, -1, self.head_dim)
+ kv_input_shape = (bsz, src_len, -1, self.head_dim)
+
+ # get query proj
+ query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2)
+
+ is_updated = False
+ if past_key_values is not None:
+ if isinstance(past_key_values, EncoderDecoderCache):
+ is_updated = past_key_values.is_updated.get(self.layer_idx)
+ if is_cross_attention:
+ # after the first generated id, we can subsequently re-use all key/value_layer from cache
+ curr_past_key_value = past_key_values.cross_attention_cache
+ else:
+ curr_past_key_value = past_key_values.self_attention_cache
+ else:
+ curr_past_key_value = past_key_values
+
+ current_states = key_value_states if is_cross_attention else hidden_states
+ if is_cross_attention and past_key_values is not None and is_updated:
+ # reuse k,v, cross_attentions
+ key_states = curr_past_key_value.layers[self.layer_idx].keys
+ value_states = curr_past_key_value.layers[self.layer_idx].values
+ else:
+ key_states = self.k_proj(current_states).view(*kv_input_shape).transpose(1, 2)
+ value_states = self.v_proj(current_states).view(*kv_input_shape).transpose(1, 2)
+
+ if past_key_values is not None:
+ # save all key/value_states to cache to be re-used for fast auto-regressive generation
+ cache_position = cache_position if not is_cross_attention else None
+ key_states, value_states = curr_past_key_value.update(
+ key_states, value_states, self.layer_idx, {"cache_position": cache_position}
+ )
+ # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
+ if is_cross_attention and isinstance(past_key_values, EncoderDecoderCache):
+ past_key_values.is_updated[self.layer_idx] = True
+
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ attn_output, attn_weights = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ dropout=0.0 if not self.training else self.dropout,
+ scaling=self.scaling,
+ output_attentions=output_attentions,
+ head_mask=layer_head_mask,
+ **kwargs,
+ )
+
+ attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous()
+ attn_output = self.out_proj(attn_output)
+
+ return attn_output, attn_weights
+
+
+class MusicgenDecoderLayer(GradientCheckpointingLayer):
+ def __init__(self, config: MusicgenDecoderConfig, layer_idx=None):
+ super().__init__()
+ self.embed_dim = config.hidden_size
+
+ self.self_attn = MusicgenAttention(
+ embed_dim=self.embed_dim,
+ num_heads=config.num_attention_heads,
+ dropout=config.attention_dropout,
+ is_decoder=True,
+ bias=False,
+ is_causal=True,
+ config=config,
+ layer_idx=layer_idx,
+ )
+ self.dropout = config.dropout
+ self.activation_fn = ACT2FN[config.activation_function]
+ self.activation_dropout = config.activation_dropout
+
+ self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
+ self.encoder_attn = MusicgenAttention(
+ self.embed_dim,
+ config.num_attention_heads,
+ dropout=config.attention_dropout,
+ is_decoder=True,
+ bias=False,
+ config=config,
+ layer_idx=layer_idx,
+ )
+ self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
+ self.fc1 = nn.Linear(self.embed_dim, config.ffn_dim, bias=False)
+ self.fc2 = nn.Linear(config.ffn_dim, self.embed_dim, bias=False)
+ self.final_layer_norm = nn.LayerNorm(self.embed_dim)
+
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ layer_head_mask: Optional[torch.Tensor] = None,
+ cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
+ past_key_values: Optional[Cache] = None,
+ output_attentions: Optional[bool] = False,
+ use_cache: Optional[bool] = True,
+ cache_position: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+ attention_mask (`torch.FloatTensor`): attention mask of size
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+ encoder_hidden_states (`torch.FloatTensor`):
+ cross attention input to the layer of shape `(batch, seq_len, embed_dim)`
+ encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+ layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
+ `(encoder_attention_heads,)`.
+ cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of
+ size `(decoder_attention_heads,)`.
+ past_key_values (`Cache`): cached past key and value projection states
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ """
+ residual = hidden_states
+ hidden_states = self.self_attn_layer_norm(hidden_states)
+
+ # Self Attention
+ hidden_states, self_attn_weights = self.self_attn(
+ hidden_states=hidden_states,
+ past_key_values=past_key_values,
+ attention_mask=attention_mask,
+ layer_head_mask=layer_head_mask,
+ output_attentions=output_attentions,
+ cache_position=cache_position,
+ )
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+ hidden_states = residual + hidden_states
+
+ # Cross-Attention Block
+ cross_attn_weights = None
+ if encoder_hidden_states is not None:
+ residual = hidden_states
+ hidden_states = self.encoder_attn_layer_norm(hidden_states)
+
+ hidden_states, cross_attn_weights = self.encoder_attn(
+ hidden_states=hidden_states,
+ key_value_states=encoder_hidden_states,
+ attention_mask=encoder_attention_mask,
+ layer_head_mask=cross_attn_layer_head_mask,
+ past_key_values=past_key_values,
+ output_attentions=output_attentions,
+ cache_position=cache_position,
+ )
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+ hidden_states = residual + hidden_states
+
+ # Fully Connected
+ residual = hidden_states
+ hidden_states = self.final_layer_norm(hidden_states)
+ hidden_states = self.activation_fn(self.fc1(hidden_states))
+ hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
+ hidden_states = self.fc2(hidden_states)
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+ hidden_states = residual + hidden_states
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (self_attn_weights, cross_attn_weights)
+ return outputs
+
+
+@auto_docstring
+class MusicgenPreTrainedModel(PreTrainedModel):
+ config: MusicgenDecoderConfig
+ base_model_prefix = "model"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["MusicgenDecoderLayer", "MusicgenAttention"]
+ _supports_flash_attn = True
+ _supports_sdpa = True
+ _supports_flex_attn = True
+
+ def _init_weights(self, module):
+ std = self.config.initializer_factor
+ if isinstance(module, nn.Linear):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.LayerNorm):
+ module.weight.data.fill_(1.0)
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+
+
+class MusicgenDecoder(MusicgenPreTrainedModel):
+ """
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`MusicgenDecoderLayer`]
+ """
+
+ def __init__(self, config: MusicgenDecoderConfig):
+ super().__init__(config)
+ self.dropout = config.dropout
+ self.layerdrop = config.layerdrop
+ self.max_target_positions = config.max_position_embeddings
+ self.d_model = config.hidden_size
+ self.num_codebooks = config.num_codebooks
+ self.embed_scale = math.sqrt(config.hidden_size) if config.scale_embedding else 1.0
+
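+ # one extra embedding row so the special pad/BOS audio code (id == config.vocab_size, 2048 by default) can be embedded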
+ embed_dim = config.vocab_size + 1
+ self.embed_tokens = nn.ModuleList(
+ [nn.Embedding(embed_dim, config.hidden_size) for _ in range(config.num_codebooks)]
+ )
+
+ self.embed_positions = MusicgenSinusoidalPositionalEmbedding(
+ config.max_position_embeddings,
+ config.hidden_size,
+ )
+
+ self.layers = nn.ModuleList(
+ [MusicgenDecoderLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)]
+ )
+ self.layer_norm = nn.LayerNorm(config.hidden_size)
+ self.attn_implementation = config._attn_implementation
+
+ self.gradient_checkpointing = False
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ encoder_attention_mask: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.Tensor] = None,
+ ) -> Union[tuple, BaseModelOutputWithPastAndCrossAttentions]:
+ r"""
+ input_ids (`torch.LongTensor` of shape `(batch_size * num_codebooks, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary, corresponding to the sequence of audio codes.
+
+ Indices can be obtained by encoding an audio prompt with an audio encoder model to predict audio codes,
+ such as with the [`EncodecModel`]. See [`EncodecModel.encode`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+
+ <Tip warning={true}>
+
+ The `input_ids` will automatically be converted from shape `(batch_size * num_codebooks,
+ target_sequence_length)` to `(batch_size, num_codebooks, target_sequence_length)` in the forward pass. If
+ you obtain audio codes from an audio encoding model, such as [`EncodecModel`], ensure that the number of
+ frames is equal to 1, and that you reshape the audio codes from `(frames, batch_size, num_codebooks,
+ target_sequence_length)` to `(batch_size * num_codebooks, target_sequence_length)` prior to passing them as
+ `input_ids`.
+
+ </Tip>
+
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention of
+ the decoder.
+ encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*):
+ Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values
+ selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
+ Mask to nullify selected heads of the cross-attention modules in the decoder to avoid performing
+ cross-attention on hidden heads. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # retrieve input_ids and inputs_embeds
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
+ elif input_ids is not None:
+ # (bsz * codebooks, seq_len) -> (bsz, codebooks, seq_len)
+ input = input_ids.reshape(-1, self.num_codebooks, input_ids.shape[-1])
+ bsz, num_codebooks, seq_len = input.shape
+ input_shape = (bsz, seq_len)
+ elif inputs_embeds is not None:
+ input_shape = inputs_embeds.size()[:-1]
+ input = inputs_embeds[:, :, -1:]
+ else:
+ raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
+
+ if self.gradient_checkpointing and self.training:
+ if use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing`. Setting `use_cache=False`..."
+ )
+ use_cache = False
+
+ if use_cache and past_key_values is None:
+ past_key_values = EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
+ if use_cache and isinstance(past_key_values, tuple):
+ logger.warning_once(
+ "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. "
+ "You should pass an instance of `EncoderDecoderCache` instead, e.g. "
+ "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`."
+ )
+ past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values)
+
+ past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
+
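+ # each codebook has its own embedding table; the per-codebook embeddings are summed into a
+ # single hidden state per position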
+ if inputs_embeds is None:
+ inputs_embeds = sum([self.embed_tokens[codebook](input[:, codebook]) for codebook in range(num_codebooks)])
+
+ attention_mask = self._update_causal_mask(
+ attention_mask,
+ input_shape,
+ inputs_embeds,
+ past_key_values_length,
+ )
+ encoder_attention_mask = self._update_cross_attn_mask(
+ encoder_hidden_states,
+ encoder_attention_mask,
+ input_shape,
+ inputs_embeds,
+ )
+
+ # embed positions
+ positions = self.embed_positions(input, past_key_values_length)
+ hidden_states = inputs_embeds + positions.to(inputs_embeds.device)
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+
+ # decoder layers
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+ all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
+
+ # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
+ for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
+ if attn_mask is not None:
+ if attn_mask.size()[0] != len(self.layers):
+ raise ValueError(
+ f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
+ f" {attn_mask.size()[0]}."
+ )
+ for idx, decoder_layer in enumerate(self.layers):
+ # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+ dropout_probability = random.uniform(0, 1)
+ if self.training and (dropout_probability < self.layerdrop):
+ continue
+
+ layer_outputs = decoder_layer(
+ hidden_states,
+ attention_mask,
+ encoder_hidden_states, # as a positional argument for gradient checkpointing
+ encoder_attention_mask=encoder_attention_mask,
+ layer_head_mask=(head_mask[idx] if head_mask is not None else None),
+ cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None),
+ past_key_values=past_key_values,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ )
+ hidden_states = layer_outputs[0]
+ if output_attentions:
+ all_self_attns += (layer_outputs[1],)
+
+ if encoder_hidden_states is not None:
+ all_cross_attentions += (layer_outputs[2],)
+
+ hidden_states = self.layer_norm(hidden_states)
+
+ # add hidden states from the last decoder layer
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ if not return_dict:
+ return tuple(
+ v
+ for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns, all_cross_attentions]
+ if v is not None
+ )
+ return BaseModelOutputWithPastAndCrossAttentions(
+ last_hidden_state=hidden_states,
+ past_key_values=past_key_values,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ cross_attentions=all_cross_attentions,
+ )
+
+ def _update_causal_mask(
+ self,
+ attention_mask: Union[torch.Tensor, None],
+ input_shape: torch.Size,
+ inputs_embeds: torch.Tensor,
+ past_key_values_length: int,
+ ):
+ if self.config._attn_implementation == "flash_attention_2":
+ # 2d mask is passed through the layers
+ attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
+ elif self.config._attn_implementation == "sdpa":
+ # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on
+ # the manual implementation that requires a 4D causal mask in all cases.
+ attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
+ attention_mask,
+ input_shape,
+ inputs_embeds,
+ past_key_values_length,
+ )
+ elif self.config._attn_implementation == "flex_attention":
+ if isinstance(attention_mask, torch.Tensor):
+ attention_mask = make_flex_block_causal_mask(attention_mask)
+ # Other attention flavors support in-built causal (when `mask is None`)
+ # while we need to create our specific block mask regardless
+ elif attention_mask is None:
+ attention_mask = make_flex_block_causal_mask(
+ torch.ones(
+ size=input_shape,
+ device=inputs_embeds.device,
+ )
+ )
+ else:
+ # 4d mask is passed through the layers
+ attention_mask = _prepare_4d_causal_attention_mask(
+ attention_mask, input_shape, inputs_embeds, past_key_values_length
+ )
+
+ return attention_mask
+
+ def _update_cross_attn_mask(
+ self,
+ encoder_hidden_states: Union[torch.Tensor, None],
+ encoder_attention_mask: Union[torch.Tensor, None],
+ input_shape: torch.Size,
+ inputs_embeds: torch.Tensor,
+ ):
+ # expand encoder attention mask
+ if encoder_hidden_states is not None and encoder_attention_mask is not None:
+ if self.config._attn_implementation == "flash_attention_2":
+ encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None
+ elif self.config._attn_implementation == "sdpa":
+ # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on
+ # the manual implementation that requires a 4D causal mask in all cases.
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
+ encoder_attention_mask,
+ inputs_embeds.dtype,
+ tgt_len=input_shape[-1],
+ )
+ elif self.config._attn_implementation == "flex_attention":
+ if isinstance(encoder_attention_mask, torch.Tensor):
+ encoder_attention_mask = make_flex_block_causal_mask(
+ encoder_attention_mask,
+ query_length=input_shape[-1],
+ is_causal=False,
+ )
+ else:
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ encoder_attention_mask = _prepare_4d_attention_mask(
+ encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
+ )
+
+ return encoder_attention_mask
+
+
+@auto_docstring
+class MusicgenModel(MusicgenPreTrainedModel):
+ def __init__(self, config: MusicgenDecoderConfig):
+ super().__init__(config)
+ self.decoder = MusicgenDecoder(config)
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.decoder.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.decoder.embed_tokens = value
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ encoder_attention_mask: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.Tensor] = None,
+ ) -> Union[tuple, BaseModelOutputWithPastAndCrossAttentions]:
+ r"""
+ input_ids (`torch.LongTensor` of shape `(batch_size * num_codebooks, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary, corresponding to the sequence of audio codes.
+
+ Indices can be obtained by encoding an audio prompt with an audio encoder model to predict audio codes,
+ such as with the [`EncodecModel`]. See [`EncodecModel.encode`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+
+
+
+ The `input_ids` will automatically be converted from shape `(batch_size * num_codebooks,
+ target_sequence_length)` to `(batch_size, num_codebooks, target_sequence_length)` in the forward pass. If
+ you obtain audio codes from an audio encoding model, such as [`EncodecModel`], ensure that the number of
+ frames is equal to 1, and that you reshape the audio codes from `(frames, batch_size, num_codebooks,
+ target_sequence_length)` to `(batch_size * num_codebooks, target_sequence_length)` prior to passing them as
+ `input_ids`.
+
+
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention of
+ the decoder.
+ encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*):
+ Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values
+ selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
+ Mask to nullify selected heads of the cross-attention modules in the decoder to avoid performing
+ cross-attention on hidden heads. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # decoder outputs consist of (dec_features, past_key_values, dec_hidden, dec_attn)
+ decoder_outputs = self.decoder(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ encoder_attention_mask=encoder_attention_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ head_mask=head_mask,
+ cross_attn_head_mask=cross_attn_head_mask,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ cache_position=cache_position,
+ )
+
+ if not return_dict:
+ return decoder_outputs
+
+ return BaseModelOutputWithPastAndCrossAttentions(
+ last_hidden_state=decoder_outputs.last_hidden_state,
+ past_key_values=decoder_outputs.past_key_values,
+ hidden_states=decoder_outputs.hidden_states,
+ attentions=decoder_outputs.attentions,
+ cross_attentions=decoder_outputs.cross_attentions,
+ )
+
+
+@auto_docstring(
+ custom_intro="""
+ The MusicGen decoder model with a language modelling head on top.
+ """
+)
+class MusicgenForCausalLM(MusicgenPreTrainedModel, GenerationMixin):
+ def __init__(self, config: MusicgenDecoderConfig):
+ super().__init__(config)
+
+ self.model = MusicgenModel(config)
+
+ self.num_codebooks = config.num_codebooks
+ self.lm_heads = nn.ModuleList(
+ [nn.Linear(config.hidden_size, config.vocab_size, bias=False) for _ in range(config.num_codebooks)]
+ )
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.decoder.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.model.decoder.embed_tokens = value
+
+ def get_output_embeddings(self):
+ return self.lm_heads
+
+ def set_output_embeddings(self, new_embeddings):
+ self.lm_heads = new_embeddings
+
+ def set_decoder(self, decoder):
+ self.model.decoder = decoder
+
+ def get_decoder(self):
+ return self.model.decoder
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ encoder_attention_mask: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.Tensor] = None,
+ **kwargs,
+ ) -> Union[tuple, CausalLMOutputWithCrossAttentions]:
+ r"""
+ input_ids (`torch.LongTensor` of shape `(batch_size * num_codebooks, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary, corresponding to the sequence of audio codes.
+
+ Indices can be obtained by encoding an audio prompt with an audio encoder model to predict audio codes,
+ such as with the [`EncodecModel`]. See [`EncodecModel.encode`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+
+
+
+ The `input_ids` will automatically be converted from shape `(batch_size * num_codebooks,
+ target_sequence_length)` to `(batch_size, num_codebooks, target_sequence_length)` in the forward pass. If
+ you obtain audio codes from an audio encoding model, such as [`EncodecModel`], ensure that the number of
+ frames is equal to 1, and that you reshape the audio codes from `(frames, batch_size, num_codebooks,
+ target_sequence_length)` to `(batch_size * num_codebooks, target_sequence_length)` prior to passing them as
+ `input_ids`.
+
+
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention of
+ the decoder.
+ encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*):
+ Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values
+ selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
+ Mask to nullify selected heads of the cross-attention modules in the decoder to avoid performing
+ cross-attention on hidden heads. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length, num_codebooks)`, *optional*):
+ Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
+ `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size]`. All labels set to `-100`
+ are ignored (masked); the loss is only computed for labels in `[0, ..., config.vocab_size]`.
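+
+ Example (a schematic sketch, not a doctest, of preparing `input_ids` from audio codes; `audio_encoder`
+ and the intermediate variable names are illustrative):
+
+ ```python
+ # `audio_encoder` is e.g. an `EncodecModel`; chunking must be disabled so that `frames == 1`
+ encoder_outputs = audio_encoder.encode(input_values, padding_mask)
+ audio_codes = encoder_outputs.audio_codes  # (frames, bsz, num_codebooks, seq_len)
+ frames, bsz, num_codebooks, seq_len = audio_codes.shape
+ input_ids = audio_codes[0, ...].reshape(bsz * num_codebooks, seq_len)
+ ```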
+ """
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if (labels is not None) and (input_ids is None and inputs_embeds is None):
+ input_ids = shift_tokens_right(labels, self.config.pad_token_id, self.config.bos_token_id)
+
+ outputs = self.model(
+ input_ids,
+ attention_mask=attention_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ head_mask=head_mask,
+ cross_attn_head_mask=cross_attn_head_mask,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ cache_position=cache_position,
+ )
+
+ hidden_states = outputs[0]
+
+ lm_logits = torch.stack([head(hidden_states) for head in self.lm_heads], dim=1)
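+ # one LM head per codebook: lm_logits has shape (bsz, num_codebooks, seq_len, vocab_size)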
+
+ loss = None
+ if labels is not None:
+ # since encoder hidden states have been concatenated to the decoder hidden states,
+ # we take the last timestamps corresponding to labels
+ logits = lm_logits[:, :, -labels.shape[1] :]
+
+ loss_fct = CrossEntropyLoss()
+ loss = torch.zeros([], device=self.device)
+
+ # mask out padding positions: labels set to -100 are ignored by the cross-entropy loss
+ labels = labels.masked_fill(labels == self.config.pad_token_id, -100)
+
+ # per codebook cross-entropy
+ # ref: https://github.com/facebookresearch/audiocraft/blob/69fea8b290ad1b4b40d28f92d1dfc0ab01dbab85/audiocraft/solvers/musicgen.py#L242-L243
+ for codebook in range(self.config.num_codebooks):
+ codebook_logits = logits[:, codebook].contiguous().view(-1, logits.shape[-1])
+ codebook_labels = labels[..., codebook].contiguous().view(-1)
+ loss += loss_fct(codebook_logits, codebook_labels)
+
+ loss = loss / self.config.num_codebooks
+
+ # (bsz, num_codebooks, seq_len, vocab_size) -> (bsz * num_codebooks, seq_len, vocab_size)
+ lm_logits = lm_logits.reshape(-1, *lm_logits.shape[2:])
+
+ if not return_dict:
+ output = (lm_logits,) + outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return CausalLMOutputWithCrossAttentions(
+ loss=loss,
+ logits=lm_logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ cross_attentions=outputs.cross_attentions,
+ )
+
+ def prepare_inputs_for_generation(
+ self,
+ input_ids,
+ attention_mask=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ head_mask=None,
+ cross_attn_head_mask=None,
+ past_key_values=None,
+ use_cache=True,
+ delay_pattern_mask=None,
+ guidance_scale=None,
+ **kwargs,
+ ):
+ # Overwritten -- MusicGen has custom processing
+ if delay_pattern_mask is None:
+ input_ids, delay_pattern_mask = self.build_delay_pattern_mask(
+ input_ids,
+ pad_token_id=self.generation_config.pad_token_id,
+ max_length=self.generation_config.max_length,
+ )
+
+ # apply the delay pattern mask
+ input_ids = self.apply_delay_pattern_mask(input_ids, delay_pattern_mask)
+
+ if guidance_scale is not None and guidance_scale > 1:
+ # for classifier free guidance we need to replicate the decoder args across the batch dim (we'll split these
+ # before sampling)
+ input_ids = input_ids.repeat((2, 1))
+ if attention_mask is not None:
+ attention_mask = attention_mask.repeat((2, 1))
+
+ if past_key_values is not None:
+ input_ids = input_ids[:, -1:]
+
+ return {
+ "input_ids": input_ids,
+ "attention_mask": attention_mask,
+ "encoder_hidden_states": encoder_hidden_states,
+ "encoder_attention_mask": encoder_attention_mask,
+ "head_mask": head_mask,
+ "cross_attn_head_mask": cross_attn_head_mask,
+ "past_key_values": past_key_values,
+ "use_cache": use_cache,
+ }
+
+ def build_delay_pattern_mask(
+ self, input_ids: torch.LongTensor, pad_token_id: int, max_length: Optional[int] = None
+ ):
+ """Build a delayed pattern mask to the input_ids. Each codebook is offset by the previous codebook by
+ one, giving a delayed pattern mask at the start of sequence and end of sequence. Take the example where there
+ are 4 codebooks and a max sequence length of 8, we have the delayed pattern mask of shape `(codebooks,
+ seq_len)`:
+ - [P, -1, -1, -1, -1, P, P, P]
+ - [P, P, -1, -1, -1, -1, P, P]
+ - [P, P, P, -1, -1, -1, -1, P]
+ - [P, P, P, P, -1, -1, -1, -1]
+ where P is the special padding token id and -1 indicates that the token is valid for prediction. If we include
+ a prompt (decoder input ids), the -1 positions indicate where new tokens should be predicted. Otherwise, the
+ mask is set to the value in the prompt:
+ - [P, a, b, -1, -1, P, P, P]
+ - [P, P, c, d, -1, -1, P, P]
+ - [P, P, P, e, f, -1, -1, P]
+ - [P, P, P, P, g, h, -1, -1]
+ where a-h indicate the input prompt (decoder input ids), offset by one position per codebook. During
+ generation, only the -1 positions are overwritten with new predictions.
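+
+ A minimal usage sketch (not a doctest; the `model` handle and the `max_length` value are illustrative),
+ assuming `model` is a `MusicgenForCausalLM` whose generation config defines a pad token id:
+
+ ```python
+ pad_token_id = model.generation_config.pad_token_id
+ # input_ids has shape (bsz * num_codebooks, seq_len)
+ input_ids, delay_pattern_mask = model.build_delay_pattern_mask(input_ids, pad_token_id, max_length=256)
+ # at every decoding step, re-impose the forced pad / prompt positions from the mask
+ input_ids = model.apply_delay_pattern_mask(input_ids, delay_pattern_mask)
+ ```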
+ """
+ # (bsz * num_codebooks, seq_len) -> (bsz, num_codebooks, seq_len)
+ input_ids = input_ids.reshape(-1, self.num_codebooks, input_ids.shape[-1])
+ bsz, num_codebooks, seq_len = input_ids.shape
+
+ max_length = max_length if max_length is not None else self.generation_config.max_length
+ input_ids_shifted = (
+ torch.ones((bsz, num_codebooks, max_length), dtype=torch.long, device=input_ids.device) * -1
+ )
+
+ channel_codebooks = num_codebooks // 2 if self.config.audio_channels == 2 else num_codebooks
+ # we only apply the mask if we have a large enough seq len - otherwise we return as is
+ if max_length < 2 * channel_codebooks - 1:
+ return input_ids.reshape(bsz * num_codebooks, -1), input_ids_shifted.reshape(bsz * num_codebooks, -1)
+
+ # fill the shifted ids with the prompt entries, offset by the codebook idx
+ for codebook in range(channel_codebooks):
+ if self.config.audio_channels == 1:
+ # mono channel - loop over the codebooks one-by-one
+ input_ids_shifted[:, codebook, codebook : seq_len + codebook] = input_ids[:, codebook]
+ else:
+ # left/right channels are interleaved in the generated codebooks, so handle one then the other
+ input_ids_shifted[:, 2 * codebook, codebook : seq_len + codebook] = input_ids[:, 2 * codebook]
+ input_ids_shifted[:, 2 * codebook + 1, codebook : seq_len + codebook] = input_ids[:, 2 * codebook + 1]
+
+ # construct a pattern mask that indicates the positions of padding tokens for each codebook
+ # first fill the upper triangular part (the EOS padding)
+ delay_pattern = torch.triu(
+ torch.ones((channel_codebooks, max_length), dtype=torch.bool), diagonal=max_length - channel_codebooks + 1
+ )
+ # then fill the lower triangular part (the BOS padding)
+ delay_pattern = delay_pattern + torch.tril(torch.ones((channel_codebooks, max_length), dtype=torch.bool))
+
+ if self.config.audio_channels == 2:
+ # for left/right channel we need to duplicate every row of the pattern mask in an interleaved fashion
+ delay_pattern = delay_pattern.repeat_interleave(2, dim=0)
+
+ mask = ~delay_pattern.to(input_ids.device)
+ input_ids = mask * input_ids_shifted + ~mask * pad_token_id
+
+ # find the first position to start generating - this is the first place we have the -1 token
+ # and will always be in the first codebook (since it has no codebook offset)
+ first_codebook_ids = input_ids[:, 0, :]
+ start_ids = (first_codebook_ids == -1).nonzero()[:, 1]
+ if len(start_ids) > 0:
+ first_start_id = min(start_ids)
+ else:
+ # we have no tokens that need to be filled - return entire matrix of input ids
+ first_start_id = seq_len
+
+ # (bsz * num_codebooks, seq_len) -> (bsz, num_codebooks, seq_len)
+ pattern_mask = input_ids.reshape(bsz * num_codebooks, -1)
+ input_ids = input_ids[..., :first_start_id].reshape(bsz * num_codebooks, -1)
+ return input_ids, pattern_mask
+
+ @staticmethod
+ def apply_delay_pattern_mask(input_ids, decoder_pad_token_mask):
+ """Apply a delay pattern mask to the decoder input ids, only preserving predictions where
+ the mask is set to -1, and otherwise setting to the value detailed in the mask."""
+ seq_len = input_ids.shape[-1]
+ decoder_pad_token_mask = decoder_pad_token_mask[..., :seq_len]
+ input_ids = torch.where(decoder_pad_token_mask == -1, input_ids, decoder_pad_token_mask)
+ return input_ids
+
+ @torch.no_grad()
+ def generate(
+ self,
+ inputs: Optional[torch.Tensor] = None,
+ generation_config: Optional[GenerationConfig] = None,
+ logits_processor: Optional[LogitsProcessorList] = None,
+ stopping_criteria: Optional[StoppingCriteriaList] = None,
+ synced_gpus: Optional[bool] = None,
+ streamer: Optional["BaseStreamer"] = None,
+ **kwargs,
+ ):
+ """
+
+ Generates sequences of token ids for models with a language modeling head.
+
+
+
+ Most generation-controlling parameters are set in `generation_config` which, if not passed, will be set to the
+ model's default generation configuration. You can override any `generation_config` by passing the corresponding
+ parameters to generate(), e.g. `.generate(inputs, num_beams=4, do_sample=True)`.
+
+ For an overview of generation strategies and code examples, check out the [following
+ guide](./generation_strategies).
+
+
+
+ Parameters:
+ inputs (`torch.Tensor` of varying shape depending on the modality, *optional*):
+ The sequence used as a prompt for the generation or as model inputs to the encoder. If `None` the
+ method initializes it with `bos_token_id` and a batch size of 1. For decoder-only models `inputs`
+ should be in the format `input_ids`. For encoder-decoder models *inputs* can represent any of
+ `input_ids`, `input_values`, `input_features`, or `pixel_values`.
+ generation_config (`~generation.GenerationConfig`, *optional*):
+ The generation configuration to be used as base parametrization for the generation call. `**kwargs`
+ passed to generate matching the attributes of `generation_config` will override them. If
+ `generation_config` is not provided, the default will be used, which has the following loading
+ priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
+ configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
+ default values, whose documentation should be checked to parameterize generation.
+ logits_processor (`LogitsProcessorList`, *optional*):
+ Custom logits processors that complement the default logits processors built from arguments and
+ generation config. If a logits processor is passed that is already created with the arguments or a
+ generation config, an error is thrown. This feature is intended for advanced users.
+ stopping_criteria (`StoppingCriteriaList`, *optional*):
+ Custom stopping criteria that complement the default stopping criteria built from arguments and a
+ generation config. If a stopping criterion is passed that is already created with the arguments or a
+ generation config, an error is thrown. This feature is intended for advanced users.
+ synced_gpus (`bool`, *optional*, defaults to `False`):
+ Whether to continue running the while loop until max_length (needed to avoid deadlocking with
+ `FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3).
+ streamer (`BaseStreamer`, *optional*):
+ Streamer object that will be used to stream the generated sequences. Generated tokens are passed
+ through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
+ kwargs (`dict[str, Any]`, *optional*):
+ Ad hoc parametrization of `generation_config` and/or additional model-specific kwargs that will be
+ forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder
+ specific kwargs should not be prefixed and decoder specific kwargs should be prefixed with *decoder_*.
+
+ Return:
+ [`~utils.ModelOutput`] or `torch.LongTensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True`
+ or when `config.return_dict_in_generate=True`) or a `torch.LongTensor`.
+
+ If the model is *not* an encoder-decoder model (`model.config.is_encoder_decoder=False`), the possible
+ [`~utils.ModelOutput`] types are:
+
+ - [`~generation.GenerateDecoderOnlyOutput`],
+ - [`~generation.GenerateBeamDecoderOnlyOutput`]
+
+ If the model is an encoder-decoder model (`model.config.is_encoder_decoder=True`), the possible
+ [`~utils.ModelOutput`] types are:
+
+ - [`~generation.GenerateEncoderDecoderOutput`],
+ - [`~generation.GenerateBeamEncoderDecoderOutput`]
+ """
+ # 1. Handle `generation_config` and kwargs that might update it, and validate the resulting objects
+ if generation_config is None:
+ generation_config = self.generation_config
+
+ generation_config = copy.deepcopy(generation_config)
+ model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs
+ generation_config.validate()
+ self._validate_model_kwargs(model_kwargs.copy())
+
+ # 2. Set generation parameters if not already defined
+ logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
+ stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
+
+ requires_attention_mask = "encoder_outputs" not in model_kwargs
+ kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None
+
+ # 3. Define model inputs
+ input_ids, model_input_name, model_kwargs = self._prepare_model_inputs(
+ inputs, generation_config.bos_token_id, model_kwargs
+ )
+ batch_size = input_ids.shape[0] // self.num_codebooks
+ self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=input_ids.device)
+
+ # 4. Define other model kwargs
+ model_kwargs["use_cache"] = generation_config.use_cache
+ model_kwargs["guidance_scale"] = generation_config.guidance_scale
+
+ if model_kwargs.get("attention_mask", None) is None and requires_attention_mask:
+ model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
+ input_ids, generation_config, model_kwargs
+ )
+
+ # 5. Prepare `max_length` depending on other stopping criteria.
+ input_ids_length = input_ids.shape[-1]
+ has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
+ has_default_min_length = kwargs.get("min_length") is None and generation_config.min_length is not None
+ generation_config = self._prepare_generated_length(
+ generation_config=generation_config,
+ has_default_max_length=has_default_max_length,
+ has_default_min_length=has_default_min_length,
+ model_input_name=model_input_name,
+ inputs_tensor=input_ids,
+ input_ids_length=input_ids_length,
+ )
+
+ self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)
+
+ # 6. Prepare the cache.
+ # - `model_kwargs` may be updated in place with a cache as defined by the parameters in `generation_config`.
+ # - different models have a different cache name expected by the model (default = "past_key_values")
+ # - `max_length`, prepared above, is used to determine the maximum cache length
+ max_cache_length = generation_config.max_length - 1
+ if (
+ input_ids.shape[1] != input_ids_length
+ and model_input_name == "inputs_embeds"
+ and not self.config.is_encoder_decoder
+ ):
+ max_cache_length += input_ids.shape[1]
+ self._prepare_cache_for_generation(
+ generation_config,
+ model_kwargs,
+ generation_mode=None,
+ batch_size=batch_size,
+ max_cache_length=max_cache_length,
+ )
+
+ # 7. Prepare `input_ids` which will be used for auto-regressive generation
+ # Build the delay pattern mask for offsetting each codebook prediction by 1 (this behaviour is specific to MusicGen)
+ input_ids, delay_pattern_mask = self.build_delay_pattern_mask(
+ input_ids,
+ pad_token_id=generation_config._decoder_start_token_tensor,
+ max_length=generation_config.max_length,
+ )
+
+ if streamer is not None:
+ streamer.put(input_ids.cpu())
+
+ # stash the delay mask so that we don't have to recompute it in each forward pass
+ model_kwargs["delay_pattern_mask"] = delay_pattern_mask
+
+ # 8. determine generation mode
+ generation_mode = generation_config.get_generation_mode()
+
+ # 9. prepare batched CFG externally (to enable coexistence with the unbatched CFG)
+ if generation_config.guidance_scale is not None and generation_config.guidance_scale > 1:
+ logits_processor.append(ClassifierFreeGuidanceLogitsProcessor(generation_config.guidance_scale))
+ generation_config.guidance_scale = None
+
+ # 10. prepare distribution pre_processing samplers
+ logits_processor = self._get_logits_processor(
+ generation_config=generation_config,
+ input_ids_seq_length=input_ids_length,
+ encoder_input_ids=input_ids,
+ prefix_allowed_tokens_fn=None,
+ logits_processor=logits_processor,
+ device=input_ids.device,
+ )
+
+ # 11. prepare stopping criteria
+ stopping_criteria = self._get_stopping_criteria(
+ generation_config=generation_config, stopping_criteria=stopping_criteria
+ )
+
+ if generation_mode in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH):
+ # expand input_ids with `num_return_sequences` additional sequences per batch
+ input_ids, model_kwargs = self._expand_inputs_for_generation(
+ input_ids=input_ids,
+ expand_size=generation_config.num_return_sequences,
+ **model_kwargs,
+ )
+
+ # 12. run sample
+ outputs = self._sample(
+ input_ids,
+ logits_processor=logits_processor,
+ stopping_criteria=stopping_criteria,
+ generation_config=generation_config,
+ synced_gpus=synced_gpus,
+ streamer=streamer,
+ **model_kwargs,
+ )
+
+ else:
+ raise ValueError(
+ "Got incompatible mode for generation, should be one of greedy or sampling. "
+ "Ensure that beam search is de-activated by setting `num_beams=1`."
+ )
+
+ if generation_config.return_dict_in_generate:
+ output_ids = outputs.sequences
+ else:
+ output_ids = outputs
+
+ # apply the pattern mask to the final ids
+ output_ids = self.apply_delay_pattern_mask(output_ids, model_kwargs["delay_pattern_mask"])
+
+ # revert the pattern delay mask by filtering the pad token id
+ output_ids = output_ids[output_ids != generation_config._pad_token_tensor].reshape(
+ batch_size, self.num_codebooks, -1
+ )
+
+ if generation_config.return_dict_in_generate:
+ outputs.sequences = output_ids
+ return outputs
+ else:
+ return output_ids
+
+
+@auto_docstring(
+ custom_intro="""
+ The composite MusicGen model with a text encoder, audio encoder and Musicgen decoder.
+ """
+)
+class MusicgenForConditionalGeneration(MusicgenPreTrainedModel, GenerationMixin):
+ config: MusicgenConfig
+ base_model_prefix = "encoder_decoder"
+ main_input_name = "input_ids"
+ supports_gradient_checkpointing = True
+
+ def __init__(
+ self,
+ config: Optional[MusicgenConfig] = None,
+ text_encoder: Optional[PreTrainedModel] = None,
+ audio_encoder: Optional[PreTrainedModel] = None,
+ decoder: Optional[MusicgenForCausalLM] = None,
+ ):
+ r"""
+ text_encoder (`PreTrainedModel`, *optional*):
+ The text encoder model that encodes text into hidden states for conditioning.
+ audio_encoder (`PreTrainedModel`, *optional*):
+ The audio encoder model that encodes audio into hidden states for conditioning.
+ decoder (`MusicgenForCausalLM`, *optional*):
+ The decoder model that generates audio tokens based on conditioning signals.
+ """
+ if config is None and (text_encoder is None or audio_encoder is None or decoder is None):
+ raise ValueError(
+ "Either a configuration has to be provided, or all three of text encoder, audio encoder and MusicGen decoder."
+ )
+ if config is None:
+ config = MusicgenConfig.from_sub_models_config(text_encoder.config, audio_encoder.config, decoder.config)
+ else:
+ if not isinstance(config, self.config_class):
+ raise ValueError(f"Config: {config} has to be of type {self.config_class}")
+
+ if config.decoder.cross_attention_hidden_size is not None:
+ if config.decoder.cross_attention_hidden_size != config.text_encoder.hidden_size:
+ raise ValueError(
+ "If `cross_attention_hidden_size` is specified in the MusicGen decoder's configuration, it has to be equal"
+ f" to the text encoder's `hidden_size`. Got {config.decoder.cross_attention_hidden_size} for"
+ f" `config.decoder.cross_attention_hidden_size` and {config.text_encoder.hidden_size} for"
+ " `config.text_encoder.hidden_size`."
+ )
+
+ # initialize with config
+ super().__init__(config)
+
+ if text_encoder is None:
+ from ..auto.modeling_auto import AutoModelForTextEncoding
+
+ text_encoder = AutoModelForTextEncoding.from_config(config.text_encoder)
+
+ if audio_encoder is None:
+ from ..auto.modeling_auto import AutoModel
+
+ audio_encoder = AutoModel.from_config(config.audio_encoder)
+
+ if decoder is None:
+ decoder = MusicgenForCausalLM._from_config(config.decoder)
+
+ self.text_encoder = text_encoder
+ self.audio_encoder = audio_encoder
+ self.decoder = decoder
+
+ if self.text_encoder.config.to_dict() != self.config.text_encoder.to_dict():
+ logger.warning(
+ f"Config of the text_encoder: {self.text_encoder.__class__} is overwritten by shared text_encoder config:"
+ f" {self.config.text_encoder}"
+ )
+ if self.audio_encoder.config.to_dict() != self.config.audio_encoder.to_dict():
+ logger.warning(
+ f"Config of the audio_encoder: {self.audio_encoder.__class__} is overwritten by shared audio_encoder config:"
+ f" {self.config.audio_encoder}"
+ )
+ if self.decoder.config.to_dict() != self.config.decoder.to_dict():
+ logger.warning(
+ f"Config of the decoder: {self.decoder.__class__} is overwritten by shared decoder config:"
+ f" {self.config.decoder}"
+ )
+
+ # make sure that the individual model's config refers to the shared config
+ # so that the updates to the config will be synced
+ self.config.text_encoder._attn_implementation = self.text_encoder.config._attn_implementation
+ self.config.audio_encoder._attn_implementation = self.audio_encoder.config._attn_implementation
+ self.config.decoder._attn_implementation = self.decoder.config._attn_implementation
+ self.text_encoder.config = self.config.text_encoder
+ self.audio_encoder.config = self.config.audio_encoder
+ self.decoder.config = self.config.decoder
+
+ # text encoder outputs might need to be projected to different dimension for decoder
+ if (
+ self.text_encoder.config.hidden_size != self.decoder.config.hidden_size
+ and self.decoder.config.cross_attention_hidden_size is None
+ ):
+ self.enc_to_dec_proj = nn.Linear(self.text_encoder.config.hidden_size, self.decoder.config.hidden_size)
+
+ if self.text_encoder.get_output_embeddings() is not None:
+ raise ValueError(
+ f"The encoder {self.text_encoder} should not have a LM Head. Please use a model without and LM Head"
+ )
+
+ decoder_signature = set(inspect.signature(self.decoder.forward).parameters.keys())
+ if "encoder_hidden_states" not in decoder_signature:
+ raise ValueError(
+ "The selected decoder is not prepared for the encoder hidden states to be passed. Please see the "
+ "following discussion on GitHub: https://github.com/huggingface/transformers/issues/23350"
+ )
+
+ # tie text encoder, decoder weights if config set accordingly
+ self.tie_weights()
+
+ def tie_weights(self):
+ # tie text encoder & decoder if needed
+ if self.config.tie_encoder_decoder:
+ # tie text encoder and decoder base model
+ decoder_base_model_prefix = self.decoder.base_model_prefix
+ tied_weights = self._tie_encoder_decoder_weights(
+ self.text_encoder,
+ self.decoder._modules[decoder_base_model_prefix],
+ self.decoder.base_model_prefix,
+ "text_encoder",
+ )
+ # Setting a dynamic variable instead of `_tied_weights_keys` because the latter is a class
+ # attribute, not an instance member; modifying it would modify the entire class, leading to
+ # issues on subsequent calls (e.g. by different tests).
+ self._dynamic_tied_weights_keys = tied_weights
+
+ def get_audio_encoder(self):
+ return self.audio_encoder
+
+ def get_text_encoder(self):
+ return self.text_encoder
+
+ def get_encoder(self):
+ # get the text encoder to compute the encoder hidden-states for generation
+ return self.get_text_encoder()
+
+ def get_input_embeddings(self):
+ return self.text_encoder.get_input_embeddings()
+
+ def get_output_embeddings(self):
+ return self.decoder.get_output_embeddings()
+
+ def set_output_embeddings(self, new_embeddings):
+ return self.decoder.set_output_embeddings(new_embeddings)
+
+ @classmethod
+ def from_sub_models_pretrained(
+ cls,
+ text_encoder_pretrained_model_name_or_path: Optional[str] = None,
+ audio_encoder_pretrained_model_name_or_path: Optional[str] = None,
+ decoder_pretrained_model_name_or_path: Optional[str] = None,
+ *model_args,
+ **kwargs,
+ ) -> PreTrainedModel:
+ r"""
+ Instantiate a text encoder, an audio encoder, and a MusicGen decoder from one, two or three base classes of the
+ library from pretrained model checkpoints.
+
+
+ The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train
+ the model, you need to first set it back in training mode with `model.train()`.
+
+ Params:
+ text_encoder_pretrained_model_name_or_path (`str`, *optional*):
+ Information necessary to initiate the text encoder. Can be either:
+
+ - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
+ - A path to a *directory* containing model weights saved using
+ [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
+
+ audio_encoder_pretrained_model_name_or_path (`str`, *optional*):
+ Information necessary to initiate the audio encoder. Can be either:
+
+ - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
+ - A path to a *directory* containing model weights saved using
+ [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
+
+ decoder_pretrained_model_name_or_path (`str`, *optional*, defaults to `None`):
+ Information necessary to initiate the decoder. Can be either:
+
+ - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
+ - A path to a *directory* containing model weights saved using
+ [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
+
+ model_args (remaining positional arguments, *optional*):
+ All remaining positional arguments will be passed to the underlying model's `__init__` method.
+
+ kwargs (remaining dictionary of keyword arguments, *optional*):
+ Can be used to update the configuration object (after it has been loaded) and initiate the model (e.g.,
+ `output_attentions=True`).
+
+ - To update the text encoder configuration, use the prefix *text_encoder_* for each configuration
+ parameter.
+ - To update the audio encoder configuration, use the prefix *audio_encoder_* for each configuration
+ parameter.
+ - To update the decoder configuration, use the prefix *decoder_* for each configuration parameter.
+ - To update the parent model configuration, do not use a prefix for each configuration parameter.
+
+ Behaves differently depending on whether a `config` is provided or automatically loaded.
+
+ Example:
+
+ ```python
+ >>> from transformers import MusicgenForConditionalGeneration
+
+ >>> # initialize a musicgen model from a t5 text encoder, encodec audio encoder, and musicgen decoder
+ >>> model = MusicgenForConditionalGeneration.from_sub_models_pretrained(
+ ... text_encoder_pretrained_model_name_or_path="google-t5/t5-base",
+ ... audio_encoder_pretrained_model_name_or_path="facebook/encodec_24khz",
+ ... decoder_pretrained_model_name_or_path="facebook/musicgen-small",
+ ... )
+ >>> # saving model after fine-tuning
+ >>> model.save_pretrained("./musicgen-ft")
+ >>> # load fine-tuned model
+ >>> model = MusicgenForConditionalGeneration.from_pretrained("./musicgen-ft")
+ ```"""
+
+ kwargs_text_encoder = {
+ argument[len("text_encoder_") :]: value
+ for argument, value in kwargs.items()
+ if argument.startswith("text_encoder_")
+ }
+
+ kwargs_audio_encoder = {
+ argument[len("audio_encoder_") :]: value
+ for argument, value in kwargs.items()
+ if argument.startswith("audio_encoder_")
+ }
+
+ kwargs_decoder = {
+ argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
+ }
+
+ # remove text encoder, audio encoder and decoder kwargs from kwargs
+ for key in kwargs_text_encoder:
+ del kwargs["text_encoder_" + key]
+ for key in kwargs_audio_encoder:
+ del kwargs["audio_encoder_" + key]
+ for key in kwargs_decoder:
+ del kwargs["decoder_" + key]
+
+ # Load and initialize the encoder and decoder
+ # The distinction between encoder and decoder at the model level is made
+ # by the value of the flag `is_decoder` that we need to set correctly.
+ text_encoder = kwargs_text_encoder.pop("model", None)
+ if text_encoder is None:
+ if text_encoder_pretrained_model_name_or_path is None:
+ raise ValueError(
+ "If `text_encoder_model` is not defined as an argument, a `text_encoder_pretrained_model_name_or_path` has "
+ "to be defined."
+ )
+
+ if "config" not in kwargs_text_encoder:
+ encoder_config, kwargs_text_encoder = AutoConfig.from_pretrained(
+ text_encoder_pretrained_model_name_or_path, **kwargs_text_encoder, return_unused_kwargs=True
+ )
+
+ if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True:
+ logger.info(
+ f"Initializing {text_encoder_pretrained_model_name_or_path} as a text_encoder model "
+ "from a decoder model. Cross-attention and causal mask are disabled."
+ )
+ encoder_config.is_decoder = False
+ encoder_config.add_cross_attention = False
+
+ kwargs_text_encoder["config"] = encoder_config
+
+ text_encoder = AutoModel.from_pretrained(
+ text_encoder_pretrained_model_name_or_path, *model_args, **kwargs_text_encoder
+ )
+
+ audio_encoder = kwargs_audio_encoder.pop("model", None)
+ if audio_encoder is None:
+ if audio_encoder_pretrained_model_name_or_path is None:
+ raise ValueError(
+ "If `audio_encoder_model` is not defined as an argument, an `audio_encoder_pretrained_model_name_or_path` has "
+ "to be defined."
+ )
+
+ if "config" not in kwargs_audio_encoder:
+ encoder_config, kwargs_audio_encoder = AutoConfig.from_pretrained(
+ audio_encoder_pretrained_model_name_or_path, **kwargs_audio_encoder, return_unused_kwargs=True
+ )
+
+ if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True:
+ logger.info(
+ f"Initializing {audio_encoder_pretrained_model_name_or_path} as an audio_encoder model "
+ "from a decoder model. Cross-attention and causal mask are disabled."
+ )
+ encoder_config.is_decoder = False
+ encoder_config.add_cross_attention = False
+
+ kwargs_audio_encoder["config"] = encoder_config
+
+ audio_encoder = AutoModel.from_pretrained(
+ audio_encoder_pretrained_model_name_or_path, *model_args, **kwargs_audio_encoder
+ )
+
+ decoder = kwargs_decoder.pop("model", None)
+ if decoder is None:
+ if decoder_pretrained_model_name_or_path is None:
+ raise ValueError(
+ "If `decoder_model` is not defined as an argument, a `decoder_pretrained_model_name_or_path` has "
+ "to be defined."
+ )
+
+ if "config" not in kwargs_decoder:
+ decoder_config, kwargs_decoder = AutoConfig.from_pretrained(
+ decoder_pretrained_model_name_or_path, **kwargs_decoder, return_unused_kwargs=True
+ )
+
+ if isinstance(decoder_config, MusicgenConfig):
+ decoder_config = decoder_config.decoder
+
+ if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False:
+ logger.info(
+ f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. Cross attention"
+ f" layers are added to {decoder_pretrained_model_name_or_path} and randomly initialized if"
+ f" {decoder_pretrained_model_name_or_path}'s architecture allows for cross attention layers."
+ )
+ decoder_config.is_decoder = True
+ decoder_config.add_cross_attention = True
+
+ kwargs_decoder["config"] = decoder_config
+
+ if kwargs_decoder["config"].is_decoder is False or kwargs_decoder["config"].add_cross_attention is False:
+ logger.warning(
+ f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. "
+ f"In order to initialize {decoder_pretrained_model_name_or_path} as a decoder, "
+ "make sure that the attributes `is_decoder` and `add_cross_attention` of `decoder_config` "
+ "passed to `.from_sub_models_pretrained(...)` are set to `True` or do not pass a "
+ "`decoder_config` to `.from_sub_models_pretrained(...)`"
+ )
+
+ decoder = MusicgenForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder)
+
+ # instantiate config with corresponding kwargs
+ config = MusicgenConfig.from_sub_models_config(
+ text_encoder.config, audio_encoder.config, decoder.config, **kwargs
+ )
+ return cls(text_encoder=text_encoder, audio_encoder=audio_encoder, decoder=decoder, config=config)
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.BoolTensor] = None,
+ input_values: Optional[torch.FloatTensor] = None,
+ padding_mask: Optional[torch.BoolTensor] = None,
+ decoder_input_ids: Optional[torch.LongTensor] = None,
+ decoder_attention_mask: Optional[torch.BoolTensor] = None,
+ encoder_outputs: Optional[tuple[torch.FloatTensor]] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ **kwargs,
+ ) -> Union[tuple, Seq2SeqLMOutput]:
+ r"""
+ padding_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ decoder_input_ids (`torch.LongTensor` of shape `(batch_size * num_codebooks, target_sequence_length)`, *optional*):
+ Indices of decoder input sequence tokens in the vocabulary, corresponding to the sequence of audio codes.
+
+ Indices can be obtained by encoding an audio prompt with an audio encoder model to predict audio codes,
+ such as with the [`EncodecModel`]. See [`EncodecModel.encode`] for details.
+
+ [What are decoder input IDs?](../glossary#decoder-input-ids)
+
+
+
+ The `decoder_input_ids` will automatically be converted from shape `(batch_size * num_codebooks,
+ target_sequence_length)` to `(batch_size, num_codebooks, target_sequence_length)` in the forward pass. If
+ you obtain audio codes from an audio encoding model, such as [`EncodecModel`], ensure that the number of
+ frames is equal to 1, and that you reshape the audio codes from `(frames, batch_size, num_codebooks,
+ target_sequence_length)` to `(batch_size * num_codebooks, target_sequence_length)` prior to passing them as
+ `decoder_input_ids`.
+
+
+ decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
+ Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
+ be used by default.
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length, num_codebooks)`, *optional*):
+ Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
+ `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size]`. All labels set to `-100`
+ are ignored (masked); the loss is only computed for labels in `[0, ..., config.vocab_size]`.
+
+ Examples:
+ ```python
+ >>> from transformers import AutoProcessor, MusicgenForConditionalGeneration
+ >>> import torch
+
+ >>> processor = AutoProcessor.from_pretrained("facebook/musicgen-small")
+ >>> model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")
+
+ >>> inputs = processor(
+ ... text=["80s pop track with bassy drums and synth", "90s rock song with loud guitars and heavy drums"],
+ ... padding=True,
+ ... return_tensors="pt",
+ ... )
+
+ >>> pad_token_id = model.generation_config.pad_token_id
+ >>> decoder_input_ids = (
+ ... torch.ones((inputs.input_ids.shape[0] * model.decoder.num_codebooks, 1), dtype=torch.long)
+ ... * pad_token_id
+ ... )
+
+ >>> logits = model(**inputs, decoder_input_ids=decoder_input_ids).logits
+ >>> logits.shape # (bsz * num_codebooks, tgt_len, vocab_size)
+ torch.Size([8, 1, 2048])
+ ```"""
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ kwargs_text_encoder = {
+ argument[len("text_encoder_")]: value
+ for argument, value in kwargs.items()
+ if argument.startswith("text_encoder_")
+ }
+
+ kwargs_audio_encoder = {
+ argument[len("audio_encoder_")]: value
+ for argument, value in kwargs.items()
+ if argument.startswith("audio_encoder_")
+ }
+
+ kwargs_decoder = {
+ argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
+ }
+
+ if encoder_outputs is None:
+ encoder_outputs = self.text_encoder(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ **kwargs_text_encoder,
+ )
+ elif isinstance(encoder_outputs, tuple):
+ encoder_outputs = BaseModelOutput(*encoder_outputs)
+
+ encoder_hidden_states = encoder_outputs[0]
+
+ # optionally project encoder_hidden_states
+ if (
+ self.text_encoder.config.hidden_size != self.decoder.config.hidden_size
+ and self.decoder.config.cross_attention_hidden_size is None
+ ):
+ encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states)
+
+ if attention_mask is not None:
+ encoder_hidden_states = encoder_hidden_states * attention_mask[..., None]
+
+ if (labels is not None) and (decoder_input_ids is None and decoder_inputs_embeds is None):
+ decoder_input_ids = shift_tokens_right(
+ labels, self.config.decoder.pad_token_id, self.config.decoder.decoder_start_token_id
+ )
+
+ elif decoder_input_ids is None and decoder_inputs_embeds is None:
+ audio_encoder_outputs = self.audio_encoder(
+ input_values=input_values,
+ padding_mask=padding_mask,
+ **kwargs_audio_encoder,
+ )
+ audio_codes = audio_encoder_outputs.audio_codes
+ frames, bsz, codebooks, seq_len = audio_codes.shape
+ if frames != 1:
+ raise ValueError(
+ f"Expected 1 frame in the audio code outputs, got {frames} frames. Ensure chunking is "
+ "disabled by setting `chunk_length=None` in the audio encoder."
+ )
+
+ if self.config.decoder.audio_channels == 2 and audio_codes.shape[2] == self.decoder.num_codebooks // 2:
+ # mono input through encodec that we convert to stereo
+ audio_codes = audio_codes.repeat_interleave(2, dim=2)
+
+ decoder_input_ids = audio_codes[0, ...].reshape(bsz * self.decoder.num_codebooks, seq_len)
+
+ # Decode
+ decoder_outputs = self.decoder(
+ input_ids=decoder_input_ids,
+ attention_mask=decoder_attention_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=attention_mask,
+ inputs_embeds=decoder_inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ use_cache=use_cache,
+ past_key_values=past_key_values,
+ return_dict=return_dict,
+ labels=labels,
+ **kwargs_decoder,
+ )
+
+ if not return_dict:
+ return decoder_outputs + encoder_outputs
+
+ return Seq2SeqLMOutput(
+ loss=decoder_outputs.loss,
+ logits=decoder_outputs.logits,
+ past_key_values=decoder_outputs.past_key_values,
+ decoder_hidden_states=decoder_outputs.hidden_states,
+ decoder_attentions=decoder_outputs.attentions,
+ cross_attentions=decoder_outputs.cross_attentions,
+ encoder_last_hidden_state=encoder_outputs.last_hidden_state,
+ encoder_hidden_states=encoder_outputs.hidden_states,
+ encoder_attentions=encoder_outputs.attentions,
+ )
+
+ def prepare_inputs_for_generation(
+ self,
+ decoder_input_ids,
+ past_key_values=None,
+ attention_mask=None,
+ head_mask=None,
+ decoder_attention_mask=None,
+ decoder_head_mask=None,
+ cross_attn_head_mask=None,
+ use_cache=None,
+ encoder_outputs=None,
+ decoder_delay_pattern_mask=None,
+ guidance_scale=None,
+ cache_position=None,
+ **kwargs,
+ ):
+ # Overwritten -- MusicGen has custom processing
+ if decoder_delay_pattern_mask is None:
+ decoder_input_ids, decoder_delay_pattern_mask = self.decoder.build_delay_pattern_mask(
+ decoder_input_ids,
+ self.generation_config.pad_token_id,
+ max_length=self.generation_config.max_length,
+ )
+
+ # apply the delay pattern mask
+ decoder_input_ids = self.decoder.apply_delay_pattern_mask(decoder_input_ids, decoder_delay_pattern_mask)
+
+ if guidance_scale is not None and guidance_scale > 1:
+ # for classifier free guidance we need to replicate the decoder args across the batch dim (we'll split these
+ # before sampling)
+ decoder_input_ids = decoder_input_ids.repeat((2, 1))
+ if decoder_attention_mask is not None:
+ decoder_attention_mask = decoder_attention_mask.repeat((2, 1))
+
+ if past_key_values is not None:
+ if cache_position[-1] >= decoder_input_ids.shape[1]:
+ decoder_input_ids = decoder_input_ids[:, -cache_position.shape[0] :]
+ elif (
+ decoder_input_ids.shape[1] != cache_position.shape[0]
+ ): # Default case (the "else", a no op, is Exception 2)
+ decoder_input_ids = decoder_input_ids[:, cache_position]
+ else:
+ # Default to old behavior: keep only final ID
+ decoder_input_ids = decoder_input_ids[:, -1:]
+
+ return {
+ "input_ids": None, # encoder_outputs is defined. input_ids not needed
+ "encoder_outputs": encoder_outputs,
+ "past_key_values": past_key_values,
+ "decoder_input_ids": decoder_input_ids,
+ "attention_mask": attention_mask,
+ "decoder_attention_mask": decoder_attention_mask,
+ "head_mask": head_mask,
+ "decoder_head_mask": decoder_head_mask,
+ "cross_attn_head_mask": cross_attn_head_mask,
+ "use_cache": use_cache,
+ }
+
+ def _prepare_decoder_input_ids_for_generation(
+ self,
+ batch_size: int,
+ model_input_name: str,
+ model_kwargs: dict[str, torch.Tensor],
+ decoder_start_token_id: Optional[int] = None,
+ bos_token_id: Optional[int] = None,
+ device: Optional[torch.device] = None,
+ ) -> tuple[torch.LongTensor, dict[str, torch.Tensor]]:
+ """Prepares `decoder_input_ids` for generation with encoder-decoder models"""
+
+ # 1. Check whether the user has defined `decoder_input_ids` manually. For flexibility in input naming,
+ # we also allow the user to pass it under `input_ids` if the encoder does not use it as the main input.
+ if model_kwargs is not None and "decoder_input_ids" in model_kwargs:
+ decoder_input_ids = model_kwargs.pop("decoder_input_ids")
+ elif "input_ids" in model_kwargs and model_input_name != "input_ids":
+ decoder_input_ids = model_kwargs.pop("input_ids")
+ else:
+ decoder_input_ids = None
+
+ # 2. Encoder-decoder models expect the `decoder_input_ids` to start with a special token. Let's ensure that.
+ decoder_start_token_id = self._get_decoder_start_token_id(decoder_start_token_id, bos_token_id)
+ if device is None:
+ device = self.device
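+ # MusicGen flattens the codebook dimension into the batch dimension, so the decoder start
+ # ids have shape (batch_size * num_codebooks, 1)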
+ decoder_input_ids_start = (
+ torch.ones((batch_size * self.decoder.num_codebooks, 1), dtype=torch.long, device=device)
+ * decoder_start_token_id
+ )
+
+ # no user input -> use decoder_start_token_id as decoder_input_ids
+ if decoder_input_ids is None:
+ decoder_input_ids = decoder_input_ids_start
+
+ # user input but doesn't start with decoder_start_token_id -> prepend decoder_start_token_id (and adjust
+ # decoder_attention_mask if provided)
+ elif (decoder_input_ids[..., 0] != decoder_start_token_id).all().item():
+ decoder_input_ids = torch.cat([decoder_input_ids_start, decoder_input_ids], dim=-1)
+ if "decoder_attention_mask" in model_kwargs:
+ decoder_attention_mask = model_kwargs["decoder_attention_mask"]
+ decoder_attention_mask = torch.cat(
+ (torch.ones_like(decoder_attention_mask)[:, :1], decoder_attention_mask),
+ dim=-1,
+ )
+ model_kwargs["decoder_attention_mask"] = decoder_attention_mask
+
+ return decoder_input_ids, model_kwargs
+
+ def _prepare_text_encoder_kwargs_for_generation(
+ self,
+ inputs_tensor: torch.Tensor,
+ model_kwargs,
+ model_input_name: Optional[str],
+ generation_config: GenerationConfig,
+ ) -> dict[str, Any]:
+ # 1. get text encoder
+ encoder = self.get_text_encoder()
+ # Compatibility with Accelerate big model inference: we need the encoder to output on the same device
+ # as the inputs.
+ if hasattr(encoder, "_hf_hook"):
+ encoder._hf_hook.io_same_device = True
+
+ # 2. Prepare encoder args and encoder kwargs from model kwargs.
+ irrelevant_prefix = ["decoder_", "cross_attn", "use_cache"]
+ encoder_kwargs = {
+ argument: value
+ for argument, value in model_kwargs.items()
+ if not any(argument.startswith(p) for p in irrelevant_prefix)
+ }
+ encoder_signature = set(inspect.signature(encoder.forward).parameters)
+ encoder_accepts_wildcard = "kwargs" in encoder_signature or "model_kwargs" in encoder_signature
+ if not encoder_accepts_wildcard:
+ encoder_kwargs = {
+ argument: value for argument, value in encoder_kwargs.items() if argument in encoder_signature
+ }
+ encoder_kwargs["output_attentions"] = generation_config.output_attentions
+ encoder_kwargs["output_hidden_states"] = generation_config.output_hidden_states
+ guidance_scale = generation_config.guidance_scale
+
+ # 3. make sure that encoder returns `ModelOutput`
+ model_input_name = model_input_name if model_input_name is not None else self.text_encoder.main_input_name
+ encoder_kwargs["return_dict"] = True
+ encoder_kwargs[model_input_name] = inputs_tensor
+ last_hidden_state = encoder(**encoder_kwargs).last_hidden_state
+
+ # for classifier free guidance we need to add a 'null' input to our encoder hidden states
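+ # the all-zero hidden states act as the unconditional branch that the CFG logits processor
+ # contrasts with the conditional branch (the doubled batch is split again before sampling)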
+ if guidance_scale is not None and guidance_scale > 1:
+ last_hidden_state = torch.concatenate([last_hidden_state, torch.zeros_like(last_hidden_state)], dim=0)
+ if "attention_mask" in model_kwargs:
+ model_kwargs["attention_mask"] = torch.concatenate(
+ [model_kwargs["attention_mask"], torch.zeros_like(model_kwargs["attention_mask"])], dim=0
+ )
+
+ model_kwargs["encoder_outputs"] = BaseModelOutput(last_hidden_state=last_hidden_state)
+
+ return model_kwargs
+
+ def _prepare_audio_encoder_kwargs_for_generation(
+ self, input_values, model_kwargs, model_input_name: Optional[str] = None
+ ):
+ # 1. get audio encoder
+ encoder = self.get_audio_encoder()
+ # Compatibility with Accelerate big model inference: we need the encoder to output on the same device
+ # as the inputs.
+ if hasattr(encoder, "_hf_hook"):
+ encoder._hf_hook.io_same_device = True
+
+ # 2. Prepare encoder args and encoder kwargs from model kwargs.
+ irrelevant_prefix = ["decoder_", "cross_attn", "use_cache"]
+ encoder_kwargs = {
+ argument: value
+ for argument, value in model_kwargs.items()
+ if not any(argument.startswith(p) for p in irrelevant_prefix)
+ }
+ encoder_signature = set(inspect.signature(encoder.forward).parameters)
+ encoder_accepts_wildcard = "kwargs" in encoder_signature or "model_kwargs" in encoder_signature
+ if not encoder_accepts_wildcard:
+ encoder_kwargs = {
+ argument: value for argument, value in encoder_kwargs.items() if argument in encoder_signature
+ }
+
+ # 3. make sure that encoder returns `ModelOutput`
+ model_input_name = model_input_name if model_input_name is not None else self.audio_encoder.main_input_name
+ encoder_kwargs["return_dict"] = True
+
+ if self.decoder.config.audio_channels == 1:
+ encoder_kwargs[model_input_name] = input_values
+ audio_encoder_outputs = encoder.encode(**encoder_kwargs)
+ audio_codes = audio_encoder_outputs.audio_codes
+ audio_scales = audio_encoder_outputs.audio_scales
+
+ frames, bsz, codebooks, seq_len = audio_codes.shape
+
+ else:
+ if input_values.shape[1] != 2:
+ raise ValueError(
+ f"Expected stereo audio (2-channels) but example has {input_values.shape[1]} channel."
+ )
+
+ encoder_kwargs[model_input_name] = input_values[:, :1, :]
+ audio_encoder_outputs_left = encoder.encode(**encoder_kwargs)
+ audio_codes_left = audio_encoder_outputs_left.audio_codes
+ audio_scales_left = audio_encoder_outputs_left.audio_scales
+
+ encoder_kwargs[model_input_name] = input_values[:, 1:, :]
+ audio_encoder_outputs_right = encoder.encode(**encoder_kwargs)
+ audio_codes_right = audio_encoder_outputs_right.audio_codes
+ audio_scales_right = audio_encoder_outputs_right.audio_scales
+
+ frames, bsz, codebooks, seq_len = audio_codes_left.shape
+ # copy alternating left/right channel codes into stereo codebook
+ audio_codes = audio_codes_left.new_ones((frames, bsz, 2 * codebooks, seq_len))
+
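+ # even codebook indices hold the left-channel codes, odd indices hold the right-channel codes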
+ audio_codes[:, :, ::2, :] = audio_codes_left
+ audio_codes[:, :, 1::2, :] = audio_codes_right
+
+ if audio_scales_left != [None] or audio_scales_right != [None]:
+ audio_scales = torch.stack([audio_scales_left, audio_scales_right], dim=1)
+ else:
+ audio_scales = [None] * bsz
+
+ if frames != 1:
+ raise ValueError(
+ f"Expected 1 frame in the audio code outputs, got {frames} frames. Ensure chunking is "
+ "disabled by setting `chunk_length=None` in the audio encoder."
+ )
+
+ decoder_input_ids = audio_codes[0, ...].reshape(bsz * self.decoder.num_codebooks, seq_len)
+
+ model_kwargs["decoder_input_ids"] = decoder_input_ids
+ model_kwargs["audio_scales"] = audio_scales
+ return model_kwargs
+
+ def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
+ return shift_tokens_right(labels, self.config.decoder.pad_token_id, self.config.decoder.bos_token_id)
+
+ def resize_token_embeddings(self, *args, **kwargs):
+ raise NotImplementedError(
+ "Resizing the embedding layers via the EncoderDecoderModel directly is not supported. Please use the"
+ " respective methods of the wrapped objects (model.encoder.resize_token_embeddings(...) or"
+ " model.decoder.resize_token_embeddings(...))"
+ )
+
+ def freeze_audio_encoder(self):
+ """
+ Freeze the audio encoder weights.
+ """
+ for param in self.audio_encoder.parameters():
+ param.requires_grad = False
+ self.audio_encoder._requires_grad = False
+
+ def freeze_text_encoder(self):
+ """
+ Freeze the text encoder weights.
+ """
+ for param in self.text_encoder.parameters():
+ param.requires_grad = False
+ self.text_encoder._requires_grad = False
+
+ def _maybe_initialize_input_ids_for_generation(
+ self,
+ inputs: Optional[torch.Tensor] = None,
+ bos_token_id: Optional[int] = None,
+ model_kwargs: Optional[dict[str, torch.Tensor]] = None,
+ ) -> torch.LongTensor:
+ """Initializes input ids for generation, if necessary."""
+ if inputs is not None:
+ return inputs
+
+ encoder_outputs = model_kwargs.get("encoder_outputs")
+ if encoder_outputs is not None:
+ # make dummy input_ids with value -100, as a sanity check ensuring that they won't be used for encoding
+ shape = encoder_outputs[0].size()[:-1]
+ return torch.ones(shape, dtype=torch.long, device=self.device) * -100
+
+ if bos_token_id is None:
+ raise ValueError("`bos_token_id` has to be defined when no `input_ids` are provided.")
+
+ # If there is some tensor in `model_kwargs`, we can infer the batch size from it. This is helpful with
+ # soft-prompting or in multimodal implementations built on top of decoder-only language models.
+ batch_size = 1
+ for value in model_kwargs.values():
+ if isinstance(value, torch.Tensor):
+ batch_size = value.shape[0]
+ break
+ return torch.ones((batch_size, 1), dtype=torch.long, device=self.device) * bos_token_id
+
+ def _get_decoder_start_token_id(
+ self, decoder_start_token_id: Optional[Union[int, list[int]]] = None, bos_token_id: Optional[int] = None
+ ) -> int:
+ decoder_start_token_id = (
+ decoder_start_token_id
+ if decoder_start_token_id is not None
+ else self.generation_config.decoder_start_token_id
+ )
+ bos_token_id = bos_token_id if bos_token_id is not None else self.generation_config.bos_token_id
+
+ if decoder_start_token_id is not None:
+ return decoder_start_token_id
+ elif bos_token_id is not None:
+ return bos_token_id
+ raise ValueError(
+ "`decoder_start_token_id` or `bos_token_id` has to be defined for encoder-decoder generation."
+ )
+
+ @torch.no_grad()
+ def generate(
+ self,
+ inputs: Optional[torch.Tensor] = None,
+ generation_config: Optional[GenerationConfig] = None,
+ logits_processor: Optional[LogitsProcessorList] = None,
+ stopping_criteria: Optional[StoppingCriteriaList] = None,
+ synced_gpus: Optional[bool] = None,
+ streamer: Optional["BaseStreamer"] = None,
+ **kwargs,
+ ):
+ """
+
+ Generates sequences of token ids for models with a language modeling head.
+
+
+
+ Most generation-controlling parameters are set in `generation_config` which, if not passed, will be set to the
+ model's default generation configuration. You can override any `generation_config` by passing the corresponding
+ parameters to generate(), e.g. `.generate(inputs, num_beams=4, do_sample=True)`.
+
+ For an overview of generation strategies and code examples, check out the [following
+ guide](./generation_strategies).
+
+
+
+ Parameters:
+ inputs (`torch.Tensor` of varying shape depending on the modality, *optional*):
+ The sequence used as a prompt for the generation or as model inputs to the encoder. If `None` the
+ method initializes it with `bos_token_id` and a batch size of 1. For decoder-only models `inputs`
+ should be in the format `input_ids`. For encoder-decoder models *inputs* can represent any of
+ `input_ids`, `input_values`, `input_features`, or `pixel_values`.
+ generation_config (`~generation.GenerationConfig`, *optional*):
+ The generation configuration to be used as base parametrization for the generation call. `**kwargs`
+ passed to generate matching the attributes of `generation_config` will override them. If
+ `generation_config` is not provided, the default will be used, which has the following loading
+ priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
+ configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
+ default values, whose documentation should be checked to parameterize generation.
+ logits_processor (`LogitsProcessorList`, *optional*):
+ Custom logits processors that complement the default logits processors built from arguments and
+ generation config. If a logit processor is passed that is already created with the arguments or a
+ generation config an error is thrown. This feature is intended for advanced users.
+ stopping_criteria (`StoppingCriteriaList`, *optional*):
+ Custom stopping criteria that complement the default stopping criteria built from arguments and a
+ generation config. If a stopping criteria is passed that is already created with the arguments or a
+ generation config an error is thrown. This feature is intended for advanced users.
+ synced_gpus (`bool`, *optional*, defaults to `False`):
+ Whether to continue running the while loop until max_length (needed to avoid deadlocking with
+ `FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3).
+ streamer (`BaseStreamer`, *optional*):
+ Streamer object that will be used to stream the generated sequences. Generated tokens are passed
+ through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
+ kwargs (`dict[str, Any]`, *optional*):
+ Ad hoc parametrization of `generation_config` and/or additional model-specific kwargs that will be
+ forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder
+ specific kwargs should not be prefixed and decoder specific kwargs should be prefixed with *decoder_*.
+
+ Return:
+ [`~utils.ModelOutput`] or `torch.FloatTensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True`
+ or when `config.return_dict_in_generate=True`) or a `torch.FloatTensor`.
+
+ If the model is *not* an encoder-decoder model (`model.config.is_encoder_decoder=False`), the possible
+ [`~utils.ModelOutput`] types are:
+
+ - [`~generation.GenerateDecoderOnlyOutput`],
+ - [`~generation.GenerateBeamDecoderOnlyOutput`]
+
+ If the model is an encoder-decoder model (`model.config.is_encoder_decoder=True`), the possible
+ [`~utils.ModelOutput`] types are:
+
+ - [`~generation.GenerateEncoderDecoderOutput`],
+ - [`~generation.GenerateBeamEncoderDecoderOutput`]
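+
+ Example (a minimal text-conditional generation sketch; the checkpoint name and generation settings are
+ illustrative rather than required values):
+
+ ```python
+ >>> from transformers import AutoProcessor, MusicgenForConditionalGeneration
+
+ >>> processor = AutoProcessor.from_pretrained("facebook/musicgen-small")
+ >>> model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")
+
+ >>> inputs = processor(text=["80s pop track with bassy drums and synth"], padding=True, return_tensors="pt")
+ >>> audio_values = model.generate(**inputs, do_sample=True, guidance_scale=3, max_new_tokens=256)
+ ```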
+ """
+ # 1. Handle `generation_config` and kwargs that might update it, and validate the resulting objects
+ if generation_config is None:
+ generation_config = self.generation_config
+
+ generation_config = copy.deepcopy(generation_config)
+ model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs
+ generation_config.validate()
+ self._validate_model_kwargs(model_kwargs.copy())
+
+ if model_kwargs.get("encoder_outputs") is not None and type(model_kwargs["encoder_outputs"]) is tuple:
+ # wrap the unconditional outputs as a BaseModelOutput for compatibility with the rest of generate
+ model_kwargs["encoder_outputs"] = BaseModelOutput(last_hidden_state=model_kwargs["encoder_outputs"][0])
+
+ # 2. Set generation parameters if not already defined
+ logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
+ stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
+
+ requires_attention_mask = "encoder_outputs" not in model_kwargs
+ kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None
+
+ # 3. Define model inputs
+ inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(
+ inputs, generation_config.bos_token_id, model_kwargs
+ )
+ batch_size = inputs_tensor.shape[0]
+ self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=inputs_tensor.device)
+
+ # 4. Define other model kwargs
+ model_kwargs["use_cache"] = generation_config.use_cache
+ model_kwargs["guidance_scale"] = generation_config.guidance_scale
+
+ if model_kwargs.get("attention_mask", None) is None and requires_attention_mask:
+ model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
+ inputs_tensor, generation_config, model_kwargs
+ )
+
+ if "encoder_outputs" not in model_kwargs:
+ # encoder_outputs are created and added to `model_kwargs`
+ model_kwargs = self._prepare_text_encoder_kwargs_for_generation(
+ inputs_tensor, model_kwargs, model_input_name, generation_config
+ )
+
+ if "decoder_input_ids" not in model_kwargs and "input_values" in model_kwargs:
+ model_kwargs = self._prepare_audio_encoder_kwargs_for_generation(
+ model_kwargs["input_values"],
+ model_kwargs,
+ )
+
+ # 5. Prepare `input_ids` which will be used for auto-regressive generation
+ input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation(
+ batch_size=batch_size,
+ model_input_name=model_input_name,
+ model_kwargs=model_kwargs,
+ decoder_start_token_id=generation_config._decoder_start_token_tensor,
+ bos_token_id=generation_config._bos_token_tensor,
+ device=inputs_tensor.device,
+ )
+
+ # 6. Prepare `max_length` depending on other stopping criteria.
+ input_ids_length = input_ids.shape[-1]
+ has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
+ has_default_min_length = kwargs.get("min_length") is None and generation_config.min_length is not None
+ generation_config = self._prepare_generated_length(
+ generation_config=generation_config,
+ has_default_max_length=has_default_max_length,
+ has_default_min_length=has_default_min_length,
+ model_input_name=model_input_name,
+ inputs_tensor=inputs_tensor,
+ input_ids_length=input_ids_length,
+ )
+
+ # build the delay pattern mask for offsetting each codebook prediction by 1 (this behaviour is specific to MusicGen)
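+ # (with the delay pattern, codebook k is offset by k positions relative to the first codebook, so all
+ # codebooks can be predicted within a single autoregressive pass)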
+ input_ids, decoder_delay_pattern_mask = self.decoder.build_delay_pattern_mask(
+ input_ids,
+ pad_token_id=generation_config._decoder_start_token_tensor,
+ max_length=generation_config.max_length,
+ )
+ # stash the delay mask so that we don't have to recompute in each forward pass
+ model_kwargs["decoder_delay_pattern_mask"] = decoder_delay_pattern_mask
+
+ # input_ids are ready to be placed on the streamer (if used)
+ if streamer is not None:
+ streamer.put(input_ids.cpu())
+
+ # 7. determine generation mode
+ generation_mode = generation_config.get_generation_mode()
+
+ # 8. prepare batched CFG externally (to enable coexistence with the unbatched CFG)
+ if generation_config.guidance_scale is not None and generation_config.guidance_scale > 1:
+ logits_processor.append(ClassifierFreeGuidanceLogitsProcessor(generation_config.guidance_scale))
+ generation_config.guidance_scale = None
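+ # the scale is cleared so that `_get_logits_processor` below does not also add the unbatched CFG processor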
+
+ # 9. prepare distribution pre_processing samplers
+ logits_processor = self._get_logits_processor(
+ generation_config=generation_config,
+ input_ids_seq_length=input_ids_length,
+ encoder_input_ids=inputs_tensor,
+ prefix_allowed_tokens_fn=None,
+ logits_processor=logits_processor,
+ device=input_ids.device,
+ )
+
+ # 10. prepare stopping criteria
+ stopping_criteria = self._get_stopping_criteria(
+ generation_config=generation_config, stopping_criteria=stopping_criteria
+ )
+
+ if generation_mode in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH):
+ # expand input_ids with `num_return_sequences` additional sequences per batch
+ input_ids, model_kwargs = self._expand_inputs_for_generation(
+ input_ids=input_ids,
+ expand_size=generation_config.num_return_sequences,
+ is_encoder_decoder=self.config.is_encoder_decoder,
+ **model_kwargs,
+ )
+
+ # 11. run sample
+ outputs = self._sample(
+ input_ids,
+ logits_processor=logits_processor,
+ stopping_criteria=stopping_criteria,
+ generation_config=generation_config,
+ synced_gpus=synced_gpus,
+ streamer=streamer,
+ **model_kwargs,
+ )
+
+ else:
+ raise ValueError(
+ "Got incompatible mode for generation, should be one of greedy or sampling. "
+ "Ensure that beam search is de-activated by setting `num_beams=1`."
+ )
+
+ if generation_config.return_dict_in_generate:
+ output_ids = outputs.sequences
+ else:
+ output_ids = outputs
+
+ # apply the pattern mask to the final ids
+ output_ids = self.decoder.apply_delay_pattern_mask(output_ids, model_kwargs["decoder_delay_pattern_mask"])
+
+ # revert the pattern delay mask by filtering the pad token id
+ output_ids = output_ids[output_ids != generation_config._pad_token_tensor].reshape(
+ batch_size, self.decoder.num_codebooks, -1
+ )
+
+ # append the frame dimension back to the audio codes
+ output_ids = output_ids[None, ...]
+
+ audio_scales = model_kwargs.get("audio_scales")
+ if audio_scales is None:
+ audio_scales = [None] * batch_size
+
+ if self.decoder.config.audio_channels == 1:
+ output_values = self.audio_encoder.decode(
+ output_ids,
+ audio_scales=audio_scales,
+ ).audio_values
+ else:
+ codec_outputs_left = self.audio_encoder.decode(output_ids[:, :, ::2, :], audio_scales=audio_scales)
+ output_values_left = codec_outputs_left.audio_values
+
+ codec_outputs_right = self.audio_encoder.decode(output_ids[:, :, 1::2, :], audio_scales=audio_scales)
+ output_values_right = codec_outputs_right.audio_values
+
+ output_values = torch.cat([output_values_left, output_values_right], dim=1)
+
+ if generation_config.return_dict_in_generate:
+ outputs.sequences = output_values
+ return outputs
+ else:
+ return output_values
+
+ def get_unconditional_inputs(self, num_samples=1):
+ """
+ Helper function to get null inputs for unconditional generation, enabling the model to be used without the
+ feature extractor or tokenizer.
+
+ Args:
+ num_samples (`int`, *optional*, defaults to 1):
+ Number of audio samples to unconditionally generate.
+
+ Note that the length of the generated audio is controlled by `max_new_tokens`, which should be passed
+ directly to [`~MusicgenForConditionalGeneration.generate`] (see the example below): more tokens means
+ longer audio samples, at the expense of longer inference (since more audio tokens need to be generated per sample).
+
+ Example:
+ ```python
+ >>> from transformers import MusicgenForConditionalGeneration
+
+ >>> model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")
+
+ >>> # get the unconditional (or 'null') inputs for the model
+ >>> unconditional_inputs = model.get_unconditional_inputs(num_samples=1)
+ >>> audio_samples = model.generate(**unconditional_inputs, max_new_tokens=256)
+ ```"""
+ last_hidden_state = torch.zeros(
+ (num_samples, 1, self.config.text_encoder.hidden_size), device=self.device, dtype=self.dtype
+ )
+
+ attention_mask = torch.zeros((num_samples, 1), device=self.device, dtype=torch.long)
+
+ return MusicgenUnconditionalInput(
+ encoder_outputs=(last_hidden_state,),
+ attention_mask=attention_mask,
+ guidance_scale=1.0,
+ )
+
+
+__all__ = ["MusicgenForConditionalGeneration", "MusicgenForCausalLM", "MusicgenModel", "MusicgenPreTrainedModel"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/musicgen/processing_musicgen.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/musicgen/processing_musicgen.py
new file mode 100644
index 0000000000000000000000000000000000000000..e2bcbf373d489a79eaaec6451e90f092439f3f9f
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/musicgen/processing_musicgen.py
@@ -0,0 +1,113 @@
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Text/audio processor class for MusicGen
+"""
+
+from typing import Any
+
+import numpy as np
+
+from ...processing_utils import ProcessorMixin
+from ...utils import to_numpy
+
+
+class MusicgenProcessor(ProcessorMixin):
+ r"""
+ Constructs a MusicGen processor which wraps an EnCodec feature extractor and a T5 tokenizer into a single processor
+ class.
+
+ [`MusicgenProcessor`] offers all the functionalities of [`EncodecFeatureExtractor`] and [`T5Tokenizer`]. See
+ [`~MusicgenProcessor.__call__`] and [`~MusicgenProcessor.decode`] for more information.
+
+ Args:
+ feature_extractor (`EncodecFeatureExtractor`):
+ An instance of [`EncodecFeatureExtractor`]. The feature extractor is a required input.
+ tokenizer (`T5Tokenizer`):
+ An instance of [`T5Tokenizer`]. The tokenizer is a required input.
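+
+ Example (an illustrative sketch; the checkpoint name is an assumption):
+
+ ```python
+ >>> from transformers import MusicgenProcessor
+
+ >>> processor = MusicgenProcessor.from_pretrained("facebook/musicgen-small")
+ >>> inputs = processor(text=["soft jazz with a saxophone lead"], padding=True, return_tensors="pt")
+ ```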
+ """
+
+ feature_extractor_class = "EncodecFeatureExtractor"
+ tokenizer_class = ("T5Tokenizer", "T5TokenizerFast")
+
+ def __init__(self, feature_extractor, tokenizer):
+ super().__init__(feature_extractor, tokenizer)
+ self.current_processor = self.feature_extractor
+ self._in_target_context_manager = False
+
+ def get_decoder_prompt_ids(self, task=None, language=None, no_timestamps=True):
+ return self.tokenizer.get_decoder_prompt_ids(task=task, language=language, no_timestamps=no_timestamps)
+
+ def __call__(self, *args, **kwargs):
+ """
+ Forwards the `audio` argument to EncodecFeatureExtractor's [`~EncodecFeatureExtractor.__call__`] and the `text`
+ argument to [`~T5Tokenizer.__call__`]. Please refer to the docstring of the above two methods for more
+ information.
+ """
+ # For backward compatibility
+ if self._in_target_context_manager:
+ return self.current_processor(*args, **kwargs)
+
+ if len(args) > 0:
+ kwargs["audio"] = args[0]
+ return super().__call__(*args, **kwargs)
+
+ def batch_decode(self, *args, **kwargs):
+ """
+ This method is used to decode either batches of audio outputs from the MusicGen model, or batches of token ids
+ from the tokenizer. In the case of decoding token ids, this method forwards all its arguments to T5Tokenizer's
+ [`~PreTrainedTokenizer.batch_decode`]. Please refer to the docstring of this method for more information.
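+
+ A hedged decoding sketch (assumes `processor` was loaded beforehand, `audio_values` come from a prior
+ `generate` call and `padding_mask` from the corresponding processor call; all three names are placeholders):
+
+ ```python
+ >>> audio_arrays = processor.batch_decode(audio_values, padding_mask=padding_mask)
+ ```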
+ """
+ audio_values = kwargs.pop("audio", None)
+ padding_mask = kwargs.pop("padding_mask", None)
+
+ if len(args) > 0:
+ audio_values = args[0]
+ args = args[1:]
+
+ if audio_values is not None:
+ return self._decode_audio(audio_values, padding_mask=padding_mask)
+ else:
+ return self.tokenizer.batch_decode(*args, **kwargs)
+
+ def _decode_audio(self, audio_values, padding_mask: Any = None) -> list[np.ndarray]:
+ """
+ This method strips any padding from the audio values to return a list of numpy audio arrays.
+ """
+ audio_values = to_numpy(audio_values)
+ bsz, channels, seq_len = audio_values.shape
+
+ if padding_mask is None:
+ return list(audio_values)
+
+ padding_mask = to_numpy(padding_mask)
+
+ # match the sequence length of the padding mask to the generated audio arrays by padding with the **non-padding**
+ # token (so that the generated audio values are **not** treated as padded tokens)
+ difference = seq_len - padding_mask.shape[-1]
+ padding_value = 1 - self.feature_extractor.padding_value
+ padding_mask = np.pad(padding_mask, ((0, 0), (0, difference)), "constant", constant_values=padding_value)
+
+ audio_values = audio_values.tolist()
+ for i in range(bsz):
+ sliced_audio = np.asarray(audio_values[i])[
+ padding_mask[i][None, :] != self.feature_extractor.padding_value
+ ]
+ audio_values[i] = sliced_audio.reshape(channels, -1)
+
+ return audio_values
+
+
+__all__ = ["MusicgenProcessor"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/musicgen_melody/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/musicgen_melody/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..51456aac76b0e0c5666acadf0d63008a7c3b50d9
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/musicgen_melody/__init__.py
@@ -0,0 +1,27 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_musicgen_melody import *
+ from .modeling_musicgen_melody import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/musicgen_melody/configuration_musicgen_melody.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/musicgen_melody/configuration_musicgen_melody.py
new file mode 100644
index 0000000000000000000000000000000000000000..c4285151c4d0189938b3dfa59e0102770610becd
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/musicgen_melody/configuration_musicgen_melody.py
@@ -0,0 +1,261 @@
+# coding=utf-8
+# Copyright 2024 Meta AI and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Musicgen Melody model configuration"""
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+from ..auto.configuration_auto import AutoConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+class MusicgenMelodyDecoderConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`MusicgenMelodyDecoder`]. It is used to instantiate a
+ Musicgen Melody decoder according to the specified arguments, defining the model architecture. Instantiating a
+ configuration with the defaults will yield a similar configuration to that of the Musicgen Melody
+ [facebook/musicgen-melody](https://huggingface.co/facebook/musicgen-melody) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 2048):
+ Vocabulary size of the MusicgenMelodyDecoder model. Defines the number of different tokens that can be
+ represented by the `inputs_ids` passed when calling [`MusicgenMelodyDecoder`].
+ max_position_embeddings (`int`, *optional*, defaults to 2048):
+ The maximum sequence length that this model might ever be used with. Typically, set this to something large
+ just in case (e.g., 512 or 1024 or 2048).
+ num_hidden_layers (`int`, *optional*, defaults to 24):
+ Number of decoder layers.
+ ffn_dim (`int`, *optional*, defaults to 4096):
+ Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer block.
+ num_attention_heads (`int`, *optional*, defaults to 16):
+ Number of attention heads for each attention layer in the Transformer block.
+ layerdrop (`float`, *optional*, defaults to 0.0):
+ The LayerDrop probability for the decoder. See the [LayerDrop paper](https://huggingface.co/papers/1909.11556)
+ for more details.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return the last key/value attentions (not used by all models).
+ activation_function (`str` or `function`, *optional*, defaults to `"gelu"`):
+ The non-linear activation function (function or string) in the decoder and pooler. If string, `"gelu"`,
+ `"relu"`, `"silu"` and `"gelu_new"` are supported.
+ hidden_size (`int`, *optional*, defaults to 1024):
+ Dimensionality of the layers and the pooler layer.
+ dropout (`float`, *optional*, defaults to 0.1):
+ The dropout probability for all fully connected layers in the embeddings, text_encoder, and pooler.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ activation_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for activations inside the fully connected layer.
+ initializer_factor (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ scale_embedding (`bool`, *optional*, defaults to `False`):
+ Scale embeddings by dividing by sqrt(hidden_size).
+ num_codebooks (`int`, *optional*, defaults to 4):
+ The number of parallel codebooks forwarded to the model.
+ audio_channels (`int`, *optional*, defaults to 1):
+ Number of audio channels used by the model (either mono or stereo). Stereo models generate a separate
+ audio stream for the left/right output channels. Mono models generate a single audio stream output.
+ pad_token_id (`int`, *optional*, defaults to 2048): The id of the *padding* token.
+ bos_token_id (`int`, *optional*, defaults to 2048): The id of the *beginning-of-sequence* token.
+ eos_token_id (`int`, *optional*): The id of the *end-of-sequence* token.
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`): Whether to tie word embeddings with the text encoder.
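+
+ Example (a minimal sketch; the values shown simply restate the defaults):
+
+ ```python
+ >>> from transformers import MusicgenMelodyDecoderConfig
+
+ >>> decoder_config = MusicgenMelodyDecoderConfig(hidden_size=1024, num_codebooks=4, audio_channels=1)
+ ```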
+ """
+
+ model_type = "musicgen_melody_decoder"
+ base_config_key = "decoder_config"
+ keys_to_ignore_at_inference = ["past_key_values"]
+
+ def __init__(
+ self,
+ vocab_size=2048,
+ max_position_embeddings=2048,
+ num_hidden_layers=24,
+ ffn_dim=4096,
+ num_attention_heads=16,
+ layerdrop=0.0,
+ use_cache=True,
+ activation_function="gelu",
+ hidden_size=1024,
+ dropout=0.1,
+ attention_dropout=0.0,
+ activation_dropout=0.0,
+ initializer_factor=0.02,
+ scale_embedding=False,
+ num_codebooks=4,
+ audio_channels=1,
+ pad_token_id=2048,
+ bos_token_id=2048,
+ eos_token_id=None,
+ tie_word_embeddings=False,
+ **kwargs,
+ ):
+ self.vocab_size = vocab_size
+ self.max_position_embeddings = max_position_embeddings
+ self.hidden_size = hidden_size
+ self.ffn_dim = ffn_dim
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.dropout = dropout
+ self.attention_dropout = attention_dropout
+ self.activation_dropout = activation_dropout
+ self.activation_function = activation_function
+ self.initializer_factor = initializer_factor
+ self.layerdrop = layerdrop
+ self.use_cache = use_cache
+ self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True
+ self.num_codebooks = num_codebooks
+
+ if audio_channels not in [1, 2]:
+ raise ValueError(f"Expected 1 (mono) or 2 (stereo) audio channels, got {audio_channels} channels.")
+ self.audio_channels = audio_channels
+
+ super().__init__(
+ pad_token_id=pad_token_id,
+ bos_token_id=bos_token_id,
+ eos_token_id=eos_token_id,
+ tie_word_embeddings=tie_word_embeddings,
+ **kwargs,
+ )
+
+
+class MusicgenMelodyConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`MusicgenMelodyModel`]. It is used to instantiate a
+ Musicgen Melody model according to the specified arguments, defining the text encoder, audio encoder and Musicgen Melody decoder
+ configs. Instantiating a configuration with the defaults will yield a similar configuration to that of the Musicgen Melody
+ [facebook/musicgen-melody](https://huggingface.co/facebook/musicgen-melody) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ num_chroma (`int`, *optional*, defaults to 12): Number of chroma bins to use.
+ chroma_length (`int`, *optional*, defaults to 235):
+ Maximum chroma duration if audio is used to condition the model. Corresponds to the maximum duration used during training.
+ kwargs (*optional*):
+ Dictionary of keyword arguments. Notably:
+
+ - **text_encoder** ([`PretrainedConfig`], *optional*) -- An instance of a configuration object that
+ defines the text encoder config.
+ - **audio_encoder** ([`PretrainedConfig`], *optional*) -- An instance of a configuration object that
+ defines the audio encoder config.
+ - **decoder** ([`PretrainedConfig`], *optional*) -- An instance of a configuration object that defines
+ the decoder config.
+
+ Example:
+
+ ```python
+ >>> from transformers import (
+ ... MusicgenMelodyConfig,
+ ... MusicgenMelodyDecoderConfig,
+ ... T5Config,
+ ... EncodecConfig,
+ ... MusicgenMelodyForConditionalGeneration,
+ ... )
+
+ >>> # Initializing text encoder, audio encoder, and decoder model configurations
+ >>> text_encoder_config = T5Config()
+ >>> audio_encoder_config = EncodecConfig()
+ >>> decoder_config = MusicgenMelodyDecoderConfig()
+
+ >>> configuration = MusicgenMelodyConfig.from_sub_models_config(
+ ... text_encoder_config, audio_encoder_config, decoder_config
+ ... )
+
+ >>> # Initializing a MusicgenMelodyForConditionalGeneration (with random weights) from the facebook/musicgen-melody style configuration
+ >>> model = MusicgenMelodyForConditionalGeneration(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ >>> config_text_encoder = model.config.text_encoder
+ >>> config_audio_encoder = model.config.audio_encoder
+ >>> config_decoder = model.config.decoder
+
+ >>> # Saving the model, including its configuration
+ >>> model.save_pretrained("musicgen_melody-model")
+
+ >>> # loading model and config from pretrained folder
+ >>> musicgen_melody_config = MusicgenMelodyConfig.from_pretrained("musicgen_melody-model")
+ >>> model = MusicgenMelodyForConditionalGeneration.from_pretrained("musicgen_melody-model", config=musicgen_melody_config)
+ ```"""
+
+ model_type = "musicgen_melody"
+ sub_configs = {
+ "text_encoder": AutoConfig,
+ "audio_encoder": AutoConfig,
+ "decoder": MusicgenMelodyDecoderConfig,
+ }
+ has_no_defaults_at_init = True
+
+ def __init__(
+ self,
+ num_chroma=12,
+ chroma_length=235,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ if "text_encoder" not in kwargs or "audio_encoder" not in kwargs or "decoder" not in kwargs:
+ raise ValueError("Config has to be initialized with text_encoder, audio_encoder and decoder config")
+
+ text_encoder_config = kwargs.pop("text_encoder")
+ text_encoder_model_type = text_encoder_config.pop("model_type")
+
+ audio_encoder_config = kwargs.pop("audio_encoder")
+ audio_encoder_model_type = audio_encoder_config.pop("model_type")
+
+ decoder_config = kwargs.pop("decoder")
+
+ self.text_encoder = AutoConfig.for_model(text_encoder_model_type, **text_encoder_config)
+ self.audio_encoder = AutoConfig.for_model(audio_encoder_model_type, **audio_encoder_config)
+ self.decoder = MusicgenMelodyDecoderConfig(**decoder_config)
+ self.is_encoder_decoder = False
+
+ self.num_chroma = num_chroma
+ self.chroma_length = chroma_length
+
+ @classmethod
+ def from_sub_models_config(
+ cls,
+ text_encoder_config: PretrainedConfig,
+ audio_encoder_config: PretrainedConfig,
+ decoder_config: MusicgenMelodyDecoderConfig,
+ **kwargs,
+ ):
+ r"""
+ Instantiate a [`MusicgenMelodyConfig`] (or a derived class) from text encoder, audio encoder and decoder
+ configurations.
+
+ Returns:
+ [`MusicgenMelodyConfig`]: An instance of a configuration object
+ """
+
+ return cls(
+ text_encoder=text_encoder_config.to_dict(),
+ audio_encoder=audio_encoder_config.to_dict(),
+ decoder=decoder_config.to_dict(),
+ **kwargs,
+ )
+
+ @property
+ # This is a property because you might want to change the codec model on the fly
+ def sampling_rate(self):
+ return self.audio_encoder.sampling_rate
+
+
+__all__ = ["MusicgenMelodyConfig", "MusicgenMelodyDecoderConfig"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/musicgen_melody/feature_extraction_musicgen_melody.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/musicgen_melody/feature_extraction_musicgen_melody.py
new file mode 100644
index 0000000000000000000000000000000000000000..ec23899e91e90925e00f79e160a9e8d64de11d99
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/musicgen_melody/feature_extraction_musicgen_melody.py
@@ -0,0 +1,336 @@
+# coding=utf-8
+# Copyright 2024 Meta AI and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Feature extractor class for Musicgen Melody
+"""
+
+import copy
+from typing import Any, Optional, Union
+
+import numpy as np
+
+from ...audio_utils import chroma_filter_bank
+from ...feature_extraction_sequence_utils import SequenceFeatureExtractor
+from ...feature_extraction_utils import BatchFeature
+from ...utils import TensorType, is_torch_available, is_torchaudio_available, logging
+from ...utils.import_utils import requires
+
+
+if is_torch_available():
+ import torch
+
+if is_torchaudio_available():
+ import torchaudio
+
+logger = logging.get_logger(__name__)
+
+
+@requires(backends=("torchaudio",))
+class MusicgenMelodyFeatureExtractor(SequenceFeatureExtractor):
+ r"""
+ Constructs a MusicgenMelody feature extractor.
+
+ This feature extractor inherits from [`~feature_extraction_sequence_utils.SequenceFeatureExtractor`] which contains
+ most of the main methods. Users should refer to this superclass for more information regarding those methods.
+
+ This class extracts chroma features from audio processed by [Demucs](https://github.com/adefossez/demucs/tree/main) or
+ directly from raw audio waveform.
+
+ Args:
+ feature_size (`int`, *optional*, defaults to 12):
+ The feature dimension of the extracted features.
+ sampling_rate (`int`, *optional*, defaults to 32000):
+ The sampling rate at which the audio files should be digitized, expressed in hertz (Hz).
+ hop_length (`int`, *optional*, defaults to 4096):
+ Length of the overlapping windows for the STFT used to obtain the chroma features.
+ chunk_length (`int`, *optional*, defaults to 30):
+ The maximum number of chunks of `sampling_rate` samples used to trim and pad longer or shorter audio
+ sequences.
+ n_fft (`int`, *optional*, defaults to 16384):
+ Size of the Fourier transform.
+ num_chroma (`int`, *optional*, defaults to 12):
+ Number of chroma bins to use.
+ padding_value (`float`, *optional*, defaults to 0.0):
+ Padding value used to pad the audio.
+ return_attention_mask (`bool`, *optional*, defaults to `False`):
+ Whether to return the attention mask. Can be overwritten when calling the feature extractor.
+
+ [What are attention masks?](../glossary#attention-mask)
+
+
+
+ For Musicgen Melody models, audio `attention_mask` is not necessary: padded inputs are filled with the
+ silence token (zero) instead.
+
+
+ stem_indices (`list[int]`, *optional*, defaults to `[3, 2]`):
+ Stem channels to extract if demucs outputs are passed.
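+
+ Example (a rough sketch on silent mono audio; the input length and the 32 kHz rate are assumptions that match
+ the defaults):
+
+ ```python
+ >>> import numpy as np
+ >>> from transformers import MusicgenMelodyFeatureExtractor
+
+ >>> feature_extractor = MusicgenMelodyFeatureExtractor()
+ >>> audio = np.zeros(32000, dtype=np.float32)  # one second of audio at 32 kHz
+ >>> inputs = feature_extractor(audio, sampling_rate=32000, return_tensors="pt")
+ >>> inputs["input_features"].shape  # (batch, time, num_chroma)
+ ```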
+ """
+
+ model_input_names = ["input_features"]
+
+ def __init__(
+ self,
+ feature_size=12,
+ sampling_rate=32000,
+ hop_length=4096,
+ chunk_length=30,
+ n_fft=16384,
+ num_chroma=12,
+ padding_value=0.0,
+ return_attention_mask=False, # pad inputs to max length with silence token (zero) and no attention mask
+ stem_indices=[3, 2],
+ **kwargs,
+ ):
+ super().__init__(
+ feature_size=feature_size,
+ sampling_rate=sampling_rate,
+ padding_value=padding_value,
+ return_attention_mask=return_attention_mask,
+ **kwargs,
+ )
+ self.n_fft = n_fft
+ self.hop_length = hop_length
+ self.chunk_length = chunk_length
+ self.n_samples = chunk_length * sampling_rate
+ self.sampling_rate = sampling_rate
+ self.chroma_filters = torch.from_numpy(
+ chroma_filter_bank(sampling_rate=sampling_rate, num_frequency_bins=n_fft, tuning=0, num_chroma=num_chroma)
+ ).float()
+ self.spectrogram = torchaudio.transforms.Spectrogram(
+ n_fft=n_fft, win_length=n_fft, hop_length=hop_length, power=2, center=True, pad=0, normalized=True
+ )
+ self.stem_indices = stem_indices
+
+ def _torch_extract_fbank_features(self, waveform: torch.Tensor) -> torch.Tensor:
+ """
+ Compute the chroma spectrogram of the provided audio using the torchaudio spectrogram implementation and a librosa-style chroma filter bank.
+ """
+
+ # if wav length is not long enough, pad it
+ wav_length = waveform.shape[-1]
+ if wav_length < self.n_fft:
+ pad = self.n_fft - wav_length
+ rest = 0 if pad % 2 == 0 else 1
+ waveform = torch.nn.functional.pad(waveform, (pad // 2, pad // 2 + rest), "constant", 0)
+
+ # squeeze alongside channel dimension
+ spec = self.spectrogram(waveform).squeeze(1)
+
+ # sum along the frequency dimension
+ raw_chroma = torch.einsum("cf, ...ft->...ct", self.chroma_filters, spec)
+
+ # normalise with max value
+ norm_chroma = torch.nn.functional.normalize(raw_chroma, p=float("inf"), dim=-2, eps=1e-6)
+
+ # transpose time and chroma dimension -> (batch, time, chroma)
+ norm_chroma = norm_chroma.transpose(1, 2)
+
+ # replace max value alongside chroma dimension with 1 and replace the rest with 0
+ idx = norm_chroma.argmax(-1, keepdim=True)
+ norm_chroma[:] = 0
+ norm_chroma.scatter_(dim=-1, index=idx, value=1)
+
+ return norm_chroma
+
+ def _extract_stem_indices(self, audio, sampling_rate=None):
+ """
+ Extracts stems from the output of the [Demucs](https://github.com/adefossez/demucs/tree/main) audio separation model,
+ then converts them to a mono-channel waveform and resamples it to the feature extractor's sampling rate.
+
+ Args:
+ audio (`torch.Tensor` of shape `(batch_size, num_stems, channel_size, audio_length)`):
+ The output of the Demucs model to be processed.
+ sampling_rate (`int`, *optional*):
+ Demucs sampling rate. If not specified, defaults to `44000`.
+ """
+ sampling_rate = 44000 if sampling_rate is None else sampling_rate
+
+ # extract "vocals" and "others" sources from audio encoder (demucs) output
+ # [batch_size, num_stems, channel_size, audio_length]
+ wav = audio[:, torch.tensor(self.stem_indices)]
+
+ # merge extracted stems to single waveform
+ wav = wav.sum(1)
+
+ # convert to mono-channel waveform
+ wav = wav.mean(dim=1, keepdim=True)
+
+ # resample to model sampling rate
+ # not equivalent to julius.resample
+ if sampling_rate != self.sampling_rate:
+ wav = torchaudio.functional.resample(
+ wav, sampling_rate, self.sampling_rate, rolloff=0.945, lowpass_filter_width=24
+ )
+
+ # [batch_size, 1, audio_length] -> [batch_size, audio_length]
+ wav = wav.squeeze(1)
+
+ return wav
+
+ def __call__(
+ self,
+ audio: Union[np.ndarray, list[float], list[np.ndarray], list[list[float]]],
+ truncation: bool = True,
+ pad_to_multiple_of: Optional[int] = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ return_attention_mask: Optional[bool] = None,
+ padding: Optional[str] = True,
+ max_length: Optional[int] = None,
+ sampling_rate: Optional[int] = None,
+ **kwargs,
+ ) -> BatchFeature:
+ """
+ Main method to featurize and prepare for the model one or several sequence(s).
+
+ Args:
+ audio (`torch.Tensor`, `np.ndarray`, `list[float]`, `list[np.ndarray]`, `list[torch.Tensor]`, `list[list[float]]`):
+ The sequence or batch of sequences to be padded. Each sequence can be a torch tensor, a numpy array, a list of float
+ values, a list of numpy arrays, a list of torch tensors, or a list of list of float values.
+ If `audio` is the output of Demucs, it has to be a torch tensor of shape `(batch_size, num_stems, channel_size, audio_length)`.
+ Otherwise, it must be mono or stereo channel audio.
+ truncation (`bool`, *optional*, defaults to `True`):
+ Activates truncation to cut input sequences longer than *max_length* to *max_length*.
+ pad_to_multiple_of (`int`, *optional*, defaults to None):
+ If set will pad the sequence to a multiple of the provided value.
+
+ This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
+ `>= 7.5` (Volta), or on TPUs which benefit from having sequence lengths be a multiple of 128.
+ return_tensors (`str` or [`~utils.TensorType`], *optional*):
+ If set, will return tensors instead of list of python integers. Acceptable values are:
+
+ - `'tf'`: Return TensorFlow `tf.constant` objects.
+ - `'pt'`: Return PyTorch `torch.Tensor` objects.
+ - `'np'`: Return Numpy `np.ndarray` objects.
+ return_attention_mask (`bool`, *optional*):
+ Whether to return the attention mask. If left to the default, will return the attention mask according
+ to the specific feature_extractor's default.
+
+ [What are attention masks?](../glossary#attention-mask)
+
+
+ For Musicgen Melody models, audio `attention_mask` is not necessary.
+
+
+ padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`):
+ Select a strategy to pad the returned sequences (according to the model's padding side and padding
+ index) among:
+
+ - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
+ sequence is provided).
+ - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
+ acceptable input length for the model if that argument is not provided.
+ - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
+ lengths).
+ max_length (`int`, *optional*):
+ Maximum length of the returned list and optionally padding length (see above).
+ sampling_rate (`int`, *optional*):
+ The sampling rate at which the `audio` input was sampled. It is strongly recommended to pass
+ `sampling_rate` at the forward call to prevent silent errors.
+ Note that if `audio` is the output of Demucs, `sampling_rate` must be the sampling rate at which Demucs operates.
+ """
+
+ if sampling_rate is None:
+ logger.warning_once(
+ f"It is strongly recommended to pass the `sampling_rate` argument to `{self.__class__.__name__}()`. "
+ "Failing to do so can result in silent errors that might be hard to debug."
+ )
+
+ if isinstance(audio, torch.Tensor) and len(audio.shape) == 4:
+ logger.warning_once(
+ "`audio` is a 4-dimensional torch tensor and has thus been recognized as the output of `Demucs`. "
+ "If this is not the case, make sure to read Musicgen Melody docstrings and "
+ "to correct `audio` to get the right behaviour. "
+ "Link to the docstrings: https://huggingface.co/docs/transformers/main/en/model_doc/musicgen_melody"
+ )
+ audio = self._extract_stem_indices(audio, sampling_rate=sampling_rate)
+ elif sampling_rate is not None and sampling_rate != self.sampling_rate:
+ audio = torchaudio.functional.resample(
+ audio, sampling_rate, self.sampling_rate, rolloff=0.945, lowpass_filter_width=24
+ )
+
+ is_batched = isinstance(audio, (np.ndarray, torch.Tensor)) and len(audio.shape) > 1
+ is_batched = is_batched or (
+ isinstance(audio, (list, tuple)) and (isinstance(audio[0], (torch.Tensor, np.ndarray, tuple, list)))
+ )
+
+ if is_batched and not isinstance(audio[0], torch.Tensor):
+ audio = [torch.tensor(speech, dtype=torch.float32).unsqueeze(-1) for speech in audio]
+ elif is_batched:
+ audio = [speech.unsqueeze(-1) for speech in audio]
+ elif not is_batched and not isinstance(audio, torch.Tensor):
+ audio = torch.tensor(audio, dtype=torch.float32).unsqueeze(-1)
+
+ if isinstance(audio[0], torch.Tensor) and audio[0].dtype is torch.float64:
+ audio = [speech.to(torch.float32) for speech in audio]
+
+ # always return batch
+ if not is_batched:
+ audio = [audio]
+
+ if len(audio[0].shape) == 3:
+ logger.warning_once(
+ "`audio` has been detected as a batch of stereo signals and will be converted to mono signals. "
+ "If this is an undesired behaviour, make sure to read Musicgen Melody docstrings and "
+ "to correct `audio` to get the right behaviour. "
+ "Link to the docstrings: https://huggingface.co/docs/transformers/main/en/model_doc/musicgen_melody"
+ )
+ # convert to mono-channel waveform
+ audio = [stereo.mean(dim=0) for stereo in audio]
+
+ batched_speech = BatchFeature({"input_features": audio})
+
+ padded_inputs = self.pad(
+ batched_speech,
+ padding=padding,
+ max_length=max_length if max_length else self.n_samples,
+ truncation=truncation,
+ pad_to_multiple_of=pad_to_multiple_of,
+ return_attention_mask=return_attention_mask,
+ return_tensors="pt",
+ )
+
+ input_features = self._torch_extract_fbank_features(padded_inputs["input_features"].squeeze(-1))
+
+ padded_inputs["input_features"] = input_features
+
+ if return_attention_mask:
+ # rescale from raw audio length to spectrogram length
+ padded_inputs["attention_mask"] = padded_inputs["attention_mask"][:, :: self.hop_length]
+
+ if return_tensors is not None:
+ padded_inputs = padded_inputs.convert_to_tensors(return_tensors)
+
+ return padded_inputs
+
+ def to_dict(self) -> dict[str, Any]:
+ """
+ Serializes this instance to a Python dictionary.
+
+ Returns:
+ `dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
+ """
+ output = copy.deepcopy(self.__dict__)
+ output["feature_extractor_type"] = self.__class__.__name__
+ if "mel_filters" in output:
+ del output["mel_filters"]
+ if "window" in output:
+ del output["window"]
+ if "chroma_filters" in output:
+ del output["chroma_filters"]
+ if "spectrogram" in output:
+ del output["spectrogram"]
+ return output
+
+
+__all__ = ["MusicgenMelodyFeatureExtractor"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/musicgen_melody/modeling_musicgen_melody.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/musicgen_melody/modeling_musicgen_melody.py
new file mode 100644
index 0000000000000000000000000000000000000000..e432dc3ff62528a5299c5017909487b257160516
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/musicgen_melody/modeling_musicgen_melody.py
@@ -0,0 +1,2288 @@
+# coding=utf-8
+# Copyright 2024 Meta AI and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch Musicgen Melody model."""
+
+import copy
+import inspect
+import math
+import random
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Any, Callable, Optional, Union
+
+import torch
+import torch.nn as nn
+from torch.nn import CrossEntropyLoss
+
+from ...activations import ACT2FN
+from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
+from ...generation import (
+ ClassifierFreeGuidanceLogitsProcessor,
+ GenerationConfig,
+ GenerationMixin,
+ GenerationMode,
+ LogitsProcessorList,
+ StoppingCriteriaList,
+)
+from ...modeling_attn_mask_utils import (
+ _prepare_4d_causal_attention_mask,
+ _prepare_4d_causal_attention_mask_for_sdpa,
+)
+from ...modeling_flash_attention_utils import (
+ FlashAttentionKwargs,
+)
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
+from ...utils import auto_docstring, is_torch_flex_attn_available, logging
+from ...utils.deprecation import deprecate_kwarg
+from ..auto.configuration_auto import AutoConfig
+from ..auto.modeling_auto import AutoModel, AutoModelForTextEncoding
+from .configuration_musicgen_melody import MusicgenMelodyConfig, MusicgenMelodyDecoderConfig
+
+
+if is_torch_flex_attn_available():
+ from ...integrations.flex_attention import make_flex_block_causal_mask
+
+
+if TYPE_CHECKING:
+ from ...generation.streamers import BaseStreamer
+
+logger = logging.get_logger(__name__)
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Base class for Musicgen Melody autoregressive outputs.
+ """
+)
+class MusicgenMelodyOutputWithPast(ModelOutput):
+ r"""
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Language modeling loss (for next-token prediction).
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
+ `past_key_values` input) to speed up sequential decoding.
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
+ Sequence of conditional hidden-states representing the concatenation of the projected text encoder output and the projected audio encoder output.
+ Used as a conditional signal.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ logits: Optional[torch.FloatTensor] = None
+ past_key_values: Optional[Cache] = None
+ hidden_states: Optional[tuple[torch.FloatTensor]] = None
+ attentions: Optional[tuple[torch.FloatTensor]] = None
+ encoder_hidden_states: Optional[torch.FloatTensor] = None
+
+
+# Copied from transformers.models.musicgen.modeling_musicgen.shift_tokens_right
+def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
+ """
+ Shift input ids one token to the right.
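+
+ A small illustrative sketch (labels are assumed to have shape `(bsz, seq_len, num_codebooks)`):
+
+ ```python
+ >>> import torch
+ >>> labels = torch.tensor([[[5, 6], [7, 8]]])  # (bsz=1, seq_len=2, num_codebooks=2)
+ >>> shift_tokens_right(labels, pad_token_id=2048, decoder_start_token_id=2048)  # -> shape (bsz, num_codebooks, seq_len), step 0 filled with 2048
+ ```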
+ """
+ # transpose to get (bsz, num_codebooks, seq_len)
+ input_ids = input_ids.transpose(1, 2)
+ shifted_input_ids = input_ids.new_zeros(input_ids.shape)
+ shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
+ if decoder_start_token_id is None:
+ raise ValueError("Make sure to set the decoder_start_token_id attribute of the model's configuration.")
+ shifted_input_ids[..., 0] = decoder_start_token_id
+
+ if pad_token_id is None:
+ raise ValueError("Make sure to set the pad_token_id attribute of the model's configuration.")
+ # replace possible -100 values in labels by `pad_token_id`
+ shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
+
+ return shifted_input_ids
+
+
+# Copied from transformers.models.musicgen.modeling_musicgen.MusicgenSinusoidalPositionalEmbedding with Musicgen->MusicgenMelody
+class MusicgenMelodySinusoidalPositionalEmbedding(nn.Module):
+ """This module produces sinusoidal positional embeddings of any length."""
+
+ def __init__(self, num_positions: int, embedding_dim: int):
+ super().__init__()
+ self.embedding_dim = embedding_dim
+ self.make_weights(num_positions, embedding_dim)
+
+ def make_weights(self, num_embeddings: int, embedding_dim: int):
+ emb_weights = self.get_embedding(num_embeddings, embedding_dim)
+ if hasattr(self, "weights"):
+ # in forward put the weights on the correct dtype and device of the param
+ emb_weights = emb_weights.to(dtype=self.weights.dtype, device=self.weights.device)
+
+ self.register_buffer("weights", emb_weights, persistent=False)
+
+ @staticmethod
+ def get_embedding(num_embeddings: int, embedding_dim: int):
+ """
+ Build sinusoidal embeddings. This matches the implementation in tensor2tensor, but differs slightly from the
+ description in Section 3.5 of "Attention Is All You Need".
+ """
+ half_dim = embedding_dim // 2
+ emb = math.log(10000) / (half_dim - 1)
+ emb = torch.exp(torch.arange(half_dim, dtype=torch.int64).float() * -emb)
+ emb = torch.arange(num_embeddings, dtype=torch.int64).float().unsqueeze(1) * emb.unsqueeze(0)
+ emb = torch.cat([torch.cos(emb), torch.sin(emb)], dim=1).view(num_embeddings, -1)
+ if embedding_dim % 2 == 1:
+ # zero pad
+ emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
+ return emb.to(torch.get_default_dtype())
+
+ @torch.no_grad()
+ # Ignore copy
+ def forward(self, inputs_embeds: torch.Tensor, past_key_values_length: int = 0):
+ bsz, seq_len, _ = inputs_embeds.size()
+ # Create the position ids from the input token ids.
+ position_ids = (torch.arange(seq_len) + past_key_values_length).to(inputs_embeds.device)
+ # expand embeddings if needed
+ if seq_len > self.weights.size(0):
+ self.make_weights(seq_len, self.embedding_dim)
+ return self.weights.index_select(0, position_ids.view(-1)).detach()
+
+
+# Copied from transformers.models.bart.modeling_bart.eager_attention_forward
+def eager_attention_forward(
+ module: nn.Module,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_mask: Optional[torch.Tensor],
+ scaling: Optional[float] = None,
+ dropout: float = 0.0,
+ head_mask: Optional[torch.Tensor] = None,
+ **kwargs,
+):
+ if scaling is None:
+ scaling = query.size(-1) ** -0.5
+
+ attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
+ if attention_mask is not None:
+ attn_weights = attn_weights + attention_mask
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+
+ if head_mask is not None:
+ attn_weights = attn_weights * head_mask.view(1, -1, 1, 1)
+
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+ attn_output = torch.matmul(attn_weights, value)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ return attn_output, attn_weights
+
+
+# Copied from transformers.models.musicgen.modeling_musicgen.MusicgenAttention with Musicgen->MusicgenMelody
+class MusicgenMelodyAttention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(
+ self,
+ embed_dim: int,
+ num_heads: int,
+ dropout: Optional[float] = 0.0,
+ is_decoder: Optional[bool] = False,
+ bias: Optional[bool] = True,
+ is_causal: Optional[bool] = False,
+ config: Optional[MusicgenMelodyConfig] = None,
+ layer_idx: Optional[int] = None,
+ ):
+ super().__init__()
+ self.embed_dim = embed_dim
+ self.num_heads = num_heads
+ self.dropout = dropout
+ self.head_dim = embed_dim // num_heads
+ self.config = config
+
+ if (self.head_dim * num_heads) != self.embed_dim:
+ raise ValueError(
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
+ f" and `num_heads`: {num_heads})."
+ )
+ self.scaling = self.head_dim**-0.5
+ self.is_decoder = is_decoder
+ self.is_causal = is_causal
+ self.layer_idx = layer_idx
+
+ self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+ self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+ self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+ self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ key_value_states: Optional[torch.Tensor] = None,
+ past_key_values: Optional[Cache] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ layer_head_mask: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = False,
+ cache_position: Optional[torch.Tensor] = None,
+ # TODO: we need a refactor so that the different attention modules can get their specific kwargs
+ # ATM, we have mixed things: encoder, decoder, and encoder-decoder attn
+ **kwargs: Unpack[FlashAttentionKwargs],
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+ """Input shape: Batch x Time x Channel"""
+
+ # if key_value_states are provided this layer is used as a cross-attention layer
+ # for the decoder
+ is_cross_attention = key_value_states is not None
+
+ # determine input shapes
+ bsz, tgt_len = hidden_states.shape[:-1]
+ src_len = key_value_states.shape[1] if is_cross_attention else tgt_len
+
+ q_input_shape = (bsz, tgt_len, -1, self.head_dim)
+ kv_input_shape = (bsz, src_len, -1, self.head_dim)
+
+ # get query proj
+ query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2)
+
+ is_updated = False
+ if past_key_values is not None:
+ if isinstance(past_key_values, EncoderDecoderCache):
+ is_updated = past_key_values.is_updated.get(self.layer_idx)
+ if is_cross_attention:
+ # after the first generated id, we can subsequently re-use all key/value states from the cache
+ curr_past_key_value = past_key_values.cross_attention_cache
+ else:
+ curr_past_key_value = past_key_values.self_attention_cache
+ else:
+ curr_past_key_value = past_key_values
+
+ current_states = key_value_states if is_cross_attention else hidden_states
+ if is_cross_attention and past_key_values is not None and is_updated:
+ # reuse k,v, cross_attentions
+ key_states = curr_past_key_value.layers[self.layer_idx].keys
+ value_states = curr_past_key_value.layers[self.layer_idx].values
+ else:
+ key_states = self.k_proj(current_states).view(*kv_input_shape).transpose(1, 2)
+ value_states = self.v_proj(current_states).view(*kv_input_shape).transpose(1, 2)
+
+ if past_key_values is not None:
+ # save all key/value_states to cache to be re-used for fast auto-regressive generation
+ cache_position = cache_position if not is_cross_attention else None
+ key_states, value_states = curr_past_key_value.update(
+ key_states, value_states, self.layer_idx, {"cache_position": cache_position}
+ )
+ # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
+ if is_cross_attention and isinstance(past_key_values, EncoderDecoderCache):
+ past_key_values.is_updated[self.layer_idx] = True
+
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ attn_output, attn_weights = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ dropout=0.0 if not self.training else self.dropout,
+ scaling=self.scaling,
+ output_attentions=output_attentions,
+ head_mask=layer_head_mask,
+ **kwargs,
+ )
+
+ attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous()
+ attn_output = self.out_proj(attn_output)
+
+ return attn_output, attn_weights
+
+
+class MusicgenMelodyDecoderLayer(GradientCheckpointingLayer):
+ def __init__(self, config: MusicgenMelodyDecoderConfig, layer_idx=None):
+ super().__init__()
+ self.embed_dim = config.hidden_size
+
+ self.self_attn = MusicgenMelodyAttention(
+ embed_dim=self.embed_dim,
+ num_heads=config.num_attention_heads,
+ dropout=config.attention_dropout,
+ is_decoder=True,
+ bias=False,
+ is_causal=True,
+ config=config,
+ layer_idx=layer_idx,
+ )
+ self.dropout = config.dropout
+ self.activation_fn = ACT2FN[config.activation_function]
+ self.activation_dropout = config.activation_dropout
+
+ self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
+
+ self.fc1 = nn.Linear(self.embed_dim, config.ffn_dim, bias=False)
+ self.fc2 = nn.Linear(config.ffn_dim, self.embed_dim, bias=False)
+ self.final_layer_norm = nn.LayerNorm(self.embed_dim)
+
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ layer_head_mask: Optional[torch.Tensor] = None,
+ past_key_values: Optional[Cache] = None,
+ output_attentions: Optional[bool] = False,
+ use_cache: Optional[bool] = True,
+ cache_position: Optional[torch.Tensor] = None,
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+ attention_mask (`torch.FloatTensor`): attention mask of size
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+ layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size `(attention_heads,)`.
+ past_key_values (`Cache`): cached past key and value projection states
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ """
+ residual = hidden_states
+ hidden_states = self.self_attn_layer_norm(hidden_states)
+
+ # Self Attention
+ hidden_states, self_attn_weights = self.self_attn(
+ hidden_states=hidden_states,
+ past_key_values=past_key_values,
+ attention_mask=attention_mask,
+ layer_head_mask=layer_head_mask,
+ output_attentions=output_attentions,
+ cache_position=cache_position,
+ )
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+ hidden_states = residual + hidden_states
+
+ # Fully Connected
+ residual = hidden_states
+ hidden_states = self.final_layer_norm(hidden_states)
+ hidden_states = self.activation_fn(self.fc1(hidden_states))
+ hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
+ hidden_states = self.fc2(hidden_states)
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+ hidden_states = residual + hidden_states
+ return hidden_states, self_attn_weights
+
+
+@auto_docstring
+# Copied from transformers.models.musicgen.modeling_musicgen.MusicgenPreTrainedModel with Musicgen->MusicgenMelody
+class MusicgenMelodyPreTrainedModel(PreTrainedModel):
+ config: MusicgenMelodyDecoderConfig
+ base_model_prefix = "model"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["MusicgenMelodyDecoderLayer", "MusicgenMelodyAttention"]
+ _supports_flash_attn = True
+ _supports_sdpa = True
+ _supports_flex_attn = True
+
+ def _init_weights(self, module):
+ std = self.config.initializer_factor
+ if isinstance(module, nn.Linear):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.LayerNorm):
+ module.weight.data.fill_(1.0)
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+
+
+# Copied from transformers.models.musicgen.modeling_musicgen.MusicgenDecoder with MUSICGEN->MUSICGEN_MELODY,Musicgen->MusicgenMelody
+class MusicgenMelodyDecoder(MusicgenMelodyPreTrainedModel):
+ """
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`MusicgenMelodyDecoderLayer`]
+ """
+
+ def __init__(self, config: MusicgenMelodyDecoderConfig):
+ super().__init__(config)
+ self.dropout = config.dropout
+ self.layerdrop = config.layerdrop
+ self.max_target_positions = config.max_position_embeddings
+ self.d_model = config.hidden_size
+ self.num_codebooks = config.num_codebooks
+ self.embed_scale = math.sqrt(config.hidden_size) if config.scale_embedding else 1.0
+
+ embed_dim = config.vocab_size + 1
+ self.embed_tokens = nn.ModuleList(
+ [nn.Embedding(embed_dim, config.hidden_size) for _ in range(config.num_codebooks)]
+ )
+
+ self.embed_positions = MusicgenMelodySinusoidalPositionalEmbedding(
+ config.max_position_embeddings,
+ config.hidden_size,
+ )
+
+ self.layers = nn.ModuleList(
+ [MusicgenMelodyDecoderLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)]
+ )
+ self.layer_norm = nn.LayerNorm(config.hidden_size)
+ self.attn_implementation = config._attn_implementation
+
+ self.gradient_checkpointing = False
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @auto_docstring
+ # Ignore copy
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ encoder_attention_mask: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.Tensor] = None,
+ ) -> Union[tuple, BaseModelOutputWithPast]:
+ r"""
+ input_ids (`torch.LongTensor` of shape `(batch_size * num_codebooks, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary, corresponding to the sequence of audio codes.
+
+ Indices can be obtained by encoding an audio prompt with an audio encoder model to predict audio codes,
+ such as with the [`EncodecModel`]. See [`EncodecModel.encode`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+
+
+
+ The `input_ids` will automatically be converted from shape `(batch_size * num_codebooks,
+ target_sequence_length)` to `(batch_size, num_codebooks, target_sequence_length)` in the forward pass. If
+ you obtain audio codes from an audio encoding model, such as [`EncodecModel`], ensure that the number of
+ frames is equal to 1, and that you reshape the audio codes from `(frames, batch_size, num_codebooks,
+ target_sequence_length)` to `(batch_size * num_codebooks, target_sequence_length)` prior to passing them as
+ `input_ids`.
+
+
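+ For instance, a minimal sketch of that reshape (assuming `audio_codes` was returned by
+ [`EncodecModel.encode`] with a single frame):
+
+ ```python
+ >>> frames, bsz, num_codebooks, seq_len = audio_codes.shape
+ >>> input_ids = audio_codes[0, ...].reshape(bsz * num_codebooks, seq_len)
+ ```
+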
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
+ Sequence of hidden-states representing the concatenation of the text encoder output and the processed audio encoder output.
+ Used as a conditional signal and will thus be concatenated to the projected `decoder_input_ids`.
+ encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*):
+ Mask to avoid performing attention on conditional hidden states. Mask values
+ selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # retrieve input_ids and inputs_embeds
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
+ elif input_ids is not None:
+ # (bsz * codebooks, seq_len) -> (bsz, codebooks, seq_len)
+ input = input_ids.reshape(-1, self.num_codebooks, input_ids.shape[-1])
+ bsz, num_codebooks, seq_len = input.shape
+ input_shape = (bsz, seq_len)
+ elif inputs_embeds is not None:
+ input_shape = inputs_embeds.size()[:-1]
+ input = inputs_embeds[:, :, -1:]
+ else:
+ raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
+
+ if self.gradient_checkpointing and self.training:
+ if use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing`. Setting `use_cache=False`..."
+ )
+ use_cache = False
+
+ if use_cache and past_key_values is None:
+ past_key_values = EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
+ if use_cache and isinstance(past_key_values, tuple):
+ logger.warning_once(
+ "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. "
+ "You should pass an instance of `EncoderDecoderCache` instead, e.g. "
+ "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`."
+ )
+ past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values)
+
+ past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
+ if inputs_embeds is None:
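+ # each codebook has its own embedding table; the per-codebook embeddings are summed into a single frame embedding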
+ inputs_embeds = sum(self.embed_tokens[codebook](input[:, codebook]) for codebook in range(num_codebooks))
+
+ if encoder_hidden_states is not None:
+ # take care of attention masks
+ if encoder_attention_mask is not None and attention_mask is None:
+ attention_mask = torch.ones(inputs_embeds.shape[:2], device=inputs_embeds.device)
+
+ if attention_mask is not None:
+ if encoder_attention_mask is None:
+ encoder_attention_mask = torch.ones(encoder_hidden_states.shape[:2], device=attention_mask.device)
+ attention_mask = torch.cat([encoder_attention_mask, attention_mask], dim=1)
+
+ # fuse encoder_hidden_states and inputs_embeds
+ inputs_embeds = torch.cat([encoder_hidden_states, inputs_embeds], dim=1)
+
+ input_shape = inputs_embeds.size()[:-1]
+
+ attention_mask = self._update_causal_mask(
+ attention_mask,
+ input_shape,
+ inputs_embeds,
+ past_key_values_length,
+ )
+
+ # embed positions
+ positions = self.embed_positions(inputs_embeds, past_key_values_length)
+ hidden_states = inputs_embeds + positions.to(inputs_embeds.device)
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+
+ # decoder layers
+ all_hidden_states = () if output_hidden_states else None
+ all_attentions = () if output_attentions else None
+
+ # check if head_mask has a correct number of layers specified if desired
+ if head_mask is not None:
+ if head_mask.size()[0] != len(self.layers):
+ raise ValueError(
+ f"The `head_mask` should be specified for {len(self.layers)} layers, but it is for"
+ f" {head_mask.size()[0]}."
+ )
+
+ for idx, decoder_layer in enumerate(self.layers):
+ # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+ dropout_probability = random.uniform(0, 1)
+ if self.training and (dropout_probability < self.layerdrop):
+ continue
+
+ layer_outputs = decoder_layer(
+ hidden_states,
+ attention_mask=attention_mask,
+ layer_head_mask=(head_mask[idx] if head_mask is not None else None),
+ past_key_values=past_key_values,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ )
+ hidden_states = layer_outputs[0]
+ if output_attentions:
+ all_attentions += (layer_outputs[1],)
+
+ hidden_states = self.layer_norm(hidden_states)
+
+ # add hidden states from the last decoder layer
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ if not return_dict:
+ return tuple(
+ v for v in [hidden_states, past_key_values, all_hidden_states, all_attentions] if v is not None
+ )
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=past_key_values,
+ hidden_states=all_hidden_states,
+ attentions=all_attentions,
+ )
+
+ def _update_causal_mask(
+ self,
+ attention_mask: Union[torch.Tensor, None],
+ input_shape: torch.Size,
+ inputs_embeds: torch.Tensor,
+ past_key_values_length: int,
+ ):
+ if self.config._attn_implementation == "flash_attention_2":
+ # 2d mask is passed through the layers
+ attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
+ elif self.config._attn_implementation == "sdpa":
+ # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on
+ # the manual implementation that requires a 4D causal mask in all cases.
+ attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
+ attention_mask,
+ input_shape,
+ inputs_embeds,
+ past_key_values_length,
+ )
+ elif self.config._attn_implementation == "flex_attention":
+ if isinstance(attention_mask, torch.Tensor):
+ attention_mask = make_flex_block_causal_mask(attention_mask)
+ # Other attention flavors support in-built causal (when `mask is None`)
+ # while we need to create our specific block mask regardless
+ elif attention_mask is None:
+ attention_mask = make_flex_block_causal_mask(
+ torch.ones(
+ size=(input_shape),
+ device=inputs_embeds.device,
+ )
+ )
+ else:
+ # 4d mask is passed through the layers
+ attention_mask = _prepare_4d_causal_attention_mask(
+ attention_mask, input_shape, inputs_embeds, past_key_values_length
+ )
+
+ return attention_mask
+
+ # Ignore copy
+ def _update_cross_attn_mask(
+ self,
+ encoder_hidden_states: Union[torch.Tensor, None],
+ encoder_attention_mask: Union[torch.Tensor, None],
+ input_shape: torch.Size,
+ inputs_embeds: torch.Tensor,
+ ):
+ # MusicgenMelody doesn't apply cross attention, hence it's ignored here
+ # and only exists to not confuse any copy checks
+ pass
+
+
+@auto_docstring
+# Copied from transformers.models.musicgen.modeling_musicgen.MusicgenModel with MUSICGEN->MUSICGEN_MELODY,Musicgen->MusicgenMelody
+class MusicgenMelodyModel(MusicgenMelodyPreTrainedModel):
+ def __init__(self, config: MusicgenMelodyDecoderConfig):
+ super().__init__(config)
+ self.decoder = MusicgenMelodyDecoder(config)
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.decoder.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.decoder.embed_tokens = value
+
+ @auto_docstring
+ # Ignore copy
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ encoder_attention_mask: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.Tensor] = None,
+ ) -> Union[tuple, BaseModelOutputWithPast]:
+ r"""
+ input_ids (`torch.LongTensor` of shape `(batch_size * num_codebooks, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary, corresponding to the sequence of audio codes.
+
+ Indices can be obtained by encoding an audio prompt with an audio encoder model to predict audio codes,
+ such as with the [`EncodecModel`]. See [`EncodecModel.encode`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+
+
+
+ The `input_ids` will automatically be converted from shape `(batch_size * num_codebooks,
+ target_sequence_length)` to `(batch_size, num_codebooks, target_sequence_length)` in the forward pass. If
+ you obtain audio codes from an audio encoding model, such as [`EncodecModel`], ensure that the number of
+ frames is equal to 1, and that you reshape the audio codes from `(frames, batch_size, num_codebooks,
+ target_sequence_length)` to `(batch_size * num_codebooks, target_sequence_length)` prior to passing them as
+ `input_ids`.
+
+
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
+ Sequence of hidden-states representing the concatenation of the text encoder output and the processed audio encoder output.
+ Used as a conditional signal and will thus be concatenated to the projected `decoder_input_ids`.
+ encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*):
+ Mask to avoid performing attention on conditional hidden states. Mask values
+ selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # decoder outputs consists of (dec_features, past_key_values, dec_hidden, dec_attn)
+ decoder_outputs = self.decoder(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ head_mask=head_mask,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ cache_position=cache_position,
+ )
+
+ if not return_dict:
+ return decoder_outputs
+
+ return BaseModelOutputWithPast(
+ last_hidden_state=decoder_outputs.last_hidden_state,
+ past_key_values=decoder_outputs.past_key_values,
+ hidden_states=decoder_outputs.hidden_states,
+ attentions=decoder_outputs.attentions,
+ )
+
+
+@auto_docstring(
+ custom_intro="""
+ The Musicgen Melody decoder model with a language modelling head on top.
+ """
+)
+# Copied from transformers.models.musicgen.modeling_musicgen.MusicgenForCausalLM with MUSICGEN->MUSICGEN_MELODY,Musicgen->MusicgenMelody,MusicGen->Musicgen Melody
+class MusicgenMelodyForCausalLM(MusicgenMelodyPreTrainedModel, GenerationMixin):
+ def __init__(self, config: MusicgenMelodyDecoderConfig):
+ super().__init__(config)
+
+ self.model = MusicgenMelodyModel(config)
+
+ self.num_codebooks = config.num_codebooks
+ self.lm_heads = nn.ModuleList(
+ [nn.Linear(config.hidden_size, config.vocab_size, bias=False) for _ in range(config.num_codebooks)]
+ )
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.decoder.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.model.decoder.embed_tokens = value
+
+ def get_output_embeddings(self):
+ return self.lm_heads
+
+ def set_output_embeddings(self, new_embeddings):
+ self.lm_heads = new_embeddings
+
+ def set_decoder(self, decoder):
+ self.model.decoder = decoder
+
+ def get_decoder(self):
+ return self.model.decoder
+
+ @auto_docstring
+ # Ignore copy
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ encoder_attention_mask: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ labels: Optional[torch.LongTensor] = None,
+ cache_position: Optional[torch.Tensor] = None,
+ ) -> Union[tuple, MusicgenMelodyOutputWithPast]:
+ r"""
+ input_ids (`torch.LongTensor` of shape `(batch_size * num_codebooks, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary, corresponding to the sequence of audio codes.
+
+ Indices can be obtained by encoding an audio prompt with an audio encoder model to predict audio codes,
+ such as with the [`EncodecModel`]. See [`EncodecModel.encode`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+
+
+
+ The `input_ids` will automatically be converted from shape `(batch_size * num_codebooks,
+ target_sequence_length)` to `(batch_size, num_codebooks, target_sequence_length)` in the forward pass. If
+ you obtain audio codes from an audio encoding model, such as [`EncodecModel`], ensure that the number of
+ frames is equal to 1, and that you reshape the audio codes from `(frames, batch_size, num_codebooks,
+ target_sequence_length)` to `(batch_size * num_codebooks, target_sequence_length)` prior to passing them as
+ `input_ids`.
+
+
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
+ Sequence of hidden-states representing the concatenation of the text encoder output and the processed audio encoder output.
+ Used as a conditional signal and will thus be concatenated to the projected `decoder_input_ids`.
+ encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*):
+ Mask to avoid performing attention on conditional hidden states. Mask values
+ selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length, num_codebooks)`, *optional*):
+ Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
+ `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
+ are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
+ """
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if (labels is not None) and (input_ids is None and inputs_embeds is None):
+ input_ids = shift_tokens_right(labels, self.config.pad_token_id, self.config.bos_token_id)
+
+ outputs = self.model(
+ input_ids,
+ attention_mask=attention_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ head_mask=head_mask,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ cache_position=cache_position,
+ )
+
+ hidden_states = outputs[0]
+
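+ # one LM head per codebook; stacking on dim=1 gives logits of shape (bsz, num_codebooks, seq_len, vocab_size)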
+ lm_logits = torch.stack([head(hidden_states) for head in self.lm_heads], dim=1)
+
+ loss = None
+ if labels is not None:
+ # since encoder hidden states have been concatenated to the decoder hidden states,
+ # we take the last timestamps corresponding to labels
+ logits = lm_logits[:, :, -labels.shape[1] :]
+
+ loss_fct = CrossEntropyLoss()
+ loss = torch.zeros([], device=self.device)
+
+ # per codebook cross-entropy
+ # ref: https://github.com/facebookresearch/audiocraft/blob/69fea8b290ad1b4b40d28f92d1dfc0ab01dbab85/audiocraft/solvers/musicgen.py#L242-L243
+ # -100 labels are ignored
+ labels = labels.masked_fill(labels == self.config.pad_token_id, -100)
+
+ for codebook in range(self.config.num_codebooks):
+ codebook_logits = logits[:, codebook].contiguous().view(-1, logits.shape[-1])
+ codebook_labels = labels[..., codebook].contiguous().view(-1)
+ loss += loss_fct(codebook_logits, codebook_labels)
+
+ loss = loss / self.config.num_codebooks
+
+ # (bsz, num_codebooks, seq_len, vocab_size) -> (bsz * num_codebooks, seq_len, vocab_size)
+ lm_logits = lm_logits.reshape(-1, *lm_logits.shape[2:])
+
+ if not return_dict:
+ output = (lm_logits,) + outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return MusicgenMelodyOutputWithPast(
+ loss=loss,
+ logits=lm_logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+ # Ignore copy
+ def prepare_inputs_for_generation(
+ self,
+ input_ids,
+ attention_mask=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ head_mask=None,
+ past_key_values=None,
+ use_cache=True,
+ delay_pattern_mask=None,
+ guidance_scale=None,
+ **kwargs,
+ ):
+ # Overwritten -- MusicGen has custom processing
+ if delay_pattern_mask is None:
+ input_ids, delay_pattern_mask = self.build_delay_pattern_mask(
+ input_ids,
+ pad_token_id=self.generation_config.pad_token_id,
+ max_length=self.generation_config.max_length,
+ )
+
+ # apply the delay pattern mask
+ input_ids = self.apply_delay_pattern_mask(input_ids, delay_pattern_mask)
+
+ if guidance_scale is not None and guidance_scale > 1:
+ # for classifier free guidance we need to replicate the decoder args across the batch dim (we'll split these
+ # before sampling)
+ input_ids = input_ids.repeat((2, 1))
+ if attention_mask is not None:
+ attention_mask = attention_mask.repeat((2, 1))
+
+ if encoder_hidden_states is not None:
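+ # the duplicated half of the batch attends to an all-zero conditioning signal (the unconditional CFG branch)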
+ encoder_hidden_states = torch.concatenate(
+ [encoder_hidden_states, torch.zeros_like(encoder_hidden_states)], dim=0
+ )
+
+ if encoder_attention_mask is not None:
+ encoder_attention_mask = torch.concatenate(
+ [encoder_attention_mask, torch.zeros_like(encoder_attention_mask)], dim=0
+ )
+
+ if past_key_values is not None:
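+ # once the cache is populated, only the last generated token needs to be fed to the decoder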
+ input_ids = input_ids[:, -1:]
+
+ # we only want to use the conditional signal in the 1st generation step, but keep the attention mask
+ encoder_hidden_states = None
+
+ return {
+ "input_ids": input_ids,
+ "attention_mask": attention_mask,
+ "encoder_hidden_states": encoder_hidden_states,
+ "encoder_attention_mask": encoder_attention_mask,
+ "head_mask": head_mask,
+ "past_key_values": past_key_values,
+ "use_cache": use_cache,
+ }
+
+ def build_delay_pattern_mask(
+ self, input_ids: torch.LongTensor, pad_token_id: int, max_length: Optional[int] = None
+ ):
+ """Build a delayed pattern mask to the input_ids. Each codebook is offset by the previous codebook by
+ one, giving a delayed pattern mask at the start of sequence and end of sequence. Take the example where there
+ are 4 codebooks and a max sequence length of 8, we have the delayed pattern mask of shape `(codebooks,
+ seq_len)`:
+ - [P, -1, -1, -1, -1, P, P, P]
+ - [P, P, -1, -1, -1, -1, P, P]
+ - [P, P, P, -1, -1, -1, -1, P]
+ - [P, P, P, P, -1, -1, -1, -1]
+ where P is the special padding token id and -1 indicates that the token is valid for prediction. If we include
+ a prompt (decoder input ids), the -1 positions indicate where new tokens should be predicted. Otherwise, the
+ mask is set to the value in the prompt:
+ - [P, a, b, -1, -1, P, P, P]
+ - [P, P, c, d, -1, -1, P, P]
+ - [P, P, P, e, f, -1, -1, P]
+ - [P, P, P, P, g, h, -1, -1]
+ where a-h indicate the input prompt (decoder input ids) that are offset by 1. Now, we only override the -1
+ tokens in our prediction.
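+
+ A minimal sketch of the shifting step on toy values (hypothetical shapes, independent of the model; the
+ padding/EOS masking that follows is not shown):
+
+ ```python
+ >>> import torch
+
+ >>> bsz, codebooks, seq_len, max_length = 1, 3, 2, 6
+ >>> prompt = torch.arange(bsz * codebooks * seq_len).reshape(bsz, codebooks, seq_len)
+ >>> shifted = torch.full((bsz, codebooks, max_length), -1)
+ >>> for k in range(codebooks):
+ ...     shifted[:, k, k : seq_len + k] = prompt[:, k]
+ >>> shifted[0]
+ tensor([[ 0,  1, -1, -1, -1, -1],
+         [-1,  2,  3, -1, -1, -1],
+         [-1, -1,  4,  5, -1, -1]])
+ ```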
+ """
+ # (bsz * num_codebooks, seq_len) -> (bsz, num_codebooks, seq_len)
+ input_ids = input_ids.reshape(-1, self.num_codebooks, input_ids.shape[-1])
+ bsz, num_codebooks, seq_len = input_ids.shape
+
+ max_length = max_length if max_length is not None else self.generation_config.max_length
+ input_ids_shifted = (
+ torch.ones((bsz, num_codebooks, max_length), dtype=torch.long, device=input_ids.device) * -1
+ )
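+ # -1 acts as a sentinel for positions that still have to be predicted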
+
+ channel_codebooks = num_codebooks // 2 if self.config.audio_channels == 2 else num_codebooks
+ # we only apply the mask if we have a large enough seq len - otherwise we return as is
+ if max_length < 2 * channel_codebooks - 1:
+ return input_ids.reshape(bsz * num_codebooks, -1), input_ids_shifted.reshape(bsz * num_codebooks, -1)
+
+ # fill the shifted ids with the prompt entries, offset by the codebook idx
+ for codebook in range(channel_codebooks):
+ if self.config.audio_channels == 1:
+ # mono channel - loop over the codebooks one-by-one
+ input_ids_shifted[:, codebook, codebook : seq_len + codebook] = input_ids[:, codebook]
+ else:
+ # left/right channels are interleaved in the generated codebooks, so handle one then the other
+ input_ids_shifted[:, 2 * codebook, codebook : seq_len + codebook] = input_ids[:, 2 * codebook]
+ input_ids_shifted[:, 2 * codebook + 1, codebook : seq_len + codebook] = input_ids[:, 2 * codebook + 1]
+
+ # construct a pattern mask that indicates the positions of padding tokens for each codebook
+ # first fill the upper triangular part (the EOS padding)
+ delay_pattern = torch.triu(
+ torch.ones((channel_codebooks, max_length), dtype=torch.bool), diagonal=max_length - channel_codebooks + 1
+ )
+ # then fill the lower triangular part (the BOS padding)
+ delay_pattern = delay_pattern + torch.tril(torch.ones((channel_codebooks, max_length), dtype=torch.bool))
+
+ if self.config.audio_channels == 2:
+ # for left/right channel we need to duplicate every row of the pattern mask in an interleaved fashion
+ delay_pattern = delay_pattern.repeat_interleave(2, dim=0)
+
+ mask = ~delay_pattern.to(input_ids.device)
+ input_ids = mask * input_ids_shifted + ~mask * pad_token_id
+
+ # find the first position to start generating - this is the first place we have the -1 token
+ # and will always be in the first codebook (since it has no codebook offset)
+ first_codebook_ids = input_ids[:, 0, :]
+ start_ids = (first_codebook_ids == -1).nonzero()[:, 1]
+ if len(start_ids) > 0:
+ first_start_id = min(start_ids)
+ else:
+ # we have no tokens that need to be filled - return entire matrix of input ids
+ first_start_id = seq_len
+
+ # (bsz * num_codebooks, seq_len) -> (bsz, num_codebooks, seq_len)
+ pattern_mask = input_ids.reshape(bsz * num_codebooks, -1)
+ input_ids = input_ids[..., :first_start_id].reshape(bsz * num_codebooks, -1)
+ return input_ids, pattern_mask
+
+ @staticmethod
+ def apply_delay_pattern_mask(input_ids, decoder_pad_token_mask):
+ """Apply a delay pattern mask to the decoder input ids, only preserving predictions where
+ the mask is set to -1, and otherwise setting to the value detailed in the mask."""
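+ # positions where the mask holds -1 keep the freshly generated ids; every other position is
+ # overwritten with the mask value (prompt token or pad token)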
+ seq_len = input_ids.shape[-1]
+ decoder_pad_token_mask = decoder_pad_token_mask[..., :seq_len]
+ input_ids = torch.where(decoder_pad_token_mask == -1, input_ids, decoder_pad_token_mask)
+ return input_ids
+
+ @torch.no_grad()
+ # Ignore copy
+ def generate(
+ self,
+ inputs: Optional[torch.Tensor] = None,
+ generation_config: Optional[GenerationConfig] = None,
+ logits_processor: Optional[LogitsProcessorList] = None,
+ stopping_criteria: Optional[StoppingCriteriaList] = None,
+ synced_gpus: Optional[bool] = None,
+ streamer: Optional["BaseStreamer"] = None,
+ **kwargs,
+ ):
+ """
+
+ Generates sequences of token ids for models with a language modeling head.
+
+
+
+ Most generation-controlling parameters are set in `generation_config` which, if not passed, will be set to the
+ model's default generation configuration. You can override any `generation_config` by passing the corresponding
+ parameters to generate(), e.g. `.generate(inputs, num_beams=4, do_sample=True)`.
+
+ For an overview of generation strategies and code examples, check out the [following
+ guide](./generation_strategies).
+
+
+
+ Parameters:
+ inputs (`torch.Tensor` of varying shape depending on the modality, *optional*):
+ The sequence used as a prompt for the generation or as model inputs to the encoder. If `None` the
+ method initializes it with `bos_token_id` and a batch size of 1. For decoder-only models `inputs`
+ should be in the format `input_ids`. For encoder-decoder models *inputs* can represent any of
+ `input_ids`, `input_values`, `input_features`, or `pixel_values`.
+ generation_config (`~generation.GenerationConfig`, *optional*):
+ The generation configuration to be used as base parametrization for the generation call. `**kwargs`
+ passed to generate matching the attributes of `generation_config` will override them. If
+ `generation_config` is not provided, the default will be used, which has the following loading
+ priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
+ configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
+ default values, whose documentation should be checked to parameterize generation.
+ logits_processor (`LogitsProcessorList`, *optional*):
+ Custom logits processors that complement the default logits processors built from arguments and
+ generation config. If a logit processor is passed that is already created with the arguments or a
+ generation config an error is thrown. This feature is intended for advanced users.
+ stopping_criteria (`StoppingCriteriaList`, *optional*):
+ Custom stopping criteria that complement the default stopping criteria built from arguments and a
+ generation config. If a stopping criteria is passed that is already created with the arguments or a
+ generation config an error is thrown. This feature is intended for advanced users.
+ synced_gpus (`bool`, *optional*, defaults to `False`):
+ Whether to continue running the while loop until max_length (needed to avoid deadlocking with
+ `FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3).
+ streamer (`BaseStreamer`, *optional*):
+ Streamer object that will be used to stream the generated sequences. Generated tokens are passed
+ through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
+ kwargs (`dict[str, Any]`, *optional*):
+ Ad hoc parametrization of `generation_config` and/or additional model-specific kwargs that will be
+ forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder
+ specific kwargs should not be prefixed and decoder specific kwargs should be prefixed with *decoder_*.
+
+ Return:
+ [`~utils.ModelOutput`] or `torch.LongTensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True`
+ or when `config.return_dict_in_generate=True`) or a `torch.LongTensor`.
+
+ If the model is *not* an encoder-decoder model (`model.config.is_encoder_decoder=False`), the possible
+ [`~utils.ModelOutput`] types are:
+
+ - [`~generation.GenerateDecoderOnlyOutput`],
+ - [`~generation.GenerateBeamDecoderOnlyOutput`]
+
+ If the model is an encoder-decoder model (`model.config.is_encoder_decoder=True`), the possible
+ [`~utils.ModelOutput`] types are:
+
+ - [`~generation.GenerateEncoderDecoderOutput`],
+ - [`~generation.GenerateBeamEncoderDecoderOutput`]
+ """
+ # 1. Handle `generation_config` and kwargs that might update it, and validate the resulting objects
+ if generation_config is None:
+ generation_config = self.generation_config
+
+ generation_config = copy.deepcopy(generation_config)
+ model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs
+ generation_config.validate()
+ self._validate_model_kwargs(model_kwargs.copy())
+
+ # 2. Set generation parameters if not already defined
+ logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
+ stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
+
+ requires_attention_mask = "encoder_outputs" not in model_kwargs
+ kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None
+
+ # 3. Define model inputs
+ input_ids, model_input_name, model_kwargs = self._prepare_model_inputs(
+ inputs, generation_config.bos_token_id, model_kwargs
+ )
+ batch_size = input_ids.shape[0] // self.num_codebooks
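+ # input_ids are stacked as (bsz * num_codebooks, seq_len), so recover the true batch size here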
+ self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=input_ids.device)
+
+ # 4. Define other model kwargs
+ model_kwargs["use_cache"] = generation_config.use_cache
+ model_kwargs["guidance_scale"] = generation_config.guidance_scale
+
+ if model_kwargs.get("attention_mask", None) is None and requires_attention_mask:
+ model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
+ input_ids, generation_config, model_kwargs
+ )
+
+ # 5. Prepare `max_length` depending on other stopping criteria.
+ input_ids_length = input_ids.shape[-1]
+ has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
+ has_default_min_length = kwargs.get("min_length") is None and generation_config.min_length is not None
+ generation_config = self._prepare_generated_length(
+ generation_config=generation_config,
+ has_default_max_length=has_default_max_length,
+ has_default_min_length=has_default_min_length,
+ model_input_name=model_input_name,
+ inputs_tensor=input_ids,
+ input_ids_length=input_ids_length,
+ )
+
+ # 6. Prepare `input_ids` which will be used for auto-regressive generation
+ # Build the delay pattern mask for offsetting each codebook prediction by 1 (this behaviour is specific to Musicgen)
+ input_ids, delay_pattern_mask = self.build_delay_pattern_mask(
+ input_ids,
+ pad_token_id=generation_config._decoder_start_token_tensor,
+ max_length=generation_config.max_length,
+ )
+
+ if streamer is not None:
+ streamer.put(input_ids.cpu())
+
+ # stash the delay mask so that we don't have to recompute it in each forward pass
+ model_kwargs["delay_pattern_mask"] = delay_pattern_mask
+
+ # 7. determine generation mode
+ generation_mode = generation_config.get_generation_mode()
+
+ # 8. prepare batched CFG externally (to enable coexistence with the unbatched CFG)
+ if generation_config.guidance_scale is not None and generation_config.guidance_scale > 1:
+ logits_processor.append(ClassifierFreeGuidanceLogitsProcessor(generation_config.guidance_scale))
+ generation_config.guidance_scale = None
+
+ # 9. prepare distribution pre_processing samplers
+ logits_processor = self._get_logits_processor(
+ generation_config=generation_config,
+ input_ids_seq_length=input_ids_length,
+ encoder_input_ids=input_ids,
+ prefix_allowed_tokens_fn=None,
+ logits_processor=logits_processor,
+ device=input_ids.device,
+ )
+
+ # 10. prepare stopping criteria
+ stopping_criteria = self._get_stopping_criteria(
+ generation_config=generation_config, stopping_criteria=stopping_criteria
+ )
+
+ if generation_mode in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH):
+ # expand input_ids with `num_return_sequences` additional sequences per batch
+ input_ids, model_kwargs = self._expand_inputs_for_generation(
+ input_ids=input_ids,
+ expand_size=generation_config.num_return_sequences,
+ **model_kwargs,
+ )
+
+ # 11. run sample
+ outputs = self._sample(
+ input_ids,
+ logits_processor=logits_processor,
+ stopping_criteria=stopping_criteria,
+ generation_config=generation_config,
+ synced_gpus=synced_gpus,
+ streamer=streamer,
+ **model_kwargs,
+ )
+
+ else:
+ raise ValueError(
+ "Got incompatible mode for generation, should be one of greedy or sampling. "
+ "Ensure that beam search is de-activated by setting `num_beams=1`."
+ )
+
+ if generation_config.return_dict_in_generate:
+ output_ids = outputs.sequences
+ else:
+ output_ids = outputs
+
+ # apply the pattern mask to the final ids
+ output_ids = self.apply_delay_pattern_mask(output_ids, model_kwargs["delay_pattern_mask"])
+
+ # revert the pattern delay mask by filtering the pad token id
+ output_ids = output_ids[output_ids != generation_config._pad_token_tensor].reshape(
+ batch_size, self.num_codebooks, -1
+ )
+
+ if generation_config.return_dict_in_generate:
+ outputs.sequences = output_ids
+ return outputs
+ else:
+ return output_ids
+
+
+@auto_docstring
+class MusicgenMelodyForConditionalGeneration(PreTrainedModel, GenerationMixin):
+ config: MusicgenMelodyConfig
+ main_input_name = "input_ids"
+ supports_gradient_checkpointing = True
+ _supports_flash_attn = True
+ _supports_sdpa = True
+ _supports_flex_attn = True
+
+ def __init__(
+ self,
+ config: MusicgenMelodyConfig = None,
+ text_encoder: Optional[PreTrainedModel] = None,
+ audio_encoder: Optional[PreTrainedModel] = None,
+ decoder: Optional[MusicgenMelodyForCausalLM] = None,
+ ):
+ r"""
+ text_encoder (`PreTrainedModel`, *optional*):
+ The text encoder model that encodes text into hidden states for conditioning.
+ audio_encoder (`PreTrainedModel`, *optional*):
+ The audio encoder model that encodes audio into hidden states for conditioning.
+ decoder (`MusicgenMelodyForCausalLM`, *optional*):
+ The decoder model that generates audio tokens based on conditioning signals.
+ """
+ if config is None and None in (text_encoder, audio_encoder, decoder):
+ raise ValueError(
+ "Either a configuration has to be provided, or all three of text encoder, audio encoder and Musicgen Melody decoder."
+ )
+ if config is None:
+ config = MusicgenMelodyConfig.from_sub_models_config(
+ text_encoder.config, audio_encoder.config, decoder.config
+ )
+ else:
+ if not isinstance(config, self.config_class):
+ raise ValueError(f"Config: {config} has to be of type {self.config_class}")
+
+ # initialize with config
+ super().__init__(config)
+
+ if text_encoder is None:
+ text_encoder = AutoModelForTextEncoding.from_config(config.text_encoder)
+
+ if audio_encoder is None:
+ audio_encoder = AutoModel.from_config(config.audio_encoder)
+
+ if decoder is None:
+ decoder = MusicgenMelodyForCausalLM._from_config(config.decoder)
+
+ self.text_encoder = text_encoder
+ self.audio_encoder = audio_encoder
+ self.decoder = decoder
+
+ # make sure that the individual model's config refers to the shared config
+ # so that the updates to the config will be synced
+ self.config.text_encoder._attn_implementation = self.text_encoder.config._attn_implementation
+ self.config.audio_encoder._attn_implementation = self.audio_encoder.config._attn_implementation
+ self.config.decoder._attn_implementation = self.decoder.config._attn_implementation
+ self.text_encoder.config = self.config.text_encoder
+ self.audio_encoder.config = self.config.audio_encoder
+ self.decoder.config = self.config.decoder
+
+ # text encoder outputs might need to be projected to different dimension for decoder
+ if self.text_encoder.config.hidden_size != self.decoder.config.hidden_size:
+ self.enc_to_dec_proj = nn.Linear(self.text_encoder.config.hidden_size, self.decoder.config.hidden_size)
+
+ # audio encoder outputs after chroma extraction might need to be projected to different dimension for decoder
+ if self.config.num_chroma != self.decoder.config.hidden_size:
+ self.audio_enc_to_dec_proj = nn.Linear(self.config.num_chroma, self.decoder.config.hidden_size)
+
+ if self.text_encoder.get_output_embeddings() is not None:
+ raise ValueError(
+ f"The encoder {self.text_encoder} should not have a LM Head. Please use a model without and LM Head"
+ )
+
+ # Initialize projection layers weights and tie text encoder and decoder weights if set accordingly
+ self.post_init()
+
+ def _init_weights(self, module):
+ # MusicgenMelodyForConditionalGeneration is made of PreTrainedModels that have already been initialized
+ # Projection layers still need to be initialized.
+ std = self.decoder.config.initializer_factor
+ if isinstance(module, nn.Linear):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.bias is not None:
+ module.bias.data.zero_()
+
+ def tie_weights(self):
+ # tie text encoder & decoder if needed
+ if self.config.tie_encoder_decoder:
+ # tie text encoder and decoder base model
+ decoder_base_model_prefix = self.decoder.base_model_prefix
+ tied_weights = self._tie_encoder_decoder_weights(
+ self.text_encoder,
+ self.decoder._modules[decoder_base_model_prefix],
+ self.decoder.base_model_prefix,
+ "text_encoder",
+ )
+ # Setting a dynamic variable instead of `_tied_weights_keys` because the latter is a class
+ # attribute, not an instance member; modifying it would modify the entire class, leading to
+ # issues on subsequent calls by different tests.
+ self._dynamic_tied_weights_keys = tied_weights
+
+ def get_text_encoder(self):
+ return self.text_encoder
+
+ def get_encoder(self):
+ # get the text encoder to compute the conditioning hidden-states for generation
+ return self.get_text_encoder()
+
+ def get_input_embeddings(self):
+ return self.text_encoder.get_input_embeddings()
+
+ def get_output_embeddings(self):
+ return self.decoder.get_output_embeddings()
+
+ def set_output_embeddings(self, new_embeddings):
+ return self.decoder.set_output_embeddings(new_embeddings)
+
+ @classmethod
+ # Copied from transformers.models.musicgen.modeling_musicgen.MusicgenForConditionalGeneration.from_sub_models_pretrained with Musicgen->MusicgenMelody, musicgen-small->musicgen-melody
+ def from_sub_models_pretrained(
+ cls,
+ text_encoder_pretrained_model_name_or_path: Optional[str] = None,
+ audio_encoder_pretrained_model_name_or_path: Optional[str] = None,
+ decoder_pretrained_model_name_or_path: Optional[str] = None,
+ *model_args,
+ **kwargs,
+ ) -> PreTrainedModel:
+ r"""
+ Instantiate a text encoder, an audio encoder, and a MusicGen decoder from one, two or three base classes of the
+ library from pretrained model checkpoints.
+
+
+ The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train
+ the model, you need to first set it back in training mode with `model.train()`.
+
+ Params:
+ text_encoder_pretrained_model_name_or_path (`str`, *optional*):
+ Information necessary to initiate the text encoder. Can be either:
+
+ - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
+ - A path to a *directory* containing model weights saved using
+ [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
+
+ audio_encoder_pretrained_model_name_or_path (`str`, *optional*):
+ Information necessary to initiate the audio encoder. Can be either:
+
+ - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
+ - A path to a *directory* containing model weights saved using
+ [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
+
+ decoder_pretrained_model_name_or_path (`str`, *optional*, defaults to `None`):
+ Information necessary to initiate the decoder. Can be either:
+
+ - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
+ - A path to a *directory* containing model weights saved using
+ [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
+
+ model_args (remaining positional arguments, *optional*):
+ All remaining positional arguments will be passed to the underlying model's `__init__` method.
+
+ kwargs (remaining dictionary of keyword arguments, *optional*):
+ Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
+ `output_attentions=True`).
+
+ - To update the text encoder configuration, use the prefix *text_encoder_* for each configuration
+ parameter.
+ - To update the audio encoder configuration, use the prefix *audio_encoder_* for each configuration
+ parameter.
+ - To update the decoder configuration, use the prefix *decoder_* for each configuration parameter.
+ - To update the parent model configuration, do not use a prefix for each configuration parameter.
+
+ Behaves differently depending on whether a `config` is provided or automatically loaded.
+
+ Example:
+
+ ```python
+ >>> from transformers import MusicgenMelodyForConditionalGeneration
+
+ >>> # initialize a musicgen model from a t5 text encoder, encodec audio encoder, and musicgen decoder
+ >>> model = MusicgenMelodyForConditionalGeneration.from_sub_models_pretrained(
+ ... text_encoder_pretrained_model_name_or_path="google-t5/t5-base",
+ ... audio_encoder_pretrained_model_name_or_path="facebook/encodec_24khz",
+ ... decoder_pretrained_model_name_or_path="facebook/musicgen-melody",
+ ... )
+ >>> # saving model after fine-tuning
+ >>> model.save_pretrained("./musicgen-ft")
+ >>> # load fine-tuned model
+ >>> model = MusicgenMelodyForConditionalGeneration.from_pretrained("./musicgen-ft")
+ ```"""
+
+ kwargs_text_encoder = {
+ argument[len("text_encoder_") :]: value
+ for argument, value in kwargs.items()
+ if argument.startswith("text_encoder_")
+ }
+
+ kwargs_audio_encoder = {
+ argument[len("audio_encoder_") :]: value
+ for argument, value in kwargs.items()
+ if argument.startswith("audio_encoder_")
+ }
+
+ kwargs_decoder = {
+ argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
+ }
+
+ # remove text encoder, audio encoder and decoder kwargs from kwargs
+ for key in kwargs_text_encoder:
+ del kwargs["text_encoder_" + key]
+ for key in kwargs_audio_encoder:
+ del kwargs["audio_encoder_" + key]
+ for key in kwargs_decoder:
+ del kwargs["decoder_" + key]
+
+ # Load and initialize the encoder and decoder
+ # The distinction between encoder and decoder at the model level is made
+ # by the value of the flag `is_decoder` that we need to set correctly.
+ text_encoder = kwargs_text_encoder.pop("model", None)
+ if text_encoder is None:
+ if text_encoder_pretrained_model_name_or_path is None:
+ raise ValueError(
+ "If `text_encoder_model` is not defined as an argument, a `text_encoder_pretrained_model_name_or_path` has "
+ "to be defined."
+ )
+
+ if "config" not in kwargs_text_encoder:
+ encoder_config, kwargs_text_encoder = AutoConfig.from_pretrained(
+ text_encoder_pretrained_model_name_or_path, **kwargs_text_encoder, return_unused_kwargs=True
+ )
+
+ if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True:
+ logger.info(
+ f"Initializing {text_encoder_pretrained_model_name_or_path} as a text_encoder model "
+ "from a decoder model. Cross-attention and causal mask are disabled."
+ )
+ encoder_config.is_decoder = False
+ encoder_config.add_cross_attention = False
+
+ kwargs_text_encoder["config"] = encoder_config
+
+ text_encoder = AutoModel.from_pretrained(
+ text_encoder_pretrained_model_name_or_path, *model_args, **kwargs_text_encoder
+ )
+
+ audio_encoder = kwargs_audio_encoder.pop("model", None)
+ if audio_encoder is None:
+ if audio_encoder_pretrained_model_name_or_path is None:
+ raise ValueError(
+ "If `audio_encoder_model` is not defined as an argument, an `audio_encoder_pretrained_model_name_or_path` has "
+ "to be defined."
+ )
+
+ if "config" not in kwargs_audio_encoder:
+ encoder_config, kwargs_audio_encoder = AutoConfig.from_pretrained(
+ audio_encoder_pretrained_model_name_or_path, **kwargs_audio_encoder, return_unused_kwargs=True
+ )
+
+ if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True:
+ logger.info(
+ f"Initializing {audio_encoder_pretrained_model_name_or_path} as an audio_encoder model "
+ "from a decoder model. Cross-attention and causal mask are disabled."
+ )
+ encoder_config.is_decoder = False
+ encoder_config.add_cross_attention = False
+
+ kwargs_audio_encoder["config"] = encoder_config
+
+ audio_encoder = AutoModel.from_pretrained(
+ audio_encoder_pretrained_model_name_or_path, *model_args, **kwargs_audio_encoder
+ )
+
+ decoder = kwargs_decoder.pop("model", None)
+ if decoder is None:
+ if decoder_pretrained_model_name_or_path is None:
+ raise ValueError(
+ "If `decoder_model` is not defined as an argument, a `decoder_pretrained_model_name_or_path` has "
+ "to be defined."
+ )
+
+ if "config" not in kwargs_decoder:
+ decoder_config, kwargs_decoder = AutoConfig.from_pretrained(
+ decoder_pretrained_model_name_or_path, **kwargs_decoder, return_unused_kwargs=True
+ )
+
+ if isinstance(decoder_config, MusicgenMelodyConfig):
+ decoder_config = decoder_config.decoder
+
+ if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False:
+ logger.info(
+ f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. Cross attention"
+ f" layers are added to {decoder_pretrained_model_name_or_path} and randomly initialized if"
+ f" {decoder_pretrained_model_name_or_path}'s architecture allows for cross attention layers."
+ )
+ decoder_config.is_decoder = True
+ decoder_config.add_cross_attention = True
+
+ kwargs_decoder["config"] = decoder_config
+
+ if kwargs_decoder["config"].is_decoder is False or kwargs_decoder["config"].add_cross_attention is False:
+ logger.warning(
+ f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. "
+ f"In order to initialize {decoder_pretrained_model_name_or_path} as a decoder, "
+ "make sure that the attributes `is_decoder` and `add_cross_attention` of `decoder_config` "
+ "passed to `.from_sub_models_pretrained(...)` are set to `True` or do not pass a "
+ "`decoder_config` to `.from_sub_models_pretrained(...)`"
+ )
+
+ decoder = MusicgenMelodyForCausalLM.from_pretrained(
+ decoder_pretrained_model_name_or_path, **kwargs_decoder
+ )
+
+ # instantiate config with corresponding kwargs
+ config = MusicgenMelodyConfig.from_sub_models_config(
+ text_encoder.config, audio_encoder.config, decoder.config, **kwargs
+ )
+ return cls(text_encoder=text_encoder, audio_encoder=audio_encoder, decoder=decoder, config=config)
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.BoolTensor] = None,
+ input_features: Optional[torch.FloatTensor] = None,
+ decoder_input_ids: Optional[torch.LongTensor] = None,
+ decoder_attention_mask: Optional[torch.BoolTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ **kwargs,
+ ) -> Union[tuple, MusicgenMelodyOutputWithPast]:
+ r"""
+ decoder_input_ids (`torch.LongTensor` of shape `(batch_size * num_codebooks, target_sequence_length)`, *optional*):
+ Indices of decoder input sequence tokens in the vocabulary, corresponding to the sequence of audio codes.
+
+ Indices can be obtained by encoding an audio prompt with an audio encoder model to predict audio codes,
+ such as with the [`EncodecModel`]. See [`EncodecModel.encode`] for details.
+
+ [What are decoder input IDs?](../glossary#decoder-input-ids)
+
+
+
+ The `decoder_input_ids` will automatically be converted from shape `(batch_size * num_codebooks,
+ target_sequence_length)` to `(batch_size, num_codebooks, target_sequence_length)` in the forward pass. If
+ you obtain audio codes from an audio encoding model, such as [`EncodecModel`], ensure that the number of
+ frames is equal to 1, and that you reshape the audio codes from `(frames, batch_size, num_codebooks,
+ target_sequence_length)` to `(batch_size * num_codebooks, target_sequence_length)` prior to passing them as
+ `decoder_input_ids`.
+
+
+ decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
+ Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
+ be used by default.
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
+ Sequence of conditional hidden-states representing the concatenation of the projected text encoder output and the projected audio encoder output.
+ Used as a conditional signal and will thus be concatenated to the projected `decoder_input_ids`.
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length, num_codebooks)`, *optional*):
+ Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
+ `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size]`. All labels set to `-100`
+ are ignored (masked); the loss is only computed for labels in `[0, ..., config.vocab_size]`.
+
+ Examples:
+ ```python
+ >>> from transformers import AutoProcessor, MusicgenMelodyForConditionalGeneration
+ >>> import torch
+
+ >>> processor = AutoProcessor.from_pretrained("facebook/musicgen-melody")
+ >>> model = MusicgenMelodyForConditionalGeneration.from_pretrained("facebook/musicgen-melody")
+
+ >>> inputs = processor(
+ ... text=["80s pop track with bassy drums and synth", "90s rock song with loud guitars and heavy drums"],
+ ... padding=True,
+ ... return_tensors="pt",
+ ... )
+
+ >>> pad_token_id = model.generation_config.pad_token_id
+ >>> decoder_input_ids = (
+ ... torch.ones((inputs.input_ids.shape[0] * model.decoder.num_codebooks, 1), dtype=torch.long)
+ ... * pad_token_id
+ ... )
+
+ >>> logits = model(**inputs, decoder_input_ids=decoder_input_ids).logits
+ >>> logits.shape # (bsz * num_codebooks, encoder_len + tgt_len, vocab_size)
+ torch.Size([8, 249, 2048])
+ ```"""
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ kwargs_text_encoder = {
+ argument[len("text_encoder_")]: value
+ for argument, value in kwargs.items()
+ if argument.startswith("text_encoder_")
+ }
+
+ kwargs_decoder = {
+ argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
+ }
+
+ if encoder_hidden_states is None:
+ if inputs_embeds is not None or input_ids is not None:
+ encoder_outputs = self.text_encoder(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ **kwargs_text_encoder,
+ )
+
+ encoder_hidden_states = encoder_outputs[0]
+
+ # optionally project encoder_hidden_states
+ if self.text_encoder.config.hidden_size != self.decoder.config.hidden_size:
+ encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states)
+
+ if attention_mask is not None and encoder_hidden_states is not None:
+ encoder_hidden_states = encoder_hidden_states * attention_mask[..., None]
+
+ # set a default audio conditional hidden states if text is not None
+ if encoder_hidden_states is not None and input_features is None:
+ input_features = torch.zeros(
+ (encoder_hidden_states.shape[0], 1, self.config.num_chroma),
+ device=self.device,
+ dtype=self.dtype,
+ )
+ input_features[:, :, 0] = 1
+
+ if input_features is not None:
+ audio_hidden_states = input_features
+
+ # optionally project audio_hidden_states ->
+ # (batch_size, seq_len, num_chroma) -> (batch_size, seq_len, hidden_size)
+ if self.config.num_chroma != self.decoder.config.hidden_size:
+ audio_hidden_states = self.audio_enc_to_dec_proj(audio_hidden_states)
+
+ # pad or truncate to config.chroma_length
+ if audio_hidden_states.shape[1] < self.config.chroma_length:
+ n_repeat = int(math.ceil(self.config.chroma_length / audio_hidden_states.shape[1]))
+ audio_hidden_states = audio_hidden_states.repeat(1, n_repeat, 1)
+ else:
+ logger.warning(
+ f"The conditional audio signal is of length {audio_hidden_states.shape[1]}, which exceeds"
+ f"the maximum chroma duration of {self.config.chroma_length}."
+ f"The audio will be truncated to {self.config.chroma_length} frames."
+ )
+ audio_hidden_states = audio_hidden_states[:, : self.config.chroma_length]
+
+ if encoder_hidden_states is not None:
+ encoder_hidden_states = torch.cat([audio_hidden_states, encoder_hidden_states], dim=1)
+ else:
+ encoder_hidden_states = audio_hidden_states
+
+ if (labels is not None) and (decoder_input_ids is None and decoder_inputs_embeds is None):
+ decoder_input_ids = shift_tokens_right(
+ labels, self.config.decoder.pad_token_id, self.config.decoder.bos_token_id
+ )
+
+ # Decode
+ decoder_outputs = self.decoder(
+ input_ids=decoder_input_ids,
+ attention_mask=decoder_attention_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ inputs_embeds=decoder_inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ use_cache=use_cache,
+ past_key_values=past_key_values,
+ return_dict=return_dict,
+ labels=labels,
+ **kwargs_decoder,
+ )
+
+ if not return_dict:
+ return decoder_outputs + (encoder_hidden_states,)
+
+ return MusicgenMelodyOutputWithPast(
+ loss=decoder_outputs.loss,
+ logits=decoder_outputs.logits,
+ past_key_values=decoder_outputs.past_key_values,
+ hidden_states=decoder_outputs.hidden_states,
+ attentions=decoder_outputs.attentions,
+ encoder_hidden_states=encoder_hidden_states,
+ )
+
+ def prepare_inputs_for_generation(
+ self,
+ decoder_input_ids,
+ encoder_hidden_states=None,
+ past_key_values=None,
+ attention_mask=None,
+ decoder_attention_mask=None,
+ decoder_head_mask=None,
+ use_cache=None,
+ decoder_delay_pattern_mask=None,
+ guidance_scale=None,
+ **kwargs,
+ ):
+ # Overwritten -- MusicGen has custom processing
+ if decoder_delay_pattern_mask is None:
+ decoder_input_ids, decoder_delay_pattern_mask = self.decoder.build_delay_pattern_mask(
+ decoder_input_ids,
+ self.generation_config.pad_token_id,
+ max_length=self.generation_config.max_length,
+ )
+
+ # apply the delay pattern mask
+ decoder_input_ids = self.decoder.apply_delay_pattern_mask(decoder_input_ids, decoder_delay_pattern_mask)
+
+ if guidance_scale is not None and guidance_scale > 1:
+ # for classifier free guidance we need to replicate the decoder args across the batch dim (we'll split these
+ # before sampling)
+ decoder_input_ids = decoder_input_ids.repeat((2, 1))
+ if decoder_attention_mask is not None:
+ decoder_attention_mask = decoder_attention_mask.repeat((2, 1))
+
+ if past_key_values is not None:
+ past_length = past_key_values.get_seq_length()
+
+ # Some generation methods already pass only the last input ID
+ if decoder_input_ids.shape[1] > past_length:
+ remove_prefix_length = past_length
+ else:
+ # Default to old behavior: keep only final ID
+ remove_prefix_length = decoder_input_ids.shape[1] - 1
+
+ decoder_input_ids = decoder_input_ids[:, remove_prefix_length:]
+
+ # we only want to use the conditional signal in the 1st generation step but keep the attention mask
+ encoder_hidden_states = None
+ # we also have to update the attention mask
+
+ return {
+ "input_ids": None, # encoder_hidden_states is defined. input_ids not needed
+ "encoder_hidden_states": encoder_hidden_states,
+ "past_key_values": past_key_values,
+ "decoder_input_ids": decoder_input_ids,
+ "attention_mask": attention_mask,
+ "decoder_attention_mask": decoder_attention_mask,
+ "decoder_head_mask": decoder_head_mask,
+ "use_cache": use_cache,
+ }
+
+ # Copied from transformers.models.musicgen.modeling_musicgen.MusicgenForConditionalGeneration._prepare_decoder_input_ids_for_generation
+ def _prepare_decoder_input_ids_for_generation(
+ self,
+ batch_size: int,
+ model_input_name: str,
+ model_kwargs: dict[str, torch.Tensor],
+ decoder_start_token_id: Optional[int] = None,
+ bos_token_id: Optional[int] = None,
+ device: Optional[torch.device] = None,
+ ) -> tuple[torch.LongTensor, dict[str, torch.Tensor]]:
+ """Prepares `decoder_input_ids` for generation with encoder-decoder models"""
+
+ # 1. Check whether the user has defined `decoder_input_ids` manually. To facilitate in terms of input naming,
+ # we also allow the user to pass it under `input_ids`, if the encoder does not use it as the main input.
+ if model_kwargs is not None and "decoder_input_ids" in model_kwargs:
+ decoder_input_ids = model_kwargs.pop("decoder_input_ids")
+ elif "input_ids" in model_kwargs and model_input_name != "input_ids":
+ decoder_input_ids = model_kwargs.pop("input_ids")
+ else:
+ decoder_input_ids = None
+
+ # 2. Encoder-decoder models expect the `decoder_input_ids` to start with a special token. Let's ensure that.
+ decoder_start_token_id = self._get_decoder_start_token_id(decoder_start_token_id, bos_token_id)
+ if device is None:
+ device = self.device
+ decoder_input_ids_start = (
+ torch.ones((batch_size * self.decoder.num_codebooks, 1), dtype=torch.long, device=device)
+ * decoder_start_token_id
+ )
+
+ # no user input -> use decoder_start_token_id as decoder_input_ids
+ if decoder_input_ids is None:
+ decoder_input_ids = decoder_input_ids_start
+
+ # user input but doesn't start with decoder_start_token_id -> prepend decoder_start_token_id (and adjust
+ # decoder_attention_mask if provided)
+ elif (decoder_input_ids[..., 0] != decoder_start_token_id).all().item():
+ decoder_input_ids = torch.cat([decoder_input_ids_start, decoder_input_ids], dim=-1)
+ if "decoder_attention_mask" in model_kwargs:
+ decoder_attention_mask = model_kwargs["decoder_attention_mask"]
+ decoder_attention_mask = torch.cat(
+ (torch.ones_like(decoder_attention_mask)[:, :1], decoder_attention_mask),
+ dim=-1,
+ )
+ model_kwargs["decoder_attention_mask"] = decoder_attention_mask
+
+ return decoder_input_ids, model_kwargs
+
+ def _prepare_encoder_hidden_states_kwargs_for_generation(
+ self,
+ inputs_tensor: torch.Tensor,
+ model_kwargs,
+ model_input_name: Optional[str],
+ generation_config: GenerationConfig,
+ ) -> dict[str, Any]:
+ encoder_hidden_states = None
+ # attention mask is consumed once to produce text conditional hidden states through the text encoder
+ encoder_attention_mask = model_kwargs.pop("attention_mask")
+ guidance_scale = generation_config.guidance_scale
+
+ # 1. condition on text
+ if inputs_tensor is not None:
+ encoder = self.get_text_encoder()
+ # Compatibility with Accelerate big model inference: we need the encoder to produce outputs on the same device
+ # as the inputs.
+ if hasattr(encoder, "_hf_hook"):
+ encoder._hf_hook.io_same_device = True
+
+ # Prepare args and kwargs from model kwargs.
+ irrelevant_prefix = ["decoder_", "use_cache"]
+ encoder_kwargs = {
+ argument: value
+ for argument, value in model_kwargs.items()
+ if not any(argument.startswith(p) for p in irrelevant_prefix)
+ }
+ encoder_signature = set(inspect.signature(encoder.forward).parameters)
+ encoder_accepts_wildcard = "kwargs" in encoder_signature or "model_kwargs" in encoder_signature
+ if not encoder_accepts_wildcard:
+ encoder_kwargs = {
+ argument: value for argument, value in encoder_kwargs.items() if argument in encoder_signature
+ }
+ encoder_kwargs["output_attentions"] = generation_config.output_attentions
+ encoder_kwargs["output_hidden_states"] = generation_config.output_hidden_states
+
+ # make sure that encoder returns `ModelOutput`
+ model_input_name = model_input_name if model_input_name is not None else self.text_encoder.main_input_name
+ encoder_kwargs["return_dict"] = True
+ encoder_kwargs[model_input_name] = inputs_tensor
+ if encoder_attention_mask is not None:
+ encoder_kwargs["attention_mask"] = encoder_attention_mask
+ encoder_hidden_states = encoder(**encoder_kwargs).last_hidden_state
+
+ # optionally project encoder_hidden_states
+ if self.text_encoder.config.hidden_size != self.decoder.config.hidden_size:
+ encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states)
+
+ # for classifier free guidance we need to add a 'null' input to our encoder hidden states
+ if guidance_scale is not None and guidance_scale > 1:
+ encoder_hidden_states = torch.concatenate(
+ [encoder_hidden_states, torch.zeros_like(encoder_hidden_states)], dim=0
+ )
+ if encoder_attention_mask is not None:
+ encoder_attention_mask = torch.concatenate(
+ [encoder_attention_mask, torch.zeros_like(encoder_attention_mask)], dim=0
+ )
+ if encoder_attention_mask is not None:
+ encoder_hidden_states = encoder_hidden_states * encoder_attention_mask[..., None]
+
+ # 2. condition on audio
+ audio_hidden_states = model_kwargs.get("input_features", None)
+
+ if inputs_tensor is not None:
+ if audio_hidden_states is not None:
+ null_audio_hidden_states = torch.zeros_like(audio_hidden_states)
+ else:
+ null_audio_hidden_states = torch.zeros(
+ (inputs_tensor.shape[0], 1, self.config.num_chroma), device=self.device, dtype=self.dtype
+ )
+ null_audio_hidden_states[:, :, 0] = 1
+
+ if audio_hidden_states is None:
+ audio_hidden_states = null_audio_hidden_states
+
+ if audio_hidden_states is not None:
+ # for classifier free guidance we need to add a 'null' input to our audio hidden states
+ if guidance_scale is not None and guidance_scale > 1:
+ audio_hidden_states = torch.concatenate([audio_hidden_states, null_audio_hidden_states], dim=0)
+
+ # optionally project audio_hidden_states ->
+ # (batch_size, seq_len, num_chroma) -> (batch_size, seq_len, hidden_size)
+ if self.config.num_chroma != self.decoder.config.hidden_size:
+ audio_hidden_states = self.audio_enc_to_dec_proj(audio_hidden_states)
+
+ # pad or truncate to config.chroma_length
+ if audio_hidden_states.shape[1] < self.config.chroma_length:
+ n_repeat = int(math.ceil(self.config.chroma_length / audio_hidden_states.shape[1]))
+ audio_hidden_states = audio_hidden_states.repeat(1, n_repeat, 1)
+ audio_hidden_states = audio_hidden_states[:, : self.config.chroma_length]
+
+ if encoder_hidden_states is not None:
+ encoder_hidden_states = torch.cat([audio_hidden_states, encoder_hidden_states], dim=1)
+ else:
+ encoder_hidden_states = audio_hidden_states
+
+ model_kwargs["encoder_hidden_states"] = encoder_hidden_states
+
+ return model_kwargs
+
+ def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
+ return shift_tokens_right(labels, self.config.decoder.pad_token_id, self.config.decoder.bos_token_id)
+
+ def resize_token_embeddings(self, *args, **kwargs):
+ raise NotImplementedError(
+ "Resizing the embedding layers via the EncoderDecoderModel directly is not supported. Please use the"
+ " respective methods of the wrapped objects (model.encoder.resize_token_embeddings(...) or"
+ " model.decoder.resize_token_embeddings(...))"
+ )
+
+ def _maybe_initialize_input_ids_for_generation(
+ self,
+ inputs: Optional[torch.Tensor] = None,
+ bos_token_id: Optional[int] = None,
+ model_kwargs: Optional[dict[str, torch.Tensor]] = None,
+ ) -> torch.LongTensor:
+ """Initializes input ids for generation, if necessary."""
+ if inputs is not None:
+ return inputs
+
+ if bos_token_id is None:
+ raise ValueError("`bos_token_id` has to be defined when no `input_ids` are provided.")
+
+ # If there is some tensor in `model_kwargs`, we can infer the batch size from it. This is helpful with
+ # soft-prompting or in multimodal implementations built on top of decoder-only language models.
+ batch_size = 1
+ for value in model_kwargs.values():
+ if isinstance(value, torch.Tensor):
+ batch_size = value.shape[0]
+ break
+ return torch.ones((batch_size, 1), dtype=torch.long, device=self.device) * bos_token_id
+
+ def freeze_audio_encoder(self):
+ """
+ Freeze the audio encoder weights.
+ """
+ for param in self.audio_encoder.parameters():
+ param.requires_grad = False
+ self.audio_encoder._requires_grad = False
+
+ def freeze_text_encoder(self):
+ """
+ Freeze the text encoder weights.
+ """
+ for param in self.text_encoder.parameters():
+ param.requires_grad = False
+ self.text_encoder._requires_grad = False
+
+ # Copied from transformers.models.musicgen.modeling_musicgen.MusicgenForConditionalGeneration._get_decoder_start_token_id
+ def _get_decoder_start_token_id(
+ self, decoder_start_token_id: Optional[Union[int, list[int]]] = None, bos_token_id: Optional[int] = None
+ ) -> int:
+ decoder_start_token_id = (
+ decoder_start_token_id
+ if decoder_start_token_id is not None
+ else self.generation_config.decoder_start_token_id
+ )
+ bos_token_id = bos_token_id if bos_token_id is not None else self.generation_config.bos_token_id
+
+ if decoder_start_token_id is not None:
+ return decoder_start_token_id
+ elif bos_token_id is not None:
+ return bos_token_id
+ raise ValueError(
+ "`decoder_start_token_id` or `bos_token_id` has to be defined for encoder-decoder generation."
+ )
+
+ @torch.no_grad()
+ def generate(
+ self,
+ inputs: Optional[torch.Tensor] = None,
+ generation_config: Optional[GenerationConfig] = None,
+ logits_processor: Optional[LogitsProcessorList] = None,
+ stopping_criteria: Optional[StoppingCriteriaList] = None,
+ synced_gpus: Optional[bool] = None,
+ streamer: Optional["BaseStreamer"] = None,
+ **kwargs,
+ ):
+ """
+
+ Generates sequences of token ids for models with a language modeling head.
+
+
+
+ Most generation-controlling parameters are set in `generation_config` which, if not passed, will be set to the
+ model's default generation configuration. You can override any `generation_config` by passing the corresponding
+ parameters to generate(), e.g. `.generate(inputs, num_beams=4, do_sample=True)`.
+
+ For an overview of generation strategies and code examples, check out the [following
+ guide](./generation_strategies).
+
+
+
+ Parameters:
+ inputs (`torch.Tensor` of varying shape depending on the modality, *optional*):
+ The sequence used as a prompt for the generation or as model inputs to the encoder. If `None` the
+ method initializes it with `bos_token_id` and a batch size of 1. For decoder-only models `inputs`
+ should be in the format `input_ids`. For encoder-decoder models *inputs* can represent any of
+ `input_ids`, `input_values`, `input_features`, or `pixel_values`.
+ generation_config (`~generation.GenerationConfig`, *optional*):
+ The generation configuration to be used as base parametrization for the generation call. `**kwargs`
+ passed to generate matching the attributes of `generation_config` will override them. If
+ `generation_config` is not provided, the default will be used, which has the following loading
+ priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
+ configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
+ default values, whose documentation should be checked to parameterize generation.
+ logits_processor (`LogitsProcessorList`, *optional*):
+ Custom logits processors that complement the default logits processors built from arguments and
+ generation config. If a logit processor is passed that is already created with the arguments or a
+ generation config an error is thrown. This feature is intended for advanced users.
+ stopping_criteria (`StoppingCriteriaList`, *optional*):
+ Custom stopping criteria that complement the default stopping criteria built from arguments and a
+ generation config. If a stopping criteria is passed that is already created with the arguments or a
+ generation config an error is thrown. This feature is intended for advanced users.
+ synced_gpus (`bool`, *optional*, defaults to `False`):
+ Whether to continue running the while loop until max_length (needed to avoid deadlocking with
+ `FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3).
+ streamer (`BaseStreamer`, *optional*):
+ Streamer object that will be used to stream the generated sequences. Generated tokens are passed
+ through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
+ kwargs (`dict[str, Any]`, *optional*):
+ Ad hoc parametrization of `generation_config` and/or additional model-specific kwargs that will be
+ forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder
+ specific kwargs should not be prefixed and decoder specific kwargs should be prefixed with *decoder_*.
+
+ Return:
+ [`~utils.ModelOutput`] or `torch.LongTensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True`
+ or when `config.return_dict_in_generate=True`) or a `torch.FloatTensor`.
+
+ If the model is *not* an encoder-decoder model (`model.config.is_encoder_decoder=False`), the possible
+ [`~utils.ModelOutput`] types are:
+
+ - [`~generation.GenerateDecoderOnlyOutput`],
+ - [`~generation.GenerateBeamDecoderOnlyOutput`]
+
+ If the model is an encoder-decoder model (`model.config.is_encoder_decoder=True`), the possible
+ [`~utils.ModelOutput`] types are:
+
+ - [`~generation.GenerateEncoderDecoderOutput`],
+ - [`~generation.GenerateBeamEncoderDecoderOutput`]
+ """
+ # 1. Handle `generation_config` and kwargs that might update it, and validate the resulting objects
+ if generation_config is None:
+ generation_config = self.generation_config
+
+ generation_config = copy.deepcopy(generation_config)
+ model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs
+ generation_config.validate()
+ self._validate_model_kwargs(model_kwargs.copy())
+
+ # 2. Set generation parameters if not already defined
+ logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
+ stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
+
+ requires_attention_mask = "encoder_outputs" not in model_kwargs
+ kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None
+
+ # 3. Define model inputs
+ inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(
+ inputs, generation_config.bos_token_id, model_kwargs
+ )
+ batch_size = inputs_tensor.shape[0]
+ self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=inputs_tensor.device)
+
+ # 4. Define other model kwargs
+ model_kwargs["use_cache"] = generation_config.use_cache
+ model_kwargs["guidance_scale"] = generation_config.guidance_scale
+
+ if model_kwargs.get("attention_mask", None) is None and requires_attention_mask:
+ model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
+ inputs_tensor, generation_config, model_kwargs
+ )
+
+ if "encoder_hidden_states" not in model_kwargs:
+ # encoder_hidden_states are created and added to `model_kwargs`
+ model_kwargs = self._prepare_encoder_hidden_states_kwargs_for_generation(
+ inputs_tensor, model_kwargs, model_input_name, generation_config
+ )
+
+ # 5. Prepare `input_ids` which will be used for auto-regressive generation
+ input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation(
+ batch_size=batch_size,
+ model_input_name=model_input_name,
+ model_kwargs=model_kwargs,
+ decoder_start_token_id=generation_config._decoder_start_token_tensor,
+ bos_token_id=generation_config._bos_token_tensor,
+ device=inputs_tensor.device,
+ )
+
+ # 6. Prepare `max_length` depending on other stopping criteria.
+ input_ids_length = input_ids.shape[-1]
+ has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
+ has_default_min_length = kwargs.get("min_length") is None and generation_config.min_length is not None
+ generation_config = self._prepare_generated_length(
+ generation_config=generation_config,
+ has_default_max_length=has_default_max_length,
+ has_default_min_length=has_default_min_length,
+ model_input_name=model_input_name,
+ inputs_tensor=inputs_tensor,
+ input_ids_length=input_ids_length,
+ )
+
+ self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)
+
+ # 7. Prepare the cache.
+ # - `model_kwargs` may be updated in place with a cache as defined by the parameters in `generation_config`.
+ # - different models have a different cache name expected by the model (default = "past_key_values")
+ # - `max_length`, prepared above, is used to determine the maximum cache length
+ max_cache_length = generation_config.max_length - 1
+ if (
+ inputs_tensor.shape[1] != input_ids_length
+ and model_input_name == "inputs_embeds"
+ and not self.config.is_encoder_decoder
+ ):
+ max_cache_length += inputs_tensor.shape[1]
+ self._prepare_cache_for_generation(
+ generation_config,
+ model_kwargs,
+ generation_mode=None,
+ batch_size=batch_size,
+ max_cache_length=max_cache_length,
+ )
+
+ # build the delay pattern mask for offsetting each codebook prediction by 1 (this behaviour is specific to MusicGen)
+ input_ids, decoder_delay_pattern_mask = self.decoder.build_delay_pattern_mask(
+ input_ids,
+ pad_token_id=generation_config._decoder_start_token_tensor,
+ max_length=generation_config.max_length,
+ )
+ # stash the delay mask so that we don't have to recompute in each forward pass
+ model_kwargs["decoder_delay_pattern_mask"] = decoder_delay_pattern_mask
+
+ # input_ids are ready to be placed on the streamer (if used)
+ if streamer is not None:
+ streamer.put(input_ids.cpu())
+
+ # 8. determine generation mode
+ generation_mode = generation_config.get_generation_mode()
+
+ # 9. prepare batched CFG externally (to enable coexistence with the unbatched CFG)
+ if generation_config.guidance_scale is not None and generation_config.guidance_scale > 1:
+ logits_processor.append(ClassifierFreeGuidanceLogitsProcessor(generation_config.guidance_scale))
+ generation_config.guidance_scale = None
+
+ # 10. prepare distribution pre_processing samplers
+ logits_processor = self._get_logits_processor(
+ generation_config=generation_config,
+ input_ids_seq_length=input_ids_length,
+ encoder_input_ids=inputs_tensor,
+ prefix_allowed_tokens_fn=None,
+ logits_processor=logits_processor,
+ device=input_ids.device,
+ )
+
+ # 11. prepare stopping criteria
+ stopping_criteria = self._get_stopping_criteria(
+ generation_config=generation_config, stopping_criteria=stopping_criteria
+ )
+
+ if generation_mode in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH):
+ # expand input_ids with `num_return_sequences` additional sequences per batch
+ input_ids, model_kwargs = self._expand_inputs_for_generation(
+ input_ids=input_ids,
+ expand_size=generation_config.num_return_sequences,
+ is_encoder_decoder=self.config.is_encoder_decoder,
+ **model_kwargs,
+ )
+
+ # 12. run sample
+ outputs = self._sample(
+ input_ids,
+ logits_processor=logits_processor,
+ stopping_criteria=stopping_criteria,
+ generation_config=generation_config,
+ synced_gpus=synced_gpus,
+ streamer=streamer,
+ **model_kwargs,
+ )
+
+ else:
+ raise ValueError(
+ "Got incompatible mode for generation, should be one of greedy or sampling. "
+ "Ensure that beam search is de-activated by setting `num_beams=1`."
+ )
+
+ if generation_config.return_dict_in_generate:
+ output_ids = outputs.sequences
+ else:
+ output_ids = outputs
+
+ # apply the pattern mask to the final ids
+ output_ids = self.decoder.apply_delay_pattern_mask(output_ids, model_kwargs["decoder_delay_pattern_mask"])
+
+ # revert the pattern delay mask by filtering the pad token id
+ output_ids = output_ids[output_ids != generation_config._pad_token_tensor].reshape(
+ batch_size, self.decoder.num_codebooks, -1
+ )
+
+ # append the frame dimension back to the audio codes
+ output_ids = output_ids[None, ...]
+
+ audio_scales = model_kwargs.get("audio_scales")
+ if audio_scales is None:
+ audio_scales = [None] * batch_size
+
+ if self.decoder.config.audio_channels == 1:
+ output_values = self.audio_encoder.decode(
+ output_ids,
+ audio_scales=audio_scales,
+ ).audio_values
+ else:
+ codec_outputs_left = self.audio_encoder.decode(output_ids[:, :, ::2, :], audio_scales=audio_scales)
+ output_values_left = codec_outputs_left.audio_values
+
+ codec_outputs_right = self.audio_encoder.decode(output_ids[:, :, 1::2, :], audio_scales=audio_scales)
+ output_values_right = codec_outputs_right.audio_values
+
+ output_values = torch.cat([output_values_left, output_values_right], dim=1)
+
+ if generation_config.return_dict_in_generate:
+ outputs.sequences = output_values
+ return outputs
+ else:
+ return output_values
+
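+# Illustrative usage sketch (not part of the upstream module): a minimal text-conditioned generation
+# call for MusicgenMelodyForConditionalGeneration. The sampling settings (`guidance_scale=3`,
+# `max_new_tokens=256`) are assumptions chosen for illustration.
+#
+# >>> from transformers import AutoProcessor, MusicgenMelodyForConditionalGeneration
+# >>> processor = AutoProcessor.from_pretrained("facebook/musicgen-melody")
+# >>> model = MusicgenMelodyForConditionalGeneration.from_pretrained("facebook/musicgen-melody")
+# >>> inputs = processor(text=["80s pop track with bassy drums and synth"], padding=True, return_tensors="pt")
+# >>> audio_values = model.generate(**inputs, do_sample=True, guidance_scale=3, max_new_tokens=256)
+# >>> audio_values.shape  # (batch_size, num_channels, sequence_length)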
+
+__all__ = [
+ "MusicgenMelodyForConditionalGeneration",
+ "MusicgenMelodyForCausalLM",
+ "MusicgenMelodyModel",
+ "MusicgenMelodyPreTrainedModel",
+]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/musicgen_melody/processing_musicgen_melody.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/musicgen_melody/processing_musicgen_melody.py
new file mode 100644
index 0000000000000000000000000000000000000000..1c2bbcb6e4a8ad7572f8ad507a71bde18d981711
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/musicgen_melody/processing_musicgen_melody.py
@@ -0,0 +1,139 @@
+# coding=utf-8
+# Copyright 2024 Meta AI and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Text/audio processor class for MusicGen Melody
+"""
+
+from typing import Any
+
+import numpy as np
+
+from ...processing_utils import ProcessorMixin
+from ...utils import to_numpy
+from ...utils.import_utils import requires
+
+
+@requires(backends=("torchaudio",))
+class MusicgenMelodyProcessor(ProcessorMixin):
+ r"""
+ Constructs a MusicGen Melody processor which wraps a Wav2Vec2 feature extractor - for raw audio waveform processing - and a T5 tokenizer into a single processor
+ class.
+
+ [`MusicgenMelodyProcessor`] offers all the functionalities of [`MusicgenMelodyFeatureExtractor`] and [`T5Tokenizer`]. See
+ [`~MusicgenMelodyProcessor.__call__`] and [`~MusicgenMelodyProcessor.decode`] for more information.
+
+ Args:
+ feature_extractor (`MusicgenMelodyFeatureExtractor`):
+ An instance of [`MusicgenMelodyFeatureExtractor`]. The feature extractor is a required input.
+ tokenizer (`T5Tokenizer`):
+ An instance of [`T5Tokenizer`]. The tokenizer is a required input.
+ """
+
+ feature_extractor_class = "MusicgenMelodyFeatureExtractor"
+ tokenizer_class = ("T5Tokenizer", "T5TokenizerFast")
+
+ def __init__(self, feature_extractor, tokenizer):
+ super().__init__(feature_extractor, tokenizer)
+
+ # Copied from transformers.models.musicgen.processing_musicgen.MusicgenProcessor.get_decoder_prompt_ids
+ def get_decoder_prompt_ids(self, task=None, language=None, no_timestamps=True):
+ return self.tokenizer.get_decoder_prompt_ids(task=task, language=language, no_timestamps=no_timestamps)
+
+ def __call__(self, *args, **kwargs):
+ """
+ Forwards the `audio` argument to MusicgenMelodyFeatureExtractor's [`~MusicgenMelodyFeatureExtractor.__call__`] and the `text`
+ argument to [`~T5Tokenizer.__call__`]. Please refer to the docstring of the above two methods for more
+ information.
+ """
+
+ if len(args) > 0:
+ kwargs["audio"] = args[0]
+ return super().__call__(*args, **kwargs)
+
+ # Copied from transformers.models.musicgen.processing_musicgen.MusicgenProcessor.batch_decode with padding_mask->attention_mask
+ def batch_decode(self, *args, **kwargs):
+ """
+ This method is used to decode either batches of audio outputs from the MusicGen model, or batches of token ids
+ from the tokenizer. In the case of decoding token ids, this method forwards all its arguments to T5Tokenizer's
+ [`~PreTrainedTokenizer.batch_decode`]. Please refer to the docstring of this method for more information.
+ """
+ audio_values = kwargs.pop("audio", None)
+ attention_mask = kwargs.pop("attention_mask", None)
+
+ if len(args) > 0:
+ audio_values = args[0]
+ args = args[1:]
+
+ if audio_values is not None:
+ return self._decode_audio(audio_values, attention_mask=attention_mask)
+ else:
+ return self.tokenizer.batch_decode(*args, **kwargs)
+
+ # Copied from transformers.models.musicgen.processing_musicgen.MusicgenProcessor._decode_audio with padding_mask->attention_mask
+ def _decode_audio(self, audio_values, attention_mask: Any = None) -> list[np.ndarray]:
+ """
+ This method strips any padding from the audio values to return a list of numpy audio arrays.
+ """
+ audio_values = to_numpy(audio_values)
+ bsz, channels, seq_len = audio_values.shape
+
+ if attention_mask is None:
+ return list(audio_values)
+
+ attention_mask = to_numpy(attention_mask)
+
+ # match the sequence length of the padding mask to the generated audio arrays by padding with the **non-padding**
+ # token (so that the generated audio values are **not** treated as padded tokens)
+ difference = seq_len - attention_mask.shape[-1]
+ padding_value = 1 - self.feature_extractor.padding_value
+ attention_mask = np.pad(attention_mask, ((0, 0), (0, difference)), "constant", constant_values=padding_value)
+
+ audio_values = audio_values.tolist()
+ for i in range(bsz):
+ sliced_audio = np.asarray(audio_values[i])[
+ attention_mask[i][None, :] != self.feature_extractor.padding_value
+ ]
+ audio_values[i] = sliced_audio.reshape(channels, -1)
+
+ return audio_values
+
+ def get_unconditional_inputs(self, num_samples=1, return_tensors="pt"):
+ """
+ Helper function to get null inputs for unconditional generation, enabling the model to be used without the
+ feature extractor or tokenizer.
+
+ Args:
+ num_samples (`int`, *optional*, defaults to 1):
+ Number of audio samples to unconditionally generate.
+
+ Example:
+ ```python
+ >>> from transformers import MusicgenMelodyForConditionalGeneration, MusicgenMelodyProcessor
+
+ >>> model = MusicgenMelodyForConditionalGeneration.from_pretrained("facebook/musicgen-melody")
+
+ >>> # get the unconditional (or 'null') inputs for the model
+ >>> processor = MusicgenMelodyProcessor.from_pretrained("facebook/musicgen-melody")
+ >>> unconditional_inputs = processor.get_unconditional_inputs(num_samples=1)
+
+ >>> audio_samples = model.generate(**unconditional_inputs, max_new_tokens=256)
+ ```"""
+ inputs = self.tokenizer([""] * num_samples, return_tensors=return_tensors, return_attention_mask=True)
+ inputs["attention_mask"][:] = 0
+
+ return inputs
+
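+# Illustrative usage sketch (not part of the upstream module): conditioning on both a text prompt and a
+# melody waveform via the processor. The zero waveform and the 32 kHz sampling rate are assumptions for
+# demonstration; in practice you would pass a real audio array.
+#
+# >>> import numpy as np
+# >>> from transformers import AutoProcessor
+# >>> processor = AutoProcessor.from_pretrained("facebook/musicgen-melody")
+# >>> waveform = np.zeros(32000)  # one second of silence as a stand-in melody prompt
+# >>> inputs = processor(
+# ...     audio=waveform,
+# ...     sampling_rate=32000,
+# ...     text=["80s pop track with bassy drums and synth"],
+# ...     padding=True,
+# ...     return_tensors="pt",
+# ... )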
+
+__all__ = ["MusicgenMelodyProcessor"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/myt5/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/myt5/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..65c8190ee6d94d7f63727202392a625d04aecdb4
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/myt5/__init__.py
@@ -0,0 +1,26 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .tokenization_myt5 import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/myt5/tokenization_myt5.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/myt5/tokenization_myt5.py
new file mode 100644
index 0000000000000000000000000000000000000000..251e3d602b993d8335927cdbeed5518800f2bfb1
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/myt5/tokenization_myt5.py
@@ -0,0 +1,380 @@
+# coding=utf-8
+# Copyright 2024
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tokenization class for model MyT5."""
+
+import json
+import os
+import warnings
+from collections import defaultdict
+from typing import Optional, Union
+
+from ...tokenization_utils import AddedToken, PreTrainedTokenizer
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+VOCAB_FILES_NAMES = {"vocab_file": "byte_maps.json"}
+
+
+class ByteRewriter:
+ """
+ Byte rewriter class for MyT5 tokenizer.
+ This class is used to rewrite bytes using a hash tree. The hash tree is constructed from a set of rewriting rules.
+
+ Args:
+ rewriting_rules (`str` or `dict[str, str]`):
+ A path to a json file containing the rewriting rules or a dictionary containing the rewriting rules.
+
+ """
+
+ LEAF = "[LEAF]"
+
+ def __init__(self, rewriting_rules: Union[str, dict[str, str]]):
+ if isinstance(rewriting_rules, str):
+ with open(rewriting_rules, "r") as f:
+ rewriting_rules = json.load(f)
+ elif not isinstance(rewriting_rules, dict):
+ raise TypeError(
+ f"rewriting_rules should be either a path to json file or a dict, got {type(rewriting_rules)}"
+ )
+
+ self.hash_tree = self.construct_hash_tree(rewriting_rules)
+ reverse_rewriting_rules = {v: k for k, v in rewriting_rules.items()}
+ self.reverse_hash_tree = self.construct_hash_tree(reverse_rewriting_rules)
+
+ def add_leaf(self, hash_tree: dict[str, Union[dict, list[str]]], byte_in_sequence: str, byte_out_sequence: str):
+ """
+ Add a leaf with the output byte sequence to the hash tree.
+ """
+ byte_in_list = byte_in_sequence.split(" ")
+ byte_out_list = byte_out_sequence.split(" ")
+
+ tree_pointer = hash_tree
+ for b in byte_in_list:
+ if b not in tree_pointer:
+ tree_pointer[b] = {}
+ tree_pointer = tree_pointer[b]
+
+ tree_pointer[self.LEAF] = byte_out_list
+
+ def construct_hash_tree(self, rewriting_rules: dict[str, str]) -> dict[str, Union[dict, list[str]]]:
+ """
+ Construct a hash tree for rewritten byte sequences.
+ """
+ hash_tree = defaultdict(dict)
+ for b in (f"{x:02x}" for x in range(256)):
+ hash_tree[b][self.LEAF] = [b]
+
+ for in_sequence, out_sequence in rewriting_rules.items():
+ self.add_leaf(hash_tree, in_sequence, out_sequence)
+
+ return hash_tree
+
+ def search_hash_tree(self, byte_sequence: list[str]) -> Union[None, list[str]]:
+ """
+ Search the hash tree and return the rewritten byte sequence if found.
+ """
+ tree_pointer = self.hash_tree
+ for b in byte_sequence:
+ if b in tree_pointer:
+ tree_pointer = tree_pointer[b]
+ else:
+ return None
+
+ return tree_pointer[self.LEAF]
+
+ def rewrite_bytes(self, in_bytes: list[str], reverse=False) -> list[str]:
+ """
+ Rewrite a sequence of bytes using the hash tree.
+
+ Args:
+ in_bytes (`list[str]`): A list of bytes to be rewritten.
+ reverse (`bool`): If True, decoding is performed with the reverse hash tree.
+ Returns:
+ `list[str]`: The rewritten byte sequence.
+ """
+ out_bytes = []
+ b_start = 0
+ b_end = 0
+
+ while b_start < len(in_bytes):
+ tree_pointer = self.hash_tree if not reverse else self.reverse_hash_tree
+ for j in range(b_start, len(in_bytes)):
+ b = in_bytes[j]
+ if b in tree_pointer:
+ tree_pointer = tree_pointer[b]
+ elif j == b_start:
+ cur_leaf = [b]
+ b_end = j
+ break
+ else:
+ break
+ if self.LEAF in tree_pointer:
+ cur_leaf = tree_pointer[self.LEAF]
+ b_end = j
+ out_bytes.extend(cur_leaf)
+ b_start = b_end + 1
+
+ return out_bytes
+
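+# Illustrative usage sketch (not part of the upstream module): ByteRewriter rewrites space-separated hex
+# byte sequences according to its rule dictionary, while unmatched bytes pass through unchanged. The single
+# toy rule below is an assumption for demonstration.
+#
+# >>> rewriter = ByteRewriter({"68 65": "ff 01"})
+# >>> rewriter.rewrite_bytes(["68", "65", "6c"])
+# ['ff', '01', '6c']
+# >>> rewriter.rewrite_bytes(["ff", "01", "6c"], reverse=True)
+# ['68', '65', '6c']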
+
+class MyT5Tokenizer(PreTrainedTokenizer):
+ """
+ Construct a MyT5 tokenizer.
+
+ This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
+ this superclass for more information regarding those methods.
+
+ Args:
+ vocab_file (`str`): The file containing the byte rewriting rules.
+ eos_token (`str`, *optional*, defaults to `"</s>"`):
+ The end of sequence token.
+
+ unk_token (`str`, *optional*, defaults to `"<unk>"`):
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+ token instead.
+ pad_token (`str`, *optional*, defaults to `"<pad>"`):
+ The token used for padding, for example when batching sequences of different lengths.
+ extra_ids (`int`, *optional*, defaults to 125):
+ Add a number of extra ids added to the end of the vocabulary for use as sentinels. These tokens are
+ accessible as "" where "{%d}" is a number between 0 and extra_ids-1. Extra tokens are
+ indexed from the end of the vocabulary up to beginning ("" is the last token in the vocabulary
+ like in ByT5 preprocessing see
+ [here](https://github.com/google-research/text-to-text-transfer-transformer/blob/9fd7b14a769417be33bc6c850f9598764913c833/t5/data/preprocessors.py#L2117)).
+ additional_special_tokens (`list[str]`, *optional*):
+ Additional special tokens used by the tokenizer.
+ """
+
+ model_input_names = ["input_ids", "attention_mask"]
+ vocab_files_names = VOCAB_FILES_NAMES
+
+ def __init__(
+ self,
+ vocab_file,
+ eos_token="",
+ unk_token="",
+ pad_token="",
+ extra_ids=125,
+ additional_special_tokens=None,
+ **kwargs,
+ ) -> None:
+ # Add extra_ids to the special token list
+ if extra_ids > 0 and additional_special_tokens is None:
+ additional_special_tokens = [f"" for i in range(extra_ids)]
+ elif extra_ids > 0 and additional_special_tokens is not None and len(additional_special_tokens) > 0:
+ # Check that we have the right number of extra_id special tokens
+ extra_tokens = len(set(filter(lambda x: bool("extra_id" in str(x)), additional_special_tokens)))
+ if extra_tokens != extra_ids:
+ raise ValueError(
+ f"Both extra_ids ({extra_ids}) and additional_special_tokens ({additional_special_tokens}) are"
+ " provided to MyT5Tokenizer. In this case the additional_special_tokens must include the"
+ " extra_ids tokens"
+ )
+
+ pad_token = AddedToken(pad_token, lstrip=True, rstrip=True) if isinstance(pad_token, str) else pad_token
+ eos_token = AddedToken(eos_token, lstrip=True, rstrip=True) if isinstance(eos_token, str) else eos_token
+ unk_token = AddedToken(unk_token, lstrip=True, rstrip=True) if isinstance(unk_token, str) else unk_token
+ # unk token needs to be in the vocab with correct index
+ self._added_tokens_decoder = {0: pad_token, 1: eos_token, 2: unk_token}
+ self.offset = len(self._added_tokens_decoder)
+ self._utf_vocab_size = 2**8 # utf is 8 bits
+
+ # Load byte maps
+ self.byte_maps = json.load(open(vocab_file, "r"))
+
+ self.decompose_rewriter = ByteRewriter(self.byte_maps["decompose_map"])
+ self.merge_rewriter = ByteRewriter(self.byte_maps["merge_map"])
+
+ super().__init__(
+ eos_token=eos_token,
+ unk_token=unk_token,
+ pad_token=pad_token,
+ extra_ids=0,
+ additional_special_tokens=additional_special_tokens,
+ **kwargs,
+ )
+
+ @property
+ def vocab_size(self):
+ return self._utf_vocab_size
+
+ # Copied from transformers.models.byt5.tokenization_byt5.ByT5Tokenizer.get_vocab
+ def get_vocab(self):
+ vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size + self.offset)}
+ vocab.update(self.added_tokens_encoder)
+ return vocab
+
+ # Copied from transformers.models.byt5.tokenization_byt5.ByT5Tokenizer.get_special_tokens_mask
+ def get_special_tokens_mask(
+ self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None, already_has_special_tokens: bool = False
+ ) -> list[int]:
+ """
+ Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
+ special tokens using the tokenizer `prepare_for_model` method.
+
+ Args:
+ token_ids_0 (`list[int]`):
+ List of IDs.
+ token_ids_1 (`list[int]`, *optional*):
+ Optional second list of IDs for sequence pairs.
+ already_has_special_tokens (`bool`, *optional*, defaults to `False`):
+ Whether or not the token list is already formatted with special tokens for the model.
+
+ Returns:
+ `list[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
+ """
+ if already_has_special_tokens:
+ return super().get_special_tokens_mask(
+ token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
+ )
+
+ # normal case: some special tokens
+ if token_ids_1 is None:
+ return ([0] * len(token_ids_0)) + [1]
+ return ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
+
+ def _add_eos_if_not_present(self, token_ids: list[int]) -> list[int]:
+ """Do not add eos again if user already added it."""
+ if len(token_ids) > 0 and token_ids[-1] == self.eos_token_id:
+ warnings.warn(
+ f"This sequence already has {self.eos_token}. In future versions this behavior may lead to duplicated"
+ " eos tokens being added."
+ )
+ return token_ids
+ else:
+ return token_ids + [self.eos_token_id]
+
+ def create_token_type_ids_from_sequences(
+ self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None
+ ) -> list[int]:
+ """
+ Create a mask from the two sequences passed to be used in a sequence-pair classification task. MyT5 does not
+ make use of token type ids, therefore a list of zeros is returned.
+
+ Args:
+ token_ids_0 (`list[int]`):
+ List of IDs.
+ token_ids_1 (`list[int]`, *optional*):
+ Optional second list of IDs for sequence pairs.
+
+ Returns:
+ `list[int]`: List of zeros.
+ """
+ eos = [self.eos_token_id]
+
+ if token_ids_1 is None:
+ return len(token_ids_0 + eos) * [0]
+ return len(token_ids_0 + eos + token_ids_1 + eos) * [0]
+
+ # Copied from transformers.models.byt5.tokenization_byt5.ByT5Tokenizer.build_inputs_with_special_tokens
+ def build_inputs_with_special_tokens(
+ self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None
+ ) -> list[int]:
+ """
+ Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
+ adding special tokens. A sequence has the following format:
+
+ - single sequence: `X </s>`
+ - pair of sequences: `A </s> B </s>`
+
+ Args:
+ token_ids_0 (`list[int]`):
+ List of IDs to which the special tokens will be added.
+ token_ids_1 (`list[int]`, *optional*):
+ Optional second list of IDs for sequence pairs.
+
+ Returns:
+ `list[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
+ """
+ token_ids_0 = self._add_eos_if_not_present(token_ids_0)
+ if token_ids_1 is None:
+ return token_ids_0
+ else:
+ token_ids_1 = self._add_eos_if_not_present(token_ids_1)
+ return token_ids_0 + token_ids_1
+
+ def _tokenize(self, text: str, **kwargs) -> list[str]:
+ """Take as input a string and return a list of strings (tokens) for words/sub-words.
+ Represents tokens in two character hex format"""
+
+ tokens = [f"{i:02x}" for i in text.encode("utf-8")]
+ tokens = self.morphological_encode(tokens)
+ return tokens
+
+ def _convert_token_to_id(self, token):
+ """Converts a token (str) in an id using the vocab."""
+
+ if len(token) != 2:
+ token_id = None
+ else:
+ token_id = int(token, 16) + self.offset
+
+ return token_id
+
+ def _convert_id_to_token(self, index):
+ """Converts an index (integer) in a token (str) using the vocab."""
+ token = f"{index - self.offset:02x}"
+ return token
+
+ def morphological_encode(self, indices: list[str]) -> list[str]:
+ # Decompose and merge morphological sequences
+ indices = self.decompose_rewriter.rewrite_bytes(indices, reverse=False)
+ indices = self.merge_rewriter.rewrite_bytes(indices, reverse=False)
+ return indices
+
+ def morphological_decode(self, indices: list[str]) -> list[str]:
+ # Demerge and compose morphological sequences
+ indices = self.merge_rewriter.rewrite_bytes(indices, reverse=True)
+ indices = self.decompose_rewriter.rewrite_bytes(indices, reverse=True)
+ return indices
+
+ def convert_tokens_to_string(self, tokens):
+ """Converts a sequence of tokens (string) in a single string."""
+ bstring = b""
+
+ out_tokens = []
+ for token in tokens:
+ if token in self.added_tokens_decoder:
+ out_tokens.append(self.added_tokens_decoder[token])
+ elif token in self.added_tokens_encoder:
+ out_tokens.append(token)
+ else:
+ out_tokens.append(token)
+
+ out_tokens = self.morphological_decode(out_tokens)
+ _added_tokens = set(self.added_tokens_decoder.values()) | set(self.added_tokens_encoder)
+ for token in out_tokens:
+ if token in _added_tokens:
+ bstring += bytes(token, "utf-8")
+ else:
+ bstring += bytes.fromhex(token)
+ string = bstring.decode("utf-8", errors="ignore")
+ return string
+
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]:
+ if os.path.isdir(save_directory):
+ vocab_file = os.path.join(
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
+ )
+ else:
+ vocab_file = (filename_prefix + "-" if filename_prefix else "") + save_directory
+ with open(vocab_file, "w", encoding="utf-8") as writer:
+ writer.write(json.dumps(self.byte_maps, indent=2, ensure_ascii=False))
+ return (vocab_file,)
+
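+# Illustrative usage sketch (not part of the upstream module): MyT5 first maps raw UTF-8 bytes to
+# two-character hex tokens, then applies the decompose/merge byte rewriters loaded from `byte_maps.json`.
+# The checkpoint name below is an assumption for demonstration.
+#
+# >>> from transformers import AutoTokenizer
+# >>> [f"{b:02x}" for b in "hello".encode("utf-8")]  # pre-rewrite hex tokens
+# ['68', '65', '6c', '6c', '6f']
+# >>> tokenizer = AutoTokenizer.from_pretrained("Tomlim/myt5-base")
+# >>> tokenizer("hello").input_ids  # morphologically rewritten byte ids followed by the end-of-sequence id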
+
+__all__ = ["MyT5Tokenizer"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/nllb/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/nllb/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5cdb326098a30020f0d9d2cc4c083090bce70d5a
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/nllb/__init__.py
@@ -0,0 +1,27 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .tokenization_nllb import *
+ from .tokenization_nllb_fast import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/nllb/tokenization_nllb.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/nllb/tokenization_nllb.py
new file mode 100644
index 0000000000000000000000000000000000000000..4962a642bb3181d0d0e918530681ba7ff40ef70e
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/nllb/tokenization_nllb.py
@@ -0,0 +1,394 @@
+# coding=utf-8
+# Copyright 2022 The Facebook AI Research Team Authors and The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+from shutil import copyfile
+from typing import Any, Optional
+
+import sentencepiece as spm
+
+from ...tokenization_utils import AddedToken, BatchEncoding, PreTrainedTokenizer
+from ...utils import logging
+from ...utils.import_utils import requires
+
+
+logger = logging.get_logger(__name__)
+
+SPIECE_UNDERLINE = "▁"
+
+VOCAB_FILES_NAMES = {"vocab_file": "sentencepiece.bpe.model"}
+
+
+FAIRSEQ_LANGUAGE_CODES = ['ace_Arab', 'ace_Latn', 'acm_Arab', 'acq_Arab', 'aeb_Arab', 'afr_Latn', 'ajp_Arab', 'aka_Latn', 'amh_Ethi', 'apc_Arab', 'arb_Arab', 'ars_Arab', 'ary_Arab', 'arz_Arab', 'asm_Beng', 'ast_Latn', 'awa_Deva', 'ayr_Latn', 'azb_Arab', 'azj_Latn', 'bak_Cyrl', 'bam_Latn', 'ban_Latn', 'bel_Cyrl', 'bem_Latn', 'ben_Beng', 'bho_Deva', 'bjn_Arab', 'bjn_Latn', 'bod_Tibt', 'bos_Latn', 'bug_Latn', 'bul_Cyrl', 'cat_Latn', 'ceb_Latn', 'ces_Latn', 'cjk_Latn', 'ckb_Arab', 'crh_Latn', 'cym_Latn', 'dan_Latn', 'deu_Latn', 'dik_Latn', 'dyu_Latn', 'dzo_Tibt', 'ell_Grek', 'eng_Latn', 'epo_Latn', 'est_Latn', 'eus_Latn', 'ewe_Latn', 'fao_Latn', 'pes_Arab', 'fij_Latn', 'fin_Latn', 'fon_Latn', 'fra_Latn', 'fur_Latn', 'fuv_Latn', 'gla_Latn', 'gle_Latn', 'glg_Latn', 'grn_Latn', 'guj_Gujr', 'hat_Latn', 'hau_Latn', 'heb_Hebr', 'hin_Deva', 'hne_Deva', 'hrv_Latn', 'hun_Latn', 'hye_Armn', 'ibo_Latn', 'ilo_Latn', 'ind_Latn', 'isl_Latn', 'ita_Latn', 'jav_Latn', 'jpn_Jpan', 'kab_Latn', 'kac_Latn', 'kam_Latn', 'kan_Knda', 'kas_Arab', 'kas_Deva', 'kat_Geor', 'knc_Arab', 'knc_Latn', 'kaz_Cyrl', 'kbp_Latn', 'kea_Latn', 'khm_Khmr', 'kik_Latn', 'kin_Latn', 'kir_Cyrl', 'kmb_Latn', 'kon_Latn', 'kor_Hang', 'kmr_Latn', 'lao_Laoo', 'lvs_Latn', 'lij_Latn', 'lim_Latn', 'lin_Latn', 'lit_Latn', 'lmo_Latn', 'ltg_Latn', 'ltz_Latn', 'lua_Latn', 'lug_Latn', 'luo_Latn', 'lus_Latn', 'mag_Deva', 'mai_Deva', 'mal_Mlym', 'mar_Deva', 'min_Latn', 'mkd_Cyrl', 'plt_Latn', 'mlt_Latn', 'mni_Beng', 'khk_Cyrl', 'mos_Latn', 'mri_Latn', 'zsm_Latn', 'mya_Mymr', 'nld_Latn', 'nno_Latn', 'nob_Latn', 'npi_Deva', 'nso_Latn', 'nus_Latn', 'nya_Latn', 'oci_Latn', 'gaz_Latn', 'ory_Orya', 'pag_Latn', 'pan_Guru', 'pap_Latn', 'pol_Latn', 'por_Latn', 'prs_Arab', 'pbt_Arab', 'quy_Latn', 'ron_Latn', 'run_Latn', 'rus_Cyrl', 'sag_Latn', 'san_Deva', 'sat_Beng', 'scn_Latn', 'shn_Mymr', 'sin_Sinh', 'slk_Latn', 'slv_Latn', 'smo_Latn', 'sna_Latn', 'snd_Arab', 'som_Latn', 'sot_Latn', 'spa_Latn', 'als_Latn', 'srd_Latn', 'srp_Cyrl', 'ssw_Latn', 'sun_Latn', 'swe_Latn', 'swh_Latn', 'szl_Latn', 'tam_Taml', 'tat_Cyrl', 'tel_Telu', 'tgk_Cyrl', 'tgl_Latn', 'tha_Thai', 'tir_Ethi', 'taq_Latn', 'taq_Tfng', 'tpi_Latn', 'tsn_Latn', 'tso_Latn', 'tuk_Latn', 'tum_Latn', 'tur_Latn', 'twi_Latn', 'tzm_Tfng', 'uig_Arab', 'ukr_Cyrl', 'umb_Latn', 'urd_Arab', 'uzn_Latn', 'vec_Latn', 'vie_Latn', 'war_Latn', 'wol_Latn', 'xho_Latn', 'ydd_Hebr', 'yor_Latn', 'yue_Hant', 'zho_Hans', 'zho_Hant', 'zul_Latn'] # fmt: skip
+
+
+@requires(backends=("sentencepiece",))
+class NllbTokenizer(PreTrainedTokenizer):
+ """
+ Construct an NLLB tokenizer.
+
+ Adapted from [`RobertaTokenizer`] and [`XLNetTokenizer`]. Based on
+ [SentencePiece](https://github.com/google/sentencepiece).
+
+ The tokenization method is `<tokens> <eos> <language code>` for source language documents, and
+ `<language code> <tokens> <eos>` for target language documents.
+
+ Examples:
+
+ ```python
+ >>> from transformers import NllbTokenizer
+
+ >>> tokenizer = NllbTokenizer.from_pretrained(
+ ... "facebook/nllb-200-distilled-600M", src_lang="eng_Latn", tgt_lang="fra_Latn"
+ ... )
+ >>> example_english_phrase = " UN Chief Says There Is No Military Solution in Syria"
+ >>> expected_translation_french = "Le chef de l'ONU affirme qu'il n'y a pas de solution militaire en Syrie."
+ >>> inputs = tokenizer(example_english_phrase, text_target=expected_translation_french, return_tensors="pt")
+ ```
+
+ Args:
+ vocab_file (`str`):
+ Path to the vocabulary file.
+ bos_token (`str`, *optional*, defaults to `"<s>"`):
+ The beginning of sequence token that was used during pretraining. Can be used as a sequence classifier token.
+
+ <Tip>
+
+ When building a sequence using special tokens, this is not the token that is used for the beginning of
+ sequence. The token used is the `cls_token`.
+
+ </Tip>
+
+ eos_token (`str`, *optional*, defaults to `"</s>"`):
+ The end of sequence token.
+
+ <Tip>
+
+ When building a sequence using special tokens, this is not the token that is used for the end of sequence.
+ The token used is the `sep_token`.
+
+ </Tip>
+
+ sep_token (`str`, *optional*, defaults to `"</s>"`):
+ The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
+ sequence classification or for a text and a question for question answering. It is also used as the last
+ token of a sequence built with special tokens.
+ cls_token (`str`, *optional*, defaults to `"<s>"`):
+ The classifier token which is used when doing sequence classification (classification of the whole sequence
+ instead of per-token classification). It is the first token of the sequence when built with special tokens.
+ unk_token (`str`, *optional*, defaults to `"<unk>"`):
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+ token instead.
+ pad_token (`str`, *optional*, defaults to `"<pad>"`):
+ The token used for padding, for example when batching sequences of different lengths.
+ mask_token (`str`, *optional*, defaults to `"<mask>"`):
+ The token used for masking values. This is the token used when training this model with masked language
+ modeling. This is the token which the model will try to predict.
+ tokenizer_file (`str`, *optional*):
+ The path to a tokenizer file to use instead of the vocab file.
+ src_lang (`str`, *optional*):
+ The language to use as source language for translation.
+ tgt_lang (`str`, *optional*):
+ The language to use as target language for translation.
+ sp_model_kwargs (`dict[str, Any]`, *optional*):
+ Additional keyword arguments to pass to `SentencePieceProcessor.__init__()`.
+ """
+
+ vocab_files_names = VOCAB_FILES_NAMES
+ model_input_names = ["input_ids", "attention_mask"]
+
+ prefix_tokens: list[int] = []
+ suffix_tokens: list[int] = []
+
+ def __init__(
+ self,
+ vocab_file,
+ bos_token="<s>",
+ eos_token="</s>",
+ sep_token="</s>",
+ cls_token="<s>",
+ unk_token="<unk>",
+ pad_token="<pad>",
+ mask_token="<mask>",
+ tokenizer_file=None,
+ src_lang=None,
+ tgt_lang=None,
+ sp_model_kwargs: Optional[dict[str, Any]] = None,
+ additional_special_tokens=None,
+ legacy_behaviour=False,
+ **kwargs,
+ ):
+ if additional_special_tokens is None:
+ additional_special_tokens = FAIRSEQ_LANGUAGE_CODES
+ bos_token = AddedToken(bos_token, normalized=False, special=True) if isinstance(bos_token, str) else bos_token
+ pad_token = AddedToken(pad_token, normalized=False, special=True) if isinstance(pad_token, str) else pad_token
+ eos_token = AddedToken(eos_token, normalized=False, special=True) if isinstance(eos_token, str) else eos_token
+ unk_token = AddedToken(unk_token, normalized=False, special=True) if isinstance(unk_token, str) else unk_token
+ # Mask token behaves like a normal word, i.e. it includes the space before it
+ mask_token = (
+ AddedToken(mask_token, normalized=True, lstrip=True, special=True)
+ if isinstance(mask_token, str)
+ else mask_token
+ )
+
+ self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
+ self.legacy_behaviour = legacy_behaviour
+
+ self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
+ self.sp_model.Load(str(vocab_file))
+ self.vocab_file = vocab_file
+ # Original fairseq vocab and spm vocab must be "aligned":
+ # Vocab | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9
+ # -------- | ------- | ------- | ------ | ------- | ---- | ---- | ---- | ---- | ---- | ----
+ # fairseq | '<s>' | '<pad>' | '</s>' | '<unk>' | 'an' | '▁n' | '▁m' | '▁t' | '▁k' | '▁a'
+ # spm | '<unk>' | '<s>' | '</s>' | 'an' | '▁n' | '▁m' | '▁t' | '▁k' | '▁a' | '▁s'
+
+ # unk token needs to be in the vocab with correct index
+ self._added_tokens_decoder = {0: bos_token, 1: pad_token, 2: eos_token, 3: unk_token}
+ # The first "real" token "," has position 4 in the original fairseq vocab and position 3 in the spm vocab
+ self.fairseq_offset = 1
+ self.sp_model_size = len(self.sp_model)
+
+ super().__init__(
+ bos_token=bos_token,
+ eos_token=eos_token,
+ unk_token=unk_token,
+ sep_token=sep_token,
+ cls_token=cls_token,
+ pad_token=pad_token,
+ mask_token=mask_token,
+ tokenizer_file=tokenizer_file,
+ src_lang=src_lang,
+ tgt_lang=tgt_lang,
+ additional_special_tokens=additional_special_tokens,
+ sp_model_kwargs=self.sp_model_kwargs,
+ legacy_behaviour=legacy_behaviour,
+ **kwargs,
+ )
+
+ self._src_lang = src_lang if src_lang is not None else "eng_Latn"
+ self.cur_lang_code_id = self.convert_tokens_to_ids(self._src_lang)
+ self.tgt_lang = tgt_lang
+ self.set_src_lang_special_tokens(self._src_lang)
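+
+ # Sketch of the resulting id mapping (illustrative only; "▁the" is just an assumed
+ # in-vocabulary piece):
+ #
+ #   tok = NllbTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")
+ #   piece_id = tok.sp_model.PieceToId("▁the")                      # raw SentencePiece id
+ #   assert tok.convert_tokens_to_ids("▁the") == piece_id + tok.fairseq_offset
+ #   # ids 0-3 stay reserved for <s>, <pad>, </s> and <unk> via _added_tokens_decoder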
+
+ def __getstate__(self):
+ state = self.__dict__.copy()
+ state["sp_model"] = None
+ state["sp_model_proto"] = self.sp_model.serialized_model_proto()
+ return state
+
+ def __setstate__(self, d):
+ self.__dict__ = d
+
+ # for backward compatibility
+ if not hasattr(self, "sp_model_kwargs"):
+ self.sp_model_kwargs = {}
+
+ self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
+ self.sp_model.LoadFromSerializedProto(self.sp_model_proto)
+
+ @property
+ def vocab_size(self):
+ return len(self.sp_model) + self.fairseq_offset
+
+ @property
+ def src_lang(self) -> str:
+ return self._src_lang
+
+ @src_lang.setter
+ def src_lang(self, new_src_lang: str) -> None:
+ self._src_lang = new_src_lang
+ self.set_src_lang_special_tokens(self._src_lang)
+
+ def get_special_tokens_mask(
+ self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None, already_has_special_tokens: bool = False
+ ) -> list[int]:
+ """
+ Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
+ special tokens using the tokenizer `prepare_for_model` method.
+
+ Args:
+ token_ids_0 (`list[int]`):
+ List of IDs.
+ token_ids_1 (`list[int]`, *optional*):
+ Optional second list of IDs for sequence pairs.
+ already_has_special_tokens (`bool`, *optional*, defaults to `False`):
+ Whether or not the token list is already formatted with special tokens for the model.
+
+ Returns:
+ `list[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
+ """
+
+ if already_has_special_tokens:
+ return super().get_special_tokens_mask(
+ token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
+ )
+
+ prefix_ones = [1] * len(self.prefix_tokens)
+ suffix_ones = [1] * len(self.suffix_tokens)
+ if token_ids_1 is None:
+ return prefix_ones + ([0] * len(token_ids_0)) + suffix_ones
+ return prefix_ones + ([0] * len(token_ids_0)) + ([0] * len(token_ids_1)) + suffix_ones
+
+ def build_inputs_with_special_tokens(
+ self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None
+ ) -> list[int]:
+ """
+ Build model inputs from a sequence or a pair of sequences for sequence classification tasks by concatenating and
+ adding special tokens. An NLLB sequence has the following format, where `X` represents the sequence:
+
+ - `input_ids` (for encoder) `X [eos, src_lang_code]`
+ - `decoder_input_ids`: (for decoder) `X [eos, tgt_lang_code]`
+
+ BOS is never used. Pairs of sequences are not the expected use case, but they will be handled without a
+ separator.
+
+ Args:
+ token_ids_0 (`list[int]`):
+ List of IDs to which the special tokens will be added.
+ token_ids_1 (`list[int]`, *optional*):
+ Optional second list of IDs for sequence pairs.
+
+ Returns:
+ `list[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
+ """
+ if token_ids_1 is None:
+ return self.prefix_tokens + token_ids_0 + self.suffix_tokens
+ # We don't expect to process pairs, but leave the pair logic for API consistency
+ return self.prefix_tokens + token_ids_0 + token_ids_1 + self.suffix_tokens
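+
+ # Usage sketch (assuming the checkpoint from the class docstring example): with the
+ # default, non-legacy behaviour the encoder input is [src_lang_code] + X + [eos];
+ # with `legacy_behaviour=True` it becomes X + [eos, src_lang_code] instead.
+ #
+ #   tok = NllbTokenizer.from_pretrained("facebook/nllb-200-distilled-600M", src_lang="eng_Latn")
+ #   ids = tok("Hello world")["input_ids"]
+ #   assert ids[0] == tok.convert_tokens_to_ids("eng_Latn") and ids[-1] == tok.eos_token_id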
+
+ def create_token_type_ids_from_sequences(
+ self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None
+ ) -> list[int]:
+ """
+ Create a mask from the two sequences passed to be used in a sequence-pair classification task. NLLB does not
+ make use of token type ids, therefore a list of zeros is returned.
+
+ Args:
+ token_ids_0 (`list[int]`):
+ List of IDs.
+ token_ids_1 (`list[int]`, *optional*):
+ Optional second list of IDs for sequence pairs.
+
+ Returns:
+ `list[int]`: List of zeros.
+
+ """
+
+ sep = [self.sep_token_id]
+ cls = [self.cls_token_id]
+
+ if token_ids_1 is None:
+ return len(cls + token_ids_0 + sep) * [0]
+ return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0]
+
+ def _build_translation_inputs(
+ self, raw_inputs, return_tensors: str, src_lang: Optional[str], tgt_lang: Optional[str], **extra_kwargs
+ ):
+ """Used by translation pipeline, to prepare inputs for the generate function"""
+ if src_lang is None or tgt_lang is None:
+ raise ValueError("Translation requires a `src_lang` and a `tgt_lang` for this model")
+ self.src_lang = src_lang
+ inputs = self(raw_inputs, add_special_tokens=True, return_tensors=return_tensors, **extra_kwargs)
+ tgt_lang_id = self.convert_tokens_to_ids(tgt_lang)
+ inputs["forced_bos_token_id"] = tgt_lang_id
+ return inputs
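+
+ # Usage sketch (normally called by the translation pipeline, but usable directly; reuses
+ # the `tok` instance from the sketch above, and the phrase is just an example string):
+ #
+ #   batch = tok._build_translation_inputs(
+ #       "UN Chief says there is no military solution in Syria",
+ #       return_tensors="pt", src_lang="eng_Latn", tgt_lang="fra_Latn",
+ #   )
+ #   # batch["forced_bos_token_id"] is the id of "fra_Latn"; pass it on to model.generate()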
+
+ def get_vocab(self):
+ vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
+ vocab.update(self.added_tokens_encoder)
+ return vocab
+
+ def _tokenize(self, text: str) -> list[str]:
+ return self.sp_model.encode(text, out_type=str)
+
+ def _convert_token_to_id(self, token):
+ """Converts a token (str) in an id using the vocab."""
+ spm_id = self.sp_model.PieceToId(token)
+ # Need to return unknown token if the SP model returned 0
+ return spm_id + self.fairseq_offset if spm_id else self.unk_token_id
+
+ def _convert_id_to_token(self, index):
+ """Converts an index (integer) in a token (str) using the vocab."""
+ return self.sp_model.IdToPiece(index - self.fairseq_offset)
+
+ def convert_tokens_to_string(self, tokens):
+ """Converts a sequence of tokens (strings for sub-words) in a single string."""
+ out_string = "".join(tokens).replace(SPIECE_UNDERLINE, " ").strip()
+ return out_string
+
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]:
+ if not os.path.isdir(save_directory):
+ logger.error(f"Vocabulary path ({save_directory}) should be a directory")
+ return
+ out_vocab_file = os.path.join(
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
+ )
+
+ if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
+ copyfile(self.vocab_file, out_vocab_file)
+ elif not os.path.isfile(self.vocab_file):
+ with open(out_vocab_file, "wb") as fi:
+ content_spiece_model = self.sp_model.serialized_model_proto()
+ fi.write(content_spiece_model)
+
+ return (out_vocab_file,)
+
+ def prepare_seq2seq_batch(
+ self,
+ src_texts: list[str],
+ src_lang: str = "eng_Latn",
+ tgt_texts: Optional[list[str]] = None,
+ tgt_lang: str = "fra_Latn",
+ **kwargs,
+ ) -> BatchEncoding:
+ self.src_lang = src_lang
+ self.tgt_lang = tgt_lang
+ return super().prepare_seq2seq_batch(src_texts, tgt_texts, **kwargs)
+
+ def _switch_to_input_mode(self):
+ return self.set_src_lang_special_tokens(self.src_lang)
+
+ def _switch_to_target_mode(self):
+ return self.set_tgt_lang_special_tokens(self.tgt_lang)
+
+ def set_src_lang_special_tokens(self, src_lang) -> None:
+ """Reset the special tokens to the source lang setting.
+ - In legacy mode: No prefix and suffix=[eos, src_lang_code].
+ - In default mode: Prefix=[src_lang_code], suffix = [eos]
+ """
+ self.cur_lang_code = self.convert_tokens_to_ids(src_lang)
+ if self.legacy_behaviour:
+ self.prefix_tokens = []
+ self.suffix_tokens = [self.eos_token_id, self.cur_lang_code]
+ else:
+ self.prefix_tokens = [self.cur_lang_code]
+ self.suffix_tokens = [self.eos_token_id]
+
+ def set_tgt_lang_special_tokens(self, lang: str) -> None:
+ """Reset the special tokens to the target lang setting.
+ - In legacy mode: No prefix and suffix=[eos, tgt_lang_code].
+ - In default mode: Prefix=[tgt_lang_code], suffix = [eos]
+ """
+ self.cur_lang_code = self.convert_tokens_to_ids(lang)
+ if self.legacy_behaviour:
+ self.prefix_tokens = []
+ self.suffix_tokens = [self.eos_token_id, self.cur_lang_code]
+ else:
+ self.prefix_tokens = [self.cur_lang_code]
+ self.suffix_tokens = [self.eos_token_id]
+
+
+__all__ = ["NllbTokenizer"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/nllb/tokenization_nllb_fast.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/nllb/tokenization_nllb_fast.py
new file mode 100644
index 0000000000000000000000000000000000000000..5300b3942b5d261e5e83be8c55f1a41da3427b09
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/nllb/tokenization_nllb_fast.py
@@ -0,0 +1,327 @@
+# coding=utf-8
+# Copyright 2022 The Facebook AI Research Team Authors and The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+from shutil import copyfile
+from typing import Optional
+
+from tokenizers import processors
+
+from ...tokenization_utils import AddedToken, BatchEncoding
+from ...tokenization_utils_fast import PreTrainedTokenizerFast
+from ...utils import is_sentencepiece_available, logging
+
+
+if is_sentencepiece_available():
+ from .tokenization_nllb import NllbTokenizer
+else:
+ NllbTokenizer = None
+
+
+logger = logging.get_logger(__name__)
+
+
+VOCAB_FILES_NAMES = {"vocab_file": "sentencepiece.bpe.model", "tokenizer_file": "tokenizer.json"}
+
+
+FAIRSEQ_LANGUAGE_CODES = ['ace_Arab', 'ace_Latn', 'acm_Arab', 'acq_Arab', 'aeb_Arab', 'afr_Latn', 'ajp_Arab', 'aka_Latn', 'amh_Ethi', 'apc_Arab', 'arb_Arab', 'ars_Arab', 'ary_Arab', 'arz_Arab', 'asm_Beng', 'ast_Latn', 'awa_Deva', 'ayr_Latn', 'azb_Arab', 'azj_Latn', 'bak_Cyrl', 'bam_Latn', 'ban_Latn', 'bel_Cyrl', 'bem_Latn', 'ben_Beng', 'bho_Deva', 'bjn_Arab', 'bjn_Latn', 'bod_Tibt', 'bos_Latn', 'bug_Latn', 'bul_Cyrl', 'cat_Latn', 'ceb_Latn', 'ces_Latn', 'cjk_Latn', 'ckb_Arab', 'crh_Latn', 'cym_Latn', 'dan_Latn', 'deu_Latn', 'dik_Latn', 'dyu_Latn', 'dzo_Tibt', 'ell_Grek', 'eng_Latn', 'epo_Latn', 'est_Latn', 'eus_Latn', 'ewe_Latn', 'fao_Latn', 'pes_Arab', 'fij_Latn', 'fin_Latn', 'fon_Latn', 'fra_Latn', 'fur_Latn', 'fuv_Latn', 'gla_Latn', 'gle_Latn', 'glg_Latn', 'grn_Latn', 'guj_Gujr', 'hat_Latn', 'hau_Latn', 'heb_Hebr', 'hin_Deva', 'hne_Deva', 'hrv_Latn', 'hun_Latn', 'hye_Armn', 'ibo_Latn', 'ilo_Latn', 'ind_Latn', 'isl_Latn', 'ita_Latn', 'jav_Latn', 'jpn_Jpan', 'kab_Latn', 'kac_Latn', 'kam_Latn', 'kan_Knda', 'kas_Arab', 'kas_Deva', 'kat_Geor', 'knc_Arab', 'knc_Latn', 'kaz_Cyrl', 'kbp_Latn', 'kea_Latn', 'khm_Khmr', 'kik_Latn', 'kin_Latn', 'kir_Cyrl', 'kmb_Latn', 'kon_Latn', 'kor_Hang', 'kmr_Latn', 'lao_Laoo', 'lvs_Latn', 'lij_Latn', 'lim_Latn', 'lin_Latn', 'lit_Latn', 'lmo_Latn', 'ltg_Latn', 'ltz_Latn', 'lua_Latn', 'lug_Latn', 'luo_Latn', 'lus_Latn', 'mag_Deva', 'mai_Deva', 'mal_Mlym', 'mar_Deva', 'min_Latn', 'mkd_Cyrl', 'plt_Latn', 'mlt_Latn', 'mni_Beng', 'khk_Cyrl', 'mos_Latn', 'mri_Latn', 'zsm_Latn', 'mya_Mymr', 'nld_Latn', 'nno_Latn', 'nob_Latn', 'npi_Deva', 'nso_Latn', 'nus_Latn', 'nya_Latn', 'oci_Latn', 'gaz_Latn', 'ory_Orya', 'pag_Latn', 'pan_Guru', 'pap_Latn', 'pol_Latn', 'por_Latn', 'prs_Arab', 'pbt_Arab', 'quy_Latn', 'ron_Latn', 'run_Latn', 'rus_Cyrl', 'sag_Latn', 'san_Deva', 'sat_Beng', 'scn_Latn', 'shn_Mymr', 'sin_Sinh', 'slk_Latn', 'slv_Latn', 'smo_Latn', 'sna_Latn', 'snd_Arab', 'som_Latn', 'sot_Latn', 'spa_Latn', 'als_Latn', 'srd_Latn', 'srp_Cyrl', 'ssw_Latn', 'sun_Latn', 'swe_Latn', 'swh_Latn', 'szl_Latn', 'tam_Taml', 'tat_Cyrl', 'tel_Telu', 'tgk_Cyrl', 'tgl_Latn', 'tha_Thai', 'tir_Ethi', 'taq_Latn', 'taq_Tfng', 'tpi_Latn', 'tsn_Latn', 'tso_Latn', 'tuk_Latn', 'tum_Latn', 'tur_Latn', 'twi_Latn', 'tzm_Tfng', 'uig_Arab', 'ukr_Cyrl', 'umb_Latn', 'urd_Arab', 'uzn_Latn', 'vec_Latn', 'vie_Latn', 'war_Latn', 'wol_Latn', 'xho_Latn', 'ydd_Hebr', 'yor_Latn', 'yue_Hant', 'zho_Hans', 'zho_Hant', 'zul_Latn'] # fmt: skip
+
+
+class NllbTokenizerFast(PreTrainedTokenizerFast):
+ """
+ Construct a "fast" NLLB tokenizer (backed by HuggingFace's *tokenizers* library). Based on
+ [BPE](https://huggingface.co/docs/tokenizers/python/latest/components.html?highlight=BPE#models).
+
+ This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should
+ refer to this superclass for more information regarding those methods.
+
+ The tokenization method is `<tokens> <eos> <language code>` for source language documents, and `<language code>
+ <tokens> <eos>` for target language documents.
+
+ Examples:
+
+ ```python
+ >>> from transformers import NllbTokenizerFast
+
+ >>> tokenizer = NllbTokenizerFast.from_pretrained(
+ ... "facebook/nllb-200-distilled-600M", src_lang="eng_Latn", tgt_lang="fra_Latn"
+ ... )
+ >>> example_english_phrase = " UN Chief Says There Is No Military Solution in Syria"
+ >>> expected_translation_french = "Le chef de l'ONU affirme qu'il n'y a pas de solution militaire en Syrie."
+ >>> inputs = tokenizer(example_english_phrase, text_target=expected_translation_french, return_tensors="pt")
+ ```
+
+ Args:
+ vocab_file (`str`):
+ Path to the vocabulary file.
+ bos_token (`str`, *optional*, defaults to `"<s>"`):
+ The beginning of sequence token that was used during pretraining. Can be used as a sequence classifier token.
+
+ <Tip>
+
+ When building a sequence using special tokens, this is not the token that is used for the beginning of
+ sequence. The token used is the `cls_token`.
+
+ </Tip>
+
+ eos_token (`str`, *optional*, defaults to `"</s>"`):
+ The end of sequence token.
+
+ <Tip>
+
+ When building a sequence using special tokens, this is not the token that is used for the end of sequence.
+ The token used is the `sep_token`.
+
+ </Tip>
+
+ sep_token (`str`, *optional*, defaults to `"</s>"`):
+ The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
+ sequence classification or for a text and a question for question answering. It is also used as the last
+ token of a sequence built with special tokens.
+ cls_token (`str`, *optional*, defaults to `"<s>"`):
+ The classifier token which is used when doing sequence classification (classification of the whole sequence
+ instead of per-token classification). It is the first token of the sequence when built with special tokens.
+ unk_token (`str`, *optional*, defaults to `"<unk>"`):
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+ token instead.
+ pad_token (`str`, *optional*, defaults to `"<pad>"`):
+ The token used for padding, for example when batching sequences of different lengths.
+ mask_token (`str`, *optional*, defaults to `"<mask>"`):
+ The token used for masking values. This is the token used when training this model with masked language
+ modeling. This is the token which the model will try to predict.
+ tokenizer_file (`str`, *optional*):
+ The path to a tokenizer file to use instead of the vocab file.
+ src_lang (`str`, *optional*):
+ The language to use as source language for translation.
+ tgt_lang (`str`, *optional*):
+ The language to use as target language for translation.
+ """
+
+ vocab_files_names = VOCAB_FILES_NAMES
+ model_input_names = ["input_ids", "attention_mask"]
+ slow_tokenizer_class = NllbTokenizer
+
+ prefix_tokens: list[int] = []
+ suffix_tokens: list[int] = []
+
+ def __init__(
+ self,
+ vocab_file=None,
+ tokenizer_file=None,
+ bos_token="<s>",
+ eos_token="</s>",
+ sep_token="</s>",
+ cls_token="<s>",
+ unk_token="<unk>",
+ pad_token="<pad>",
+ mask_token="<mask>",
+ src_lang=None,
+ tgt_lang=None,
+ additional_special_tokens=None,
+ legacy_behaviour=False,
+ **kwargs,
+ ):
+ if additional_special_tokens is None:
+ additional_special_tokens = FAIRSEQ_LANGUAGE_CODES
+
+ self.vocab_file = vocab_file
+ # Mask token behaves like a normal word, i.e. it includes the space before it
+ mask_token = (
+ AddedToken(mask_token, normalized=True, lstrip=True, special=True)
+ if isinstance(mask_token, str)
+ else mask_token
+ )
+ self.legacy_behaviour = legacy_behaviour
+ super().__init__(
+ vocab_file=vocab_file,
+ tokenizer_file=tokenizer_file,
+ bos_token=bos_token,
+ eos_token=eos_token,
+ sep_token=sep_token,
+ cls_token=cls_token,
+ unk_token=unk_token,
+ pad_token=pad_token,
+ src_lang=src_lang,
+ tgt_lang=tgt_lang,
+ mask_token=mask_token,
+ additional_special_tokens=additional_special_tokens,
+ legacy_behaviour=legacy_behaviour,
+ **kwargs,
+ )
+
+ self._src_lang = src_lang if src_lang is not None else "eng_Latn"
+ self.cur_lang_code = self.convert_tokens_to_ids(self._src_lang)
+ self.tgt_lang = tgt_lang
+ self.set_src_lang_special_tokens(self._src_lang)
+
+ @property
+ def src_lang(self) -> str:
+ return self._src_lang
+
+ @src_lang.setter
+ def src_lang(self, new_src_lang: str) -> None:
+ self._src_lang = new_src_lang
+ self.set_src_lang_special_tokens(self._src_lang)
+
+ def build_inputs_with_special_tokens(
+ self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None
+ ) -> list[int]:
+ """
+ Build model inputs from a sequence or a pair of sequences for sequence classification tasks by concatenating and
+ adding special tokens. The special tokens depend on the language set via `set_src_lang_special_tokens` and
+ `set_tgt_lang_special_tokens`.
+
+ An NLLB sequence has the following format, where `X` represents the sequence:
+
+ - `input_ids` (for encoder) `X [eos, src_lang_code]`
+ - `decoder_input_ids`: (for decoder) `X [eos, tgt_lang_code]`
+
+ BOS is never used. Pairs of sequences are not the expected use case, but they will be handled without a
+ separator.
+
+ Args:
+ token_ids_0 (`list[int]`):
+ List of IDs to which the special tokens will be added.
+ token_ids_1 (`list[int]`, *optional*):
+ Optional second list of IDs for sequence pairs.
+
+ Returns:
+ `list[int]`: list of [input IDs](../glossary#input-ids) with the appropriate special tokens.
+ """
+ if token_ids_1 is None:
+ return self.prefix_tokens + token_ids_0 + self.suffix_tokens
+ # We don't expect to process pairs, but leave the pair logic for API consistency
+ return self.prefix_tokens + token_ids_0 + token_ids_1 + self.suffix_tokens
+
+ def create_token_type_ids_from_sequences(
+ self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None
+ ) -> list[int]:
+ """
+ Create a mask from the two sequences passed to be used in a sequence-pair classification task. NLLB does not
+ make use of token type ids, therefore a list of zeros is returned.
+
+ Args:
+ token_ids_0 (`list[int]`):
+ List of IDs.
+ token_ids_1 (`list[int]`, *optional*):
+ Optional second list of IDs for sequence pairs.
+
+ Returns:
+ `list[int]`: List of zeros.
+
+ """
+
+ sep = [self.sep_token_id]
+ cls = [self.cls_token_id]
+
+ if token_ids_1 is None:
+ return len(cls + token_ids_0 + sep) * [0]
+ return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0]
+
+ def _build_translation_inputs(
+ self, raw_inputs, return_tensors: str, src_lang: Optional[str], tgt_lang: Optional[str], **extra_kwargs
+ ):
+ """Used by translation pipeline, to prepare inputs for the generate function"""
+ if src_lang is None or tgt_lang is None:
+ raise ValueError("Translation requires a `src_lang` and a `tgt_lang` for this model")
+ self.src_lang = src_lang
+ inputs = self(raw_inputs, add_special_tokens=True, return_tensors=return_tensors, **extra_kwargs)
+ tgt_lang_id = self.convert_tokens_to_ids(tgt_lang)
+ inputs["forced_bos_token_id"] = tgt_lang_id
+ return inputs
+
+ def prepare_seq2seq_batch(
+ self,
+ src_texts: list[str],
+ src_lang: str = "eng_Latn",
+ tgt_texts: Optional[list[str]] = None,
+ tgt_lang: str = "fra_Latn",
+ **kwargs,
+ ) -> BatchEncoding:
+ self.src_lang = src_lang
+ self.tgt_lang = tgt_lang
+ return super().prepare_seq2seq_batch(src_texts, tgt_texts, **kwargs)
+
+ def _switch_to_input_mode(self):
+ return self.set_src_lang_special_tokens(self.src_lang)
+
+ def _switch_to_target_mode(self):
+ return self.set_tgt_lang_special_tokens(self.tgt_lang)
+
+ def set_src_lang_special_tokens(self, src_lang) -> None:
+ """Reset the special tokens to the source lang setting.
+ - In legacy mode: No prefix and suffix=[eos, src_lang_code].
+ - In default mode: Prefix=[src_lang_code], suffix = [eos]
+ """
+ self.cur_lang_code = self.convert_tokens_to_ids(src_lang)
+
+ if self.legacy_behaviour:
+ self.prefix_tokens = []
+ self.suffix_tokens = [self.eos_token_id, self.cur_lang_code]
+ else:
+ self.prefix_tokens = [self.cur_lang_code]
+ self.suffix_tokens = [self.eos_token_id]
+
+ prefix_tokens_str = self.convert_ids_to_tokens(self.prefix_tokens)
+ suffix_tokens_str = self.convert_ids_to_tokens(self.suffix_tokens)
+
+ self._tokenizer.post_processor = processors.TemplateProcessing(
+ single=prefix_tokens_str + ["$A"] + suffix_tokens_str,
+ pair=prefix_tokens_str + ["$A", "$B"] + suffix_tokens_str,
+ special_tokens=list(zip(prefix_tokens_str + suffix_tokens_str, self.prefix_tokens + self.suffix_tokens)),
+ )
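+
+ # Sketch of what this template produces (assuming the default non-legacy mode and
+ # src_lang="eng_Latn", so prefix_tokens_str == ["eng_Latn"] and suffix_tokens_str == ["</s>"]):
+ #
+ #   single sequence: ["eng_Latn", "$A", "</s>"]        -> lang code, tokens, eos
+ #   pair:            ["eng_Latn", "$A", "$B", "</s>"]  -> pairs share one prefix/suffix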
+
+ def set_tgt_lang_special_tokens(self, lang: str) -> None:
+ """Reset the special tokens to the target lang setting.
+ - In legacy mode: No prefix and suffix=[eos, tgt_lang_code].
+ - In default mode: Prefix=[tgt_lang_code], suffix = [eos]
+ """
+ self.cur_lang_code = self.convert_tokens_to_ids(lang)
+ if self.legacy_behaviour:
+ self.prefix_tokens = []
+ self.suffix_tokens = [self.eos_token_id, self.cur_lang_code]
+ else:
+ self.prefix_tokens = [self.cur_lang_code]
+ self.suffix_tokens = [self.eos_token_id]
+
+ prefix_tokens_str = self.convert_ids_to_tokens(self.prefix_tokens)
+ suffix_tokens_str = self.convert_ids_to_tokens(self.suffix_tokens)
+
+ self._tokenizer.post_processor = processors.TemplateProcessing(
+ single=prefix_tokens_str + ["$A"] + suffix_tokens_str,
+ pair=prefix_tokens_str + ["$A", "$B"] + suffix_tokens_str,
+ special_tokens=list(zip(prefix_tokens_str + suffix_tokens_str, self.prefix_tokens + self.suffix_tokens)),
+ )
+
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]:
+ if not self.can_save_slow_tokenizer:
+ raise ValueError(
+ "Your fast tokenizer does not have the necessary information to save the vocabulary for a slow "
+ "tokenizer."
+ )
+
+ if not os.path.isdir(save_directory):
+ logger.error(f"Vocabulary path ({save_directory}) should be a directory.")
+ return
+ out_vocab_file = os.path.join(
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
+ )
+
+ if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
+ copyfile(self.vocab_file, out_vocab_file)
+
+ return (out_vocab_file,)
+
+
+__all__ = ["NllbTokenizerFast"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/olmo2/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/olmo2/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e2161a4948b5e32f600af135c33330c2e2c353c7
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/olmo2/__init__.py
@@ -0,0 +1,27 @@
+# Copyright 2024 EleutherAI and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_olmo2 import *
+ from .modeling_olmo2 import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/olmo2/configuration_olmo2.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/olmo2/configuration_olmo2.py
new file mode 100644
index 0000000000000000000000000000000000000000..c7a0dabaf4e6d054a5dbb739ed7cab6619e73d8a
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/olmo2/configuration_olmo2.py
@@ -0,0 +1,180 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/olmo2/modular_olmo2.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_olmo2.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+
+from ...configuration_utils import PretrainedConfig
+
+
+class Olmo2Config(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`Olmo2Model`]. It is used to instantiate an OLMo2
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
+ defaults will yield a similar configuration to that of the [allenai/Olmo2-7B-1124-hf](https://huggingface.co/allenai/Olmo2-7B-1124-hf).
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 50304):
+ Vocabulary size of the Olmo2 model. Defines the number of different tokens that can be represented by the
+ `inputs_ids` passed when calling [`Olmo2Model`]
+ hidden_size (`int`, *optional*, defaults to 4096):
+ Dimension of the hidden representations.
+ intermediate_size (`int`, *optional*, defaults to 11008):
+ Dimension of the MLP representations.
+ num_hidden_layers (`int`, *optional*, defaults to 32):
+ Number of hidden layers in the Transformer decoder.
+ num_attention_heads (`int`, *optional*, defaults to 32):
+ Number of attention heads for each attention layer in the Transformer decoder.
+ num_key_value_heads (`int`, *optional*):
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
+ `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+ by meanpooling all the original heads within that group. For more details, check out [this
+ paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to
+ `num_attention_heads`.
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+ The non-linear activation function (function or string) in the decoder.
+ max_position_embeddings (`int`, *optional*, defaults to 2048):
+ The maximum sequence length that this model might ever be used with.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
+ relevant if `config.is_decoder=True`.
+ pad_token_id (`int`, *optional*, defaults to 1):
+ Padding token id.
+ bos_token_id (`int`, *optional*):
+ Beginning of stream token id.
+ eos_token_id (`int`, *optional*, defaults to 50279):
+ End of stream token id.
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+ Whether to tie weight embeddings
+ rope_theta (`float`, *optional*, defaults to 10000.0):
+ The base period of the RoPE embeddings.
+ rope_scaling (`Dict`, *optional*):
+ Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
+ strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
+ `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
+ `max_position_embeddings` to the expected new maximum. See the following thread for more information on how
+ these scaling strategies behave:
+ https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
+ experimental feature, subject to breaking API changes in future versions.
+ attention_bias (`bool`, *optional*, defaults to `False`):
+ Whether to use a bias in the query, key, value and output projection layers during self-attention.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ rms_norm_eps (`float`, *optional*, defaults to 1e-05):
+ The epsilon used by the rms normalization layers.
+
+ ```python
+ >>> from transformers import Olmo2Model, Olmo2Config
+
+ >>> # Initializing a Olmo2 7B style configuration
+ >>> configuration = Olmo2Config()
+
+ >>> # Initializing a model from the Olmo2 7B style configuration
+ >>> model = Olmo2Model(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```
+ """
+
+ model_type = "olmo2"
+ keys_to_ignore_at_inference = ["past_key_values"]
+ base_model_tp_plan = {
+ "layers.*.self_attn.q_proj": "colwise_rep", # we need to replicate here due to the added norm on q and k
+ "layers.*.self_attn.k_proj": "colwise_rep", # we need to replicate here due to the added norm on q and k
+ "layers.*.self_attn.v_proj": "colwise_rep", # we need to replicate here due to the added norm on q and k
+ "layers.*.self_attn.o_proj": "rowwise_rep", # we need to replicate here due to the added norm on q and k
+ "layers.*.mlp.gate_proj": "colwise",
+ "layers.*.mlp.up_proj": "colwise",
+ "layers.*.mlp.down_proj": "rowwise",
+ }
+ base_model_pp_plan = {
+ "embed_tokens": (["input_ids"], ["inputs_embeds"]),
+ "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
+ "norm": (["hidden_states"], ["hidden_states"]),
+ }
+
+ def __init__(
+ self,
+ vocab_size=50304,
+ hidden_size=4096,
+ intermediate_size=11008,
+ num_hidden_layers=32,
+ num_attention_heads=32,
+ num_key_value_heads=None,
+ hidden_act="silu",
+ max_position_embeddings=2048,
+ initializer_range=0.02,
+ use_cache=True,
+ pad_token_id=1,
+ bos_token_id=None,
+ eos_token_id=50279,
+ tie_word_embeddings=False,
+ rope_theta=10000.0,
+ rope_scaling=None,
+ attention_bias=False,
+ attention_dropout=0.0,
+ rms_norm_eps=1e-5,
+ **kwargs,
+ ):
+ super().__init__(
+ pad_token_id=pad_token_id,
+ bos_token_id=bos_token_id,
+ eos_token_id=eos_token_id,
+ tie_word_embeddings=tie_word_embeddings,
+ **kwargs,
+ )
+ self.vocab_size = vocab_size
+ self.max_position_embeddings = max_position_embeddings
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+
+ # for backward compatibility
+ if num_key_value_heads is None:
+ num_key_value_heads = num_attention_heads
+
+ self.num_key_value_heads = num_key_value_heads
+ self.hidden_act = hidden_act
+ self.initializer_range = initializer_range
+ self.use_cache = use_cache
+ self.rope_theta = rope_theta
+ self.rope_scaling = rope_scaling
+ self._rope_scaling_validation()
+ self.attention_bias = attention_bias
+ self.attention_dropout = attention_dropout
+
+ self.rms_norm_eps = rms_norm_eps
+
+ def _rope_scaling_validation(self):
+ """
+ Validate the `rope_scaling` configuration.
+ """
+ if self.rope_scaling is None:
+ return
+
+ if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
+ raise ValueError(
+ f"`rope_scaling` must be a dictionary with two fields, `type` and `factor`, got {self.rope_scaling}"
+ )
+ rope_scaling_type = self.rope_scaling.get("type", None)
+ rope_scaling_factor = self.rope_scaling.get("factor", None)
+ if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
+ raise ValueError(
+ f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
+ )
+ if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
+ raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}")
+
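+ # Validation sketch (assuming the defaults above): the following would pass,
+ #
+ #   cfg = Olmo2Config(rope_scaling={"type": "dynamic", "factor": 2.0})
+ #
+ # while {"type": "dynamic"} alone, an unknown type, or a factor <= 1.0 raises a ValueError.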
+
+__all__ = ["Olmo2Config"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/olmo2/modeling_olmo2.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/olmo2/modeling_olmo2.py
new file mode 100644
index 0000000000000000000000000000000000000000..3fe4cfaf91dea9804a96963c947bacf4e99174e4
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/olmo2/modeling_olmo2.py
@@ -0,0 +1,470 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/olmo2/modular_olmo2.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_olmo2.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+from typing import Callable, Optional, Union
+
+import torch
+import torch.nn as nn
+
+from transformers.utils.generic import TransformersKwargs
+
+from ...activations import ACT2FN
+from ...cache_utils import Cache, DynamicCache
+from ...generation import GenerationMixin
+from ...integrations import use_kernel_forward_from_hub
+from ...masking_utils import create_causal_mask
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
+from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
+from ...utils import auto_docstring, can_return_tuple
+from ...utils.deprecation import deprecate_kwarg
+from ...utils.generic import check_model_inputs
+from .configuration_olmo2 import Olmo2Config
+
+
+@use_kernel_forward_from_hub("RMSNorm")
+class Olmo2RMSNorm(nn.Module):
+ def __init__(self, hidden_size, eps=1e-6):
+ """
+ Olmo2RMSNorm is equivalent to T5LayerNorm
+ """
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+ self.variance_epsilon = eps
+
+ def forward(self, hidden_states):
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+ return (self.weight * hidden_states).to(input_dtype)
+
+ def extra_repr(self):
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
+
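+ # Reference sketch: per token, this computes weight * x / sqrt(mean(x**2) + eps), with the
+ # mean over the last dimension and the normalization done in float32. Assuming hidden size 8:
+ #
+ #   norm = Olmo2RMSNorm(8, eps=1e-6)
+ #   y = norm(torch.randn(2, 3, 8))   # output keeps the input shape (2, 3, 8)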
+
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+ """
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+ """
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+ if n_rep == 1:
+ return hidden_states
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
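+ # Equivalence sketch (toy sizes assumed): with batch=2, num_key_value_heads=4, n_rep=3,
+ # seqlen=5, head_dim=8 the output has shape (2, 12, 5, 8) and matches repeat_interleave:
+ #
+ #   x = torch.randn(2, 4, 5, 8)
+ #   assert torch.equal(repeat_kv(x, 3), torch.repeat_interleave(x, repeats=3, dim=1))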
+
+def eager_attention_forward(
+ module: nn.Module,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_mask: Optional[torch.Tensor],
+ scaling: float,
+ dropout: float = 0.0,
+ **kwargs: Unpack[TransformersKwargs],
+):
+ key_states = repeat_kv(key, module.num_key_value_groups)
+ value_states = repeat_kv(value, module.num_key_value_groups)
+
+ attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
+ if attention_mask is not None:
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+ attn_weights = attn_weights + causal_mask
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+ attn_output = torch.matmul(attn_weights, value_states)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ return attn_output, attn_weights
+
+
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+ """Applies Rotary Position Embedding to the query and key tensors.
+
+ Args:
+ q (`torch.Tensor`): The query tensor.
+ k (`torch.Tensor`): The key tensor.
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
+ position_ids (`torch.Tensor`, *optional*):
+ Deprecated and unused.
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+ Returns:
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+ """
+ q_type, k_type = q.dtype, k.dtype
+ cos = cos.unsqueeze(unsqueeze_dim)
+ sin = sin.unsqueeze(unsqueeze_dim)
+ q_embed = (q * cos) + (rotate_half(q) * sin)
+ k_embed = (k * cos) + (rotate_half(k) * sin)
+ return q_embed.to(q_type), k_embed.to(k_type)
+
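+ # Shape sketch (toy sizes assumed): q and k are (batch, heads, seq, head_dim) and cos/sin are
+ # (batch, seq, head_dim) as returned by Olmo2RotaryEmbedding below; outputs keep q/k shapes.
+ #
+ #   q, k = torch.randn(1, 2, 4, 8), torch.randn(1, 2, 4, 8)
+ #   cos, sin = torch.randn(1, 4, 8), torch.randn(1, 4, 8)
+ #   q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)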
+
+def rotate_half(x):
+ """Rotates half the hidden dims of the input."""
+ x1 = x[..., : x.shape[-1] // 2]
+ x2 = x[..., x.shape[-1] // 2 :]
+ return torch.cat((-x2, x1), dim=-1)
+
+
+class Olmo2Attention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(self, config: Olmo2Config, layer_idx: Optional[int] = None):
+ super().__init__()
+ self.config = config
+ self.layer_idx = layer_idx
+ self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+ self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
+ self.scaling = self.head_dim**-0.5
+ self.attention_dropout = config.attention_dropout
+ self.is_causal = True
+
+ self.q_proj = nn.Linear(
+ config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
+ )
+ self.k_proj = nn.Linear(
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
+ )
+ self.v_proj = nn.Linear(
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
+ )
+ self.o_proj = nn.Linear(
+ config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
+ )
+ self.q_norm = Olmo2RMSNorm(config.num_attention_heads * self.head_dim, config.rms_norm_eps)
+ self.k_norm = Olmo2RMSNorm(config.num_key_value_heads * self.head_dim, config.rms_norm_eps)
+
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
+ attention_mask: Optional[torch.Tensor],
+ past_key_values: Optional[Cache] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+ input_shape = hidden_states.shape[:-1]
+ hidden_shape = (*input_shape, -1, self.head_dim)
+
+ query_states = self.q_norm(self.q_proj(hidden_states))
+ key_states = self.k_norm(self.k_proj(hidden_states))
+ value_states = self.v_proj(hidden_states)
+
+ query_states = query_states.view(hidden_shape).transpose(1, 2)
+ key_states = key_states.view(hidden_shape).transpose(1, 2)
+ value_states = value_states.view(hidden_shape).transpose(1, 2)
+
+ cos, sin = position_embeddings
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_values is not None:
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ attn_output, attn_weights = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ dropout=0.0 if not self.training else self.attention_dropout,
+ scaling=self.scaling,
+ **kwargs,
+ )
+
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+ attn_output = self.o_proj(attn_output)
+ return attn_output, attn_weights
+
+
+class Olmo2MLP(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.hidden_size = config.hidden_size
+ self.intermediate_size = config.intermediate_size
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+ self.act_fn = ACT2FN[config.hidden_act]
+
+ def forward(self, x):
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+ return down_proj
+
+
+class Olmo2DecoderLayer(GradientCheckpointingLayer):
+ def __init__(self, config: Olmo2Config, layer_idx: int):
+ super().__init__()
+ self.hidden_size = config.hidden_size
+ self.self_attn = Olmo2Attention(config=config, layer_idx=layer_idx)
+
+ self.mlp = Olmo2MLP(config)
+ self.post_attention_layernorm = Olmo2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.post_feedforward_layernorm = Olmo2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ use_cache: Optional[bool] = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> torch.Tensor:
+ residual = hidden_states
+ hidden_states, _ = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ **kwargs,
+ )
+ hidden_states = self.post_attention_layernorm(hidden_states)
+ hidden_states = residual + hidden_states
+
+ # Fully Connected
+ residual = hidden_states
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = self.post_feedforward_layernorm(hidden_states)
+ hidden_states = residual + hidden_states
+ return hidden_states
+
+
+class Olmo2RotaryEmbedding(nn.Module):
+ inv_freq: torch.Tensor # fix linting for `register_buffer`
+
+ def __init__(self, config: Olmo2Config, device=None):
+ super().__init__()
+ # BC: "rope_type" was originally "type"
+ if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
+ else:
+ self.rope_type = "default"
+ self.max_seq_len_cached = config.max_position_embeddings
+ self.original_max_seq_len = config.max_position_embeddings
+
+ self.config = config
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
+ self.original_inv_freq = self.inv_freq
+
+ @torch.no_grad()
+ @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
+ def forward(self, x, position_ids):
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
+ position_ids_expanded = position_ids[:, None, :].float()
+
+ device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
+ with torch.autocast(device_type=device_type, enabled=False): # Force float32
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+ emb = torch.cat((freqs, freqs), dim=-1)
+ cos = emb.cos() * self.attention_scaling
+ sin = emb.sin() * self.attention_scaling
+ return cos, sin
+
+
+@auto_docstring
+class Olmo2PreTrainedModel(PreTrainedModel):
+ config: Olmo2Config
+ base_model_prefix = "model"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["Olmo2DecoderLayer"]
+ _skip_keys_device_placement = ["past_key_values"]
+ _supports_flash_attn = True
+ _supports_sdpa = True
+ _supports_flex_attn = True
+
+ _can_compile_fullgraph = True
+ _supports_attention_backend = True
+ _can_record_outputs = {
+ "hidden_states": Olmo2DecoderLayer,
+ "attentions": Olmo2Attention,
+ }
+
+
+@auto_docstring
+class Olmo2Model(Olmo2PreTrainedModel):
+ def __init__(self, config: Olmo2Config):
+ super().__init__(config)
+ self.padding_idx = config.pad_token_id
+ self.vocab_size = config.vocab_size
+
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+ self.layers = nn.ModuleList(
+ [Olmo2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+ )
+ self.norm = Olmo2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.rotary_emb = Olmo2RotaryEmbedding(config=config)
+ self.gradient_checkpointing = False
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @check_model_inputs
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> BaseModelOutputWithPast:
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ if inputs_embeds is None:
+ inputs_embeds: torch.Tensor = self.embed_tokens(input_ids)
+
+ if use_cache and past_key_values is None:
+ past_key_values = DynamicCache(config=self.config)
+
+ if cache_position is None:
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ cache_position: torch.Tensor = torch.arange(
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+ )
+
+ if position_ids is None:
+ position_ids = cache_position.unsqueeze(0)
+
+ causal_mask = create_causal_mask(
+ config=self.config,
+ input_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ cache_position=cache_position,
+ past_key_values=past_key_values,
+ position_ids=position_ids,
+ )
+
+ hidden_states = inputs_embeds
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+ for decoder_layer in self.layers[: self.config.num_hidden_layers]:
+ hidden_states = decoder_layer(
+ hidden_states,
+ attention_mask=causal_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ **kwargs,
+ )
+
+ hidden_states = self.norm(hidden_states)
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=past_key_values,
+ )
+
+
+@auto_docstring
+class Olmo2ForCausalLM(Olmo2PreTrainedModel, GenerationMixin):
+ _tied_weights_keys = ["lm_head.weight"]
+ _tp_plan = {"lm_head": "colwise_rep"}
+ _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.model = Olmo2Model(config)
+ self.vocab_size = config.vocab_size
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ logits_to_keep: Union[int, torch.Tensor] = 0,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> CausalLMOutputWithPast:
+ r"""
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, Olmo2ForCausalLM
+
+ >>> model = Olmo2ForCausalLM.from_pretrained("allenai/Olmo2-7B-1124-hf")
+ >>> tokenizer = AutoTokenizer.from_pretrained("allenai/Olmo2-7B-1124-hf")
+
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+ >>> # Generate
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+ ```"""
+ outputs: BaseModelOutputWithPast = self.model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+ hidden_states = outputs.last_hidden_state
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
+
+ return CausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+__all__ = ["Olmo2ForCausalLM", "Olmo2Model", "Olmo2PreTrainedModel"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/olmo2/modular_olmo2.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/olmo2/modular_olmo2.py
new file mode 100644
index 0000000000000000000000000000000000000000..84aa2509007dfc196945e02d9e60fde62d157f6c
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/olmo2/modular_olmo2.py
@@ -0,0 +1,321 @@
+from typing import Callable, Optional
+
+import torch
+import torch.nn as nn
+
+from transformers.utils.generic import TransformersKwargs
+
+from ...cache_utils import Cache
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
+from ...processing_utils import Unpack
+from ...utils import logging
+from ...utils.deprecation import deprecate_kwarg
+from ..llama.modeling_llama import LlamaPreTrainedModel, LlamaRMSNorm, eager_attention_forward
+from ..olmo.configuration_olmo import OlmoConfig
+from ..olmo.modeling_olmo import (
+ OlmoAttention,
+ OlmoDecoderLayer,
+ OlmoForCausalLM,
+ OlmoModel,
+ OlmoRotaryEmbedding,
+ apply_rotary_pos_emb,
+)
+
+
+logger = logging.get_logger(__name__)
+
+
+class Olmo2Config(OlmoConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`Olmo2Model`]. It is used to instantiate an OLMo2
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
+ defaults will yield a similar configuration to that of the [allenai/Olmo2-7B-1124-hf](https://huggingface.co/allenai/Olmo2-7B-1124-hf).
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 50304):
+ Vocabulary size of the Olmo2 model. Defines the number of different tokens that can be represented by the
+ `input_ids` passed when calling [`Olmo2Model`]
+ hidden_size (`int`, *optional*, defaults to 4096):
+ Dimension of the hidden representations.
+ intermediate_size (`int`, *optional*, defaults to 11008):
+ Dimension of the MLP representations.
+ num_hidden_layers (`int`, *optional*, defaults to 32):
+ Number of hidden layers in the Transformer decoder.
+ num_attention_heads (`int`, *optional*, defaults to 32):
+ Number of attention heads for each attention layer in the Transformer decoder.
+ num_key_value_heads (`int`, *optional*):
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
+ `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+ by mean-pooling all the original heads within that group. For more details, check out [this
+ paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to
+ `num_attention_heads`.
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+ The non-linear activation function (function or string) in the decoder.
+ max_position_embeddings (`int`, *optional*, defaults to 2048):
+ The maximum sequence length that this model might ever be used with.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
+ relevant if `config.is_decoder=True`.
+ pad_token_id (`int`, *optional*, defaults to 1):
+ Padding token id.
+ bos_token_id (`int`, *optional*):
+ Beginning of stream token id.
+ eos_token_id (`int`, *optional*, defaults to 50279):
+ End of stream token id.
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+ Whether to tie the input and output word embeddings.
+ rope_theta (`float`, *optional*, defaults to 10000.0):
+ The base period of the RoPE embeddings.
+ rope_scaling (`Dict`, *optional*):
+ Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
+ strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
+ `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
+ `max_position_embeddings` to the expected new maximum. See the following thread for more information on how
+ these scaling strategies behave:
+ https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
+ experimental feature, subject to breaking API changes in future versions.
+ attention_bias (`bool`, *optional*, defaults to `False`):
+ Whether to use a bias in the query, key, value and output projection layers during self-attention.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ rms_norm_eps (`float`, *optional*, defaults to 1e-05):
+ The epsilon used by the rms normalization layers.
+
+ ```python
+ >>> from transformers import Olmo2Model, Olmo2Config
+
+ >>> # Initializing a Olmo2 7B style configuration
+ >>> configuration = Olmo2Config()
+
+ >>> # Initializing a model from the Olmo2 7B style configuration
+ >>> model = Olmo2Model(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```
+ """
+
+ model_type = "olmo2"
+ base_model_tp_plan = {
+ "layers.*.self_attn.q_proj": "colwise_rep", # we need to replicate here due to the added norm on q and k
+ "layers.*.self_attn.k_proj": "colwise_rep", # we need to replicate here due to the added norm on q and k
+ "layers.*.self_attn.v_proj": "colwise_rep", # we need to replicate here due to the added norm on q and k
+ "layers.*.self_attn.o_proj": "rowwise_rep", # we need to replicate here due to the added norm on q and k
+ "layers.*.mlp.gate_proj": "colwise",
+ "layers.*.mlp.up_proj": "colwise",
+ "layers.*.mlp.down_proj": "rowwise",
+ }
+ base_model_pp_plan = {
+ "embed_tokens": (["input_ids"], ["inputs_embeds"]),
+ "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
+ "norm": (["hidden_states"], ["hidden_states"]),
+ }
+
+ def __init__(
+ self,
+ vocab_size=50304,
+ hidden_size=4096,
+ intermediate_size=11008,
+ num_hidden_layers=32,
+ num_attention_heads=32,
+ num_key_value_heads=None,
+ hidden_act="silu",
+ max_position_embeddings=2048,
+ initializer_range=0.02,
+ use_cache=True,
+ pad_token_id=1,
+ bos_token_id=None,
+ eos_token_id=50279,
+ tie_word_embeddings=False,
+ rope_theta=10000.0,
+ rope_scaling=None,
+ attention_bias=False,
+ attention_dropout=0.0,
+ rms_norm_eps=1e-5,
+ **kwargs,
+ ):
+ super().__init__(
+ vocab_size=vocab_size,
+ hidden_size=hidden_size,
+ intermediate_size=intermediate_size,
+ num_hidden_layers=num_hidden_layers,
+ num_attention_heads=num_attention_heads,
+ num_key_value_heads=num_key_value_heads,
+ hidden_act=hidden_act,
+ max_position_embeddings=max_position_embeddings,
+ initializer_range=initializer_range,
+ use_cache=use_cache,
+ pad_token_id=pad_token_id,
+ bos_token_id=bos_token_id,
+ eos_token_id=eos_token_id,
+ tie_word_embeddings=tie_word_embeddings,
+ rope_theta=rope_theta,
+ rope_scaling=rope_scaling,
+ attention_bias=attention_bias,
+ attention_dropout=attention_dropout,
+ **kwargs,
+ )
+
+ self.rms_norm_eps = rms_norm_eps
+ del self.clip_qkv
+
+
+# OLMo2 RMS norm is identical to Llama RMS norm except:
+# - Weight and hidden states are multiplied before converting back to the input dtype, rather than after.
+class Olmo2RMSNorm(LlamaRMSNorm):
+ def forward(self, hidden_states):
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+ return (self.weight * hidden_states).to(input_dtype)
+
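+# A minimal illustrative sketch (not part of the library) of the ordering difference described above;
+# `llama_order` and `olmo2_order` are hypothetical names used only for comparison:
+#
+#     x = torch.randn(2, 4, dtype=torch.bfloat16).to(torch.float32)  # normalized hidden states in fp32
+#     w = torch.ones(4)
+#     llama_order = w * x.to(torch.bfloat16)     # cast back to the input dtype first, then scale
+#     olmo2_order = (w * x).to(torch.bfloat16)   # scale in fp32 first, then cast back (this module)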
+
+def rotate_half(x):
+ """Rotates half the hidden dims of the input."""
+ x1 = x[..., : x.shape[-1] // 2]
+ x2 = x[..., x.shape[-1] // 2 :]
+ return torch.cat((-x2, x1), dim=-1)
+
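+# Illustrative only: for a last dimension of [1, 2, 3, 4], rotate_half returns [-3, -4, 1, 2]
+# (the second half is negated and moved in front of the first half), e.g.:
+#
+#     rotate_half(torch.tensor([1.0, 2.0, 3.0, 4.0]))  # tensor([-3., -4.,  1.,  2.])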
+
+# Olmo2 attention is identical to OLMo attention except:
+# - Norm is applied to attention queries and keys.
+# - No qkv clipping.
+class Olmo2Attention(OlmoAttention):
+ def __init__(self, config: Olmo2Config, layer_idx: Optional[int] = None):
+ super().__init__(config, layer_idx=layer_idx)
+ self.q_norm = Olmo2RMSNorm(config.num_attention_heads * self.head_dim, config.rms_norm_eps)
+ self.k_norm = Olmo2RMSNorm(config.num_key_value_heads * self.head_dim, config.rms_norm_eps)
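+ # Note that both norms span the full projected width (num_heads * head_dim): they are applied to the
+ # flat q/k projections in `forward` before the reshape into per-head tensors.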
+
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
+ attention_mask: Optional[torch.Tensor],
+ past_key_values: Optional[Cache] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+ input_shape = hidden_states.shape[:-1]
+ hidden_shape = (*input_shape, -1, self.head_dim)
+
+ query_states = self.q_norm(self.q_proj(hidden_states))
+ key_states = self.k_norm(self.k_proj(hidden_states))
+ value_states = self.v_proj(hidden_states)
+
+ query_states = query_states.view(hidden_shape).transpose(1, 2)
+ key_states = key_states.view(hidden_shape).transpose(1, 2)
+ value_states = value_states.view(hidden_shape).transpose(1, 2)
+
+ cos, sin = position_embeddings
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_values is not None:
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ attn_output, attn_weights = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ dropout=0.0 if not self.training else self.attention_dropout,
+ scaling=self.scaling,
+ **kwargs,
+ )
+
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+ attn_output = self.o_proj(attn_output)
+ return attn_output, attn_weights
+
+
+# The OLMo2 layers are identical to those of the OLMo model except:
+# - RMSNorm is used instead of standard layer norm.
+# - Norm is applied after attention/feedforward rather than before.
+class Olmo2DecoderLayer(OlmoDecoderLayer):
+ def __init__(self, config: Olmo2Config, layer_idx: int):
+ super().__init__(config, layer_idx=layer_idx)
+ self.post_attention_layernorm = Olmo2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.post_feedforward_layernorm = Olmo2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.self_attn = Olmo2Attention(config=config, layer_idx=layer_idx)
+ del self.input_layernorm
+
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ use_cache: Optional[bool] = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> torch.Tensor:
+ residual = hidden_states
+ hidden_states, _ = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ **kwargs,
+ )
+ hidden_states = self.post_attention_layernorm(hidden_states)
+ hidden_states = residual + hidden_states
+
+ # Fully Connected
+ residual = hidden_states
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = self.post_feedforward_layernorm(hidden_states)
+ hidden_states = residual + hidden_states
+ return hidden_states
+
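+# A schematic sketch (illustrative pseudocode, not library API) of the post-norm layout used above,
+# contrasted with the pre-norm layout of the parent OLMo layer:
+#
+#     # OLMo (pre-norm):   x = x + attn(input_layernorm(x));  x = x + mlp(post_attention_layernorm(x))
+#     # OLMo2 (post-norm): x = x + post_attention_layernorm(attn(x));  x = x + post_feedforward_layernorm(mlp(x))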
+
+class Olmo2RotaryEmbedding(OlmoRotaryEmbedding):
+ pass
+
+
+class Olmo2PreTrainedModel(LlamaPreTrainedModel):
+ pass
+
+
+# The OLMo2 model is identical to the OLMo model, except RMSNorm is used instead of
+# standard layer norm for the output norm.
+class Olmo2Model(OlmoModel):
+ def __init__(self, config: Olmo2Config):
+ super().__init__(config)
+ self.norm = Olmo2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.layers = nn.ModuleList(
+ [Olmo2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+ )
+
+
+ # The head now only needs to redefine the inner model to the correct `Olmo2Model`
+class Olmo2ForCausalLM(OlmoForCausalLM):
+ pass
+
+
+__all__ = [
+ "Olmo2Config",
+ "Olmo2ForCausalLM",
+ "Olmo2Model",
+ "Olmo2PreTrainedModel",
+]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/openai/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/openai/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a07b0ab669f3a8b386e0fc99e99f5d5a780ef9c5
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/openai/__init__.py
@@ -0,0 +1,30 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_openai import *
+ from .modeling_openai import *
+ from .modeling_tf_openai import *
+ from .tokenization_openai import *
+ from .tokenization_openai_fast import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/openai/configuration_openai.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/openai/configuration_openai.py
new file mode 100644
index 0000000000000000000000000000000000000000..b4f2fae9d304bc63b6e69a3e456141531a644bfd
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/openai/configuration_openai.py
@@ -0,0 +1,156 @@
+# coding=utf-8
+# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
+# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""OpenAI GPT configuration"""
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class OpenAIGPTConfig(PretrainedConfig):
+ """
+ This is the configuration class to store the configuration of a [`OpenAIGPTModel`] or a [`TFOpenAIGPTModel`]. It is
+ used to instantiate a GPT model according to the specified arguments, defining the model architecture.
+ Instantiating a configuration with the defaults will yield a similar configuration to that of the GPT
+ [openai-community/openai-gpt](https://huggingface.co/openai-community/openai-gpt) architecture from OpenAI.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 40478):
+ Vocabulary size of the GPT model. Defines the number of different tokens that can be represented by the
+ `input_ids` passed when calling [`OpenAIGPTModel`] or [`TFOpenAIGPTModel`].
+ n_positions (`int`, *optional*, defaults to 512):
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
+ just in case (e.g., 512 or 1024 or 2048).
+ n_embd (`int`, *optional*, defaults to 768):
+ Dimensionality of the embeddings and hidden states.
+ n_layer (`int`, *optional*, defaults to 12):
+ Number of hidden layers in the Transformer encoder.
+ n_head (`int`, *optional*, defaults to 12):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ afn (`str` or `Callable`, *optional*, defaults to `"gelu"`):
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+ `"relu"`, `"silu"` and `"gelu_new"` are supported.
+ resid_pdrop (`float`, *optional*, defaults to 0.1):
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+ embd_pdrop (`float`, *optional*, defaults to 0.1):
+ The dropout ratio for the embeddings.
+ attn_pdrop (`float`, *optional*, defaults to 0.1):
+ The dropout ratio for the attention.
+ layer_norm_epsilon (`float`, *optional*, defaults to 1e-05):
+ The epsilon to use in the layer normalization layers.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ summary_type (`str`, *optional*, defaults to `"cls_index"`):
+ Argument used when doing sequence summary, used in the models [`OpenAIGPTDoubleHeadsModel`] and
+ [`TFOpenAIGPTDoubleHeadsModel`].
+
+ Has to be one of the following options:
+
+ - `"last"`: Take the last token hidden state (like XLNet).
+ - `"first"`: Take the first token hidden state (like BERT).
+ - `"mean"`: Take the mean of all tokens hidden states.
+ - `"cls_index"`: Supply a Tensor of classification token position (like GPT/GPT-2).
+ - `"attn"`: Not implemented now, use multi-head attention.
+ summary_use_proj (`bool`, *optional*, defaults to `True`):
+ Argument used when doing sequence summary, used in the models [`OpenAIGPTDoubleHeadsModel`] and
+ [`TFOpenAIGPTDoubleHeadsModel`].
+
+ Whether or not to add a projection after the vector extraction.
+ summary_activation (`str`, *optional*):
+ Argument used when doing sequence summary, used in the models [`OpenAIGPTDoubleHeadsModel`] and
+ [`TFOpenAIGPTDoubleHeadsModel`].
+
+ Pass `"tanh"` for a tanh activation to the output, any other value will result in no activation.
+ summary_proj_to_labels (`bool`, *optional*, defaults to `True`):
+ Argument used when doing sequence summary, used in the models [`OpenAIGPTDoubleHeadsModel`] and
+ [`TFOpenAIGPTDoubleHeadsModel`].
+
+ Whether the projection outputs should have `config.num_labels` or `config.hidden_size` classes.
+ summary_first_dropout (`float`, *optional*, defaults to 0.1):
+ Argument used when doing sequence summary, used in the models [`OpenAIGPTDoubleHeadsModel`] and
+ [`TFOpenAIGPTDoubleHeadsModel`].
+
+ The dropout ratio to be used after the projection and activation.
+
+
+ Examples:
+
+ ```python
+ >>> from transformers import OpenAIGPTConfig, OpenAIGPTModel
+
+ >>> # Initializing a GPT configuration
+ >>> configuration = OpenAIGPTConfig()
+
+ >>> # Initializing a model (with random weights) from the configuration
+ >>> model = OpenAIGPTModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "openai-gpt"
+ attribute_map = {
+ "max_position_embeddings": "n_positions",
+ "hidden_size": "n_embd",
+ "num_attention_heads": "n_head",
+ "num_hidden_layers": "n_layer",
+ }
+
+ def __init__(
+ self,
+ vocab_size=40478,
+ n_positions=512,
+ n_embd=768,
+ n_layer=12,
+ n_head=12,
+ afn="gelu",
+ resid_pdrop=0.1,
+ embd_pdrop=0.1,
+ attn_pdrop=0.1,
+ layer_norm_epsilon=1e-5,
+ initializer_range=0.02,
+ summary_type="cls_index",
+ summary_use_proj=True,
+ summary_activation=None,
+ summary_proj_to_labels=True,
+ summary_first_dropout=0.1,
+ **kwargs,
+ ):
+ self.vocab_size = vocab_size
+ self.n_positions = n_positions
+ self.n_embd = n_embd
+ self.n_layer = n_layer
+ self.n_head = n_head
+ self.afn = afn
+ self.resid_pdrop = resid_pdrop
+ self.embd_pdrop = embd_pdrop
+ self.attn_pdrop = attn_pdrop
+ self.layer_norm_epsilon = layer_norm_epsilon
+ self.initializer_range = initializer_range
+ self.summary_type = summary_type
+ self.summary_use_proj = summary_use_proj
+ self.summary_activation = summary_activation
+ self.summary_first_dropout = summary_first_dropout
+ self.summary_proj_to_labels = summary_proj_to_labels
+ super().__init__(**kwargs)
+
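+ # A small usage sketch (illustrative only) of the `attribute_map` aliases defined above: the generic
+ # attribute names resolve to the GPT-specific ones, e.g.
+ #
+ #     config = OpenAIGPTConfig(n_embd=512, n_layer=6)
+ #     assert config.hidden_size == config.n_embd == 512
+ #     assert config.num_hidden_layers == config.n_layer == 6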
+
+__all__ = ["OpenAIGPTConfig"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/openai/modeling_openai.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/openai/modeling_openai.py
new file mode 100644
index 0000000000000000000000000000000000000000..44fa05227ff831bbed49d229ad5a22e846977f14
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/openai/modeling_openai.py
@@ -0,0 +1,853 @@
+# coding=utf-8
+# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
+# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch OpenAI GPT model."""
+
+import json
+import math
+import os
+from dataclasses import dataclass
+from typing import Any, Callable, Optional, Union
+
+import torch
+from torch import nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+
+from ...activations import gelu_new, get_activation, silu
+from ...generation import GenerationMixin
+from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer
+from ...utils import (
+ ModelOutput,
+ auto_docstring,
+ logging,
+)
+from .configuration_openai import OpenAIGPTConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+def load_tf_weights_in_openai_gpt(model, config, openai_checkpoint_folder_path):
+ """Load tf pre-trained weights in a pytorch model (from NumPy arrays here)"""
+ import re
+
+ import numpy as np
+
+ if ".ckpt" in openai_checkpoint_folder_path:
+ openai_checkpoint_folder_path = os.path.dirname(openai_checkpoint_folder_path)
+
+ logger.info(f"Loading weights from {openai_checkpoint_folder_path}")
+
+ with open(openai_checkpoint_folder_path + "/parameters_names.json", "r", encoding="utf-8") as names_handle:
+ names = json.load(names_handle)
+ with open(openai_checkpoint_folder_path + "/params_shapes.json", "r", encoding="utf-8") as shapes_handle:
+ shapes = json.load(shapes_handle)
+ offsets = np.cumsum([np.prod(shape) for shape in shapes])
+ init_params = [np.load(openai_checkpoint_folder_path + f"/params_{n}.npy") for n in range(10)]
+ init_params = np.split(np.concatenate(init_params, 0), offsets)[:-1]
+ init_params = [param.reshape(shape) for param, shape in zip(init_params, shapes)]
+
+ # This was used when we had a single embedding matrix for positions and tokens
+ # init_params[0] = np.concatenate([init_params[1], init_params[0]], 0)
+ # del init_params[1]
+ init_params = [arr.squeeze() for arr in init_params]
+
+ # Check that the token and position embeddings weight dimensions match those of the init parameters.
+ if model.tokens_embed.weight.shape != init_params[1].shape:
+ raise ValueError(
+ f"tokens_embed.weight.shape: {model.tokens_embed.weight.shape} does not match init_param[1].shape:"
+ f" {init_params[1].shape}"
+ )
+
+ if model.positions_embed.weight.shape != init_params[0].shape:
+ raise ValueError(
+ f"positions_embed.weight.shape: {model.positions_embed.weight.shape} does not match init_param[0].shape:"
+ f" {init_params[0].shape}"
+ )
+
+ model.tokens_embed.weight.data = torch.from_numpy(init_params[1])
+ model.positions_embed.weight.data = torch.from_numpy(init_params[0])
+ names.pop(0)
+ # Pop position and token embedding arrays
+ init_params.pop(0)
+ init_params.pop(0)
+
+ for name, array in zip(names, init_params): # names[1:n_transfer], init_params[1:n_transfer]):
+ name = name[6:] # skip "model/"
+ if name[-2:] != ":0":
+ raise ValueError(f"Layer {name} does not end with :0")
+ name = name[:-2]
+ name = name.split("/")
+ pointer = model
+ for m_name in name:
+ if re.fullmatch(r"[A-Za-z]+\d+", m_name):
+ scope_names = re.split(r"(\d+)", m_name)
+ else:
+ scope_names = [m_name]
+ if scope_names[0] == "g":
+ pointer = getattr(pointer, "weight")
+ elif scope_names[0] == "b":
+ pointer = getattr(pointer, "bias")
+ elif scope_names[0] == "w":
+ pointer = getattr(pointer, "weight")
+ else:
+ pointer = getattr(pointer, scope_names[0])
+ if len(scope_names) >= 2:
+ num = int(scope_names[1])
+ pointer = pointer[num]
+
+ # Ensure that the pointer and array have compatible shapes.
+ if pointer.shape != array.shape:
+ raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched")
+
+ logger.info(f"Initialize PyTorch weight {name}")
+ pointer.data = torch.from_numpy(array)
+ return model
+
+
+ACT_FNS = {"relu": nn.ReLU(), "silu": silu, "gelu": gelu_new, "swish": silu}
+
+
+class Attention(nn.Module):
+ def __init__(self, nx, n_positions, config, scale=False):
+ super().__init__()
+ n_state = nx # in Attention: n_state=768 (nx=n_embd)
+ # [switch nx => n_state from Block to Attention to keep identical to TF implementation]
+ if n_state % config.n_head != 0:
+ raise ValueError(f"Attention n_state shape: {n_state} must be divisible by config.n_head {config.n_head}")
+ self.register_buffer(
+ "bias",
+ torch.tril(torch.ones(n_positions, n_positions)).view(1, 1, n_positions, n_positions),
+ persistent=False,
+ )
+ self.n_head = config.n_head
+ self.split_size = n_state
+ self.scale = scale
+
+ self.c_attn = Conv1D(n_state * 3, nx)
+ self.c_proj = Conv1D(n_state, nx)
+ self.attn_dropout = nn.Dropout(config.attn_pdrop)
+ self.resid_dropout = nn.Dropout(config.resid_pdrop)
+ self.pruned_heads = set()
+
+ def prune_heads(self, heads):
+ if len(heads) == 0:
+ return
+ heads, index = find_pruneable_heads_and_indices(
+ heads, self.n_head, self.split_size // self.n_head, self.pruned_heads
+ )
+ index_attn = torch.cat([index, index + self.split_size, index + (2 * self.split_size)])
+ # Prune conv1d layers
+ self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1)
+ self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0)
+ # Update hyper params
+ self.split_size = (self.split_size // self.n_head) * (self.n_head - len(heads))
+ self.n_head = self.n_head - len(heads)
+ self.pruned_heads = self.pruned_heads.union(heads)
+
+ def _attn(self, q, k, v, attention_mask=None, head_mask=None, output_attentions=False):
+ w = torch.matmul(q, k)
+ if self.scale:
+ w = w / math.sqrt(v.size(-1))
+ # w = w * self.bias + -1e9 * (1 - self.bias) # TF implementation method: mask_attn_weights
+ # XD: self.bias may be larger than w, so we need to crop it
+ b = self.bias[:, :, : w.size(-2), : w.size(-1)]
+ w = w * b + -1e4 * (1 - b)
+
+ if attention_mask is not None:
+ # Apply the attention mask
+ w = w + attention_mask
+
+ w = nn.functional.softmax(w, dim=-1)
+ w = self.attn_dropout(w)
+
+ # Mask heads if we want to
+ if head_mask is not None:
+ w = w * head_mask
+
+ outputs = [torch.matmul(w, v)]
+ if output_attentions:
+ outputs.append(w)
+ return outputs
+
+ def merge_heads(self, x):
+ x = x.permute(0, 2, 1, 3).contiguous()
+ new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),)
+ return x.view(*new_x_shape) # in Tensorflow implementation: fct merge_states
+
+ def split_heads(self, x, k=False):
+ new_x_shape = x.size()[:-1] + (self.n_head, x.size(-1) // self.n_head)
+ x = x.view(*new_x_shape) # in Tensorflow implementation: fct split_states
+ if k:
+ return x.permute(0, 2, 3, 1)
+ else:
+ return x.permute(0, 2, 1, 3)
+
+ def forward(self, x, attention_mask=None, head_mask=None, output_attentions=False):
+ x = self.c_attn(x)
+ query, key, value = x.split(self.split_size, dim=2)
+ query = self.split_heads(query)
+ key = self.split_heads(key, k=True)
+ value = self.split_heads(value)
+
+ attn_outputs = self._attn(query, key, value, attention_mask, head_mask, output_attentions)
+ a = attn_outputs[0]
+
+ a = self.merge_heads(a)
+ a = self.c_proj(a)
+ a = self.resid_dropout(a)
+
+ outputs = [a] + attn_outputs[1:]
+ return outputs # a, (attentions)
+
+
+class MLP(nn.Module):
+ def __init__(self, n_state, config): # in MLP: n_state=3072 (4 * n_embd)
+ super().__init__()
+ nx = config.n_embd
+ self.c_fc = Conv1D(n_state, nx)
+ self.c_proj = Conv1D(nx, n_state)
+ self.act = ACT_FNS[config.afn]
+ self.dropout = nn.Dropout(config.resid_pdrop)
+
+ def forward(self, x):
+ h = self.act(self.c_fc(x))
+ h2 = self.c_proj(h)
+ return self.dropout(h2)
+
+
+class Block(nn.Module):
+ def __init__(self, n_positions, config, scale=False):
+ super().__init__()
+ nx = config.n_embd
+ self.attn = Attention(nx, n_positions, config, scale)
+ self.ln_1 = nn.LayerNorm(nx, eps=config.layer_norm_epsilon)
+ self.mlp = MLP(4 * nx, config)
+ self.ln_2 = nn.LayerNorm(nx, eps=config.layer_norm_epsilon)
+
+ def forward(self, x, attention_mask=None, head_mask=None, output_attentions=False):
+ attn_outputs = self.attn(
+ x,
+ attention_mask=attention_mask,
+ head_mask=head_mask,
+ output_attentions=output_attentions,
+ )
+ a = attn_outputs[0]
+
+ n = self.ln_1(x + a)
+ m = self.mlp(n)
+ h = self.ln_2(n + m)
+
+ outputs = [h] + attn_outputs[1:]
+ return outputs
+
+
+# Copied from transformers.models.xlm.modeling_xlm.XLMSequenceSummary with XLM->OpenAIGPT
+class OpenAIGPTSequenceSummary(nn.Module):
+ r"""
+ Compute a single vector summary of a sequence hidden states.
+
+ Args:
+ config ([`OpenAIGPTConfig`]):
+ The config used by the model. Relevant arguments in the config class of the model are (refer to the actual
+ config class of your model for the default values it uses):
+
+ - **summary_type** (`str`) -- The method to use to make this summary. Accepted values are:
+
+ - `"last"` -- Take the last token hidden state (like XLNet)
+ - `"first"` -- Take the first token hidden state (like Bert)
+ - `"mean"` -- Take the mean of all tokens hidden states
+ - `"cls_index"` -- Supply a Tensor of classification token position (GPT/GPT-2)
+ - `"attn"` -- Not implemented now, use multi-head attention
+
+ - **summary_use_proj** (`bool`) -- Add a projection after the vector extraction.
+ - **summary_proj_to_labels** (`bool`) -- If `True`, the projection outputs to `config.num_labels` classes
+ (otherwise to `config.hidden_size`).
+ - **summary_activation** (`Optional[str]`) -- Set to `"tanh"` to add a tanh activation to the output,
+ another string or `None` will add no activation.
+ - **summary_first_dropout** (`float`) -- Optional dropout probability before the projection and activation.
+ - **summary_last_dropout** (`float`) -- Optional dropout probability after the projection and activation.
+ """
+
+ def __init__(self, config: OpenAIGPTConfig):
+ super().__init__()
+
+ self.summary_type = getattr(config, "summary_type", "last")
+ if self.summary_type == "attn":
+ # We should use a standard multi-head attention module with absolute positional embedding for that.
+ # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276
+ # We can probably just use the multi-head attention module of PyTorch >=1.1.0
+ raise NotImplementedError
+
+ self.summary = nn.Identity()
+ if hasattr(config, "summary_use_proj") and config.summary_use_proj:
+ if hasattr(config, "summary_proj_to_labels") and config.summary_proj_to_labels and config.num_labels > 0:
+ num_classes = config.num_labels
+ else:
+ num_classes = config.hidden_size
+ self.summary = nn.Linear(config.hidden_size, num_classes)
+
+ activation_string = getattr(config, "summary_activation", None)
+ self.activation: Callable = get_activation(activation_string) if activation_string else nn.Identity()
+
+ self.first_dropout = nn.Identity()
+ if hasattr(config, "summary_first_dropout") and config.summary_first_dropout > 0:
+ self.first_dropout = nn.Dropout(config.summary_first_dropout)
+
+ self.last_dropout = nn.Identity()
+ if hasattr(config, "summary_last_dropout") and config.summary_last_dropout > 0:
+ self.last_dropout = nn.Dropout(config.summary_last_dropout)
+
+ def forward(
+ self, hidden_states: torch.FloatTensor, cls_index: Optional[torch.LongTensor] = None
+ ) -> torch.FloatTensor:
+ """
+ Compute a single vector summary of a sequence hidden states.
+
+ Args:
+ hidden_states (`torch.FloatTensor` of shape `[batch_size, seq_len, hidden_size]`):
+ The hidden states of the last layer.
+ cls_index (`torch.LongTensor` of shape `[batch_size]` or `[batch_size, ...]` where ... are optional leading dimensions of `hidden_states`, *optional*):
+ Used if `summary_type == "cls_index"`; when not provided, the last token of the sequence is taken as the classification token.
+
+ Returns:
+ `torch.FloatTensor`: The summary of the sequence hidden states.
+ """
+ if self.summary_type == "last":
+ output = hidden_states[:, -1]
+ elif self.summary_type == "first":
+ output = hidden_states[:, 0]
+ elif self.summary_type == "mean":
+ output = hidden_states.mean(dim=1)
+ elif self.summary_type == "cls_index":
+ if cls_index is None:
+ cls_index = torch.full_like(
+ hidden_states[..., :1, :],
+ hidden_states.shape[-2] - 1,
+ dtype=torch.long,
+ )
+ else:
+ cls_index = cls_index.unsqueeze(-1).unsqueeze(-1)
+ cls_index = cls_index.expand((-1,) * (cls_index.dim() - 1) + (hidden_states.size(-1),))
+ # shape of cls_index: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states
+ output = hidden_states.gather(-2, cls_index).squeeze(-2) # shape (bsz, XX, hidden_size)
+ elif self.summary_type == "attn":
+ raise NotImplementedError
+
+ output = self.first_dropout(output)
+ output = self.summary(output)
+ output = self.activation(output)
+ output = self.last_dropout(output)
+
+ return output
+
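+# A minimal usage sketch (illustrative only, assuming the default summary settings of `OpenAIGPTConfig`
+# and the default `num_labels=2` of `PretrainedConfig`):
+#
+#     config = OpenAIGPTConfig(n_embd=8)
+#     summary = OpenAIGPTSequenceSummary(config)
+#     hidden = torch.randn(1, 5, 8)   # (batch, seq_len, hidden_size)
+#     pooled = summary(hidden)        # summary_type="cls_index", cls_index=None -> last token; shape (1, 2)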
+
+@auto_docstring
+class OpenAIGPTPreTrainedModel(PreTrainedModel):
+ config: OpenAIGPTConfig
+ load_tf_weights = load_tf_weights_in_openai_gpt
+ base_model_prefix = "transformer"
+
+ def _init_weights(self, module):
+ """Initialize the weights."""
+ if isinstance(module, (nn.Linear, Conv1D)):
+ # Slightly different from the TF version which uses truncated_normal for initialization
+ # cf https://github.com/pytorch/pytorch/pull/5617
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+ elif isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Base class for outputs of models predicting if two sentences are consecutive or not.
+ """
+)
+class OpenAIGPTDoubleHeadsModelOutput(ModelOutput):
+ r"""
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Language modeling loss.
+ mc_loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `mc_labels` is provided):
+ Multiple choice classification loss.
+ logits (`torch.FloatTensor` of shape `(batch_size, num_choices, sequence_length, config.vocab_size)`):
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+ mc_logits (`torch.FloatTensor` of shape `(batch_size, num_choices)`):
+ Prediction scores of the multiple choice classification head (scores for each choice before SoftMax).
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ mc_loss: Optional[torch.FloatTensor] = None
+ logits: Optional[torch.FloatTensor] = None
+ mc_logits: Optional[torch.FloatTensor] = None
+ hidden_states: Optional[tuple[torch.FloatTensor]] = None
+ attentions: Optional[tuple[torch.FloatTensor]] = None
+
+
+@auto_docstring
+class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+
+ self.tokens_embed = nn.Embedding(config.vocab_size, config.n_embd)
+ self.positions_embed = nn.Embedding(config.n_positions, config.n_embd)
+ self.drop = nn.Dropout(config.embd_pdrop)
+ self.h = nn.ModuleList([Block(config.n_positions, config, scale=True) for _ in range(config.n_layer)])
+
+ self.register_buffer("position_ids", torch.arange(config.n_positions), persistent=False)
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.tokens_embed
+
+ def set_input_embeddings(self, new_embeddings):
+ self.tokens_embed = new_embeddings
+
+ def _prune_heads(self, heads_to_prune):
+ """
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
+ """
+ for layer, heads in heads_to_prune.items():
+ self.h[layer].attn.prune_heads(heads)
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[tuple[torch.Tensor], BaseModelOutput]:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+ elif input_ids is not None:
+ self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
+ input_shape = input_ids.size()
+ input_ids = input_ids.view(-1, input_shape[-1])
+ elif inputs_embeds is not None:
+ input_shape = inputs_embeds.size()[:-1]
+ else:
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+ if position_ids is None:
+ # Code is different from when we had a single embedding matrix for position and token embeddings
+ position_ids = self.position_ids[None, : input_shape[-1]]
+
+ # Attention mask.
+ if attention_mask is not None:
+ # We create a 3D attention mask from a 2D tensor mask.
+ # Sizes are [batch_size, 1, 1, to_seq_length]
+ # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
+ # This attention mask is simpler than the triangular masking of causal attention
+ # used in OpenAI GPT; we just need to prepare the broadcast dimension here.
+ attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
+
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
+ # masked positions, this operation will create a tensor which is 0.0 for
+ # positions we want to attend and the dtype's smallest value for masked positions.
+ # Since we are adding it to the raw scores before the softmax, this is
+ # effectively the same as removing these entirely.
+ attention_mask = attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
+ attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
+
+ # Prepare head mask if needed
+ head_mask = self.get_head_mask(head_mask, self.config.n_layer)
+
+ if inputs_embeds is None:
+ inputs_embeds = self.tokens_embed(input_ids)
+ position_embeds = self.positions_embed(position_ids)
+ if token_type_ids is not None:
+ token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1))
+ token_type_embeds = self.tokens_embed(token_type_ids)
+ else:
+ token_type_embeds = 0
+ hidden_states = inputs_embeds + position_embeds + token_type_embeds
+ hidden_states = self.drop(hidden_states)
+
+ output_shape = input_shape + (hidden_states.size(-1),)
+
+ all_attentions = () if output_attentions else None
+ all_hidden_states = () if output_hidden_states else None
+ for i, block in enumerate(self.h):
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ outputs = block(hidden_states, attention_mask, head_mask[i], output_attentions=output_attentions)
+ hidden_states = outputs[0]
+ if output_attentions:
+ all_attentions = all_attentions + (outputs[1],)
+
+ hidden_states = hidden_states.view(*output_shape)
+ # Add last layer
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
+
+ return BaseModelOutput(
+ last_hidden_state=hidden_states,
+ hidden_states=all_hidden_states,
+ attentions=all_attentions,
+ )
+
+
+@auto_docstring(
+ custom_intro="""
+ OpenAI GPT Model transformer with a language modeling head on top (linear layer with weights tied to the input
+ embeddings).
+ """
+)
+class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel, GenerationMixin):
+ _tied_weights_keys = ["lm_head.weight"]
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.transformer = OpenAIGPTModel(config)
+ self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ **kwargs,
+ ) -> Union[tuple[torch.Tensor], CausalLMOutput]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
+ `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size]`. All labels set to `-100`
+ are ignored (masked); the loss is only computed for labels in `[0, ..., config.vocab_size]`.
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ transformer_outputs = self.transformer(
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ hidden_states = transformer_outputs[0]
+ lm_logits = self.lm_head(hidden_states)
+
+ loss = None
+ if labels is not None:
+ # Flatten the tokens
+ loss = self.loss_function(
+ lm_logits,
+ labels,
+ vocab_size=self.config.vocab_size,
+ **kwargs,
+ )
+
+ if not return_dict:
+ output = (lm_logits,) + transformer_outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return CausalLMOutput(
+ loss=loss,
+ logits=lm_logits,
+ hidden_states=transformer_outputs.hidden_states,
+ attentions=transformer_outputs.attentions,
+ )
+
+ def prepare_inputs_for_generation(self, input_ids: torch.LongTensor, **kwargs) -> dict[str, Any]:
+ # Overwritten -- old model with reduced inputs
+ model_inputs = {"input_ids": input_ids}
+
+ # Forward ALL kwargs that are uninitialized (e.g. `use_cache`).
+ for key, value in kwargs.items():
+ if key not in model_inputs:
+ model_inputs[key] = value
+
+ return model_inputs
+
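+# A minimal generation sketch (illustrative only) for the LM head model, using the
+# `openai-community/openai-gpt` checkpoint referenced elsewhere in this file:
+#
+#     from transformers import AutoTokenizer, OpenAIGPTLMHeadModel
+#
+#     tokenizer = AutoTokenizer.from_pretrained("openai-community/openai-gpt")
+#     model = OpenAIGPTLMHeadModel.from_pretrained("openai-community/openai-gpt")
+#     inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
+#     output_ids = model.generate(**inputs, max_length=20)
+#     print(tokenizer.decode(output_ids[0], skip_special_tokens=True))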
+
+@auto_docstring(
+ custom_intro="""
+ OpenAI GPT Model transformer with a language modeling and a multiple-choice classification head on top e.g. for
+ RocStories/SWAG tasks. The two heads are two linear layers. The language modeling head has its weights tied to the
+ input embeddings, while the classification head takes as input the hidden state at a specified classification token
+ index in the input sequence.
+ """
+)
+class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
+ _tied_weights_keys = ["lm_head.weight"]
+
+ def __init__(self, config):
+ super().__init__(config)
+
+ config.num_labels = 1
+ self.transformer = OpenAIGPTModel(config)
+ self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
+ self.multiple_choice_head = OpenAIGPTSequenceSummary(config)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ mc_token_ids: Optional[torch.LongTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ mc_labels: Optional[torch.LongTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[tuple[torch.Tensor], OpenAIGPTDoubleHeadsModelOutput]:
+ r"""
+ mc_token_ids (`torch.LongTensor` of shape `(batch_size, num_choices)`, *optional*, defaults to the index of the last token of the input):
+ Index of the classification token in each input sequence. Selected in the range `[0, input_ids.size(-1) -
+ 1]`.
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
+ `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size]`. All labels set to `-100` are
+ ignored (masked); the loss is only computed for labels in `[0, ..., config.vocab_size]`.
+ mc_labels (`torch.LongTensor` of shape `(batch_size)`, *optional*):
+ Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., num_choices - 1]`
+ where *num_choices* is the size of the second dimension of the input tensors. (see *input_ids* above)
+
+ Examples:
+
+ ```python
+ >>> from transformers import AutoTokenizer, OpenAIGPTDoubleHeadsModel
+ >>> import torch
+
+ >>> tokenizer = AutoTokenizer.from_pretrained("openai-community/openai-gpt")
+ >>> model = OpenAIGPTDoubleHeadsModel.from_pretrained("openai-community/openai-gpt")
+ >>> tokenizer.add_special_tokens(
+ ... {"cls_token": "[CLS]"}
+ ... ) # Add a [CLS] to the vocabulary (we should train it also!)
+ >>> model.resize_token_embeddings(len(tokenizer))
+
+ >>> choices = ["Hello, my dog is cute [CLS]", "Hello, my cat is cute [CLS]"]
+ >>> input_ids = torch.tensor([tokenizer.encode(s) for s in choices]).unsqueeze(0) # Batch size 1, 2 choices
+ >>> mc_token_ids = torch.tensor([input_ids.size(-1) - 1, input_ids.size(-1) - 1]).unsqueeze(0) # Batch size 1
+
+ >>> outputs = model(input_ids, mc_token_ids=mc_token_ids)
+ >>> lm_logits = outputs.logits
+ >>> mc_logits = outputs.mc_logits
+ ```"""
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ transformer_outputs = self.transformer(
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ hidden_states = transformer_outputs[0]
+
+ lm_logits = self.lm_head(hidden_states)
+ mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids).squeeze(-1)
+
+ lm_loss, mc_loss = None, None
+ if mc_labels is not None:
+ loss_fct = CrossEntropyLoss()
+ mc_loss = loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1))
+ if labels is not None:
+ shift_logits = lm_logits[..., :-1, :].contiguous()
+ shift_labels = labels[..., 1:].contiguous()
+ loss_fct = CrossEntropyLoss()
+ lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
+
+ if not return_dict:
+ output = (lm_logits, mc_logits) + transformer_outputs[1:]
+ if mc_loss is not None:
+ output = (mc_loss,) + output
+ return ((lm_loss,) + output) if lm_loss is not None else output
+
+ return OpenAIGPTDoubleHeadsModelOutput(
+ loss=lm_loss,
+ mc_loss=mc_loss,
+ logits=lm_logits,
+ mc_logits=mc_logits,
+ hidden_states=transformer_outputs.hidden_states,
+ attentions=transformer_outputs.attentions,
+ )
+
+
+@auto_docstring(
+ custom_intro="""
+ The Original OpenAI GPT Model transformer with a sequence classification head on top (linear layer).
+ [`OpenAIGPTForSequenceClassification`] uses the last token in order to do the classification, as other causal
+ models (e.g. GPT-2) do. Since it does classification on the last token, it requires knowing the position of the
+ last token. If a `pad_token_id` is defined in the configuration, it finds the last token that is not a padding
+ token in each row. If no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since
+ it cannot guess the padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (takes
+ the last value in each row of the batch).
+ """
+)
+class OpenAIGPTForSequenceClassification(OpenAIGPTPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+ self.transformer = OpenAIGPTModel(config)
+ self.score = nn.Linear(config.n_embd, self.num_labels, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[tuple[torch.Tensor], SequenceClassifierOutput]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ transformer_outputs = self.transformer(
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ hidden_states = transformer_outputs[0]
+ logits = self.score(hidden_states)
+
+ if input_ids is not None:
+ batch_size, sequence_length = input_ids.shape[:2]
+ else:
+ batch_size, sequence_length = inputs_embeds.shape[:2]
+
+ # Batch sizes > 1 can only be handled if a padding token is defined.
+ if self.config.pad_token_id is None and batch_size != 1:
+ raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
+ if self.config.pad_token_id is None:
+ last_non_pad_token = -1
+ elif input_ids is not None:
+ # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
+ non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
+ token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
+ last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
+ else:
+ last_non_pad_token = -1
+ logger.warning_once(
+ f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
+ "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
+ )
+
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
+
+ loss = None
+ if labels is not None:
+ if self.config.problem_type is None:
+ if self.num_labels == 1:
+ self.config.problem_type = "regression"
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+ self.config.problem_type = "single_label_classification"
+ else:
+ self.config.problem_type = "multi_label_classification"
+
+ if self.config.problem_type == "regression":
+ loss_fct = MSELoss()
+ if self.num_labels == 1:
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
+ else:
+ loss = loss_fct(pooled_logits, labels)
+ elif self.config.problem_type == "single_label_classification":
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
+ elif self.config.problem_type == "multi_label_classification":
+ loss_fct = BCEWithLogitsLoss()
+ loss = loss_fct(pooled_logits, labels)
+ if not return_dict:
+ output = (pooled_logits,) + transformer_outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return SequenceClassifierOutput(
+ loss=loss,
+ logits=pooled_logits,
+ hidden_states=transformer_outputs.hidden_states,
+ attentions=transformer_outputs.attentions,
+ )
+
+
+__all__ = [
+ "OpenAIGPTDoubleHeadsModel",
+ "OpenAIGPTForSequenceClassification",
+ "OpenAIGPTLMHeadModel",
+ "OpenAIGPTModel",
+ "OpenAIGPTPreTrainedModel",
+ "load_tf_weights_in_openai_gpt",
+]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/openai/modeling_tf_openai.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/openai/modeling_tf_openai.py
new file mode 100644
index 0000000000000000000000000000000000000000..0235159633b4fc6410f1e973dca90c29a1ddaa08
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/openai/modeling_tf_openai.py
@@ -0,0 +1,936 @@
+# coding=utf-8
+# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
+# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""TF 2.0 OpenAI GPT model."""
+
+from __future__ import annotations
+
+from dataclasses import dataclass
+
+import numpy as np
+import tensorflow as tf
+
+from ...activations_tf import get_tf_activation
+from ...modeling_tf_outputs import TFBaseModelOutput, TFCausalLMOutput, TFSequenceClassifierOutput
+from ...modeling_tf_utils import (
+ TFCausalLanguageModelingLoss,
+ TFConv1D,
+ TFModelInputType,
+ TFPreTrainedModel,
+ TFSequenceClassificationLoss,
+ TFSequenceSummary,
+ TFSharedEmbeddings,
+ get_initializer,
+ keras,
+ keras_serializable,
+ unpack_inputs,
+)
+from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax
+from ...utils import (
+ ModelOutput,
+ add_code_sample_docstrings,
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ logging,
+ replace_return_docstrings,
+)
+from .configuration_openai import OpenAIGPTConfig
+
+
+logger = logging.get_logger(__name__)
+
+_CHECKPOINT_FOR_DOC = "openai-community/openai-gpt"
+_CONFIG_FOR_DOC = "OpenAIGPTConfig"
+
+
+class TFAttention(keras.layers.Layer):
+ def __init__(self, nx, config, scale=False, **kwargs):
+ super().__init__(**kwargs)
+
+ n_state = nx # in Attention: n_state=768 (nx=n_embd)
+ # [switch nx => n_state from Block to Attention to keep identical to TF implementation]
+ assert n_state % config.n_head == 0, (
+ f"Hidden dimension {n_state} not dividable by number of heads {config.n_head}"
+ )
+ self.n_head = config.n_head
+ self.split_size = n_state
+ self.scale = scale
+ self.output_attentions = config.output_attentions
+
+ self.c_attn = TFConv1D(n_state * 3, nx, initializer_range=config.initializer_range, name="c_attn")
+ self.c_proj = TFConv1D(n_state, nx, initializer_range=config.initializer_range, name="c_proj")
+ self.attn_dropout = keras.layers.Dropout(config.attn_pdrop)
+ self.resid_dropout = keras.layers.Dropout(config.resid_pdrop)
+ self.n_state = n_state
+ self.pruned_heads = set()
+
+ def prune_heads(self, heads):
+ pass
+
+ @staticmethod
+ def causal_attention_mask(nd, ns):
+ """
+ 1's in the lower triangle, counting from the lower right corner. Same as tf.matrix_band_part(tf.ones([nd, ns]),
+ -1, ns-nd), but doesn't produce garbage on TPUs.
+ """
+ i = tf.range(nd)[:, None]
+ j = tf.range(ns)
+ m = i >= j - ns + nd
+ return m
+
+ def _attn(self, q, k, v, attention_mask, head_mask, output_attentions, training=False):
+ # q, k, v have shape [batch, heads, sequence, features]
+ w = tf.matmul(q, k, transpose_b=True)
+ if self.scale:
+ dk = tf.cast(shape_list(k)[-1], dtype=w.dtype) # scale attention_scores
+ w = w / tf.math.sqrt(dk)
+
+ # w has shape [batch, heads, dst_sequence, src_sequence], where information flows from src to dst.
+ _, _, nd, ns = shape_list(w)
+ b = tf.cast(self.causal_attention_mask(nd, ns), dtype=w.dtype)
+ b = tf.reshape(b, [1, 1, nd, ns])
+ w = w * b - 1e4 * (1 - b)
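+ # Masked (b == 0) positions receive a large negative bias so their softmax weight
+ # collapses to ~0, while allowed (b == 1) positions keep their original scores.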
+
+ if attention_mask is not None:
+ # Apply the attention mask
+ attention_mask = tf.cast(attention_mask, dtype=w.dtype)
+ w = w + attention_mask
+
+ w = stable_softmax(w, axis=-1)
+ w = self.attn_dropout(w, training=training)
+
+ # Mask heads if we want to
+ if head_mask is not None:
+ w = w * head_mask
+
+ outputs = [tf.matmul(w, v)]
+ if output_attentions:
+ outputs.append(w)
+ return outputs
+
+ def merge_heads(self, x):
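+ # Inverse of split_heads: (batch, head, seq_length, head_features) -> (batch, seq_length, hidden_size)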
+ x = tf.transpose(x, [0, 2, 1, 3])
+ x_shape = shape_list(x)
+ new_x_shape = x_shape[:-2] + [x_shape[-2] * x_shape[-1]]
+ return tf.reshape(x, new_x_shape)
+
+ def split_heads(self, x):
+ x_shape = shape_list(x)
+ new_x_shape = x_shape[:-1] + [self.n_head, x_shape[-1] // self.n_head]
+ x = tf.reshape(x, new_x_shape)
+ return tf.transpose(x, (0, 2, 1, 3)) # (batch, head, seq_length, head_features)
+
+ def call(self, x, attention_mask, head_mask, output_attentions, training=False):
+ x = self.c_attn(x)
+ query, key, value = tf.split(x, 3, axis=2)
+ query = self.split_heads(query)
+ key = self.split_heads(key)
+ value = self.split_heads(value)
+
+ attn_outputs = self._attn(query, key, value, attention_mask, head_mask, output_attentions, training=training)
+ a = attn_outputs[0]
+
+ a = self.merge_heads(a)
+ a = self.c_proj(a)
+ a = self.resid_dropout(a, training=training)
+
+ outputs = [a] + attn_outputs[1:]
+ return outputs # a, (attentions)
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "c_attn", None) is not None:
+ with tf.name_scope(self.c_attn.name):
+ self.c_attn.build([None, None, self.n_state * 3])
+ if getattr(self, "c_proj", None) is not None:
+ with tf.name_scope(self.c_proj.name):
+ self.c_proj.build([None, None, self.n_state])
+
+
+class TFMLP(keras.layers.Layer):
+ def __init__(self, n_state, config, **kwargs):
+ super().__init__(**kwargs)
+ nx = config.n_embd
+ self.c_fc = TFConv1D(n_state, nx, initializer_range=config.initializer_range, name="c_fc")
+ self.c_proj = TFConv1D(nx, n_state, initializer_range=config.initializer_range, name="c_proj")
+ self.act = get_tf_activation("gelu")
+ self.dropout = keras.layers.Dropout(config.resid_pdrop)
+ self.nx = nx
+ self.n_state = n_state
+
+ def call(self, x, training=False):
+ h = self.act(self.c_fc(x))
+ h2 = self.c_proj(h)
+ h2 = self.dropout(h2, training=training)
+ return h2
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "c_fc", None) is not None:
+ with tf.name_scope(self.c_fc.name):
+ self.c_fc.build([None, None, self.n_state])
+ if getattr(self, "c_proj", None) is not None:
+ with tf.name_scope(self.c_proj.name):
+ self.c_proj.build([None, None, self.nx])
+
+
+class TFBlock(keras.layers.Layer):
+ def __init__(self, config, scale=False, **kwargs):
+ super().__init__(**kwargs)
+ nx = config.n_embd
+ self.attn = TFAttention(nx, config, scale, name="attn")
+ self.ln_1 = keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon, name="ln_1")
+ self.mlp = TFMLP(4 * nx, config, name="mlp")
+ self.ln_2 = keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon, name="ln_2")
+ self.nx = nx
+
+ def call(self, x, attention_mask, head_mask, output_attentions, training=False):
+ output_attn = self.attn(x, attention_mask, head_mask, output_attentions, training=training)
+ a = output_attn[0] # output_attn: a, (attentions)
+
+ n = self.ln_1(x + a)
+ m = self.mlp(n, training=training)
+ h = self.ln_2(n + m)
+
+ outputs = [h] + output_attn[1:]
+ return outputs # x, (attentions)
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "attn", None) is not None:
+ with tf.name_scope(self.attn.name):
+ self.attn.build(None)
+ if getattr(self, "ln_1", None) is not None:
+ with tf.name_scope(self.ln_1.name):
+ self.ln_1.build([None, None, self.nx])
+ if getattr(self, "mlp", None) is not None:
+ with tf.name_scope(self.mlp.name):
+ self.mlp.build(None)
+ if getattr(self, "ln_2", None) is not None:
+ with tf.name_scope(self.ln_2.name):
+ self.ln_2.build([None, None, self.nx])
+
+
+@keras_serializable
+class TFOpenAIGPTMainLayer(keras.layers.Layer):
+ config_class = OpenAIGPTConfig
+
+ def __init__(self, config, *inputs, **kwargs):
+ super().__init__(*inputs, **kwargs)
+
+ self.config = config
+ self.output_hidden_states = config.output_hidden_states
+ self.output_attentions = config.output_attentions
+ self.return_dict = config.use_return_dict
+ self.num_hidden_layers = config.n_layer
+ self.n_embd = config.n_embd
+ self.n_positions = config.n_positions
+ self.initializer_range = config.initializer_range
+
+ self.tokens_embed = TFSharedEmbeddings(
+ config.vocab_size, config.n_embd, initializer_range=config.initializer_range, name="tokens_embed"
+ )
+ self.drop = keras.layers.Dropout(config.embd_pdrop)
+ self.h = [TFBlock(config, scale=True, name=f"h_._{i}") for i in range(config.n_layer)]
+
+ def build(self, input_shape=None):
+ with tf.name_scope("positions_embed"):
+ self.positions_embed = self.add_weight(
+ name="embeddings",
+ shape=[self.n_positions, self.n_embd],
+ initializer=get_initializer(self.initializer_range),
+ )
+
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "tokens_embed", None) is not None:
+ with tf.name_scope(self.tokens_embed.name):
+ self.tokens_embed.build(None)
+ if getattr(self, "h", None) is not None:
+ for layer in self.h:
+ with tf.name_scope(layer.name):
+ layer.build(None)
+
+ def get_input_embeddings(self):
+ return self.tokens_embed
+
+ def set_input_embeddings(self, value):
+ self.tokens_embed.weight = value
+ self.tokens_embed.vocab_size = shape_list(value)[0]
+
+ def _prune_heads(self, heads_to_prune):
+ """
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
+ """
+ raise NotImplementedError
+
+ @unpack_inputs
+ def call(
+ self,
+ input_ids: TFModelInputType | None = None,
+ attention_mask: np.ndarray | tf.Tensor | None = None,
+ token_type_ids: np.ndarray | tf.Tensor | None = None,
+ position_ids: np.ndarray | tf.Tensor | None = None,
+ head_mask: np.ndarray | tf.Tensor | None = None,
+ inputs_embeds: np.ndarray | tf.Tensor | None = None,
+ output_attentions: bool | None = None,
+ output_hidden_states: bool | None = None,
+ return_dict: bool | None = None,
+ training: bool | None = False,
+ ) -> tuple | TFBaseModelOutput:
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+ elif input_ids is not None:
+ input_shape = shape_list(input_ids)
+ input_ids = tf.reshape(input_ids, [-1, input_shape[-1]])
+ elif inputs_embeds is not None:
+ input_shape = shape_list(inputs_embeds)[:-1]
+ else:
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+ if position_ids is None:
+ position_ids = tf.expand_dims(tf.range(input_shape[-1]), axis=0)
+
+ if attention_mask is not None:
+ # We create a 3D attention mask from a 2D tensor mask.
+ # Sizes are [batch_size, 1, 1, to_seq_length]
+ # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
+ # this attention mask is simpler than the triangular masking of causal attention
+ # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
+ attention_mask = tf.reshape(attention_mask, (input_shape[0], 1, 1, input_shape[1]))
+
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
+ # masked positions, this operation will create a tensor which is 0.0 for
+ # positions we want to attend and -10000.0 for masked positions.
+ # Since we are adding it to the raw scores before the softmax, this is
+ # effectively the same as removing these entirely.
+
+ one_cst = tf.constant(1.0)
+ attention_mask = tf.cast(attention_mask, dtype=one_cst.dtype)
+ attention_mask = tf.multiply(tf.subtract(one_cst, attention_mask), tf.constant(-10000.0))
+ else:
+ attention_mask = None
+
+ # Prepare head mask if needed
+ # 1.0 in head_mask indicate we keep the head
+ # attention_probs has shape bsz x n_heads x N x N
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+ if head_mask is not None:
+ raise NotImplementedError
+ else:
+ head_mask = [None] * self.num_hidden_layers
+ # head_mask = tf.constant([0] * self.num_hidden_layers)
+
+ position_ids = tf.reshape(position_ids, [-1, shape_list(position_ids)[-1]])
+
+ if inputs_embeds is None:
+ check_embeddings_within_bounds(input_ids, self.config.vocab_size)
+ inputs_embeds = self.tokens_embed(input_ids, mode="embedding")
+ position_embeds = tf.gather(self.positions_embed, position_ids)
+ if token_type_ids is not None:
+ token_type_ids = tf.reshape(token_type_ids, [-1, shape_list(token_type_ids)[-1]])
+ check_embeddings_within_bounds(token_type_ids, self.config.vocab_size, "token_type_ids")
+ token_type_embeds = self.tokens_embed(token_type_ids, mode="embedding")
+ else:
+ token_type_embeds = 0
+ hidden_states = inputs_embeds + position_embeds + token_type_embeds
+ hidden_states = self.drop(hidden_states, training=training)
+
+ output_shape = input_shape + [shape_list(hidden_states)[-1]]
+
+ all_attentions = () if output_attentions else None
+ all_hidden_states = () if output_hidden_states else None
+ for i, block in enumerate(self.h):
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (tf.reshape(hidden_states, output_shape),)
+
+ outputs = block(
+ hidden_states,
+ attention_mask,
+ head_mask[i],
+ output_attentions,
+ training=training,
+ )
+ hidden_states = outputs[0]
+ if output_attentions:
+ all_attentions = all_attentions + (outputs[1],)
+
+ hidden_states = tf.reshape(hidden_states, output_shape)
+ # Add last hidden state
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if output_attentions:
+ # let the number of heads free (-1) so we can extract attention even after head pruning
+ attention_output_shape = input_shape[:-1] + [-1] + shape_list(all_attentions[0])[-2:]
+ all_attentions = tuple(tf.reshape(t, attention_output_shape) for t in all_attentions)
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
+
+ return TFBaseModelOutput(
+ last_hidden_state=hidden_states,
+ hidden_states=all_hidden_states,
+ attentions=all_attentions,
+ )
+
+
+class TFOpenAIGPTPreTrainedModel(TFPreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = OpenAIGPTConfig
+ base_model_prefix = "transformer"
+
+
+@dataclass
+class TFOpenAIGPTDoubleHeadsModelOutput(ModelOutput):
+ """
+ Base class for outputs of models predicting if two sentences are consecutive or not.
+
+ Args:
+ logits (`tf.Tensor` of shape `(batch_size, num_choices, sequence_length, config.vocab_size)`):
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+ mc_logits (`tf.Tensor` of shape `(batch_size, num_choices)`):
+ Prediction scores of the multiple choice classification head (scores for each choice before SoftMax).
+ hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
+ `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+ attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ """
+
+ logits: tf.Tensor | None = None
+ mc_logits: tf.Tensor | None = None
+ hidden_states: tuple[tf.Tensor] | None = None
+ attentions: tuple[tf.Tensor] | None = None
+
+
+OPENAI_GPT_START_DOCSTRING = r"""
+
+ This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
+ etc.)
+
+ This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it
+ as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and
+ behavior.
+
+
+ <Tip>
+
+ TensorFlow models and layers in `transformers` accept two formats as input:
+
+ - having all inputs as keyword arguments (like PyTorch models), or
+ - having all inputs as a list, tuple or dict in the first positional argument.
+
+ The reason the second format is supported is that Keras methods prefer this format when passing inputs to models
+ and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just
+ pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second
+ format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with
+ the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first
+ positional argument:
+
+ - a single Tensor with `input_ids` only and nothing else: `model(input_ids)`
+ - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
+ `model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])`
+ - a dictionary with one or several input Tensors associated to the input names given in the docstring:
+ `model({"input_ids": input_ids, "token_type_ids": token_type_ids})`
+
+ Note that when creating models and layers with
+ [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry
+ about any of this, as you can just pass inputs like you would to any other Python function!
+
+ </Tip>
+
+ Parameters:
+ config ([`OpenAIGPTConfig`]): Model configuration class with all the parameters of the model.
+ Initializing with a config file does not load the weights associated with the model, only the
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+OPENAI_GPT_INPUTS_DOCSTRING = r"""
+ Args:
+ input_ids (`Numpy array` or `tf.Tensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.__call__`] and
+ [`PreTrainedTokenizer.encode`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ attention_mask (`tf.Tensor` or `Numpy array` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ token_type_ids (`tf.Tensor` or `Numpy array` of shape `(batch_size, sequence_length)`, *optional*):
+ Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
+ 1]`:
+
+ - 0 corresponds to a *sentence A* token,
+ - 1 corresponds to a *sentence B* token.
+
+ [What are token type IDs?](../glossary#token-type-ids)
+ position_ids (`tf.Tensor` or `Numpy array` of shape `(batch_size, sequence_length)`, *optional*):
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+ config.max_position_embeddings - 1]`.
+
+ [What are position IDs?](../glossary#position-ids)
+ head_mask (`tf.Tensor` or `Numpy array` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ inputs_embeds (`tf.Tensor` or `Numpy array` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
+ model's internal embedding lookup matrix.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the
+ config will be used instead.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail. This argument can be used only in eager mode, in graph mode the value in the config will be
+ used instead.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in
+ eager mode, in graph mode the value will always be set to True.
+ training (`bool`, *optional*, defaults to `False`):
+ Whether or not to use the model in training mode (some modules like dropout modules have different
+ behaviors between training and evaluation).
+"""
+
+
+@add_start_docstrings(
+ "The bare OpenAI GPT transformer model outputting raw hidden-states without any specific head on top.",
+ OPENAI_GPT_START_DOCSTRING,
+)
+class TFOpenAIGPTModel(TFOpenAIGPTPreTrainedModel):
+ def __init__(self, config, *inputs, **kwargs):
+ super().__init__(config, *inputs, **kwargs)
+ self.transformer = TFOpenAIGPTMainLayer(config, name="transformer")
+
+ @unpack_inputs
+ @add_start_docstrings_to_model_forward(OPENAI_GPT_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=TFBaseModelOutput,
+ config_class=_CONFIG_FOR_DOC,
+ )
+ def call(
+ self,
+ input_ids: TFModelInputType | None = None,
+ attention_mask: np.ndarray | tf.Tensor | None = None,
+ token_type_ids: np.ndarray | tf.Tensor | None = None,
+ position_ids: np.ndarray | tf.Tensor | None = None,
+ head_mask: np.ndarray | tf.Tensor | None = None,
+ inputs_embeds: np.ndarray | tf.Tensor | None = None,
+ output_attentions: bool | None = None,
+ output_hidden_states: bool | None = None,
+ return_dict: bool | None = None,
+ training: bool | None = False,
+ ) -> tuple | TFBaseModelOutput:
+ outputs = self.transformer(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ training=training,
+ )
+ return outputs
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "transformer", None) is not None:
+ with tf.name_scope(self.transformer.name):
+ self.transformer.build(None)
+
+
+@add_start_docstrings(
+ """
+ OpenAI GPT Model transformer with a language modeling head on top (linear layer with weights tied to the input
+ embeddings).
+ """,
+ OPENAI_GPT_START_DOCSTRING,
+)
+class TFOpenAIGPTLMHeadModel(TFOpenAIGPTPreTrainedModel, TFCausalLanguageModelingLoss):
+ def __init__(self, config, *inputs, **kwargs):
+ super().__init__(config, *inputs, **kwargs)
+ self.transformer = TFOpenAIGPTMainLayer(config, name="transformer")
+ # OpenAIGPT does not have past caching features
+ self.supports_xla_generation = False
+
+ def get_output_embeddings(self):
+ return self.get_input_embeddings()
+
+ def set_output_embeddings(self, value):
+ self.set_input_embeddings(value)
+
+ @unpack_inputs
+ @add_start_docstrings_to_model_forward(OPENAI_GPT_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=TFCausalLMOutput,
+ config_class=_CONFIG_FOR_DOC,
+ )
+ def call(
+ self,
+ input_ids: TFModelInputType | None = None,
+ attention_mask: np.ndarray | tf.Tensor | None = None,
+ token_type_ids: np.ndarray | tf.Tensor | None = None,
+ position_ids: np.ndarray | tf.Tensor | None = None,
+ head_mask: np.ndarray | tf.Tensor | None = None,
+ inputs_embeds: np.ndarray | tf.Tensor | None = None,
+ output_attentions: bool | None = None,
+ output_hidden_states: bool | None = None,
+ return_dict: bool | None = None,
+ labels: np.ndarray | tf.Tensor | None = None,
+ training: bool | None = False,
+ ) -> tuple | TFCausalLMOutput:
+ r"""
+ labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the cross entropy classification loss. Indices should be in `[0, ...,
+ config.vocab_size - 1]`.
+ """
+
+ transformer_outputs = self.transformer(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ training=training,
+ )
+ hidden_states = transformer_outputs[0]
+
+ logits = self.transformer.tokens_embed(hidden_states, mode="linear")
+
+ loss = None
+ if labels is not None:
+ # shift labels to the left and cut last logit token
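+ # e.g. for labels [t1, t2, t3, t4], the logits at positions 0..2 are scored against [t2, t3, t4]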
+ shifted_logits = logits[:, :-1]
+ labels = labels[:, 1:]
+ loss = self.hf_compute_loss(labels, shifted_logits)
+
+ if not return_dict:
+ output = (logits,) + transformer_outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return TFCausalLMOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=transformer_outputs.hidden_states,
+ attentions=transformer_outputs.attentions,
+ )
+
+ def prepare_inputs_for_generation(self, inputs, **kwargs):
+ return {"input_ids": inputs}
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "transformer", None) is not None:
+ with tf.name_scope(self.transformer.name):
+ self.transformer.build(None)
+
+
+@add_start_docstrings(
+ """
+ OpenAI GPT Model transformer with a language modeling and a multiple-choice classification head on top e.g. for
+ RocStories/SWAG tasks. The two heads are two linear layers. The language modeling head has its weights tied to the
+ input embeddings, and the classification head takes as input the hidden state at a specified classification token
+ index in the input sequence.
+ """,
+ OPENAI_GPT_START_DOCSTRING,
+)
+class TFOpenAIGPTDoubleHeadsModel(TFOpenAIGPTPreTrainedModel):
+ def __init__(self, config, *inputs, **kwargs):
+ super().__init__(config, *inputs, **kwargs)
+ config.num_labels = 1
+ self.transformer = TFOpenAIGPTMainLayer(config, name="transformer")
+ self.multiple_choice_head = TFSequenceSummary(
+ config, initializer_range=config.initializer_range, name="multiple_choice_head"
+ )
+
+ @unpack_inputs
+ @add_start_docstrings_to_model_forward(OPENAI_GPT_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=TFOpenAIGPTDoubleHeadsModelOutput, config_class=_CONFIG_FOR_DOC)
+ def call(
+ self,
+ input_ids: TFModelInputType | None = None,
+ attention_mask: np.ndarray | tf.Tensor | None = None,
+ token_type_ids: np.ndarray | tf.Tensor | None = None,
+ position_ids: np.ndarray | tf.Tensor | None = None,
+ head_mask: np.ndarray | tf.Tensor | None = None,
+ inputs_embeds: np.ndarray | tf.Tensor | None = None,
+ mc_token_ids: np.ndarray | tf.Tensor | None = None,
+ output_attentions: bool | None = None,
+ output_hidden_states: bool | None = None,
+ return_dict: bool | None = None,
+ training: bool | None = False,
+ ) -> tuple | TFOpenAIGPTDoubleHeadsModelOutput:
+ r"""
+ mc_token_ids (`tf.Tensor` or `Numpy array` of shape `(batch_size, num_choices)`, *optional*, defaults to the index of the last token of the input):
+ Index of the classification token in each input sequence. Selected in the range `[0, input_ids.size(-1) -
+ 1]`.
+
+ Return:
+
+ Examples:
+
+ ```python
+ >>> import tensorflow as tf
+ >>> from transformers import AutoTokenizer, TFOpenAIGPTDoubleHeadsModel
+
+ >>> tokenizer = AutoTokenizer.from_pretrained("openai-community/openai-gpt")
+ >>> model = TFOpenAIGPTDoubleHeadsModel.from_pretrained("openai-community/openai-gpt")
+
+ >>> # Add a [CLS] to the vocabulary (we should train it also!)
+ >>> tokenizer.add_special_tokens({"cls_token": "[CLS]"})
+ >>> model.resize_token_embeddings(len(tokenizer)) # Update the model embeddings with the new vocabulary size
+ >>> print(tokenizer.cls_token_id, len(tokenizer))  # The newly added token is the last token of the vocabulary
+
+ >>> choices = ["Hello, my dog is cute [CLS]", "Hello, my cat is cute [CLS]"]
+ >>> encoding = tokenizer(choices, return_tensors="tf")
+ >>> inputs = {k: tf.expand_dims(v, 0) for k, v in encoding.items()}
+ >>> inputs["mc_token_ids"] = tf.constant(
+ ... [inputs["input_ids"].shape[-1] - 1, inputs["input_ids"].shape[-1] - 1]
+ ... )[
+ ... None, :
+ ... ] # Batch size 1
+ >>> outputs = model(inputs)
+ >>> lm_prediction_scores, mc_prediction_scores = outputs[:2]
+ ```"""
+
+ if input_ids is not None:
+ input_shapes = shape_list(input_ids)
+ else:
+ input_shapes = shape_list(inputs_embeds)[:-1]
+
+ seq_length = input_shapes[-1]
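+ # Double-heads inputs arrive as (batch_size, num_choices, seq_length); flatten the
+ # first two dimensions so the shared transformer processes every choice as one row.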
+ flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None
+ flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
+ flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None
+ flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None
+ transformer_outputs = self.transformer(
+ flat_input_ids,
+ flat_attention_mask,
+ flat_token_type_ids,
+ flat_position_ids,
+ head_mask,
+ inputs_embeds,
+ output_attentions,
+ output_hidden_states,
+ return_dict=return_dict,
+ training=training,
+ )
+ hidden_states = transformer_outputs[0]
+ hidden_states = tf.reshape(hidden_states, input_shapes + shape_list(hidden_states)[-1:])
+ if return_dict and output_hidden_states:
+ # We do this to match the slightly odd PT behaviour - the final hidden state is reshaped to rank 4 when the
+ # input is rank 3, but all other hidden states remain at rank-3 (with the first 2 dims merged)
+ all_hidden_states = transformer_outputs.hidden_states[:-1] + (hidden_states,)
+ else:
+ all_hidden_states = None
+ lm_logits = self.transformer.tokens_embed(hidden_states, mode="linear")
+ mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids, training=training)
+ mc_logits = tf.squeeze(mc_logits, axis=-1)
+
+ if not return_dict:
+ return (lm_logits, mc_logits) + transformer_outputs[1:]
+
+ return TFOpenAIGPTDoubleHeadsModelOutput(
+ logits=lm_logits,
+ mc_logits=mc_logits,
+ hidden_states=all_hidden_states,
+ attentions=transformer_outputs.attentions,
+ )
+
+ @property
+ def input_signature(self):
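+ # The extra leading dimension is num_choices: inputs are (batch_size, num_choices, seq_length).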
+ return {
+ "input_ids": tf.TensorSpec((None, None, None), tf.int32, name="input_ids"),
+ "attention_mask": tf.TensorSpec((None, None, None), tf.int32, name="attention_mask"),
+ "mc_token_ids": tf.TensorSpec((None, None), tf.int32, name="token_type_ids"),
+ }
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "transformer", None) is not None:
+ with tf.name_scope(self.transformer.name):
+ self.transformer.build(None)
+ if getattr(self, "multiple_choice_head", None) is not None:
+ with tf.name_scope(self.multiple_choice_head.name):
+ self.multiple_choice_head.build(None)
+
+
+@add_start_docstrings(
+ """
+ The OpenAI GPT Model transformer with a sequence classification head on top (linear layer).
+
+ [`TFOpenAIGPTForSequenceClassification`] uses the last token in order to do the classification, as other causal
+ models (e.g. GPT-2) do.
+
+ Since it does classification on the last token, it requires to know the position of the last token. If a
+ `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
+ no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
+ padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
+ each row of the batch).
+ """,
+ OPENAI_GPT_START_DOCSTRING,
+)
+class TFOpenAIGPTForSequenceClassification(TFOpenAIGPTPreTrainedModel, TFSequenceClassificationLoss):
+ def __init__(self, config, *inputs, **kwargs):
+ super().__init__(config, *inputs, **kwargs)
+ self.num_labels = config.num_labels
+ self.score = keras.layers.Dense(
+ config.num_labels,
+ kernel_initializer=get_initializer(config.initializer_range),
+ name="score",
+ use_bias=False,
+ )
+ self.transformer = TFOpenAIGPTMainLayer(config, name="transformer")
+ self.config = config
+
+ @unpack_inputs
+ @add_start_docstrings_to_model_forward(OPENAI_GPT_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=TFSequenceClassifierOutput,
+ config_class=_CONFIG_FOR_DOC,
+ )
+ def call(
+ self,
+ input_ids: TFModelInputType | None = None,
+ attention_mask: np.ndarray | tf.Tensor | None = None,
+ token_type_ids: np.ndarray | tf.Tensor | None = None,
+ position_ids: np.ndarray | tf.Tensor | None = None,
+ head_mask: np.ndarray | tf.Tensor | None = None,
+ inputs_embeds: np.ndarray | tf.Tensor | None = None,
+ output_attentions: bool | None = None,
+ output_hidden_states: bool | None = None,
+ return_dict: bool | None = None,
+ labels: np.ndarray | tf.Tensor | None = None,
+ training: bool | None = False,
+ ) -> tuple | TFSequenceClassifierOutput:
+ r"""
+ labels (`tf.Tensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the sequence classification loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`.
+ """
+ transformer_outputs = self.transformer(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ training=training,
+ )
+ hidden_states = transformer_outputs[0]
+ logits = self.score(hidden_states)
+ logits_shape = shape_list(logits)
+ batch_size = logits_shape[0]
+
+ if self.config.pad_token_id is None:
+ last_non_pad_token = tf.fill((batch_size,), value=logits_shape[1] - 1)
+ else:
+ if input_ids is not None:
+ token_indices = tf.range(shape_list(input_ids)[-1])
+ non_pad_mask = tf.cast(input_ids != self.config.pad_token_id, token_indices.dtype)
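+ # e.g. with pad_token_id=0 and an input_ids row [5, 6, 0, 0]: indices [0, 1, 2, 3] * mask
+ # [1, 1, 0, 0] -> [0, 1, 0, 0], so the reduce_max below picks position 1 (the last real token)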
+ last_non_pad_token = tf.reduce_max(token_indices * non_pad_mask, axis=-1)
+ else:
+ last_non_pad_token = tf.fill((batch_size,), value=logits_shape[1] - 1)
+ logger.warning_once(
+ f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
+ "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
+ )
+ loss = None
+
+ pooled_logits = tf.gather(logits, last_non_pad_token, batch_dims=1, axis=1)
+
+ if labels is not None:
+ if self.config.pad_token_id is None and logits_shape[0] != 1:
+ raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
+
+ loss = self.hf_compute_loss(tf.reshape(labels, [-1]), tf.reshape(pooled_logits, [-1, self.num_labels]))
+
+ if not return_dict:
+ output = (pooled_logits,) + transformer_outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return TFSequenceClassifierOutput(
+ loss=loss,
+ logits=pooled_logits,
+ hidden_states=transformer_outputs.hidden_states,
+ attentions=transformer_outputs.attentions,
+ )
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "score", None) is not None:
+ with tf.name_scope(self.score.name):
+ self.score.build([None, None, self.config.n_embd])
+ if getattr(self, "transformer", None) is not None:
+ with tf.name_scope(self.transformer.name):
+ self.transformer.build(None)
+
+
+__all__ = [
+ "TFOpenAIGPTDoubleHeadsModel",
+ "TFOpenAIGPTForSequenceClassification",
+ "TFOpenAIGPTLMHeadModel",
+ "TFOpenAIGPTMainLayer",
+ "TFOpenAIGPTModel",
+ "TFOpenAIGPTPreTrainedModel",
+]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/openai/tokenization_openai.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/openai/tokenization_openai.py
new file mode 100644
index 0000000000000000000000000000000000000000..8a9184cc395cab5beabfea30ef11eab928aca89c
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/openai/tokenization_openai.py
@@ -0,0 +1,396 @@
+# coding=utf-8
+# Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tokenization classes for OpenAI GPT."""
+
+import json
+import os
+import re
+import unicodedata
+from typing import Optional
+
+from ...tokenization_utils import PreTrainedTokenizer, _is_control, _is_punctuation, _is_whitespace
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+VOCAB_FILES_NAMES = {
+ "vocab_file": "vocab.json",
+ "merges_file": "merges.txt",
+}
+
+
+# Copied from transformers.models.bert.tokenization_bert.whitespace_tokenize
+def whitespace_tokenize(text):
+ """Runs basic whitespace cleaning and splitting on a piece of text."""
+ text = text.strip()
+ if not text:
+ return []
+ tokens = text.split()
+ return tokens
+
+
+# Copied from transformers.models.bert.tokenization_bert.BasicTokenizer
+class BasicTokenizer:
+ """
+ Constructs a BasicTokenizer that will run basic tokenization (punctuation splitting, lower casing, etc.).
+
+ Args:
+ do_lower_case (`bool`, *optional*, defaults to `True`):
+ Whether or not to lowercase the input when tokenizing.
+ never_split (`Iterable`, *optional*):
+ Collection of tokens which will never be split during tokenization. Only has an effect when
+ `do_basic_tokenize=True`
+ tokenize_chinese_chars (`bool`, *optional*, defaults to `True`):
+ Whether or not to tokenize Chinese characters.
+
+ This should likely be deactivated for Japanese (see this
+ [issue](https://github.com/huggingface/transformers/issues/328)).
+ strip_accents (`bool`, *optional*):
+ Whether or not to strip all accents. If this option is not specified, then it will be determined by the
+ value for `lowercase` (as in the original BERT).
+ do_split_on_punc (`bool`, *optional*, defaults to `True`):
+ In some instances we want to skip the basic punctuation splitting so that later tokenization can capture
+ the full context of the words, such as contractions.
+ """
+
+ def __init__(
+ self,
+ do_lower_case=True,
+ never_split=None,
+ tokenize_chinese_chars=True,
+ strip_accents=None,
+ do_split_on_punc=True,
+ ):
+ if never_split is None:
+ never_split = []
+ self.do_lower_case = do_lower_case
+ self.never_split = set(never_split)
+ self.tokenize_chinese_chars = tokenize_chinese_chars
+ self.strip_accents = strip_accents
+ self.do_split_on_punc = do_split_on_punc
+
+ def tokenize(self, text, never_split=None):
+ """
+ Basic Tokenization of a piece of text. For sub-word tokenization, see WordPieceTokenizer.
+
+ Args:
+ never_split (`List[str]`, *optional*):
+ Kept for backward compatibility purposes. Now implemented directly at the base class level (see
+ [`PreTrainedTokenizer.tokenize`]). List of tokens not to split.
+ """
+ # union() returns a new set by concatenating the two sets.
+ never_split = self.never_split.union(set(never_split)) if never_split else self.never_split
+ text = self._clean_text(text)
+
+ # This was added on November 1st, 2018 for the multilingual and Chinese
+ # models. This is also applied to the English models now, but it doesn't
+ # matter since the English models were not trained on any Chinese data
+ # and generally don't have any Chinese data in them (there are Chinese
+ # characters in the vocabulary because Wikipedia does have some Chinese
+ # words in the English Wikipedia.).
+ if self.tokenize_chinese_chars:
+ text = self._tokenize_chinese_chars(text)
+ # prevents treating the same character with different unicode codepoints as different characters
+ unicode_normalized_text = unicodedata.normalize("NFC", text)
+ orig_tokens = whitespace_tokenize(unicode_normalized_text)
+ split_tokens = []
+ for token in orig_tokens:
+ if token not in never_split:
+ if self.do_lower_case:
+ token = token.lower()
+ if self.strip_accents is not False:
+ token = self._run_strip_accents(token)
+ elif self.strip_accents:
+ token = self._run_strip_accents(token)
+ split_tokens.extend(self._run_split_on_punc(token, never_split))
+
+ output_tokens = whitespace_tokenize(" ".join(split_tokens))
+ return output_tokens
+
+ def _run_strip_accents(self, text):
+ """Strips accents from a piece of text."""
+ text = unicodedata.normalize("NFD", text)
+ output = []
+ for char in text:
+ cat = unicodedata.category(char)
+ if cat == "Mn":
+ continue
+ output.append(char)
+ return "".join(output)
+
+ def _run_split_on_punc(self, text, never_split=None):
+ """Splits punctuation on a piece of text."""
+ if not self.do_split_on_punc or (never_split is not None and text in never_split):
+ return [text]
+ chars = list(text)
+ i = 0
+ start_new_word = True
+ output = []
+ while i < len(chars):
+ char = chars[i]
+ if _is_punctuation(char):
+ output.append([char])
+ start_new_word = True
+ else:
+ if start_new_word:
+ output.append([])
+ start_new_word = False
+ output[-1].append(char)
+ i += 1
+
+ return ["".join(x) for x in output]
+
+ def _tokenize_chinese_chars(self, text):
+ """Adds whitespace around any CJK character."""
+ output = []
+ for char in text:
+ cp = ord(char)
+ if self._is_chinese_char(cp):
+ output.append(" ")
+ output.append(char)
+ output.append(" ")
+ else:
+ output.append(char)
+ return "".join(output)
+
+ def _is_chinese_char(self, cp):
+ """Checks whether CP is the codepoint of a CJK character."""
+ # This defines a "chinese character" as anything in the CJK Unicode block:
+ # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
+ #
+ # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
+ # despite its name. The modern Korean Hangul alphabet is a different block,
+ # as is Japanese Hiragana and Katakana. Those alphabets are used to write
+ # space-separated words, so they are not treated specially and handled
+ # like all of the other languages.
+ if (
+ (cp >= 0x4E00 and cp <= 0x9FFF)
+ or (cp >= 0x3400 and cp <= 0x4DBF)
+ or (cp >= 0x20000 and cp <= 0x2A6DF)
+ or (cp >= 0x2A700 and cp <= 0x2B73F)
+ or (cp >= 0x2B740 and cp <= 0x2B81F)
+ or (cp >= 0x2B820 and cp <= 0x2CEAF)
+ or (cp >= 0xF900 and cp <= 0xFAFF)
+ or (cp >= 0x2F800 and cp <= 0x2FA1F)
+ ):
+ return True
+
+ return False
+
+ def _clean_text(self, text):
+ """Performs invalid character removal and whitespace cleanup on text."""
+ output = []
+ for char in text:
+ cp = ord(char)
+ if cp == 0 or cp == 0xFFFD or _is_control(char):
+ continue
+ if _is_whitespace(char):
+ output.append(" ")
+ else:
+ output.append(char)
+ return "".join(output)
+
+
+def get_pairs(word):
+ """
+ Return set of symbol pairs in a word. word is represented as tuple of symbols (symbols being variable-length
+ strings)
+ """
+ pairs = set()
+ prev_char = word[0]
+ for char in word[1:]:
+ pairs.add((prev_char, char))
+ prev_char = char
+ return pairs
+
+
+def text_standardize(text):
+ """
+ Fixes some issues the spacy tokenizer had on the BooksCorpus dataset and performs some whitespace standardization.
+ """
+ text = text.replace("—", "-")
+ text = text.replace("–", "-")
+ text = text.replace("―", "-")
+ text = text.replace("…", "...")
+ text = text.replace("´", "'")
+ text = re.sub(r"""(-+|~+|!+|"+|;+|\?+|\++|,+|\)+|\(+|\\+|\/+|\*+|\[+|\]+|}+|{+|\|+|_+)""", r" \1 ", text)
+ text = re.sub(r"\s*\n\s*", " \n ", text)
+ text = re.sub(r"[^\S\n]+", " ", text)
+ return text.strip()
+
+
+class OpenAIGPTTokenizer(PreTrainedTokenizer):
+ """
+ Construct a GPT Tokenizer. Based on Byte-Pair-Encoding with the following peculiarities:
+
+ - lowercases all inputs,
+ - uses `SpaCy` tokenizer and `ftfy` for pre-BPE tokenization if they are installed, falling back to BERT's
+ `BasicTokenizer` if not.
+
+ This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
+ this superclass for more information regarding those methods.
+
+ Args:
+ vocab_file (`str`):
+ Path to the vocabulary file.
+ merges_file (`str`):
+ Path to the merges file.
+ unk_token (`str`, *optional*, defaults to `"<unk>"`):
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+ token instead.
+ """
+
+ vocab_files_names = VOCAB_FILES_NAMES
+ model_input_names = ["input_ids", "attention_mask"]
+
+ def __init__(self, vocab_file, merges_file, unk_token="<unk>", **kwargs):
+ try:
+ import ftfy
+ from spacy.lang.en import English
+
+ _nlp = English()
+ self.nlp = _nlp.tokenizer
+ self.fix_text = ftfy.fix_text
+ except ImportError:
+ logger.warning("ftfy or spacy is not installed using BERT BasicTokenizer instead of SpaCy & ftfy.")
+ self.nlp = BasicTokenizer(do_lower_case=True)
+ self.fix_text = None
+
+ with open(vocab_file, encoding="utf-8") as vocab_handle:
+ self.encoder = json.load(vocab_handle)
+ self.decoder = {v: k for k, v in self.encoder.items()}
+ with open(merges_file, encoding="utf-8") as merges_handle:
+ merges = merges_handle.read().split("\n")[1:-1]
+ merges = [tuple(merge.split()) for merge in merges]
+ self.bpe_ranks = dict(zip(merges, range(len(merges))))
+ self.cache = {}
+
+ super().__init__(unk_token=unk_token, **kwargs)
+
+ @property
+ def do_lower_case(self):
+ return True
+
+ @property
+ def vocab_size(self):
+ return len(self.encoder)
+
+ def get_vocab(self):
+ return dict(self.encoder, **self.added_tokens_encoder)
+
+ def bpe(self, token):
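+ # Greedily merges the highest-ranked adjacent symbol pair until no remaining pair
+ # appears in the merges table; results are cached per token. The output depends
+ # entirely on the loaded merges file (e.g. "lower" may become "low er</w>").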
+ word = tuple(token[:-1]) + (token[-1] + "</w>",)
+ if token in self.cache:
+ return self.cache[token]
+ pairs = get_pairs(word)
+
+ if not pairs:
+ return token + "</w>"
+
+ while True:
+ bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
+ if bigram not in self.bpe_ranks:
+ break
+ first, second = bigram
+ new_word = []
+ i = 0
+ while i < len(word):
+ try:
+ j = word.index(first, i)
+ except ValueError:
+ new_word.extend(word[i:])
+ break
+ else:
+ new_word.extend(word[i:j])
+ i = j
+
+ if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
+ new_word.append(first + second)
+ i += 2
+ else:
+ new_word.append(word[i])
+ i += 1
+ new_word = tuple(new_word)
+ word = new_word
+ if len(word) == 1:
+ break
+ else:
+ pairs = get_pairs(word)
+ word = " ".join(word)
+ if word == "\n ":
+ word = "\n"
+ self.cache[token] = word
+ return word
+
+ def _tokenize(self, text):
+ """Tokenize a string."""
+ split_tokens = []
+ if self.fix_text is None:
+ # Using BERT's BasicTokenizer
+ text = self.nlp.tokenize(text)
+ for token in text:
+ split_tokens.extend(list(self.bpe(token).split(" ")))
+ else:
+ # Using SpaCy & ftfy (original tokenization process of OpenAI GPT)
+ text = self.nlp(text_standardize(self.fix_text(text)))
+ for token in text:
+ split_tokens.extend(list(self.bpe(token.text.lower()).split(" ")))
+ return split_tokens
+
+ def _convert_token_to_id(self, token):
+ """Converts a token (str) in an id using the vocab."""
+ return self.encoder.get(token, self.encoder.get(self.unk_token))
+
+ def _convert_id_to_token(self, index):
+ """Converts an id in a token (BPE) using the vocab."""
+ return self.decoder.get(index, self.unk_token)
+
+ def convert_tokens_to_string(self, tokens):
+ """Converts a sequence of tokens (string) in a single string."""
+ out_string = "".join(tokens).replace("", " ").strip()
+ return out_string
+
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]:
+ if not os.path.isdir(save_directory):
+ logger.error(f"Vocabulary path ({save_directory}) should be a directory")
+ return
+ vocab_file = os.path.join(
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
+ )
+ merge_file = os.path.join(
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["merges_file"]
+ )
+
+ with open(vocab_file, "w", encoding="utf-8") as f:
+ f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
+
+ index = 0
+ with open(merge_file, "w", encoding="utf-8") as writer:
+ writer.write("#version: 0.2\n")
+ for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
+ if index != token_index:
+ logger.warning(
+ f"Saving vocabulary to {merge_file}: BPE merge indices are not consecutive."
+ " Please check that the tokenizer is not corrupted!"
+ )
+ index = token_index
+ writer.write(" ".join(bpe_tokens) + "\n")
+ index += 1
+
+ return vocab_file, merge_file
+
+
+__all__ = ["OpenAIGPTTokenizer"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/openai/tokenization_openai_fast.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/openai/tokenization_openai_fast.py
new file mode 100644
index 0000000000000000000000000000000000000000..83edf5eafa9468347c4b1c3b6bfab9d0e7759ec3
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/openai/tokenization_openai_fast.py
@@ -0,0 +1,66 @@
+# coding=utf-8
+# Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Fast Tokenization classes for OpenAI GPT."""
+
+from typing import Optional
+
+from ...tokenization_utils_fast import PreTrainedTokenizerFast
+from ...utils import logging
+from .tokenization_openai import OpenAIGPTTokenizer
+
+
+logger = logging.get_logger(__name__)
+
+VOCAB_FILES_NAMES = {"vocab_file": "vocab.json", "merges_file": "merges.txt", "tokenizer_file": "tokenizer.json"}
+
+
+class OpenAIGPTTokenizerFast(PreTrainedTokenizerFast):
+ """
+ Construct a "fast" GPT Tokenizer (backed by HuggingFace's *tokenizers* library). Based on Byte-Pair-Encoding with
+ the following peculiarities:
+
+ - lower case all inputs
+ - uses BERT's BasicTokenizer for pre-BPE tokenization
+
+ This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should
+ refer to this superclass for more information regarding those methods.
+
+ Args:
+ vocab_file (`str`):
+ Path to the vocabulary file.
+ merges_file (`str`):
+ Path to the merges file.
+ unk_token (`str`, *optional*, defaults to `"<unk>"`):
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+ token instead.
+ """
+
+ vocab_files_names = VOCAB_FILES_NAMES
+ model_input_names = ["input_ids", "attention_mask"]
+ slow_tokenizer_class = OpenAIGPTTokenizer
+
+ def __init__(self, vocab_file=None, merges_file=None, tokenizer_file=None, unk_token="<unk>", **kwargs):
+ super().__init__(vocab_file, merges_file, tokenizer_file=tokenizer_file, unk_token=unk_token, **kwargs)
+
+ @property
+ def do_lower_case(self):
+ return True
+
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]:
+ files = self._tokenizer.model.save(save_directory, name=filename_prefix)
+ return tuple(files)
+
+
+__all__ = ["OpenAIGPTTokenizerFast"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/ovis2/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/ovis2/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..902e015b055f816e59bb2a69ccb0357d5b05cdb1
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/ovis2/__init__.py
@@ -0,0 +1,32 @@
+# coding=utf-8
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_ovis2 import *
+ from .image_processing_ovis2 import *
+ from .image_processing_ovis2_fast import *
+ from .modeling_ovis2 import *
+ from .processing_ovis2 import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/ovis2/configuration_ovis2.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/ovis2/configuration_ovis2.py
new file mode 100644
index 0000000000000000000000000000000000000000..cd2a1c0af4d622f3aa878739be942d27e06e2785
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/ovis2/configuration_ovis2.py
@@ -0,0 +1,182 @@
+# coding=utf-8
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ...configuration_utils import PretrainedConfig
+from ..qwen2.configuration_qwen2 import Qwen2Config
+
+
+class Ovis2VisionConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`Ovis2VisionModel`]. It is used to instantiate a
+ Ovis2VisionModel model according to the specified arguments, defining the model architecture. Instantiating a configuration
+ with the defaults will yield a similar configuration to that of Ovis2.
+
+ Args:
+ hidden_size (`int`, *optional*, defaults to 1024):
+ Dimensionality of the encoder layers and the pooler layer.
+ intermediate_size (`int`, *optional*, defaults to 2816):
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
+ num_hidden_layers (`int`, *optional*, defaults to 24):
+ Number of hidden layers in the Transformer encoder.
+ num_attention_heads (`int`, *optional*, defaults to 8):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ num_channels (`int`, *optional*, defaults to 3):
+ Number of channels in the input images.
+ image_size (`int`, *optional*, defaults to 224):
+ The size (resolution) of each image.
+ patch_size (`int`, *optional*, defaults to 14):
+ The size (resolution) of each patch.
+ rms_norm_eps (`float`, *optional*, defaults to 1e-05):
+ The epsilon used by the RMSNorm layers.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ qkv_bias (`bool`, *optional*, defaults to `False`):
+ Whether to add a learnable bias to the query, key, and value sequences at each attention head.
+ mlp_bias (`bool`, *optional*, defaults to `False`):
+ Whether to add a learnable bias to the MLP layers.
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+ `"relu"`, `"selu"` and `"gelu_new"` `"quick_gelu"` are supported.
+ vocab_size (`int`, *optional*, defaults to 16384):
+ Vocabulary size of the Vision Transformer.
+ hidden_stride (`int`, *optional*, defaults to 1):
+ The stride of the hidden layer in the Vision Transformer.
+ num_visual_indicator_tokens (`int`, *optional*, defaults to 5):
+ Number of visual indicator tokens.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated normal initializer for initializing all weight matrices.
+ tokenize_function (`str`, *optional*, defaults to `"softmax"`):
+ The function used to tokenize the visual indicator tokens.
+ """
+
+ base_config_key = "vision_config"
+
+ def __init__(
+ self,
+ hidden_size: int = 1024,
+ intermediate_size: int = 2816,
+ num_hidden_layers: int = 24,
+ num_attention_heads: int = 8,
+ num_channels: int = 3,
+ image_size: int = 224,
+ patch_size: int = 14,
+ rms_norm_eps: float = 1e-5,
+ attention_dropout: float = 0.0,
+ qkv_bias: bool = False,
+ mlp_bias: bool = False,
+ hidden_act="silu",
+ vocab_size=16384,
+ hidden_stride=1,
+ num_visual_indicator_tokens=5,
+ initializer_range=0.02,
+ tokenize_function="softmax",
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.num_channels = num_channels
+ self.patch_size = patch_size
+ self.image_size = image_size
+
+ self.attention_dropout = attention_dropout
+ self.hidden_act = hidden_act
+ self.qkv_bias = qkv_bias
+ self.mlp_bias = mlp_bias
+ self.rms_norm_eps = rms_norm_eps
+ self.vocab_size = vocab_size
+ self.hidden_stride = hidden_stride
+ self.num_visual_indicator_tokens = num_visual_indicator_tokens
+ self.tokenize_function = tokenize_function
+ self.initializer_range = initializer_range
+
+
+class Ovis2Config(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`Ovis2ForConditionalGeneration`]. It is used to instantiate a
+ Ovis2 model according to the specified arguments, defining the model architecture. Instantiating a configuration
+ with the defaults will yield a similar configuration to that of Ovis2.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ e.g. [thisisiron/Ovis2-1B-hf](https://huggingface.co/thisisiron/Ovis2-1B-hf)
+
+ Args:
+ vision_config (`Union[AutoConfig, dict]`, *optional*, defaults to `Ovis2VisionConfig`):
+ The config object or dictionary of the vision backbone.
+ text_config (`Union[AutoConfig, dict]`, *optional*, defaults to `Qwen2Config`):
+ The config object or dictionary of the text backbone.
+ image_token_id (`int`, *optional*, defaults to 151665):
+ The image token id to encode the image prompt.
+ visual_indicator_token_ids (`List[int]`, *optional*, defaults to `[151666, 151667, 151668, 151669, 151670]`):
+ The visual indicator token ids to encode the image prompt.
+ vocab_size (`int`, *optional*, defaults to 151643):
+ Vocabulary size of the text model.
+ hidden_size (`int`, *optional*, defaults to 1536):
+ Dimensionality of the encoder layers and the pooler layer.
+
+ ```python
+ >>> from transformers import Ovis2ForConditionalGeneration, Ovis2Config
+
+ >>> # Initializing an Ovis2 style configuration
+ >>> configuration = Ovis2Config()
+
+ >>> # Initializing a model from the Ovis2 style configuration
+ >>> model = Ovis2ForConditionalGeneration(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
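+
+ >>> # A further sketch: building the config from explicit sub-configs
+ >>> # (hypothetical values; assumes both config classes are exported from `transformers`)
+ >>> from transformers import Ovis2VisionConfig, Qwen2Config
+ >>> custom_configuration = Ovis2Config(
+ ... vision_config=Ovis2VisionConfig(hidden_size=1024),
+ ... text_config=Qwen2Config(hidden_size=1536),
+ ... )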
+ ```
+ """
+
+ model_type = "ovis2"
+ sub_configs = {"text_config": Qwen2Config, "vision_config": Ovis2VisionConfig}
+
+ def __init__(
+ self,
+ vision_config=None,
+ text_config=None,
+ image_token_id=151665,
+ visual_indicator_token_ids=[151666, 151667, 151668, 151669, 151670],
+ vocab_size=151643,
+ hidden_size=1536,
+ **kwargs,
+ ):
+ if isinstance(vision_config, dict):
+ self.vision_config = Ovis2VisionConfig(**vision_config)
+ elif isinstance(vision_config, Ovis2VisionConfig):
+ self.vision_config = vision_config
+ elif vision_config is None:
+ self.vision_config = Ovis2VisionConfig(num_visual_indicator_tokens=len(visual_indicator_token_ids))
+
+ if isinstance(text_config, dict):
+ self.text_config = Qwen2Config(**text_config)
+ elif isinstance(text_config, Qwen2Config):
+ self.text_config = text_config
+ elif text_config is None:
+ self.text_config = Qwen2Config()
+
+ self.vocab_size = vocab_size
+ self.hidden_size = hidden_size
+ self.image_token_id = image_token_id
+ self.visual_indicator_token_ids = visual_indicator_token_ids
+ super().__init__(**kwargs)
+
+
+__all__ = ["Ovis2VisionConfig", "Ovis2Config"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/ovis2/image_processing_ovis2.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/ovis2/image_processing_ovis2.py
new file mode 100644
index 0000000000000000000000000000000000000000..4c0be26d374a125f4945dbc655e2d1d8753b3abe
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/ovis2/image_processing_ovis2.py
@@ -0,0 +1,574 @@
+# coding=utf-8
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from functools import lru_cache
+from typing import Optional, Union
+
+import numpy as np
+
+from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
+from ...image_transforms import convert_to_rgb, resize, to_channel_dimension_format
+from ...image_utils import (
+ OPENAI_CLIP_MEAN,
+ OPENAI_CLIP_STD,
+ ChannelDimension,
+ ImageInput,
+ PILImageResampling,
+ infer_channel_dimension_format,
+ is_scaled_image,
+ make_flat_list_of_images,
+ to_numpy_array,
+ valid_images,
+ validate_preprocess_arguments,
+)
+from ...utils import TensorType, filter_out_non_signature_kwargs, is_vision_available, logging
+
+
+if is_vision_available():
+ import PIL
+
+
+logger = logging.get_logger(__name__)
+
+
+# Similar to image_processing_mllama.get_all_supported_aspect_ratios
+@lru_cache(maxsize=10)
+def get_all_supported_aspect_ratios(min_image_tiles: int, max_image_tiles: int) -> list[tuple[int, int]]:
+ """
+ Computes all allowed aspect ratios for a given minimum and maximum number of input tiles.
+
+ This function calculates all possible arrangements of tiles that can be formed
+ within the constraint of the minimum and maximum number of tiles. Each arrangement is
+ represented by its aspect ratio (width/height) and the corresponding tile configuration.
+
+ Args:
+ min_image_tiles (`int`):
+ The minimum number of tiles allowed.
+ max_image_tiles (`int`):
+ The maximum number of tiles allowed.
+
+ Returns:
+ `List[Tuple[int, int]]`: A list of tuples, each tuple representing a valid (width, height)
+ configuration in terms of number of tiles.
+
+ Example:
+ >>> get_all_supported_aspect_ratios(1, 4)
+ [(1, 1), (1, 2), (2, 1), (1, 3), (3, 1), (1, 4), (2, 2), (4, 1)]
+
+ """
+ aspect_ratios = []
+ for width in range(1, max_image_tiles + 1):
+ for height in range(1, max_image_tiles + 1):
+ if width * height <= max_image_tiles and width * height >= min_image_tiles:
+ aspect_ratios.append((width, height))
+
+ aspect_ratios = sorted(aspect_ratios, key=lambda x: x[0] * x[1])
+
+ return aspect_ratios
+
+
+@lru_cache(maxsize=100)
+def get_optimal_tiled_canvas(
+ original_image_size: tuple[int, int],
+ target_tile_size: tuple[int, int],
+ min_image_tiles: int,
+ max_image_tiles: int,
+) -> tuple[int, int]:
+ """
+ Given a minimum and maximum number of tiles, find the canvas with the closest aspect ratio to the
+ original image aspect ratio.
+ When two canvases have the same aspect-ratio difference, the canvas with more tiles is favored, as long as
+ the total tile area does not exceed twice the original image area, in order to avoid unnecessarily
+ excessive tiling.
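+
+ Example (an illustrative sketch; the selected grid depends on which tile arrangements fit within
+ `max_image_tiles`):
+ >>> get_optimal_tiled_canvas((768, 1024), (384, 384), 1, 6)
+ (3, 2)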
+ """
+ possible_tile_arrangements = get_all_supported_aspect_ratios(min_image_tiles, max_image_tiles)
+
+ original_height, original_width = original_image_size
+ target_tile_height, target_tile_width = target_tile_size
+ aspect_ratio = original_width / original_height
+ area = original_width * original_height
+
+ # find the grid with the best aspect ratio
+ best_ratio_diff = float("inf")
+ best_grid = (1, 1)
+ for grid in possible_tile_arrangements:
+ grid_aspect_ratio = grid[0] / grid[1]
+ ratio_diff = abs(aspect_ratio - grid_aspect_ratio)
+ if ratio_diff < best_ratio_diff:
+ best_ratio_diff = ratio_diff
+ best_grid = grid
+ elif ratio_diff == best_ratio_diff:
+ # if the aspect ratio difference is the same, we favor the grid with more patches
+ # until the area covered by the patches is more than twice the original image area
+ if area > 0.5 * target_tile_height * target_tile_width * grid[0] * grid[1]:
+ best_grid = grid
+
+ return best_grid
+
+
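+# Estimates the area of a tile region that a square patch of side `side` can cover: if the region's
+# longer edge exceeds `side`, the region is scaled down (preserving aspect ratio) so that its longer
+# edge equals `side`, and the scaled area is returned; otherwise the full region area is returned.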
+def compute_patch_covering_area(left: int, upper: int, right: int, lower: int, side: int) -> float:
+ w = right - left
+ h = lower - upper
+ w, h = max(w, h), min(w, h)
+ if w > side:
+ h = h / w * side
+ w = side
+ return w * h
+
+
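+# Splits an (h, w) image into `grid[0] * grid[1]` tile regions (grid[0] rows of grid[1] columns),
+# returned as (left, upper, right, lower) boxes; the last row/column extends to the image border so
+# that rounding from the integer division does not drop pixels.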
+def split_image_into_grid(h: int, w: int, grid: tuple[int, int]) -> list[tuple[int, int, int, int]]:
+ row_height = h // grid[0]
+ col_width = w // grid[1]
+ return [
+ (
+ col * col_width,
+ row * row_height,
+ w if col == grid[1] - 1 else (col + 1) * col_width,
+ h if row == grid[0] - 1 else (row + 1) * row_height,
+ )
+ for row in range(grid[0])
+ for col in range(grid[1])
+ ]
+
+
+@lru_cache(maxsize=100)
+def get_min_tile_covering_grid(
+ image_size: tuple[int, int],
+ target_patch_size: int,
+ max_image_tiles: int,
+ covering_threshold: float = 0.9,
+) -> tuple[int, int]:
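+ """
+ Selects the tile grid whose tiles cover the largest share of the image. Among grids whose covering
+ ratio exceeds `covering_threshold`, the one with the fewest tiles (and then the highest ratio) is
+ returned; if no grid reaches the threshold, the grid with the highest covering ratio is used instead.
+ """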
+ image_height, image_width = image_size
+ image_area = image_width * image_height
+
+ candidate_tile_grids = get_all_supported_aspect_ratios(1, max_image_tiles)
+ evaluated_grids = []
+ sufficient_covering_grids = []
+
+ for tile_grid in candidate_tile_grids:
+ tile_regions = split_image_into_grid(image_height, image_width, tile_grid)
+ tile_covering_ratio = (
+ sum([compute_patch_covering_area(*region, target_patch_size) for region in tile_regions]) / image_area
+ )
+
+ evaluated_grids.append((tile_grid, tile_covering_ratio))
+ if tile_covering_ratio > covering_threshold:
+ sufficient_covering_grids.append((tile_grid, tile_covering_ratio))
+
+ if sufficient_covering_grids:
+ # Prefer fewer tiles and higher covering ratio
+ return min(sufficient_covering_grids, key=lambda x: (x[0][0] * x[0][1], -x[1]))[0]
+ else:
+ # Fallback: prefer higher covering even if below threshold
+ return min(evaluated_grids, key=lambda x: (-x[1], x[0][0] * x[0][1]))[0]
+
+
+class Ovis2ImageProcessor(BaseImageProcessor):
+ r"""
+ Constructs an Ovis2 image processor.
+
+ Args:
+ do_resize (`bool`, *optional*, defaults to `True`):
+ Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by the
+ `do_resize` parameter in the `preprocess` method.
+ size (`dict`, *optional*, defaults to `{"height": 384, "width": 384}`):
+ Size of the output image after resizing. Can be overridden by the `size` parameter in the `preprocess`
+ method.
+ crop_to_patches (`bool`, *optional*, defaults to `False`):
+ Whether to crop the image to patches. Can be overridden by the `crop_to_patches` parameter in the
+ `preprocess` method.
+ min_patches (`int`, *optional*, defaults to 1):
+ The minimum number of patches to be extracted from the image. Only has an effect if `crop_to_patches` is
+ set to `True`. Can be overridden by the `min_patches` parameter in the `preprocess` method.
+ max_patches (`int`, *optional*, defaults to 12):
+ The maximum number of patches to be extracted from the image. Only has an effect if `crop_to_patches` is
+ set to `True`. Can be overridden by the `max_patches` parameter in the `preprocess` method.
+ resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`):
+ Resampling filter to use if resizing the image. Only has an effect if `do_resize` is set to `True`. Can be
+ overridden by the `resample` parameter in the `preprocess` method.
+ do_rescale (`bool`, *optional*, defaults to `True`):
+ Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the
+ `do_rescale` parameter in the `preprocess` method.
+ rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
+ Scale factor to use if rescaling the image. Only has an effect if `do_rescale` is set to `True`. Can be
+ overridden by the `rescale_factor` parameter in the `preprocess` method.
+ do_normalize (`bool`, *optional*, defaults to `True`):
+ Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
+ method.
+ image_mean (`float` or `List[float]`, *optional*, defaults to `OPENAI_CLIP_MEAN`):
+ Mean to use if normalizing the image. This is a float or list of floats the length of the number of
+ channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
+ image_std (`float` or `List[float]`, *optional*, defaults to `OPENAI_CLIP_STD`):
+ Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
+ number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess`
+ method.
+ do_convert_rgb (`bool`, *optional*, defaults to `True`):
+ Whether to convert the image to RGB.
+ use_covering_area_grid (`bool`, *optional*, defaults to `True`):
+ Whether to use the covering area grid to determine the number of patches. Only has an effect if
+ `crop_to_patches` is set to `True`. Can be overridden by the `use_covering_area_grid` parameter in the
+ `preprocess` method.
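+
+ Example (a minimal usage sketch with synthetic data; parameter values are illustrative, not taken
+ from a released checkpoint):
+
+ ```python
+ >>> import numpy as np
+ >>> from transformers import Ovis2ImageProcessor
+
+ >>> processor = Ovis2ImageProcessor(crop_to_patches=True, max_patches=9)
+ >>> image = np.random.randint(0, 256, (512, 768, 3), dtype=np.uint8)
+ >>> outputs = processor(images=image, return_tensors="pt")
+ >>> list(outputs.keys())
+ ['pixel_values', 'grids']
+ ```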
+ """
+
+ model_input_names = ["pixel_values"]
+
+ def __init__(
+ self,
+ do_resize: bool = True,
+ size: Optional[dict[str, int]] = None,
+ crop_to_patches: bool = False,
+ min_patches: int = 1,
+ max_patches: int = 12,
+ resample: PILImageResampling = PILImageResampling.BICUBIC,
+ do_rescale: bool = True,
+ rescale_factor: Union[int, float] = 1 / 255,
+ do_normalize: bool = True,
+ image_mean: Optional[Union[float, list[float]]] = None,
+ image_std: Optional[Union[float, list[float]]] = None,
+ do_convert_rgb: bool = True,
+ use_covering_area_grid: bool = True,
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+ size = size if size is not None else {"height": 384, "width": 384}
+ size = get_size_dict(size, default_to_square=True)
+
+ self.do_resize = do_resize
+ self.size = size
+ self.crop_to_patches = crop_to_patches
+ self.min_patches = min_patches
+ self.max_patches = max_patches
+ self.resample = resample
+ self.do_rescale = do_rescale
+ self.rescale_factor = rescale_factor
+ self.do_normalize = do_normalize
+ self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN
+ self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD
+ self.do_convert_rgb = do_convert_rgb
+ self.use_covering_area_grid = use_covering_area_grid
+
+ def resize(
+ self,
+ image: np.ndarray,
+ size: dict[str, int],
+ resample: PILImageResampling = PILImageResampling.BICUBIC,
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ **kwargs,
+ ) -> np.ndarray:
+ """
+ Resize an image to `(size["height"], size["width"])`.
+
+ Args:
+ image (`np.ndarray`):
+ Image to resize.
+ size (`Dict[str, int]`):
+ Dictionary in the format `{"height": int, "width": int}` specifying the size of the output image.
+ resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
+ `PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BICUBIC`.
+ data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format for the output image. If unset, the channel dimension format of the input
+ image is used. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
+ input_data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
+ from the input image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
+
+ Returns:
+ `np.ndarray`: The resized image.
+ """
+ size = get_size_dict(size)
+ if "height" not in size or "width" not in size:
+ raise ValueError(f"The `size` dictionary must contain the keys `height` and `width`. Got {size.keys()}")
+ output_size = (size["height"], size["width"])
+ return resize(
+ image,
+ size=output_size,
+ resample=resample,
+ data_format=data_format,
+ input_data_format=input_data_format,
+ **kwargs,
+ )
+
+ @filter_out_non_signature_kwargs()
+ def preprocess(
+ self,
+ images: ImageInput,
+ do_resize: Optional[bool] = None,
+ size: Optional[dict[str, int]] = None,
+ crop_to_patches: Optional[bool] = None,
+ min_patches: Optional[int] = None,
+ max_patches: Optional[int] = None,
+ resample: Optional[PILImageResampling] = None,
+ do_rescale: Optional[bool] = None,
+ rescale_factor: Optional[float] = None,
+ do_normalize: Optional[bool] = None,
+ image_mean: Optional[Union[float, list[float]]] = None,
+ image_std: Optional[Union[float, list[float]]] = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ do_convert_rgb: Optional[bool] = None,
+ data_format: ChannelDimension = ChannelDimension.FIRST,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ use_covering_area_grid: bool = True,
+ ) -> BatchFeature:
+ """
+ Preprocess an image or batch of images.
+
+ Args:
+ images (`ImageInput`):
+ Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
+ passing in images with pixel values between 0 and 1, set `do_rescale=False`.
+ do_resize (`bool`, *optional*, defaults to `self.do_resize`):
+ Whether to resize the image.
+ size (`Dict[str, int]`, *optional*, defaults to `self.size`):
+ Controls the size of the image after `resize`. Each image (or each patch, when `crop_to_patches` is
+ set to `True`) is resized to `(size["height"], size["width"])`.
+ crop_to_patches (`bool`, *optional*, defaults to `self.crop_to_patches`):
+ Whether to crop the image to patches.
+ min_patches (`int`, *optional*, defaults to `self.min_patches`):
+ The minimum number of patches to be extracted from the image. Only has an effect if `crop_to_patches` is
+ set to `True`.
+ max_patches (`int`, *optional*, defaults to `self.max_patches`):
+ The maximum number of patches to be extracted from the image. Only has an effect if `crop_to_patches` is
+ set to `True`.
+ resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
+ Resampling filter to use if resizing the image. Only has an effect if `do_resize` is set to `True`.
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
+ Whether to rescale the image values to the `[0, 1]` range.
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
+ Rescale factor to rescale the image by if `do_rescale` is set to `True`.
+ do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
+ Whether to normalize the image.
+ image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
+ Image mean to normalize the image by if `do_normalize` is set to `True`.
+ image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
+ Image standard deviation to normalize the image by if `do_normalize` is set to `True`.
+ do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
+ Whether to convert the image to RGB.
+ return_tensors (`str` or `TensorType`, *optional*):
+ The type of tensors to return. Can be one of:
+ - Unset: Return a list of `np.ndarray`.
+ - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
+ - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
+ data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
+ The channel dimension format for the output image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - Unset: Use the channel dimension format of the input image.
+ input_data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
+ from the input image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
+ use_covering_area_grid (`bool`, *optional*, defaults to `True`):
+ Whether to use the covering area grid to determine the number of patches. Only has an effect if
+ `crop_to_patches` is set to `True`.
+ """
+ do_resize = do_resize if do_resize is not None else self.do_resize
+ crop_to_patches = crop_to_patches if crop_to_patches is not None else self.crop_to_patches
+ min_patches = min_patches if min_patches is not None else self.min_patches
+ max_patches = max_patches if max_patches is not None else self.max_patches
+ resample = resample if resample is not None else self.resample
+ do_rescale = do_rescale if do_rescale is not None else self.do_rescale
+ rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
+ do_normalize = do_normalize if do_normalize is not None else self.do_normalize
+ image_mean = image_mean if image_mean is not None else self.image_mean
+ image_std = image_std if image_std is not None else self.image_std
+ do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
+
+ size = size if size is not None else self.size
+ size = get_size_dict(size, default_to_square=False)
+ images = self.fetch_images(images)
+ images = make_flat_list_of_images(images)
+
+ if not valid_images(images):
+ raise ValueError(
+ "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
+ "torch.Tensor, tf.Tensor or jax.ndarray."
+ )
+
+ validate_preprocess_arguments(
+ do_rescale=do_rescale,
+ rescale_factor=rescale_factor,
+ do_normalize=do_normalize,
+ image_mean=image_mean,
+ image_std=image_std,
+ do_resize=do_resize,
+ size=size,
+ resample=resample,
+ )
+ # PIL RGBA images are converted to RGB
+ if do_convert_rgb:
+ images = [convert_to_rgb(image) for image in images]
+
+ # All transformations expect numpy arrays.
+ images = [to_numpy_array(image) for image in images]
+
+ if do_rescale and is_scaled_image(images[0]):
+ logger.warning_once(
+ "It looks like you are trying to rescale already rescaled images. If the input"
+ " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
+ )
+
+ if input_data_format is None:
+ # We assume that all images have the same channel dimension format.
+ input_data_format = infer_channel_dimension_format(images[0])
+
+ if crop_to_patches and max_patches > 1:
+ images = [
+ self.crop_image_to_patches(
+ image,
+ min_patches=min_patches,
+ max_patches=max_patches,
+ patch_size=size,
+ data_format=input_data_format,
+ use_covering_area_grid=use_covering_area_grid,
+ )
+ for image in images
+ ]
+ grids = [grid for _, grid in images]
+ images = [image for images_list, _ in images for image in images_list]
+ else:
+ grids = [(1, 1)] * len(images)
+
+ for i, image in enumerate(images):
+ if do_resize:
+ images[i] = self.resize(image, size=size, resample=resample, input_data_format=input_data_format)
+
+ if do_rescale:
+ images[i] = self.rescale(image=images[i], scale=rescale_factor, input_data_format=input_data_format)
+
+ if do_normalize:
+ images[i] = self.normalize(
+ image=images[i],
+ mean=image_mean,
+ std=image_std,
+ input_data_format=input_data_format,
+ )
+
+ images[i] = to_channel_dimension_format(images[i], data_format, input_channel_dim=input_data_format)
+
+ encoded_outputs = BatchFeature(data={"pixel_values": images, "grids": grids}, tensor_type=return_tensors)
+
+ return encoded_outputs
+
+ def crop_image_to_patches(
+ self,
+ images: np.ndarray,
+ min_patches: int,
+ max_patches: int,
+ use_covering_area_grid: bool = True,
+ patch_size: Optional[Union[tuple, int, dict]] = None,
+ data_format: Optional[ChannelDimension] = None,
+ covering_threshold: float = 0.9,
+ ):
+ """
+ Crop the image to patches and return a list of cropped images.
+ The number of patches and their grid arrangement are determined by the original image size,
+ the target patch size and the minimum and maximum number of patches.
+ The aspect ratio of the patches grid is chosen to be the closest to the original image aspect ratio.
+
+ Args:
+ images (`np.ndarray`):
+ The image to be cropped.
+ min_patches (`int`):
+ The minimum number of patches to be extracted from the image.
+ max_patches (`int`):
+ The maximum number of patches to be extracted from the image.
+ use_covering_area_grid (`bool`, *optional*, defaults to `True`):
+ Whether to use the covering area grid to determine the number of patches.
+ patch_size (`int`, `Tuple[int, int]`, `dict`, *optional*):
+ The size of the output patches.
+ data_format (`ChannelDimension`, *optional*):
+ The format of the image data. If `None`, the format is inferred from the input image.
+ covering_threshold (`float`, *optional*, defaults to `0.9`):
+ The threshold for the covering area grid. If the covering area is less than this value, the grid is
+ considered invalid.
+
+ Returns:
+ `Tuple[List[np.ndarray], Tuple[int, int]]`: The list of cropped patches (with a resized thumbnail of
+ the full image prepended when more than one patch is produced) and the `(num_rows, num_columns)` grid.
+ """
+ if data_format is None:
+ data_format = infer_channel_dimension_format(images)
+ images = to_channel_dimension_format(images, ChannelDimension.FIRST, data_format)
+ patch_size_height, patch_size_width = patch_size["height"], patch_size["width"]
+ original_height, original_width = images.shape[-2:]
+
+ if use_covering_area_grid:
+ # Use the original OVIS2 approach: compute the minimal number of tiles that cover at least 90% of the image area
+ num_columns, num_rows = get_min_tile_covering_grid(
+ (original_height, original_width),
+ target_patch_size=patch_size_height, # square patch size
+ max_image_tiles=max_patches,
+ covering_threshold=covering_threshold,
+ )
+ else:
+ # find the closest aspect ratio to the target
+ num_columns, num_rows = get_optimal_tiled_canvas(
+ (original_height, original_width),
+ (patch_size_height, patch_size_width),
+ min_patches,
+ max_patches,
+ )
+
+ # calculate the target width and height
+ target_width = patch_size_width * num_columns
+ target_height = patch_size_height * num_rows
+ num_blocks = num_columns * num_rows
+
+ # resize the image so that each patch is of patch_size
+ resized_image = self.resize(
+ images,
+ {"height": target_height, "width": target_width},
+ data_format=ChannelDimension.FIRST,
+ input_data_format=ChannelDimension.FIRST,
+ )
+
+ # split the image into patches
+ processed_images = []
+ for i in range(num_blocks):
+ column = i % num_columns
+ row = i // num_columns
+ box = (
+ column * patch_size_width,
+ row * patch_size_height,
+ (column + 1) * patch_size_width,
+ (row + 1) * patch_size_height,
+ )
+ # split the image
+ patch_image = resized_image[..., box[1] : box[3], box[0] : box[2]]
+ patch_image = to_channel_dimension_format(patch_image, data_format, ChannelDimension.FIRST)
+ processed_images.append(patch_image)
+
+ if len(processed_images) != 1:
+ thumbnail_img = self.resize(
+ images, patch_size, data_format=data_format, input_data_format=ChannelDimension.FIRST
+ )
+ processed_images.insert(0, thumbnail_img)
+
+ return processed_images, (num_rows, num_columns)
+
+
+__all__ = ["Ovis2ImageProcessor"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/ovis2/image_processing_ovis2_fast.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/ovis2/image_processing_ovis2_fast.py
new file mode 100644
index 0000000000000000000000000000000000000000..04b79299e9e14f4c69aa6a39c0d51d04e25e79f8
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/ovis2/image_processing_ovis2_fast.py
@@ -0,0 +1,245 @@
+# coding=utf-8
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional, Union
+
+import torch
+from torchvision.transforms.v2 import functional as F
+
+from ...image_processing_utils import BatchFeature
+from ...image_processing_utils_fast import (
+ BaseImageProcessorFast,
+ DefaultFastImageProcessorKwargs,
+ group_images_by_shape,
+ reorder_images,
+)
+from ...image_utils import (
+ OPENAI_CLIP_MEAN,
+ OPENAI_CLIP_STD,
+ ImageInput,
+ PILImageResampling,
+ SizeDict,
+)
+from ...processing_utils import Unpack
+from ...utils import (
+ TensorType,
+ auto_docstring,
+)
+from .image_processing_ovis2 import get_min_tile_covering_grid, get_optimal_tiled_canvas
+
+
+class Ovis2ImageProcessorKwargs(DefaultFastImageProcessorKwargs):
+ """
+ Args:
+ crop_to_patches (`bool`, *optional*, defaults to `False`):
+ Whether to crop the image to patches. Can be overridden by the `crop_to_patches` parameter in the
+ `preprocess` method.
+ min_patches (`int`, *optional*, defaults to 1):
+ The minimum number of patches to be extracted from the image. Only has an effect if `crop_to_patches` is
+ set to `True`. Can be overridden by the `min_patches` parameter in the `preprocess` method.
+ max_patches (`int`, *optional*, defaults to 12):
+ The maximum number of patches to be extracted from the image. Only has an effect if `crop_to_patches` is
+ set to `True`. Can be overridden by the `max_patches` parameter in the `preprocess` method.
+ use_covering_area_grid (`bool`, *optional*, defaults to `True`):
+ Whether to use the covering area grid to determine the number of patches. Only has an effect if
+ `crop_to_patches` is set to `True`. Can be overridden by the `use_covering_area_grid` parameter in the
+ `preprocess` method.
+ """
+
+ crop_to_patches: Optional[bool]
+ min_patches: Optional[int]
+ max_patches: Optional[int]
+ use_covering_area_grid: Optional[bool]
+
+
+@auto_docstring
+class Ovis2ImageProcessorFast(BaseImageProcessorFast):
+ resample = PILImageResampling.BICUBIC
+ image_mean = OPENAI_CLIP_MEAN
+ image_std = OPENAI_CLIP_STD
+ size = {"height": 384, "width": 384}
+ default_to_square = None
+ do_resize = True
+ do_rescale = True
+ do_normalize = True
+ do_convert_rgb = True
+ crop_to_patches = False
+ min_patches = 1
+ max_patches = 12
+ use_covering_area_grid = True
+ valid_kwargs = Ovis2ImageProcessorKwargs
+
+ @auto_docstring
+ def preprocess(self, images: ImageInput, **kwargs: Unpack[Ovis2ImageProcessorKwargs]) -> BatchFeature:
+ return super().preprocess(images, **kwargs)
+
+ def crop_image_to_patches(
+ self,
+ images: "torch.Tensor",
+ min_patches: int,
+ max_patches: int,
+ use_covering_area_grid: bool = True,
+ covering_threshold: float = 0.9,
+ patch_size: Optional[Union[tuple, int, dict]] = None,
+ interpolation: Optional["F.InterpolationMode"] = None,
+ ):
+ """
+ Crop the images to patches and return a list of cropped images.
+ The number of patches and their grid arrangement are determined by the original image size,
+ the target patch size and the minimum and maximum number of patches.
+ The aspect ratio of the patches grid is chosen to be the closest to the original image aspect ratio.
+
+ Args:
+ images (`torch.Tensor`):
+ The images to be cropped.
+ min_patches (`int`):
+ The minimum number of patches to be extracted from the image.
+ max_patches (`int`):
+ The maximum number of patches to be extracted from the image.
+ use_covering_area_grid (`bool`, *optional*, defaults to `True`):
+ Whether to use the original OVIS2 approach: compute the minimal number of tiles that cover at least 90%
+ of the image area. If `False`, the closest aspect ratio to the target is used.
+ covering_threshold (`float`, *optional*, defaults to `0.9`):
+ The threshold for the covering area. Only has an effect if `use_covering_area_grid` is set to `True`.
+ patch_size (`int`, `Tuple[int, int]`, `dict`, *optional*):
+ The size of the output patches.
+ interpolation (`InterpolationMode`):
+ Resampling filter to use if resizing the image.
+
+ Returns:
+ `Tuple[torch.Tensor, List[List[int]]]`: The stacked patch tensor of shape `(num_images, num_patches,
+ channels, height, width)` (with a thumbnail prepended when more than one patch is produced) and the
+ `[num_rows, num_columns]` grid for each image.
+ """
+ num_image = images.shape[0]
+ patch_size_height, patch_size_width = patch_size.height, patch_size.width
+ original_height, original_width = images.shape[-2:]
+
+ if use_covering_area_grid:
+ # Use the original OVIS2 approach: compute the minimal number of tiles that cover at least 90% of the image area
+ num_columns, num_rows = get_min_tile_covering_grid(
+ (original_height, original_width),
+ target_patch_size=patch_size_height, # square patch size
+ max_image_tiles=max_patches,
+ covering_threshold=covering_threshold,
+ )
+ else:
+ # find the closest aspect ratio to the target
+ num_columns, num_rows = get_optimal_tiled_canvas(
+ (original_height, original_width), (patch_size_height, patch_size_width), min_patches, max_patches
+ )
+
+ # calculate the target width and height
+ target_width = patch_size_width * num_columns
+ target_height = patch_size_height * num_rows
+ num_blocks = num_columns * num_rows
+
+ # resize the image so that each patch is of patch_size
+ resized_image = self.resize(
+ images, SizeDict(height=target_height, width=target_width), interpolation=interpolation
+ )
+ # split the image into patches
+ processed_images = []
+ for i in range(num_blocks):
+ column = i % num_columns
+ row = i // num_columns
+ box = (
+ column * patch_size_width,
+ row * patch_size_height,
+ (column + 1) * patch_size_width,
+ (row + 1) * patch_size_height,
+ )
+ # split the image
+ patch_image = resized_image[..., box[1] : box[3], box[0] : box[2]]
+ processed_images.append(patch_image)
+
+ if len(processed_images) != 1:
+ thumbnail_img = self.resize(images, patch_size, interpolation=interpolation)
+ processed_images.insert(0, thumbnail_img)
+
+ processed_images = torch.stack(processed_images, dim=0).transpose(0, 1).contiguous()
+ grid = [[num_rows, num_columns] for _ in range(num_image)]
+
+ return processed_images, grid
+
+ def _preprocess(
+ self,
+ images: list["torch.Tensor"],
+ do_resize: bool,
+ size: SizeDict,
+ crop_to_patches: bool,
+ min_patches: int,
+ max_patches: int,
+ use_covering_area_grid: bool,
+ interpolation: Optional["F.InterpolationMode"],
+ do_center_crop: bool,
+ crop_size: SizeDict,
+ do_rescale: bool,
+ rescale_factor: float,
+ do_normalize: bool,
+ image_mean: Optional[Union[float, list[float]]],
+ image_std: Optional[Union[float, list[float]]],
+ disable_grouping: Optional[bool],
+ return_tensors: Optional[Union[str, TensorType]],
+ **kwargs,
+ ) -> BatchFeature:
+ if crop_to_patches and max_patches > 1:
+ grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
+ processed_images_grouped = {}
+ grids = {}
+ for shape, stacked_images in grouped_images.items():
+ stacked_images, grid = self.crop_image_to_patches(
+ stacked_images,
+ min_patches,
+ max_patches,
+ patch_size=size,
+ use_covering_area_grid=use_covering_area_grid,
+ interpolation=interpolation,
+ )
+ processed_images_grouped[shape] = stacked_images
+ grids[shape] = grid
+ images = reorder_images(processed_images_grouped, grouped_images_index)
+ images = [image for images_list in images for image in images_list]
+ grids = reorder_images(grids, grouped_images_index)
+ else:
+ grids = [[1, 1] for _ in range(len(images))]
+
+ # Group images by size for batched resizing
+ grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
+ resized_images_grouped = {}
+ for shape, stacked_images in grouped_images.items():
+ if do_resize:
+ stacked_images = self.resize(image=stacked_images, size=size, interpolation=interpolation)
+ resized_images_grouped[shape] = stacked_images
+ resized_images = reorder_images(resized_images_grouped, grouped_images_index)
+
+ # Group images by size for further processing
+ # Needed in case do_resize is False, or resize returns images with different sizes
+ grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping)
+ processed_images_grouped = {}
+ for shape, stacked_images in grouped_images.items():
+ if do_center_crop:
+ stacked_images = self.center_crop(stacked_images, crop_size)
+ # Fused rescale and normalize
+ stacked_images = self.rescale_and_normalize(
+ stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std
+ )
+ processed_images_grouped[shape] = stacked_images
+
+ processed_images = reorder_images(processed_images_grouped, grouped_images_index)
+ processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
+ return BatchFeature(data={"pixel_values": processed_images, "grids": grids}, tensor_type=return_tensors)
+
+
+__all__ = ["Ovis2ImageProcessorFast"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/ovis2/modeling_ovis2.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/ovis2/modeling_ovis2.py
new file mode 100644
index 0000000000000000000000000000000000000000..75ff19ab9d14fc34b69a0759947373531a4f91f7
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/ovis2/modeling_ovis2.py
@@ -0,0 +1,828 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/ovis2/modular_ovis2.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_ovis2.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from dataclasses import dataclass
+from typing import Callable, Optional, Union
+
+import torch
+from torch import nn
+
+from ...activations import ACT2FN
+from ...cache_utils import Cache
+from ...generation import GenerationMixin
+from ...integrations import use_kernel_forward_from_hub
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPast
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
+from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple
+from ..auto import AutoModel
+from .configuration_ovis2 import Ovis2Config, Ovis2VisionConfig
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Base class for Ovis2 model outputs, with hidden states and attentions.
+ """
+)
+class Ovis2ModelOutputWithPast(BaseModelOutputWithPast):
+ r"""
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
+ `past_key_values` input) to speed up sequential decoding.
+ image_hidden_states (`torch.FloatTensor`, *optional*):
+ A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
+ image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
+ """
+
+ image_hidden_states: Optional[torch.FloatTensor] = None
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Base class for Ovis2 causal language model (or autoregressive) outputs.
+ """
+)
+class Ovis2CausalLMOutputWithPast(ModelOutput):
+ r"""
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Language modeling loss (for next-token prediction).
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
+ `past_key_values` input) to speed up sequential decoding.
+ image_hidden_states (`torch.FloatTensor`, *optional*):
+ A `torch.FloatTensor` of size `(batch_size * num_patches, num_images, sequence_length, hidden_size)`.
+ image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ logits: Optional[torch.FloatTensor] = None
+ past_key_values: Optional[Cache] = None
+ hidden_states: Optional[tuple[torch.FloatTensor]] = None
+ attentions: Optional[tuple[torch.FloatTensor]] = None
+ image_hidden_states: Optional[torch.FloatTensor] = None
+
+
+@use_kernel_forward_from_hub("RMSNorm")
+class Ovis2RMSNorm(nn.Module):
+ def __init__(self, hidden_size, eps=1e-6):
+ """
+ Ovis2RMSNorm is equivalent to T5LayerNorm
+ """
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+ self.variance_epsilon = eps
+
+ def forward(self, hidden_states):
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+ return self.weight * hidden_states.to(input_dtype)
+
+ def extra_repr(self):
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
+
+
+class Ovis2VisionMLP(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.hidden_size = config.hidden_size
+ self.intermediate_size = config.intermediate_size
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
+ self.act_fn = ACT2FN[config.hidden_act]
+
+ def forward(self, x):
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+ return down_proj
+
+
+class Ovis2VisionEmbeddings(nn.Module):
+ def __init__(self, config: Ovis2VisionConfig):
+ super().__init__()
+ self.config = config
+ self.embed_dim = config.hidden_size
+ self.image_size = config.image_size
+ self.patch_size = config.patch_size
+
+ self.patch_embedding = nn.Conv2d(
+ in_channels=config.num_channels,
+ out_channels=self.embed_dim,
+ kernel_size=self.patch_size,
+ stride=self.patch_size,
+ padding="valid",
+ )
+
+ self.num_patches = (self.image_size // self.patch_size) ** 2
+ self.num_positions = self.num_patches
+ self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
+ self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)
+ self.rms_norm = Ovis2RMSNorm(config.hidden_size, config.rms_norm_eps)
+
+ def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
+ target_dtype = self.patch_embedding.weight.dtype
+ patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))
+ embeddings = patch_embeds.flatten(2).transpose(1, 2)
+ embeddings = self.rms_norm(embeddings)
+
+ embeddings = embeddings + self.position_embedding(self.position_ids)
+
+ return embeddings
+
+
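+# Reference (eager) scaled dot-product attention, used as a fallback when no optimized backend
+# (e.g. SDPA or FlashAttention) is selected through `config._attn_implementation`.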
+def eager_attention_forward(
+ module: nn.Module,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_mask: Optional[torch.Tensor],
+ scaling: float,
+ dropout: float = 0.0,
+ **kwargs,
+):
+ attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
+ if attention_mask is not None:
+ attn_weights = attn_weights + attention_mask
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+
+ attn_output = torch.matmul(attn_weights, value)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ return attn_output, attn_weights
+
+
+class Ovis2VisionAttention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.embed_dim = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.head_dim = self.embed_dim // self.num_heads
+ if self.head_dim * self.num_heads != self.embed_dim:
+ raise ValueError(
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
+ f" {self.num_heads})."
+ )
+ self.scale = self.head_dim**-0.5
+ self.dropout = config.attention_dropout
+ self.is_causal = False
+ self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.qkv_bias)
+ self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.qkv_bias)
+ self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.qkv_bias)
+ self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.qkv_bias)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ **kwargs,
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+ """Input shape: Batch x Time x Channel"""
+
+ batch_size, seq_length, embed_dim = hidden_states.shape
+
+ queries = self.q_proj(hidden_states)
+ keys = self.k_proj(hidden_states)
+ values = self.v_proj(hidden_states)
+
+ queries = queries.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
+ keys = keys.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
+ values = values.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
+
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ attn_output, attn_weights = attention_interface(
+ self,
+ queries,
+ keys,
+ values,
+ attention_mask,
+ is_causal=self.is_causal,
+ scaling=self.scale,
+ dropout=0.0 if not self.training else self.dropout,
+ )
+
+ attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous()
+ attn_output = self.out_proj(attn_output)
+
+ return attn_output, attn_weights
+
+
+class Ovis2MLP(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.hidden_size = config.hidden_size
+ self.intermediate_size = config.intermediate_size
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
+ self.act_fn = ACT2FN[config.hidden_act]
+
+ def forward(self, x):
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+ return down_proj
+
+
+class Ovis2Attention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.embed_dim = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.head_dim = self.embed_dim // self.num_heads
+ if self.head_dim * self.num_heads != self.embed_dim:
+ raise ValueError(
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
+ f" {self.num_heads})."
+ )
+ self.scale = self.head_dim**-0.5
+ self.dropout = config.attention_dropout
+ self.is_causal = False
+ self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.qkv_bias)
+ self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.qkv_bias)
+ self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.qkv_bias)
+ self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.qkv_bias)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ **kwargs,
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+ """Input shape: Batch x Time x Channel"""
+
+ batch_size, seq_length, embed_dim = hidden_states.shape
+
+ queries = self.q_proj(hidden_states)
+ keys = self.k_proj(hidden_states)
+ values = self.v_proj(hidden_states)
+
+ queries = queries.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
+ keys = keys.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
+ values = values.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
+
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ attn_output, attn_weights = attention_interface(
+ self,
+ queries,
+ keys,
+ values,
+ attention_mask,
+ is_causal=self.is_causal,
+ scaling=self.scale,
+ dropout=0.0 if not self.training else self.dropout,
+ )
+
+ attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous()
+ attn_output = self.out_proj(attn_output)
+
+ return attn_output, attn_weights
+
+
+class Ovis2VisionEncoderLayer(GradientCheckpointingLayer):
+ def __init__(self, config: Ovis2VisionConfig):
+ super().__init__()
+ self.attention = Ovis2Attention(config)
+ self.ffn = Ovis2MLP(config)
+ self.rms_norm1 = Ovis2RMSNorm(config.hidden_size, config.rms_norm_eps)
+ self.rms_norm2 = Ovis2RMSNorm(config.hidden_size, config.rms_norm_eps)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> torch.Tensor:
+ norm_hidden_states = self.rms_norm1(hidden_states)
+ attn_output, _ = self.attention(hidden_states=norm_hidden_states, attention_mask=attention_mask, **kwargs)
+
+ hidden_states = hidden_states + attn_output
+ norm_hidden_states = self.rms_norm2(hidden_states)
+ mlp_output = self.ffn(norm_hidden_states)
+
+ hidden_states = hidden_states + mlp_output
+ return hidden_states
+
+
+class Ovis2VisionEncoder(nn.Module):
+ """
+ Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
+ [`Ovis2VisionEncoderLayer`].
+
+ Args:
+ config: Ovis2VisionConfig
+ """
+
+ def __init__(self, config: Ovis2VisionConfig):
+ super().__init__()
+ self.config = config
+ self.layers = nn.ModuleList([Ovis2VisionEncoderLayer(config) for _ in range(config.num_hidden_layers)])
+ self.gradient_checkpointing = False
+
+ # Ignore copy
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ inputs_embeds,
+ attention_mask: Optional[torch.Tensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> BaseModelOutput:
+ hidden_states = inputs_embeds
+ for encoder_layer in self.layers:
+ hidden_states = encoder_layer(hidden_states, attention_mask, **kwargs)
+
+ return BaseModelOutput(last_hidden_state=hidden_states)
+
+
+class Ovis2VisionTransformer(nn.Module):
+ def __init__(self, config: Ovis2VisionConfig):
+ super().__init__()
+ self.config = config
+ self.embeddings = Ovis2VisionEmbeddings(config)
+ self.encoder = Ovis2VisionEncoder(config)
+ self.rms_norm = Ovis2RMSNorm(config.hidden_size, config.rms_norm_eps)
+ self.gradient_checkpointing = False
+
+ @can_return_tuple
+ def forward(
+ self,
+ pixel_values,
+ attention_mask: Optional[torch.Tensor] = None,
+ **kwargs,
+ ):
+ hidden_states = self.embeddings(pixel_values)
+
+ encoder_outputs: BaseModelOutput = self.encoder(
+ inputs_embeds=hidden_states,
+ attention_mask=attention_mask,
+ **kwargs,
+ )
+
+ last_hidden_state = encoder_outputs.last_hidden_state
+ last_hidden_state = self.rms_norm(last_hidden_state)
+
+ return BaseModelOutput(last_hidden_state=last_hidden_state)
+
+
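+# Embedding table shared by hard and soft visual tokens: integer token ids go through the regular
+# embedding lookup, while soft tokens (probability distributions over the visual vocabulary) are
+# embedded as a probability-weighted sum of the embedding rows via a matmul.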
+class Ovis2VisualEmbeddingTable(nn.Embedding):
+ def forward(self, visual_tokens: torch.Tensor) -> torch.Tensor:
+ if visual_tokens.dtype in [torch.int8, torch.int16, torch.int32, torch.int64, torch.long]:
+ return super().forward(visual_tokens)
+ return torch.matmul(visual_tokens, self.weight)
+
+
+class Ovis2PreTrainedModel(PreTrainedModel):
+ config: Ovis2Config
+ base_model_prefix = "model"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["Ovis2VisionAttention"]
+ _skip_keys_device_placement = "past_key_values"
+ _supports_cache_class = True
+ _supports_flash_attn = True
+ _supports_flex_attn = True
+ _supports_sdpa = True
+
+ _can_compile_fullgraph = True
+ _supports_attention_backend = True
+
+
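+# Straight-through argmax: the forward pass yields a one-hot vector at the argmax of `logits`, while
+# gradients flow through the soft softmax term (`y_hard - y_soft.detach() + y_soft`).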
+def hard_softmax(logits: torch.Tensor, dim: int):
+ y_soft = logits.softmax(dim)
+ # Straight through.
+ index = y_soft.max(dim, keepdim=True)[1]
+ y_hard = torch.zeros_like(logits, memory_format=torch.legacy_contiguous_format).scatter_(dim, index, 1.0)
+ ret = y_hard - y_soft.detach() + y_soft
+
+ return ret
+
+
+class Ovis2VisionModel(Ovis2PreTrainedModel):
+ config: Ovis2VisionConfig
+
+ def __init__(self, config: Ovis2VisionConfig):
+ super().__init__(config)
+ self.config = config
+ self.transformer = Ovis2VisionTransformer(config)
+ self.num_visual_indicator_tokens = config.num_visual_indicator_tokens
+ self.vocab_size = config.vocab_size
+ self.head_linear = nn.Linear(
+ config.hidden_size * config.hidden_stride * config.hidden_stride,
+ self.vocab_size - self.num_visual_indicator_tokens,
+ bias=False,
+ )
+ self.head_norm = nn.LayerNorm(self.vocab_size - self.num_visual_indicator_tokens)
+
+ def forward(self, pixel_values: torch.FloatTensor, **kwargs) -> tuple[torch.Tensor, torch.Tensor]:
+ outputs = self.transformer(pixel_values, **kwargs)
+ last_hidden_state = outputs[0]
+ if self.config.hidden_stride > 1:
+ num_images, seq_len, hidden_dim = last_hidden_state.shape
+ hidden_stride = self.config.hidden_stride
+
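+ # Merge each `hidden_stride x hidden_stride` block of neighboring patch tokens into a single token:
+ # the flat sequence is viewed as a (sqrt_l x sqrt_l) grid, padded to a multiple of `hidden_stride`,
+ # and regrouped so that the feature dimension becomes `hidden_stride**2 * hidden_dim`.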
+ sqrt_l = int(math.sqrt(seq_len))
+ if sqrt_l * sqrt_l != seq_len:
+ raise ValueError("Token sequence length must be a perfect square")
+
+ pad_size = (hidden_stride - (sqrt_l % hidden_stride)) % hidden_stride
+ last_hidden_state = nn.functional.pad(last_hidden_state, (0, 0, 0, pad_size, 0, pad_size), "constant", 0)
+ sqrt_l += pad_size
+
+ last_hidden_state = last_hidden_state.reshape(
+ num_images, sqrt_l // hidden_stride, hidden_stride, sqrt_l // hidden_stride, hidden_stride, hidden_dim
+ )
+ last_hidden_state = last_hidden_state.permute(0, 1, 3, 2, 4, 5)
+ last_hidden_state = last_hidden_state.reshape(
+ num_images, -1, hidden_stride * hidden_stride * hidden_dim
+ ) # (n, (sqrt_l//hs)^2, hs^2*d)
+
+ logits = self.head_linear(last_hidden_state)
+ logits = self.head_norm(logits)
+
+ if self.config.tokenize_function == "gumbel_argmax":
+ prob_token = nn.functional.gumbel_softmax(logits, dim=-1, hard=True)
+ elif self.config.tokenize_function == "st_argmax":
+ prob_token = hard_softmax(logits, dim=-1)
+ elif self.config.tokenize_function == "softmax":
+ prob_token = nn.functional.softmax(logits, dim=-1)
+
+ return prob_token
+
+
+@auto_docstring(
+ custom_intro="""
+ The Ovis2 model which consists of a vision backbone and a language model, without a language modeling head.
+ """
+)
+class Ovis2Model(Ovis2PreTrainedModel):
+ _checkpoint_conversion_mapping = {}
+
+ def __init__(self, config: Ovis2Config):
+ super().__init__(config)
+ self.vision_tower = Ovis2VisionModel(config.vision_config)
+ self.language_model = AutoModel.from_config(config.text_config)
+ self.visual_embeddings_table = Ovis2VisualEmbeddingTable(config.vision_config.vocab_size, config.hidden_size)
+
+ self.visual_vocab_size = config.vision_config.vocab_size
+ self.vocab_size = config.vocab_size
+ self.visual_indicator_token_ids = config.visual_indicator_token_ids
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.language_model.get_input_embeddings()
+
+ def set_input_embeddings(self, value):
+ self.language_model.set_input_embeddings(value)
+
+ def set_decoder(self, decoder):
+ self.language_model = decoder
+
+ def get_decoder(self):
+ return self.language_model
+
+ def get_image_features(
+ self,
+ pixel_values: torch.FloatTensor,
+ ) -> torch.FloatTensor:
+ """
+ Obtains image last hidden states from the vision tower and apply multimodal projection.
+
+ Args:
+ pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`):
+ The tensors corresponding to the input images.
+ vision_feature_layer (`Union[int, list[int]]`, *optional*):
+ The index of the layer to select the vision feature. If multiple indices are provided,
+ the vision feature of the corresponding indices will be concatenated to form the
+ vision features.
+ vision_feature_select_strategy (`str`, *optional*):
+ The feature selection strategy used to select the vision feature from the vision backbone.
+ Can be one of `"default"` or `"full"`
+ Returns:
+ image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
+ """
+ image_features = self.vision_tower(pixel_values)
+ batch_size, img_seq_len, _ = image_features.shape
+ padding_tensor = torch.zeros(
+ (batch_size, img_seq_len, self.vision_tower.num_visual_indicator_tokens),
+ dtype=image_features.dtype,
+ device=image_features.device,
+ requires_grad=False,
+ layout=image_features.layout,
+ )
+ image_features = torch.cat([image_features, padding_tensor], dim=2)
+ image_features = self.visual_embeddings_table(image_features)
+
+ visual_indicator = torch.arange(
+ self.visual_vocab_size - self.vision_tower.num_visual_indicator_tokens,
+ self.visual_vocab_size,
+ dtype=torch.long,
+ ).to(image_features.device)
+ visual_indicator_features = self.visual_embeddings_table(visual_indicator)
+
+ return image_features, visual_indicator_features
+
+ def get_placeholder_mask(
+ self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, image_features: torch.FloatTensor
+ ):
+ """
+ Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is
+ equal to the length of multimodal features. If the lengths are different, an error is raised.
+ """
+ if input_ids is None:
+ special_image_mask = inputs_embeds == self.get_input_embeddings()(
+ torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
+ )
+ special_image_mask = special_image_mask.all(-1)
+ else:
+ special_image_mask = input_ids == self.config.image_token_id
+
+ n_image_tokens = special_image_mask.sum()
+ special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
+ n_image_features = image_features.shape[0] * image_features.shape[1]
+ if inputs_embeds[special_image_mask].numel() != image_features.numel():
+ raise ValueError(
+ f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
+ )
+ return special_image_mask
+
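+ # Illustrative sketch (hypothetical ids): with image_token_id = 7 and input_ids = [[5, 7, 7, 6]], the mask is
+ # [[False, True, True, False]], expanded over the embedding dimension so that `masked_scatter` can write one
+ # row of `image_features` into each placeholder position.
+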
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ logits_to_keep: Union[int, torch.Tensor] = 0,
+ **kwargs,
+ ) -> Union[tuple, Ovis2ModelOutputWithPast]:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ if inputs_embeds is None:
+ inputs_embeds = self.get_input_embeddings()(input_ids)
+
+ if pixel_values is not None:
+ image_features, visual_indicator_features = self.get_image_features(pixel_values=pixel_values)
+
+ special_image_mask = self.get_placeholder_mask(
+ input_ids,
+ inputs_embeds=inputs_embeds,
+ image_features=image_features,
+ )
+ inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
+
+ for i, visual_indicator_id in enumerate(self.visual_indicator_token_ids):
+ if input_ids is None:
+ mask = inputs_embeds == self.get_input_embeddings()(
+ torch.tensor(visual_indicator_id, dtype=torch.long, device=inputs_embeds.device)
+ )
+ mask = mask.all(-1)
+ else:
+ mask = (input_ids == visual_indicator_id).to(inputs_embeds.device)
+
+ if mask.any():
+ inputs_embeds[mask] = (
+ visual_indicator_features[i]
+ .expand_as(inputs_embeds[mask])
+ .to(inputs_embeds.device, inputs_embeds.dtype)
+ )
+
+ outputs = self.language_model(
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=True,
+ cache_position=cache_position,
+ logits_to_keep=logits_to_keep,
+ **kwargs,
+ )
+
+ return Ovis2ModelOutputWithPast(
+ last_hidden_state=outputs.last_hidden_state,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ image_hidden_states=image_features if pixel_values is not None else None,
+ )
+
+
+@auto_docstring
+class Ovis2ForConditionalGeneration(Ovis2PreTrainedModel, GenerationMixin):
+ _checkpoint_conversion_mapping = {}
+ _tied_weights_keys = ["lm_head.weight"]
+
+ def __init__(self, config: Ovis2Config):
+ super().__init__(config)
+ self.model = Ovis2Model(config)
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.get_input_embeddings()
+
+ def set_input_embeddings(self, value):
+ self.model.set_input_embeddings(value)
+
+ def get_output_embeddings(self) -> nn.Module:
+ return self.lm_head
+
+ def set_decoder(self, decoder):
+ self.model.set_decoder(decoder)
+
+ def get_decoder(self):
+ return self.model.get_decoder()
+
+ def get_image_features(self, pixel_values: torch.FloatTensor):
+ return self.model.get_image_features(pixel_values=pixel_values)
+
+ # Make modules available through conditional class for BC
+ @property
+ def language_model(self):
+ return self.model.language_model
+
+ @property
+ def vision_tower(self):
+ return self.model.vision_tower
+
+ @property
+ def multi_modal_projector(self):
+ raise AttributeError("Not needed for Ovis2")
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ logits_to_keep: Union[int, torch.Tensor] = 0,
+ **kwargs,
+ ) -> Union[tuple, Ovis2CausalLMOutputWithPast]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+ Example:
+
+ ```python
+ >>> from PIL import Image
+ >>> import requests
+ >>> from transformers import AutoProcessor, Ovis2ForConditionalGeneration
+
+ >>> model = Ovis2ForConditionalGeneration.from_pretrained("thisisiron/Ovis2-2B-hf")
+ >>> processor = AutoProcessor.from_pretrained("thisisiron/Ovis2-2B-hf")
+
+ >>> prompt = "<|im_start|>user\n\nDescribe the image.<|im_end|>\n<|im_start|>assistant\n"
+ >>> url = "http://images.cocodataset.org/val2014/COCO_val2014_000000537955.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> inputs = processor(images=image, text=prompt, return_tensors="pt")
+
+ >>> # Generate
+ >>> generate_ids = model.generate(**inputs, max_new_tokens=15)
+ >>> processor.batch_decode(generate_ids, skip_special_tokens=True)[0]
+ "user\n\nDescribe the image.\nassistant\nThe image features a brown dog standing on a wooden floor, looking up with"
+ ```"""
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+
+ outputs = self.model(
+ input_ids=input_ids,
+ pixel_values=pixel_values,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=True,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+ hidden_states = outputs[0]
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(
+ logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
+ )
+
+ return Ovis2CausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ image_hidden_states=outputs.image_hidden_states,
+ )
+
+ def prepare_inputs_for_generation(
+ self,
+ input_ids,
+ past_key_values=None,
+ inputs_embeds=None,
+ pixel_values=None,
+ attention_mask=None,
+ cache_position=None,
+ logits_to_keep=None,
+ **kwargs,
+ ):
+ # Overwritten -- in specific circumstances we don't want to forward image inputs to the model
+
+ model_inputs = super().prepare_inputs_for_generation(
+ input_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ cache_position=cache_position,
+ logits_to_keep=logits_to_keep,
+ **kwargs,
+ )
+
+ if cache_position[0] == 0:
+ # Pixel values are only forwarded on the first (prefill) step; in the cached decoding stage the input ids
+ # no longer contain the image placeholder tokens, so `pixel_values` stays `None`.
+ model_inputs["pixel_values"] = pixel_values
+
+ return model_inputs
+
+
+__all__ = ["Ovis2PreTrainedModel", "Ovis2Model", "Ovis2ForConditionalGeneration"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/ovis2/modular_ovis2.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/ovis2/modular_ovis2.py
new file mode 100644
index 0000000000000000000000000000000000000000..09ce53703a15851f85a3dac6a10ce00f9522cf75
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/ovis2/modular_ovis2.py
@@ -0,0 +1,434 @@
+# coding=utf-8
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from typing import Optional, Union
+
+import torch
+from torch import nn
+
+from ...cache_utils import Cache
+from ...generation import GenerationMixin
+from ...modeling_outputs import BaseModelOutput
+from ...modeling_utils import PreTrainedModel
+from ...processing_utils import Unpack
+from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
+from ..aimv2.modeling_aimv2 import Aimv2Attention, Aimv2EncoderLayer
+from ..auto import AutoModel
+from ..llama.modeling_llama import LlamaMLP, LlamaRMSNorm
+from ..llava.modeling_llava import LlavaForConditionalGeneration, LlavaModel
+from ..llava_next.modeling_llava_next import LlavaNextCausalLMOutputWithPast, LlavaNextModelOutputWithPast
+from ..siglip.modeling_siglip import SiglipEncoder, SiglipVisionEmbeddings
+from .configuration_ovis2 import Ovis2Config, Ovis2VisionConfig
+
+
+def hard_softmax(logits: torch.Tensor, dim: int):
+ y_soft = logits.softmax(dim)
+ # Straight through.
+ index = y_soft.max(dim, keepdim=True)[1]
+ y_hard = torch.zeros_like(logits, memory_format=torch.legacy_contiguous_format).scatter_(dim, index, 1.0)
+ ret = y_hard - y_soft.detach() + y_soft
+
+ return ret
+
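+# Illustrative sketch of the straight-through behaviour above (not part of the module's API): the forward
+# value is a one-hot argmax while gradients flow through the softmax term.
+#
+#     logits = torch.tensor([[2.0, 0.5, -1.0]], requires_grad=True)
+#     out = hard_softmax(logits, dim=-1)   # tensor([[1., 0., 0.]]) in the forward pass
+#     out[0, 0].backward()                 # logits.grad equals d softmax(logits)[0, 0] / d logits
+#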
+
+class Ovis2ModelOutputWithPast(LlavaNextModelOutputWithPast):
+ pass
+
+
+class Ovis2CausalLMOutputWithPast(LlavaNextCausalLMOutputWithPast):
+ pass
+
+
+class Ovis2RMSNorm(LlamaRMSNorm):
+ pass
+
+
+class Ovis2VisionMLP(LlamaMLP):
+ pass
+
+
+class Ovis2VisionEmbeddings(SiglipVisionEmbeddings):
+ def __init__(self, config: Ovis2VisionConfig):
+ super().__init__(config)
+ self.rms_norm = Ovis2RMSNorm(config.hidden_size, config.rms_norm_eps)
+
+ def interpolate_pos_encoding(self):
+ raise NotImplementedError("Not needed for Ovis2")
+
+ def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
+ target_dtype = self.patch_embedding.weight.dtype
+ patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))
+ embeddings = patch_embeds.flatten(2).transpose(1, 2)
+ embeddings = self.rms_norm(embeddings)
+
+ embeddings = embeddings + self.position_embedding(self.position_ids)
+
+ return embeddings
+
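+# Shape sketch for the embedding step above (illustrative numbers; they depend on the config): a 448x448 image
+# with patch_size=14 and hidden_size=1152 gives
+#
+#     (batch, 3, 448, 448)  -> patch_embedding      -> (batch, 1152, 32, 32)
+#                           -> flatten / transpose  -> (batch, 1024, 1152)
+#
+# which is RMS-normalized and then offset by the learned position embeddings.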
+
+class Ovis2VisionAttention(Aimv2Attention):
+ pass
+
+
+class Ovis2VisionEncoderLayer(Aimv2EncoderLayer):
+ pass
+
+
+class Ovis2VisionEncoder(SiglipEncoder):
+ def __init__(self, config: Ovis2VisionConfig):
+ super().__init__(config)
+ self.layers = nn.ModuleList([Ovis2VisionEncoderLayer(config) for _ in range(config.num_hidden_layers)])
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ inputs_embeds,
+ attention_mask: Optional[torch.Tensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> BaseModelOutput:
+ hidden_states = inputs_embeds
+ for encoder_layer in self.layers:
+ hidden_states = encoder_layer(hidden_states, attention_mask, **kwargs)
+
+ return BaseModelOutput(last_hidden_state=hidden_states)
+
+
+class Ovis2VisionTransformer(nn.Module):
+ def __init__(self, config: Ovis2VisionConfig):
+ super().__init__()
+ self.config = config
+ self.embeddings = Ovis2VisionEmbeddings(config)
+ self.encoder = Ovis2VisionEncoder(config)
+ self.rms_norm = Ovis2RMSNorm(config.hidden_size, config.rms_norm_eps)
+ self.gradient_checkpointing = False
+
+ @can_return_tuple
+ def forward(
+ self,
+ pixel_values,
+ attention_mask: Optional[torch.Tensor] = None,
+ **kwargs,
+ ):
+ hidden_states = self.embeddings(pixel_values)
+
+ encoder_outputs: BaseModelOutput = self.encoder(
+ inputs_embeds=hidden_states,
+ attention_mask=attention_mask,
+ **kwargs,
+ )
+
+ last_hidden_state = encoder_outputs.last_hidden_state
+ last_hidden_state = self.rms_norm(last_hidden_state)
+
+ return BaseModelOutput(last_hidden_state=last_hidden_state)
+
+
+class Ovis2VisualEmbeddingTable(nn.Embedding):
+ def forward(self, visual_tokens: torch.Tensor) -> torch.Tensor:
+ if visual_tokens.dtype in [torch.int8, torch.int16, torch.int32, torch.int64, torch.long]:
+ return super().forward(visual_tokens)
+ return torch.matmul(visual_tokens, self.weight)
+
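+# Illustrative sketch (hypothetical sizes): the table behaves like a normal `nn.Embedding` for integer ids and
+# like a soft lookup for probability vectors over the visual vocabulary.
+#
+#     table = Ovis2VisualEmbeddingTable(num_embeddings=8, embedding_dim=4)
+#     table(torch.tensor([1, 3]))                       # standard embedding lookup, shape (2, 4)
+#     probs = torch.softmax(torch.randn(2, 8), dim=-1)
+#     table(probs)                                      # probs @ weight, a convex mix of rows, shape (2, 4)
+#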
+
+class Ovis2PreTrainedModel(PreTrainedModel):
+ config: Ovis2Config
+ base_model_prefix = "model"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["Ovis2VisionAttention"]
+ _skip_keys_device_placement = "past_key_values"
+ _supports_cache_class = True
+ _supports_flash_attn = True
+ _supports_flex_attn = True
+ _supports_sdpa = True
+
+ _can_compile_fullgraph = True
+ _supports_attention_backend = True
+
+
+class Ovis2VisionModel(Ovis2PreTrainedModel):
+ config: Ovis2VisionConfig
+
+ def __init__(self, config: Ovis2VisionConfig):
+ super().__init__(config)
+ self.config = config
+ self.transformer = Ovis2VisionTransformer(config)
+ self.num_visual_indicator_tokens = config.num_visual_indicator_tokens
+ self.vocab_size = config.vocab_size
+ self.head_linear = nn.Linear(
+ config.hidden_size * config.hidden_stride * config.hidden_stride,
+ self.vocab_size - self.num_visual_indicator_tokens,
+ bias=False,
+ )
+ self.head_norm = nn.LayerNorm(self.vocab_size - self.num_visual_indicator_tokens)
+
+ def forward(self, pixel_values: torch.FloatTensor, **kwargs) -> torch.Tensor:
+ outputs = self.transformer(pixel_values, **kwargs)
+ last_hidden_state = outputs[0]
+ if self.config.hidden_stride > 1:
+ num_images, seq_len, hidden_dim = last_hidden_state.shape
+ hidden_stride = self.config.hidden_stride
+
+ sqrt_l = int(math.sqrt(seq_len))
+ if sqrt_l * sqrt_l != seq_len:
+ raise ValueError("Token sequence length must be a perfect square")
+
+ pad_size = (hidden_stride - (sqrt_l % hidden_stride)) % hidden_stride
+ last_hidden_state = nn.functional.pad(last_hidden_state, (0, 0, 0, pad_size, 0, pad_size), "constant", 0)
+ sqrt_l += pad_size
+
+ last_hidden_state = last_hidden_state.reshape(
+ num_images, sqrt_l // hidden_stride, hidden_stride, sqrt_l // hidden_stride, hidden_stride, hidden_dim
+ )
+ last_hidden_state = last_hidden_state.permute(0, 1, 3, 2, 4, 5)
+ last_hidden_state = last_hidden_state.reshape(
+ num_images, -1, hidden_stride * hidden_stride * hidden_dim
+ ) # (n, (sqrt_l//hs)^2, hs^2*d)
+
+ logits = self.head_linear(last_hidden_state)
+ logits = self.head_norm(logits)
+
+ if self.config.tokenize_function == "gumbel_argmax":
+ prob_token = nn.functional.gumbel_softmax(logits, dim=-1, hard=True)
+ elif self.config.tokenize_function == "st_argmax":
+ prob_token = hard_softmax(logits, dim=-1)
+ elif self.config.tokenize_function == "softmax":
+ prob_token = nn.functional.softmax(logits, dim=-1)
+
+ return prob_token
+
+
+class Ovis2Model(LlavaModel):
+ _checkpoint_conversion_mapping = {}
+
+ def __init__(self, config: Ovis2Config):
+ super().__init__(config)
+ self.vision_tower = Ovis2VisionModel(config.vision_config)
+ self.visual_embeddings_table = Ovis2VisualEmbeddingTable(config.vision_config.vocab_size, config.hidden_size)
+
+ self.visual_vocab_size = config.vision_config.vocab_size
+ self.vocab_size = config.vocab_size
+ self.visual_indicator_token_ids = config.visual_indicator_token_ids
+ self.language_model = AutoModel.from_config(config.text_config)
+ del self.multi_modal_projector
+
+ def get_image_features(
+ self,
+ pixel_values: torch.FloatTensor,
+ ) -> tuple[torch.FloatTensor, torch.FloatTensor]:
+ image_features = self.vision_tower(pixel_values)
+ batch_size, img_seq_len, _ = image_features.shape
+ padding_tensor = torch.zeros(
+ (batch_size, img_seq_len, self.vision_tower.num_visual_indicator_tokens),
+ dtype=image_features.dtype,
+ device=image_features.device,
+ requires_grad=False,
+ layout=image_features.layout,
+ )
+ image_features = torch.cat([image_features, padding_tensor], dim=2)
+ image_features = self.visual_embeddings_table(image_features)
+
+ visual_indicator = torch.arange(
+ self.visual_vocab_size - self.vision_tower.num_visual_indicator_tokens,
+ self.visual_vocab_size,
+ dtype=torch.long,
+ ).to(image_features.device)
+ visual_indicator_features = self.visual_embeddings_table(visual_indicator)
+
+ return image_features, visual_indicator_features
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ logits_to_keep: Union[int, torch.Tensor] = 0,
+ **kwargs,
+ ) -> Union[tuple, Ovis2ModelOutputWithPast]:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ if inputs_embeds is None:
+ inputs_embeds = self.get_input_embeddings()(input_ids)
+
+ if pixel_values is not None:
+ image_features, visual_indicator_features = self.get_image_features(pixel_values=pixel_values)
+
+ special_image_mask = self.get_placeholder_mask(
+ input_ids,
+ inputs_embeds=inputs_embeds,
+ image_features=image_features,
+ )
+ inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
+
+ for i, visual_indicator_id in enumerate(self.visual_indicator_token_ids):
+ if input_ids is None:
+ mask = inputs_embeds == self.get_input_embeddings()(
+ torch.tensor(visual_indicator_id, dtype=torch.long, device=inputs_embeds.device)
+ )
+ mask = mask.all(-1)
+ else:
+ mask = (input_ids == visual_indicator_id).to(inputs_embeds.device)
+
+ if mask.any():
+ inputs_embeds[mask] = (
+ visual_indicator_features[i]
+ .expand_as(inputs_embeds[mask])
+ .to(inputs_embeds.device, inputs_embeds.dtype)
+ )
+
+ outputs = self.language_model(
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=True,
+ cache_position=cache_position,
+ logits_to_keep=logits_to_keep,
+ **kwargs,
+ )
+
+ return Ovis2ModelOutputWithPast(
+ last_hidden_state=outputs.last_hidden_state,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ image_hidden_states=image_features if pixel_values is not None else None,
+ )
+
+
+@auto_docstring
+class Ovis2ForConditionalGeneration(LlavaForConditionalGeneration, GenerationMixin):
+ _checkpoint_conversion_mapping = {}
+
+ def __init__(self, config: Ovis2Config):
+ super().__init__(config)
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+ @property
+ def multi_modal_projector(self):
+ raise AttributeError("Not needed for Ovis2")
+
+ def get_image_features(self, pixel_values: torch.FloatTensor):
+ return self.model.get_image_features(pixel_values=pixel_values)
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ logits_to_keep: Union[int, torch.Tensor] = 0,
+ **kwargs,
+ ) -> Union[tuple, Ovis2CausalLMOutputWithPast]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+ Example:
+
+ ```python
+ >>> from PIL import Image
+ >>> import requests
+ >>> from transformers import AutoProcessor, Ovis2ForConditionalGeneration
+
+ >>> model = Ovis2ForConditionalGeneration.from_pretrained("thisisiron/Ovis2-2B-hf")
+ >>> processor = AutoProcessor.from_pretrained("thisisiron/Ovis2-2B-hf")
+
+ >>> prompt = "<|im_start|>user\n\nDescribe the image.<|im_end|>\n<|im_start|>assistant\n"
+ >>> url = "http://images.cocodataset.org/val2014/COCO_val2014_000000537955.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> inputs = processor(images=image, text=prompt, return_tensors="pt")
+
+ >>> # Generate
+ >>> generate_ids = model.generate(**inputs, max_new_tokens=15)
+ >>> processor.batch_decode(generate_ids, skip_special_tokens=True)[0]
+ "user\n\nDescribe the image.\nassistant\nThe image features a brown dog standing on a wooden floor, looking up with"
+ ```"""
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+
+ outputs = self.model(
+ input_ids=input_ids,
+ pixel_values=pixel_values,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=True,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+ hidden_states = outputs[0]
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(
+ logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
+ )
+
+ return Ovis2CausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ image_hidden_states=outputs.image_hidden_states,
+ )
+
+
+__all__ = ["Ovis2PreTrainedModel", "Ovis2Model", "Ovis2ForConditionalGeneration"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/ovis2/processing_ovis2.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/ovis2/processing_ovis2.py
new file mode 100644
index 0000000000000000000000000000000000000000..efb79409da2bd3edc5e9a89d0d0cabc3b76c53e0
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/ovis2/processing_ovis2.py
@@ -0,0 +1,181 @@
+# coding=utf-8
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional, Union
+
+from ...feature_extraction_utils import BatchFeature
+from ...image_utils import ImageInput
+from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
+from ...tokenization_utils_base import PreTokenizedInput, TextInput
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class Ovis2ProcessorKwargs(ProcessingKwargs, total=False):
+ _defaults = {
+ "text_kwargs": {
+ "padding": False,
+ },
+ "image_kwargs": {},
+ }
+
+
+class Ovis2Processor(ProcessorMixin):
+ r"""
+ Constructs an Ovis2 processor which wraps an Ovis2 image processor and a Qwen2 tokenizer into a single processor.
+
+ [`Ovis2Processor`] offers all the functionalities of [`Ovis2ImageProcessor`] and [`Qwen2TokenizerFast`]. See the
+ [`~Ovis2Processor.__call__`] and [`~Ovis2Processor.decode`] for more information.
+
+ Args:
+ image_processor ([`Ovis2ImageProcessor`], *optional*):
+ The image processor is a required input.
+ tokenizer ([`Qwen2TokenizerFast`], *optional*):
+ The tokenizer is a required input.
+ chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
+ in a chat into a tokenizable string.
+ image_token (`str`, *optional*, defaults to `""`):
+ Special token used to denote image location.
+ image_seq_length (`int`, *optional*, defaults to 256):
+ The number of image tokens to be used for each image in the input.
+ """
+
+ attributes = ["image_processor", "tokenizer"]
+ image_processor_class = "AutoImageProcessor"
+ tokenizer_class = "AutoTokenizer"
+
+ def __init__(
+ self,
+ image_processor=None,
+ tokenizer=None,
+ chat_template=None,
+ image_token="",
+ image_seq_length=256,
+ **kwargs,
+ ):
+ self.image_seq_length = image_seq_length
+ self.image_token = tokenizer.image_token if hasattr(tokenizer, "image_token") else image_token
+ self.image_token_id = (
+ tokenizer.image_token_id
+ if getattr(tokenizer, "image_token_id", None)
+ else tokenizer.convert_tokens_to_ids(self.image_token)
+ )
+ super().__init__(image_processor, tokenizer, chat_template=chat_template, **kwargs)
+
+ def __call__(
+ self,
+ images: Optional[ImageInput] = None,
+ text: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]] = None,
+ **kwargs: Unpack[Ovis2ProcessorKwargs],
+ ) -> BatchFeature:
+ """
+ Main method to prepare for the model one or several sequence(s) and image(s). This method forwards the `text`
+ and `kwargs` arguments to Qwen2TokenizerFast's [`~Qwen2TokenizerFast.__call__`] if `text` is not `None` to encode
+ the text. To prepare the image(s), this method forwards the `images` and `kwargs` arguments to
+ Ovis2ImageProcessor's [`~Ovis2ImageProcessor.__call__`] if `images` is not `None`. Please refer to the docstring
+ of the above two methods for more information.
+
+ Args:
+ images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
+ The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
+ tensor. Both channels-first and channels-last formats are supported.
+ text (`str`, `List[str]`, `List[List[str]]`):
+ The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
+ (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
+ `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
+
+ Returns:
+ [`BatchFeature`]: A [`BatchFeature`] with the following fields:
+
+ - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
+ - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
+ `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
+ `None`).
+ - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
+ - **image_sizes** -- Size of each image that will be used to unpad an image. Returned when `images` is not `None`.
+ """
+
+ output_kwargs = self._merge_kwargs(
+ Ovis2ProcessorKwargs,
+ tokenizer_init_kwargs=self.tokenizer.init_kwargs,
+ **kwargs,
+ )
+
+ if isinstance(text, str):
+ text = [text]
+ elif not isinstance(text, list) or not isinstance(text[0], str):
+ raise ValueError("Invalid input text. Please provide a string, or a list of strings")
+
+ image_inputs = {}
+
+ if images is not None:
+ image_inputs = self.image_processor(images, **output_kwargs["images_kwargs"])
+ image_grids = image_inputs.pop("grids").tolist()
+ text = self._expand_image_tokens(text, image_grids)
+
+ text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
+ return BatchFeature(data={**text_inputs, **image_inputs})
+
+ def _expand_image_tokens(
+ self,
+ text: list[TextInput],
+ grids: list[list[int]],
+ ):
+ processed_text = []
+ grid_index = 0
+ for sample in text:
+ while "" in sample:
+ grid = grids[grid_index]
+ row, col = grid[0], grid[1]
+ placeholder = f"{'' * self.image_seq_length}"
+ if row * col > 1:
+ for r in range(row):
+ for c in range(col):
+ placeholder += f"{'' * self.image_seq_length}"
+ if c < col - 1:
+ placeholder += ""
+ if r < row - 1:
+ placeholder += ""
+ placeholder += ""
+
+ sample = sample.replace(self.image_token, placeholder, 1)
+ grid_index += 1
+ processed_text.append(sample)
+ return processed_text
+
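+ # Illustrative count (assuming the default image_seq_length of 256): for a grid of (2, 2) crops, one image
+ # placeholder expands to 1 + 2 * 2 = 5 blocks of 256 image tokens, with row/column separator tokens inserted
+ # between the crop blocks.
+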
+ def batch_decode(self, *args, **kwargs):
+ """
+ This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
+ refer to the docstring of this method for more information.
+ """
+ return self.tokenizer.batch_decode(*args, **kwargs)
+
+ def decode(self, *args, **kwargs):
+ """
+ This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
+ the docstring of this method for more information.
+ """
+ return self.tokenizer.decode(*args, **kwargs)
+
+ @property
+ def model_input_names(self):
+ tokenizer_input_names = self.tokenizer.model_input_names
+ image_processor_input_names = self.image_processor.model_input_names
+ return list(tokenizer_input_names) + list(image_processor_input_names)
+
+
+__all__ = ["Ovis2Processor"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/owlv2/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/owlv2/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c3d6deae9caf621e03d4042e52bbcd6d28b801ec
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/owlv2/__init__.py
@@ -0,0 +1,30 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_owlv2 import *
+ from .image_processing_owlv2 import *
+ from .image_processing_owlv2_fast import *
+ from .modeling_owlv2 import *
+ from .processing_owlv2 import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/owlv2/configuration_owlv2.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/owlv2/configuration_owlv2.py
new file mode 100644
index 0000000000000000000000000000000000000000..310a46508b84ac1952008f67a1a0845ad9512e18
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/owlv2/configuration_owlv2.py
@@ -0,0 +1,283 @@
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""OWLv2 model configuration"""
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+# Copied from transformers.models.owlvit.configuration_owlvit.OwlViTTextConfig with OwlViT->Owlv2, owlvit-base-patch32->owlv2-base-patch16, owlvit->owlv2, OWL-ViT->OWLv2
+class Owlv2TextConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of an [`Owlv2TextModel`]. It is used to instantiate an
+ Owlv2 text encoder according to the specified arguments, defining the model architecture. Instantiating a
+ configuration with the defaults will yield a similar configuration to that of the Owlv2
+ [google/owlv2-base-patch16](https://huggingface.co/google/owlv2-base-patch16) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 49408):
+ Vocabulary size of the OWLv2 text model. Defines the number of different tokens that can be represented
+ by the `inputs_ids` passed when calling [`Owlv2TextModel`].
+ hidden_size (`int`, *optional*, defaults to 512):
+ Dimensionality of the encoder layers and the pooler layer.
+ intermediate_size (`int`, *optional*, defaults to 2048):
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
+ num_hidden_layers (`int`, *optional*, defaults to 12):
+ Number of hidden layers in the Transformer encoder.
+ num_attention_heads (`int`, *optional*, defaults to 8):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ max_position_embeddings (`int`, *optional*, defaults to 16):
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
+ just in case (e.g., 512 or 1024 or 2048).
+ hidden_act (`str` or `function`, *optional*, defaults to `"quick_gelu"`):
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+ `"relu"`, `"selu"` and `"gelu_new"` `"quick_gelu"` are supported.
+ layer_norm_eps (`float`, *optional*, defaults to 1e-05):
+ The epsilon used by the layer normalization layers.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ initializer_factor (`float`, *optional*, defaults to 1.0):
+ A factor for initializing all weight matrices (should be kept to 1, used internally for initialization
+ testing).
+ pad_token_id (`int`, *optional*, defaults to 0):
+ The id of the padding token in the input sequences.
+ bos_token_id (`int`, *optional*, defaults to 49406):
+ The id of the beginning-of-sequence token in the input sequences.
+ eos_token_id (`int`, *optional*, defaults to 49407):
+ The id of the end-of-sequence token in the input sequences.
+
+ Example:
+
+ ```python
+ >>> from transformers import Owlv2TextConfig, Owlv2TextModel
+
+ >>> # Initializing an Owlv2TextConfig with google/owlv2-base-patch16 style configuration
+ >>> configuration = Owlv2TextConfig()
+
+ >>> # Initializing an Owlv2TextModel (with random weights) from the google/owlv2-base-patch16 style configuration
+ >>> model = Owlv2TextModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "owlv2_text_model"
+ base_config_key = "text_config"
+
+ def __init__(
+ self,
+ vocab_size=49408,
+ hidden_size=512,
+ intermediate_size=2048,
+ num_hidden_layers=12,
+ num_attention_heads=8,
+ max_position_embeddings=16,
+ hidden_act="quick_gelu",
+ layer_norm_eps=1e-5,
+ attention_dropout=0.0,
+ initializer_range=0.02,
+ initializer_factor=1.0,
+ pad_token_id=0,
+ bos_token_id=49406,
+ eos_token_id=49407,
+ **kwargs,
+ ):
+ super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
+
+ self.vocab_size = vocab_size
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.max_position_embeddings = max_position_embeddings
+ self.hidden_act = hidden_act
+ self.layer_norm_eps = layer_norm_eps
+ self.attention_dropout = attention_dropout
+ self.initializer_range = initializer_range
+ self.initializer_factor = initializer_factor
+
+
+# Copied from transformers.models.owlvit.configuration_owlvit.OwlViTVisionConfig with OwlViT->Owlv2, owlvit-base-patch32->owlv2-base-patch16, owlvit->owlv2, OWL-ViT->OWLv2, 32->16
+class Owlv2VisionConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of an [`Owlv2VisionModel`]. It is used to instantiate
+ an OWLv2 image encoder according to the specified arguments, defining the model architecture. Instantiating a
+ configuration with the defaults will yield a similar configuration to that of the OWLv2
+ [google/owlv2-base-patch16](https://huggingface.co/google/owlv2-base-patch16) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ hidden_size (`int`, *optional*, defaults to 768):
+ Dimensionality of the encoder layers and the pooler layer.
+ intermediate_size (`int`, *optional*, defaults to 3072):
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
+ num_hidden_layers (`int`, *optional*, defaults to 12):
+ Number of hidden layers in the Transformer encoder.
+ num_attention_heads (`int`, *optional*, defaults to 12):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ num_channels (`int`, *optional*, defaults to 3):
+ Number of channels in the input images.
+ image_size (`int`, *optional*, defaults to 768):
+ The size (resolution) of each image.
+ patch_size (`int`, *optional*, defaults to 16):
+ The size (resolution) of each patch.
+ hidden_act (`str` or `function`, *optional*, defaults to `"quick_gelu"`):
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+ `"relu"`, `"selu"` and `"gelu_new"` `"quick_gelu"` are supported.
+ layer_norm_eps (`float`, *optional*, defaults to 1e-05):
+ The epsilon used by the layer normalization layers.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ initializer_factor (`float`, *optional*, defaults to 1.0):
+ A factor for initializing all weight matrices (should be kept to 1, used internally for initialization
+ testing).
+
+ Example:
+
+ ```python
+ >>> from transformers import Owlv2VisionConfig, Owlv2VisionModel
+
+ >>> # Initializing an Owlv2VisionConfig with google/owlv2-base-patch16 style configuration
+ >>> configuration = Owlv2VisionConfig()
+
+ >>> # Initializing an Owlv2VisionModel (with random weights) from the google/owlv2-base-patch16 style configuration
+ >>> model = Owlv2VisionModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "owlv2_vision_model"
+ base_config_key = "vision_config"
+
+ def __init__(
+ self,
+ hidden_size=768,
+ intermediate_size=3072,
+ num_hidden_layers=12,
+ num_attention_heads=12,
+ num_channels=3,
+ image_size=768,
+ patch_size=16,
+ hidden_act="quick_gelu",
+ layer_norm_eps=1e-5,
+ attention_dropout=0.0,
+ initializer_range=0.02,
+ initializer_factor=1.0,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.num_channels = num_channels
+ self.image_size = image_size
+ self.patch_size = patch_size
+ self.hidden_act = hidden_act
+ self.layer_norm_eps = layer_norm_eps
+ self.attention_dropout = attention_dropout
+ self.initializer_range = initializer_range
+ self.initializer_factor = initializer_factor
+
+
+# Copied from transformers.models.owlvit.configuration_owlvit.OwlViTConfig with OwlViT->Owlv2, owlvit-base-patch32->owlv2-base-patch16, owlvit->owlv2, OWL-ViT->OWLv2
+class Owlv2Config(PretrainedConfig):
+ r"""
+ [`Owlv2Config`] is the configuration class to store the configuration of an [`Owlv2Model`]. It is used to
+ instantiate an OWLv2 model according to the specified arguments, defining the text model and vision model
+ configs. Instantiating a configuration with the defaults will yield a similar configuration to that of the OWLv2
+ [google/owlv2-base-patch16](https://huggingface.co/google/owlv2-base-patch16) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ text_config (`dict`, *optional*):
+ Dictionary of configuration options used to initialize [`Owlv2TextConfig`].
+ vision_config (`dict`, *optional*):
+ Dictionary of configuration options used to initialize [`Owlv2VisionConfig`].
+ projection_dim (`int`, *optional*, defaults to 512):
+ Dimensionality of text and vision projection layers.
+ logit_scale_init_value (`float`, *optional*, defaults to 2.6592):
+ The initial value of the *logit_scale* parameter. Default is used as per the original OWLv2
+ implementation.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return a dictionary. If `False`, returns a tuple.
+ kwargs (*optional*):
+ Dictionary of keyword arguments.
+ """
+
+ model_type = "owlv2"
+ sub_configs = {"text_config": Owlv2TextConfig, "vision_config": Owlv2VisionConfig}
+
+ def __init__(
+ self,
+ text_config=None,
+ vision_config=None,
+ projection_dim=512,
+ logit_scale_init_value=2.6592,
+ return_dict=True,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+
+ if text_config is None:
+ text_config = {}
+ logger.info("text_config is None. Initializing the Owlv2TextConfig with default values.")
+
+ if vision_config is None:
+ vision_config = {}
+ logger.info("vision_config is None. initializing the Owlv2VisionConfig with default values.")
+
+ self.text_config = Owlv2TextConfig(**text_config)
+ self.vision_config = Owlv2VisionConfig(**vision_config)
+
+ self.projection_dim = projection_dim
+ self.logit_scale_init_value = logit_scale_init_value
+ self.return_dict = return_dict
+ self.initializer_factor = 1.0
+
+ @classmethod
+ def from_text_vision_configs(cls, text_config: dict, vision_config: dict, **kwargs):
+ r"""
+ Instantiate a [`Owlv2Config`] (or a derived class) from owlv2 text model configuration and owlv2 vision
+ model configuration.
+
+ Returns:
+ [`Owlv2Config`]: An instance of a configuration object
+ """
+ config_dict = {}
+ config_dict["text_config"] = text_config
+ config_dict["vision_config"] = vision_config
+
+ return cls.from_dict(config_dict, **kwargs)
+
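+# Usage sketch for `from_text_vision_configs` (assumes default values for the unspecified options): two plain
+# dicts are combined into a full Owlv2Config.
+#
+#     config = Owlv2Config.from_text_vision_configs(
+#         text_config={"hidden_size": 512}, vision_config={"hidden_size": 768}
+#     )
+#     config.text_config.hidden_size   # 512
+#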
+
+__all__ = ["Owlv2Config", "Owlv2TextConfig", "Owlv2VisionConfig"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/owlv2/image_processing_owlv2.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/owlv2/image_processing_owlv2.py
new file mode 100644
index 0000000000000000000000000000000000000000..64399d433f5e46157d344aa2a18ad17879afbdda
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/owlv2/image_processing_owlv2.py
@@ -0,0 +1,635 @@
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Image processor class for OWLv2."""
+
+import warnings
+from typing import TYPE_CHECKING, Optional, Union
+
+import numpy as np
+
+from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
+from ...image_transforms import (
+ center_to_corners_format,
+ pad,
+ to_channel_dimension_format,
+)
+from ...image_utils import (
+ OPENAI_CLIP_MEAN,
+ OPENAI_CLIP_STD,
+ ChannelDimension,
+ ImageInput,
+ PILImageResampling,
+ get_image_size,
+ infer_channel_dimension_format,
+ is_scaled_image,
+ make_flat_list_of_images,
+ to_numpy_array,
+ valid_images,
+ validate_preprocess_arguments,
+)
+from ...utils import (
+ TensorType,
+ filter_out_non_signature_kwargs,
+ is_scipy_available,
+ is_torch_available,
+ is_vision_available,
+ logging,
+ requires_backends,
+)
+
+
+if is_torch_available():
+ import torch
+
+
+if is_vision_available():
+ import PIL
+
+if is_scipy_available():
+ from scipy import ndimage as ndi
+
+if TYPE_CHECKING:
+ from .modeling_owlv2 import Owlv2ObjectDetectionOutput
+
+logger = logging.get_logger(__name__)
+
+
+def _scale_boxes(boxes, target_sizes):
+ """
+ Scale batch of bounding boxes to the target sizes.
+
+ Args:
+ boxes (`torch.Tensor` of shape `(batch_size, num_boxes, 4)`):
+ Bounding boxes to scale. Each box is expected to be in (x1, y1, x2, y2) format.
+ target_sizes (`list[tuple[int, int]]` or `torch.Tensor` of shape `(batch_size, 2)`):
+ Target sizes to scale the boxes to. Each target size is expected to be in (height, width) format.
+
+ Returns:
+ `torch.Tensor` of shape `(batch_size, num_boxes, 4)`: Scaled bounding boxes.
+ """
+
+ if isinstance(target_sizes, (list, tuple)):
+ image_height = torch.tensor([i[0] for i in target_sizes])
+ image_width = torch.tensor([i[1] for i in target_sizes])
+ elif isinstance(target_sizes, torch.Tensor):
+ image_height, image_width = target_sizes.unbind(1)
+ else:
+ raise TypeError("`target_sizes` must be a list, tuple or torch.Tensor")
+
+ # for owlv2 image is padded to max size unlike owlvit, that's why we have to scale boxes to max size
+ max_size = torch.max(image_height, image_width)
+
+ scale_factor = torch.stack([max_size, max_size, max_size, max_size], dim=1)
+ scale_factor = scale_factor.unsqueeze(1).to(boxes.device)
+ boxes = boxes * scale_factor
+ return boxes
+
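+# Worked example (illustrative numbers): for a target size of (height=480, width=640), max_size is 640, so a
+# box predicted in relative (x1, y1, x2, y2) coordinates as (0.10, 0.20, 0.50, 0.60) is scaled to
+# (64.0, 128.0, 320.0, 384.0) with respect to the padded 640x640 square, not the original 480x640 image.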
+
+# Copied from transformers.models.owlvit.image_processing_owlvit._upcast
+def _upcast(t):
+ # Protects from numerical overflows in multiplications by upcasting to the equivalent higher type
+ if t.is_floating_point():
+ return t if t.dtype in (torch.float32, torch.float64) else t.float()
+ else:
+ return t if t.dtype in (torch.int32, torch.int64) else t.int()
+
+
+# Copied from transformers.models.owlvit.image_processing_owlvit.box_area
+def box_area(boxes):
+ """
+ Computes the area of a set of bounding boxes, which are specified by its (x1, y1, x2, y2) coordinates.
+
+ Args:
+ boxes (`torch.FloatTensor` of shape `(number_of_boxes, 4)`):
+ Boxes for which the area will be computed. They are expected to be in (x1, y1, x2, y2) format with `0 <= x1
+ < x2` and `0 <= y1 < y2`.
+ Returns:
+ `torch.FloatTensor`: a tensor containing the area for each box.
+ """
+ boxes = _upcast(boxes)
+ return (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
+
+
+# Copied from transformers.models.owlvit.image_processing_owlvit.box_iou
+def box_iou(boxes1, boxes2):
+ area1 = box_area(boxes1)
+ area2 = box_area(boxes2)
+
+ left_top = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2]
+ right_bottom = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2]
+
+ width_height = (right_bottom - left_top).clamp(min=0) # [N,M,2]
+ inter = width_height[:, :, 0] * width_height[:, :, 1] # [N,M]
+
+ union = area1[:, None] + area2 - inter
+
+ iou = inter / union
+ return iou, union
+
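+# Worked example (illustrative numbers): boxes1 = [[0., 0., 2., 2.]] and boxes2 = [[1., 1., 3., 3.]] have areas
+# 4 and 4, an intersection of 1 (the unit square from (1, 1) to (2, 2)) and a union of 7, so box_iou returns
+# iou = 1 / 7 ≈ 0.143 and union = 7.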
+
+def _preprocess_resize_output_shape(image, output_shape):
+ """Validate resize output shape according to input image.
+
+ Args:
+ image (`np.ndarray`):
+ Image to be resized.
+ output_shape (`iterable`):
+ Size of the generated output image `(rows, cols[, ...][, dim])`. If `dim` is not provided, the number of
+ channels is preserved.
+
+ Returns:
+ image (`np.ndarray`):
+ The input image, but with additional singleton dimensions appended in the case where `len(output_shape) >
+ input.ndim`.
+ output_shape (`Tuple`):
+ The output shape converted to tuple.
+
+ Raises:
+ ValueError: If output_shape length is smaller than the image number of dimensions.
+
+ Notes:
+ The input image is reshaped if its number of dimensions is not equal to output_shape_length.
+
+ """
+ output_shape = tuple(output_shape)
+ output_ndim = len(output_shape)
+ input_shape = image.shape
+ if output_ndim > image.ndim:
+ # append dimensions to input_shape
+ input_shape += (1,) * (output_ndim - image.ndim)
+ image = np.reshape(image, input_shape)
+ elif output_ndim == image.ndim - 1:
+ # multichannel case: append shape of last axis
+ output_shape = output_shape + (image.shape[-1],)
+ elif output_ndim < image.ndim:
+ raise ValueError("output_shape length cannot be smaller than the image number of dimensions")
+
+ return image, output_shape
+
+
+def _clip_warp_output(input_image, output_image):
+ """Clip output image to range of values of input image.
+
+ Note that this function modifies the values of *output_image* in-place.
+
+ Taken from:
+ https://github.com/scikit-image/scikit-image/blob/b4b521d6f0a105aabeaa31699949f78453ca3511/skimage/transform/_warps.py#L640.
+
+ Args:
+ input_image : ndarray
+ Input image.
+ output_image : ndarray
+ Output image, which is modified in-place.
+ """
+ min_val = np.min(input_image)
+ if np.isnan(min_val):
+ # NaNs detected, use NaN-safe min/max
+ min_func = np.nanmin
+ max_func = np.nanmax
+ min_val = min_func(input_image)
+ else:
+ min_func = np.min
+ max_func = np.max
+ max_val = max_func(input_image)
+
+ output_image = np.clip(output_image, min_val, max_val)
+
+ return output_image
+
+
+class Owlv2ImageProcessor(BaseImageProcessor):
+ r"""
+ Constructs an OWLv2 image processor.
+
+ Args:
+ do_rescale (`bool`, *optional*, defaults to `True`):
+ Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in
+ the `preprocess` method.
+ rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
+ Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess`
+ method.
+ do_pad (`bool`, *optional*, defaults to `True`):
+ Whether to pad the image to a square with gray pixels on the bottom and the right. Can be overridden by
+ `do_pad` in the `preprocess` method.
+ do_resize (`bool`, *optional*, defaults to `True`):
+ Controls whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden
+ by `do_resize` in the `preprocess` method.
+ size (`dict[str, int]`, *optional*, defaults to `{"height": 960, "width": 960}`):
+ Size to resize the image to. Can be overridden by `size` in the `preprocess` method.
+ resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`):
+ Resampling method to use if resizing the image. Can be overridden by `resample` in the `preprocess` method.
+ do_normalize (`bool`, *optional*, defaults to `True`):
+ Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
+ method.
+ image_mean (`float` or `list[float]`, *optional*, defaults to `OPENAI_CLIP_MEAN`):
+ Mean to use if normalizing the image. This is a float or list of floats the length of the number of
+ channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
+ image_std (`float` or `list[float]`, *optional*, defaults to `OPENAI_CLIP_STD`):
+ Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
+ number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
+ """
+
+ model_input_names = ["pixel_values"]
+
+ def __init__(
+ self,
+ do_rescale: bool = True,
+ rescale_factor: Union[int, float] = 1 / 255,
+ do_pad: bool = True,
+ do_resize: bool = True,
+ size: Optional[dict[str, int]] = None,
+ resample: PILImageResampling = PILImageResampling.BILINEAR,
+ do_normalize: bool = True,
+ image_mean: Optional[Union[float, list[float]]] = None,
+ image_std: Optional[Union[float, list[float]]] = None,
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+
+ self.do_rescale = do_rescale
+ self.rescale_factor = rescale_factor
+ self.do_pad = do_pad
+ self.do_resize = do_resize
+ self.size = size if size is not None else {"height": 960, "width": 960}
+ self.resample = resample
+ self.do_normalize = do_normalize
+ self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN
+ self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD
+
+ def pad(
+ self,
+ image: np.ndarray,
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ ):
+ """
+ Pad an image to a square with gray pixels on the bottom and the right, as per the original OWLv2
+ implementation.
+
+ Args:
+ image (`np.ndarray`):
+ Image to pad.
+ data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format of the image. If not provided, it will be the same as the input image.
+ input_data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format of the input image. If not provided, it will be inferred from the input
+ image.
+ """
+ height, width = get_image_size(image)
+ size = max(height, width)
+ image = pad(
+ image=image,
+ padding=((0, size - height), (0, size - width)),
+ constant_values=0.5,
+ data_format=data_format,
+ input_data_format=input_data_format,
+ )
+
+ return image
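+ # Worked example: a 480x640 input becomes 640x640; padding ((0, 160), (0, 0)) appends 160 rows of
+ # the constant value 0.5 at the bottom. Because `preprocess` rescales to [0, 1] before padding,
+ # 0.5 corresponds to the gray fill mentioned in the class docstring.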
+
+ def resize(
+ self,
+ image: np.ndarray,
+ size: dict[str, int],
+ anti_aliasing: bool = True,
+ anti_aliasing_sigma=None,
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ **kwargs,
+ ) -> np.ndarray:
+ """
+ Resize an image as per the original implementation.
+
+ Args:
+ image (`np.ndarray`):
+ Image to resize.
+ size (`dict[str, int]`):
+ Dictionary containing the height and width to resize the image to.
+ anti_aliasing (`bool`, *optional*, defaults to `True`):
+ Whether to apply anti-aliasing when downsampling the image.
+ anti_aliasing_sigma (`float`, *optional*, defaults to `None`):
+ Standard deviation for Gaussian kernel when downsampling the image. If `None`, it will be calculated
+ automatically.
+ data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format of the image. If not provided, it will be the same as the input image.
+ input_data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format of the input image. If not provided, it will be inferred from the input
+ image.
+ """
+ requires_backends(self, "scipy")
+
+ output_shape = (size["height"], size["width"])
+ image = to_channel_dimension_format(image, ChannelDimension.LAST)
+ image, output_shape = _preprocess_resize_output_shape(image, output_shape)
+ input_shape = image.shape
+ factors = np.divide(input_shape, output_shape)
+
+ # Fixed boundary mode, fill value and interpolation order for the scipy.ndimage calls below
+ ndi_mode = "mirror"
+ cval = 0
+ order = 1
+ if anti_aliasing:
+ if anti_aliasing_sigma is None:
+ anti_aliasing_sigma = np.maximum(0, (factors - 1) / 2)
+ else:
+ anti_aliasing_sigma = np.atleast_1d(anti_aliasing_sigma) * np.ones_like(factors)
+ if np.any(anti_aliasing_sigma < 0):
+ raise ValueError("Anti-aliasing standard deviation must be greater than or equal to zero")
+ elif np.any((anti_aliasing_sigma > 0) & (factors <= 1)):
+ warnings.warn(
+ "Anti-aliasing standard deviation greater than zero but not down-sampling along all axes"
+ )
+ filtered = ndi.gaussian_filter(image, anti_aliasing_sigma, cval=cval, mode=ndi_mode)
+ else:
+ filtered = image
+
+ zoom_factors = [1 / f for f in factors]
+ out = ndi.zoom(filtered, zoom_factors, order=order, mode=ndi_mode, cval=cval, grid_mode=True)
+
+ image = _clip_warp_output(image, out)
+
+ image = to_channel_dimension_format(image, input_data_format, ChannelDimension.LAST)
+ image = (
+ to_channel_dimension_format(image, data_format, input_data_format) if data_format is not None else image
+ )
+ return image
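+ # Anti-aliasing example: downsampling a padded 1280x1280 image to the default 960x960 gives
+ # factors of ~1.333 along height and width (1.0 along channels), so anti_aliasing_sigma is
+ # about (1.333 - 1) / 2 = 0.167 for the spatial axes before the Gaussian filter and zoom run.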
+
+ @filter_out_non_signature_kwargs()
+ def preprocess(
+ self,
+ images: ImageInput,
+ do_pad: Optional[bool] = None,
+ do_resize: Optional[bool] = None,
+ size: Optional[dict[str, int]] = None,
+ do_rescale: Optional[bool] = None,
+ rescale_factor: Optional[float] = None,
+ do_normalize: Optional[bool] = None,
+ image_mean: Optional[Union[float, list[float]]] = None,
+ image_std: Optional[Union[float, list[float]]] = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ data_format: ChannelDimension = ChannelDimension.FIRST,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ ) -> PIL.Image.Image:
+ """
+ Preprocess an image or batch of images.
+
+ Args:
+ images (`ImageInput`):
+ Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
+ passing in images with pixel values between 0 and 1, set `do_rescale=False`.
+ do_pad (`bool`, *optional*, defaults to `self.do_pad`):
+ Whether to pad the image to a square with gray pixels on the bottom and the right.
+ do_resize (`bool`, *optional*, defaults to `self.do_resize`):
+ Whether to resize the image.
+ size (`dict[str, int]`, *optional*, defaults to `self.size`):
+ Size to resize the image to.
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
+ Whether to rescale the image values to the range [0, 1].
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
+ Rescale factor to rescale the image by if `do_rescale` is set to `True`.
+ do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
+ Whether to normalize the image.
+ image_mean (`float` or `list[float]`, *optional*, defaults to `self.image_mean`):
+ Image mean.
+ image_std (`float` or `list[float]`, *optional*, defaults to `self.image_std`):
+ Image standard deviation.
+ return_tensors (`str` or `TensorType`, *optional*):
+ The type of tensors to return. Can be one of:
+ - Unset: Return a list of `np.ndarray`.
+ - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
+ - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
+ data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
+ The channel dimension format for the output image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - Unset: Use the channel dimension format of the input image.
+ input_data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
+ from the input image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
+ """
+ do_rescale = do_rescale if do_rescale is not None else self.do_rescale
+ rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
+ do_pad = do_pad if do_pad is not None else self.do_pad
+ do_resize = do_resize if do_resize is not None else self.do_resize
+ do_normalize = do_normalize if do_normalize is not None else self.do_normalize
+ image_mean = image_mean if image_mean is not None else self.image_mean
+ image_std = image_std if image_std is not None else self.image_std
+
+ size = size if size is not None else self.size
+ size = get_size_dict(size) # for BC
+
+ images = make_flat_list_of_images(images)
+
+ if not valid_images(images):
+ raise ValueError(
+ "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
+ "torch.Tensor, tf.Tensor or jax.ndarray."
+ )
+ # Here, pad and resize methods are different from the rest of image processors
+ # as they don't have any resampling in resize()
+ # or pad size in pad() (the maximum of (height, width) is taken instead).
+ # hence, these arguments don't need to be passed in validate_preprocess_arguments.
+ validate_preprocess_arguments(
+ do_rescale=do_rescale,
+ rescale_factor=rescale_factor,
+ do_normalize=do_normalize,
+ image_mean=image_mean,
+ image_std=image_std,
+ size=size,
+ )
+
+ # All transformations expect numpy arrays.
+ images = [to_numpy_array(image) for image in images]
+
+ if do_rescale and is_scaled_image(images[0]):
+ logger.warning_once(
+ "It looks like you are trying to rescale already rescaled images. If the input"
+ " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
+ )
+
+ if input_data_format is None:
+ # We assume that all images have the same channel dimension format.
+ input_data_format = infer_channel_dimension_format(images[0])
+
+ if do_rescale:
+ images = [
+ self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
+ for image in images
+ ]
+
+ if do_pad:
+ images = [self.pad(image=image, input_data_format=input_data_format) for image in images]
+
+ if do_resize:
+ images = [
+ self.resize(
+ image=image,
+ size=size,
+ input_data_format=input_data_format,
+ )
+ for image in images
+ ]
+
+ if do_normalize:
+ images = [
+ self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
+ for image in images
+ ]
+
+ images = [
+ to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
+ ]
+
+ data = {"pixel_values": images}
+ return BatchFeature(data=data, tensor_type=return_tensors)
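+ # A minimal usage sketch (illustrative names, assuming a single RGB PIL image called `image`):
+ #     image_processor = Owlv2ImageProcessor()
+ #     inputs = image_processor(images=image, return_tensors="pt")
+ #     inputs["pixel_values"].shape  # torch.Size([1, 3, 960, 960]) with the default `size`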
+
+ # Copied from transformers.models.owlvit.image_processing_owlvit.OwlViTImageProcessor.post_process_object_detection with OwlViT->Owlv2
+ def post_process_object_detection(
+ self,
+ outputs: "Owlv2ObjectDetectionOutput",
+ threshold: float = 0.1,
+ target_sizes: Optional[Union[TensorType, list[tuple]]] = None,
+ ):
+ """
+ Converts the raw output of [`Owlv2ForObjectDetection`] into final bounding boxes in (top_left_x, top_left_y,
+ bottom_right_x, bottom_right_y) format.
+
+ Args:
+ outputs ([`Owlv2ObjectDetectionOutput`]):
+ Raw outputs of the model.
+ threshold (`float`, *optional*, defaults to 0.1):
+ Score threshold to keep object detection predictions.
+ target_sizes (`torch.Tensor` or `list[tuple[int, int]]`, *optional*):
+ Tensor of shape `(batch_size, 2)` or list of tuples (`tuple[int, int]`) containing the target size
+ `(height, width)` of each image in the batch. If unset, predictions will not be resized.
+
+ Returns:
+ `list[Dict]`: A list of dictionaries, each dictionary containing the following keys:
+ - "scores": The confidence scores for each predicted box on the image.
+ - "labels": Indexes of the classes predicted by the model on the image.
+ - "boxes": Image bounding boxes in (top_left_x, top_left_y, bottom_right_x, bottom_right_y) format.
+ """
+ batch_logits, batch_boxes = outputs.logits, outputs.pred_boxes
+ batch_size = len(batch_logits)
+
+ if target_sizes is not None and len(target_sizes) != batch_size:
+ raise ValueError("Make sure that you pass in as many target sizes as images")
+
+ # batch_logits of shape (batch_size, num_queries, num_classes)
+ batch_class_logits = torch.max(batch_logits, dim=-1)
+ batch_scores = torch.sigmoid(batch_class_logits.values)
+ batch_labels = batch_class_logits.indices
+
+ # Convert to [x0, y0, x1, y1] format
+ batch_boxes = center_to_corners_format(batch_boxes)
+
+ # Convert from relative [0, 1] to absolute [0, height] coordinates
+ if target_sizes is not None:
+ batch_boxes = _scale_boxes(batch_boxes, target_sizes)
+
+ results = []
+ for scores, labels, boxes in zip(batch_scores, batch_labels, batch_boxes):
+ keep = scores > threshold
+ scores = scores[keep]
+ labels = labels[keep]
+ boxes = boxes[keep]
+ results.append({"scores": scores, "labels": labels, "boxes": boxes})
+
+ return results
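+ # A minimal post-processing sketch (illustrative names, assuming `outputs` comes from
+ # `Owlv2ForObjectDetection` and `image` is the original PIL image):
+ #     results = image_processor.post_process_object_detection(
+ #         outputs, threshold=0.2, target_sizes=[(image.height, image.width)]
+ #     )
+ #     boxes, scores, labels = results[0]["boxes"], results[0]["scores"], results[0]["labels"]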
+
+ # Copied from transformers.models.owlvit.image_processing_owlvit.OwlViTImageProcessor.post_process_image_guided_detection with OwlViT->Owlv2
+ def post_process_image_guided_detection(self, outputs, threshold=0.0, nms_threshold=0.3, target_sizes=None):
+ """
+ Converts the output of [`Owlv2ForObjectDetection.image_guided_detection`] into the format expected by the COCO
+ api.
+
+ Args:
+ outputs ([`Owlv2ImageGuidedObjectDetectionOutput`]):
+ Raw outputs of the model.
+ threshold (`float`, *optional*, defaults to 0.0):
+ Minimum confidence threshold to use to filter out predicted boxes.
+ nms_threshold (`float`, *optional*, defaults to 0.3):
+ IoU threshold for non-maximum suppression of overlapping boxes.
+ target_sizes (`torch.Tensor`, *optional*):
+ Tensor of shape (batch_size, 2) where each entry is the (height, width) of the corresponding image in
+ the batch. If set, predicted normalized bounding boxes are rescaled to the target sizes. If left as
+ `None`, predictions are kept in normalized coordinates.
+
+ Returns:
+ `list[Dict]`: A list of dictionaries, each dictionary containing the scores, labels and boxes for an image
+ in the batch as predicted by the model. All labels are set to None as
+ `Owlv2ForObjectDetection.image_guided_detection` performs one-shot object detection.
+ """
+ logits, target_boxes = outputs.logits, outputs.target_pred_boxes
+
+ if target_sizes is not None and len(logits) != len(target_sizes):
+ raise ValueError("Make sure that you pass in as many target sizes as the batch dimension of the logits")
+ if target_sizes is not None and target_sizes.shape[1] != 2:
+ raise ValueError("Each element of target_sizes must contain the size (h, w) of each image of the batch")
+
+ probs = torch.max(logits, dim=-1)
+ scores = torch.sigmoid(probs.values)
+
+ # Convert to [x0, y0, x1, y1] format
+ target_boxes = center_to_corners_format(target_boxes)
+
+ # Apply non-maximum suppression (NMS)
+ if nms_threshold < 1.0:
+ for idx in range(target_boxes.shape[0]):
+ for i in torch.argsort(-scores[idx]):
+ if not scores[idx][i]:
+ continue
+
+ ious = box_iou(target_boxes[idx][i, :].unsqueeze(0), target_boxes[idx])[0][0]
+ ious[i] = -1.0 # Mask self-IoU.
+ scores[idx][ious > nms_threshold] = 0.0
+
+ # Convert from relative [0, 1] to absolute [0, height] coordinates
+ if target_sizes is not None:
+ target_boxes = _scale_boxes(target_boxes, target_sizes)
+
+ # Compute box display alphas based on prediction scores
+ results = []
+ alphas = torch.zeros_like(scores)
+
+ for idx in range(target_boxes.shape[0]):
+ # Select scores for boxes matching the current query:
+ query_scores = scores[idx]
+ if not query_scores.nonzero().numel():
+ continue
+
+ # Apply threshold on scores before scaling
+ query_scores[query_scores < threshold] = 0.0
+
+ # Scale box alpha such that the best box for each query has alpha 1.0 and the worst box has alpha 0.1.
+ # All other boxes will either belong to a different query, or will not be shown.
+ max_score = torch.max(query_scores) + 1e-6
+ query_alphas = (query_scores - (max_score * 0.1)) / (max_score * 0.9)
+ query_alphas = torch.clip(query_alphas, 0.0, 1.0)
+ alphas[idx] = query_alphas
+
+ mask = alphas[idx] > 0
+ box_scores = alphas[idx][mask]
+ boxes = target_boxes[idx][mask]
+ results.append({"scores": box_scores, "labels": None, "boxes": boxes})
+
+ return results
+
+
+__all__ = ["Owlv2ImageProcessor"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/owlv2/image_processing_owlv2_fast.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/owlv2/image_processing_owlv2_fast.py
new file mode 100644
index 0000000000000000000000000000000000000000..359d241686ec017f80945bc1a9eda0f35ca079b6
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/owlv2/image_processing_owlv2_fast.py
@@ -0,0 +1,409 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/owlv2/modular_owlv2.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_owlv2.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import warnings
+from typing import TYPE_CHECKING, Optional, Union
+
+import torch
+from torchvision.transforms.v2 import functional as F
+
+from ...image_processing_utils_fast import BaseImageProcessorFast, BatchFeature, DefaultFastImageProcessorKwargs
+from ...image_transforms import center_to_corners_format, group_images_by_shape, reorder_images
+from ...image_utils import (
+ OPENAI_CLIP_MEAN,
+ OPENAI_CLIP_STD,
+ ChannelDimension,
+ ImageInput,
+ PILImageResampling,
+ SizeDict,
+)
+from ...processing_utils import Unpack
+from ...utils import TensorType, auto_docstring
+from .image_processing_owlv2 import _scale_boxes, box_iou
+
+
+if TYPE_CHECKING:
+ from .modeling_owlv2 import Owlv2ObjectDetectionOutput
+
+
+class Owlv2FastImageProcessorKwargs(DefaultFastImageProcessorKwargs): ...
+
+
+@auto_docstring
+class Owlv2ImageProcessorFast(BaseImageProcessorFast):
+ resample = PILImageResampling.BILINEAR
+ image_mean = OPENAI_CLIP_MEAN
+ image_std = OPENAI_CLIP_STD
+ size = {"height": 960, "width": 960}
+ default_to_square = True
+ crop_size = None
+ do_resize = True
+ do_center_crop = None
+ do_rescale = True
+ do_normalize = True
+ do_convert_rgb = None
+ model_input_names = ["pixel_values"]
+ rescale_factor = 1 / 255
+ do_pad = True
+ valid_kwargs = Owlv2FastImageProcessorKwargs
+
+ def post_process(self, outputs, target_sizes):
+ """
+ Converts the raw output of [`Owlv2ForObjectDetection`] into final bounding boxes in (top_left_x, top_left_y,
+ bottom_right_x, bottom_right_y) format.
+
+ Args:
+ outputs ([`Owlv2ObjectDetectionOutput`]):
+ Raw outputs of the model.
+ target_sizes (`torch.Tensor` of shape `(batch_size, 2)`):
+ Tensor containing the size (h, w) of each image of the batch. For evaluation, this must be the original
+ image size (before any data augmentation). For visualization, this should be the image size after data
+ augmentation, but before padding.
+ Returns:
+ `list[Dict]`: A list of dictionaries, each dictionary containing the scores, labels and boxes for an image
+ in the batch as predicted by the model.
+ """
+ # TODO: (amy) add support for other frameworks
+ warnings.warn(
+ "`post_process` is deprecated and will be removed in v5 of Transformers, please use"
+ " `post_process_object_detection` instead, with `threshold=0.` for equivalent results.",
+ FutureWarning,
+ )
+
+ logits, boxes = outputs.logits, outputs.pred_boxes
+
+ if len(logits) != len(target_sizes):
+ raise ValueError("Make sure that you pass in as many target sizes as the batch dimension of the logits")
+ if target_sizes.shape[1] != 2:
+ raise ValueError("Each element of target_sizes must contain the size (h, w) of each image of the batch")
+
+ probs = torch.max(logits, dim=-1)
+ scores = torch.sigmoid(probs.values)
+ labels = probs.indices
+
+ # Convert to [x0, y0, x1, y1] format
+ boxes = center_to_corners_format(boxes)
+
+ # Convert from relative [0, 1] to absolute [0, height] coordinates
+ img_h, img_w = target_sizes.unbind(1)
+ scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1).to(boxes.device)
+ boxes = boxes * scale_fct[:, None, :]
+
+ results = [{"scores": s, "labels": l, "boxes": b} for s, l, b in zip(scores, labels, boxes)]
+
+ return results
+
+ def post_process_object_detection(
+ self,
+ outputs: "Owlv2ObjectDetectionOutput",
+ threshold: float = 0.1,
+ target_sizes: Optional[Union[TensorType, list[tuple]]] = None,
+ ):
+ """
+ Converts the raw output of [`Owlv2ForObjectDetection`] into final bounding boxes in (top_left_x, top_left_y,
+ bottom_right_x, bottom_right_y) format.
+
+ Args:
+ outputs ([`Owlv2ObjectDetectionOutput`]):
+ Raw outputs of the model.
+ threshold (`float`, *optional*, defaults to 0.1):
+ Score threshold to keep object detection predictions.
+ target_sizes (`torch.Tensor` or `list[tuple[int, int]]`, *optional*):
+ Tensor of shape `(batch_size, 2)` or list of tuples (`tuple[int, int]`) containing the target size
+ `(height, width)` of each image in the batch. If unset, predictions will not be resized.
+
+ Returns:
+ `list[Dict]`: A list of dictionaries, each dictionary containing the following keys:
+ - "scores": The confidence scores for each predicted box on the image.
+ - "labels": Indexes of the classes predicted by the model on the image.
+ - "boxes": Image bounding boxes in (top_left_x, top_left_y, bottom_right_x, bottom_right_y) format.
+ """
+ batch_logits, batch_boxes = outputs.logits, outputs.pred_boxes
+ batch_size = len(batch_logits)
+
+ if target_sizes is not None and len(target_sizes) != batch_size:
+ raise ValueError("Make sure that you pass in as many target sizes as images")
+
+ # batch_logits of shape (batch_size, num_queries, num_classes)
+ batch_class_logits = torch.max(batch_logits, dim=-1)
+ batch_scores = torch.sigmoid(batch_class_logits.values)
+ batch_labels = batch_class_logits.indices
+
+ # Convert to [x0, y0, x1, y1] format
+ batch_boxes = center_to_corners_format(batch_boxes)
+
+ # Convert from relative [0, 1] to absolute [0, height] coordinates
+ if target_sizes is not None:
+ batch_boxes = _scale_boxes(batch_boxes, target_sizes)
+
+ results = []
+ for scores, labels, boxes in zip(batch_scores, batch_labels, batch_boxes):
+ keep = scores > threshold
+ scores = scores[keep]
+ labels = labels[keep]
+ boxes = boxes[keep]
+ results.append({"scores": scores, "labels": labels, "boxes": boxes})
+
+ return results
+
+ def post_process_image_guided_detection(self, outputs, threshold=0.0, nms_threshold=0.3, target_sizes=None):
+ """
+ Converts the output of [`Owlv2ForObjectDetection.image_guided_detection`] into the format expected by the COCO
+ api.
+
+ Args:
+ outputs ([`Owlv2ImageGuidedObjectDetectionOutput`]):
+ Raw outputs of the model.
+ threshold (`float`, *optional*, defaults to 0.0):
+ Minimum confidence threshold to use to filter out predicted boxes.
+ nms_threshold (`float`, *optional*, defaults to 0.3):
+ IoU threshold for non-maximum suppression of overlapping boxes.
+ target_sizes (`torch.Tensor`, *optional*):
+ Tensor of shape (batch_size, 2) where each entry is the (height, width) of the corresponding image in
+ the batch. If set, predicted normalized bounding boxes are rescaled to the target sizes. If left as
+ `None`, predictions are kept in normalized coordinates.
+
+ Returns:
+ `list[Dict]`: A list of dictionaries, each dictionary containing the scores, labels and boxes for an image
+ in the batch as predicted by the model. All labels are set to None as
+ `Owlv2ForObjectDetection.image_guided_detection` performs one-shot object detection.
+ """
+ logits, target_boxes = outputs.logits, outputs.target_pred_boxes
+
+ if target_sizes is not None and len(logits) != len(target_sizes):
+ raise ValueError("Make sure that you pass in as many target sizes as the batch dimension of the logits")
+ if target_sizes is not None and target_sizes.shape[1] != 2:
+ raise ValueError("Each element of target_sizes must contain the size (h, w) of each image of the batch")
+
+ probs = torch.max(logits, dim=-1)
+ scores = torch.sigmoid(probs.values)
+
+ # Convert to [x0, y0, x1, y1] format
+ target_boxes = center_to_corners_format(target_boxes)
+
+ # Apply non-maximum suppression (NMS)
+ if nms_threshold < 1.0:
+ for idx in range(target_boxes.shape[0]):
+ for i in torch.argsort(-scores[idx]):
+ if not scores[idx][i]:
+ continue
+
+ ious = box_iou(target_boxes[idx][i, :].unsqueeze(0), target_boxes[idx])[0][0]
+ ious[i] = -1.0 # Mask self-IoU.
+ scores[idx][ious > nms_threshold] = 0.0
+
+ # Convert from relative [0, 1] to absolute [0, height] coordinates
+ if target_sizes is not None:
+ target_boxes = _scale_boxes(target_boxes, target_sizes)
+
+ # Compute box display alphas based on prediction scores
+ results = []
+ alphas = torch.zeros_like(scores)
+
+ for idx in range(target_boxes.shape[0]):
+ # Select scores for boxes matching the current query:
+ query_scores = scores[idx]
+ if not query_scores.nonzero().numel():
+ continue
+
+ # Apply threshold on scores before scaling
+ query_scores[query_scores < threshold] = 0.0
+
+ # Scale box alpha such that the best box for each query has alpha 1.0 and the worst box has alpha 0.1.
+ # All other boxes will either belong to a different query, or will not be shown.
+ max_score = torch.max(query_scores) + 1e-6
+ query_alphas = (query_scores - (max_score * 0.1)) / (max_score * 0.9)
+ query_alphas = torch.clip(query_alphas, 0.0, 1.0)
+ alphas[idx] = query_alphas
+
+ mask = alphas[idx] > 0
+ box_scores = alphas[idx][mask]
+ boxes = target_boxes[idx][mask]
+ results.append({"scores": box_scores, "labels": None, "boxes": boxes})
+
+ return results
+
+ def __init__(self, **kwargs: Unpack[Owlv2FastImageProcessorKwargs]):
+ super().__init__(**kwargs)
+
+ @auto_docstring
+ def preprocess(self, images: ImageInput, **kwargs: Unpack[Owlv2FastImageProcessorKwargs]):
+ return super().preprocess(images, **kwargs)
+
+ def _pad_images(self, images: "torch.Tensor", constant_value: float = 0.5) -> "torch.Tensor":
+ """
+ Pad images to a square, filling the bottom and right borders with `constant_value` (0.5, i.e. gray for images scaled to [0, 1]).
+ """
+ height, width = images.shape[-2:]
+ size = max(height, width)
+ pad_bottom = size - height
+ pad_right = size - width
+
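+ # torchvision's pad expects (left, top, right, bottom), so only the right and bottom borders grow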
+ padding = (0, 0, pad_right, pad_bottom)
+ padded_image = F.pad(images, padding, fill=constant_value)
+ return padded_image
+
+ def pad(
+ self,
+ images: list["torch.Tensor"],
+ disable_grouping: Optional[bool],
+ constant_value: float = 0.5,
+ **kwargs,
+ ) -> list["torch.Tensor"]:
+ """
+ Unlike the base class `pad`, where all images are padded to the maximum image size in the batch,
+ OWLv2 pads each image to a square.
+ """
+ grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
+ processed_images_grouped = {}
+ for shape, stacked_images in grouped_images.items():
+ stacked_images = self._pad_images(
+ stacked_images,
+ constant_value=constant_value,
+ )
+ processed_images_grouped[shape] = stacked_images
+
+ processed_images = reorder_images(processed_images_grouped, grouped_images_index)
+
+ return processed_images
+
+ def resize(
+ self,
+ image: "torch.Tensor",
+ size: SizeDict,
+ anti_aliasing: bool = True,
+ anti_aliasing_sigma=None,
+ **kwargs,
+ ) -> "torch.Tensor":
+ """
+ Resize an image as per the original implementation.
+
+ Args:
+ image (`Tensor`):
+ Image to resize.
+ size (`dict[str, int]`):
+ Dictionary containing the height and width to resize the image to.
+ anti_aliasing (`bool`, *optional*, defaults to `True`):
+ Whether to apply anti-aliasing when downsampling the image.
+ anti_aliasing_sigma (`float`, *optional*, defaults to `None`):
+ Standard deviation for Gaussian kernel when downsampling the image. If `None`, it will be calculated
+ automatically.
+ """
+ output_shape = (size.height, size.width)
+
+ input_shape = image.shape
+
+ # select height and width from input tensor
+ factors = torch.tensor(input_shape[2:]).to(image.device) / torch.tensor(output_shape).to(image.device)
+
+ if anti_aliasing:
+ if anti_aliasing_sigma is None:
+ anti_aliasing_sigma = ((factors - 1) / 2).clamp(min=0)
+ else:
+ anti_aliasing_sigma = torch.atleast_1d(anti_aliasing_sigma) * torch.ones_like(factors)
+ if torch.any(anti_aliasing_sigma < 0):
+ raise ValueError("Anti-aliasing standard deviation must be greater than or equal to zero")
+ elif torch.any((anti_aliasing_sigma > 0) & (factors <= 1)):
+ warnings.warn(
+ "Anti-aliasing standard deviation greater than zero but not down-sampling along all axes"
+ )
+ if torch.any(anti_aliasing_sigma == 0):
+ filtered = image
+ else:
+ kernel_sizes = 2 * torch.ceil(3 * anti_aliasing_sigma).int() + 1
+
+ filtered = F.gaussian_blur(
+ image, (kernel_sizes[0], kernel_sizes[1]), sigma=anti_aliasing_sigma.tolist()
+ )
+
+ else:
+ filtered = image
+
+ out = F.resize(filtered, size=(size.height, size.width), antialias=False)
+
+ return out
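+ # Note: this approximates the slow processor's scipy-based path with torchvision primitives
+ # (gaussian_blur with kernel size 2 * ceil(3 * sigma) + 1, then a non-antialiased resize), so
+ # the fast and slow outputs may differ by small numerical amounts.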
+
+ def _preprocess(
+ self,
+ images: list["torch.Tensor"],
+ do_resize: bool,
+ size: SizeDict,
+ interpolation: Optional["F.InterpolationMode"],
+ do_pad: bool,
+ do_rescale: bool,
+ rescale_factor: float,
+ do_normalize: bool,
+ image_mean: Optional[Union[float, list[float]]],
+ image_std: Optional[Union[float, list[float]]],
+ disable_grouping: Optional[bool],
+ return_tensors: Optional[Union[str, TensorType]],
+ **kwargs,
+ ) -> BatchFeature:
+ # Group images by size for batched resizing
+ grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
+ processed_images_grouped = {}
+
+ for shape, stacked_images in grouped_images.items():
+ # Rescale images before other operations as done in original implementation
+ stacked_images = self.rescale_and_normalize(
+ stacked_images, do_rescale, rescale_factor, False, image_mean, image_std
+ )
+ processed_images_grouped[shape] = stacked_images
+
+ processed_images = reorder_images(processed_images_grouped, grouped_images_index)
+
+ if do_pad:
+ processed_images = self.pad(processed_images, constant_value=0.5, disable_grouping=disable_grouping)
+
+ grouped_images, grouped_images_index = group_images_by_shape(
+ processed_images, disable_grouping=disable_grouping
+ )
+ resized_images_grouped = {}
+ for shape, stacked_images in grouped_images.items():
+ if do_resize:
+ resized_stack = self.resize(
+ image=stacked_images,
+ size=size,
+ interpolation=interpolation,
+ input_data_format=ChannelDimension.FIRST,
+ )
+ resized_images_grouped[shape] = resized_stack
+ resized_images = reorder_images(resized_images_grouped, grouped_images_index)
+
+ # Group images by size for further processing
+ # Needed in case do_resize is False, or resize returns images with different sizes
+ grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping)
+ processed_images_grouped = {}
+ for shape, stacked_images in grouped_images.items():
+ # Fused rescale and normalize
+ stacked_images = self.rescale_and_normalize(
+ stacked_images, False, rescale_factor, do_normalize, image_mean, image_std
+ )
+ processed_images_grouped[shape] = stacked_images
+
+ processed_images = reorder_images(processed_images_grouped, grouped_images_index)
+
+ processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
+
+ return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)
+
+
+__all__ = ["Owlv2ImageProcessorFast"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/owlv2/modeling_owlv2.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/owlv2/modeling_owlv2.py
new file mode 100644
index 0000000000000000000000000000000000000000..715df44f01f0f77eec3a76e845a1aee629e72cda
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/owlv2/modeling_owlv2.py
@@ -0,0 +1,1708 @@
+# coding=utf-8
+# Copyright 2023 Google AI and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch OWLv2 model."""
+
+from dataclasses import dataclass
+from functools import lru_cache
+from typing import Any, Optional, Union
+
+import torch
+from torch import Tensor, nn
+
+from ...activations import ACT2FN
+from ...modeling_attn_mask_utils import _create_4d_causal_attention_mask, _prepare_4d_attention_mask
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
+from ...modeling_utils import PreTrainedModel
+from ...utils import (
+ ModelOutput,
+ auto_docstring,
+ filter_out_non_signature_kwargs,
+ is_vision_available,
+ logging,
+ torch_int,
+)
+from .configuration_owlv2 import Owlv2Config, Owlv2TextConfig, Owlv2VisionConfig
+
+
+if is_vision_available():
+ from transformers.image_transforms import center_to_corners_format
+
+
+logger = logging.get_logger(__name__)
+
+
+# See all Owlv2 models at https://huggingface.co/models?filter=owlv2
+
+
+# Copied from transformers.models.clip.modeling_clip.contrastive_loss with clip->owlv2
+def contrastive_loss(logits: torch.Tensor) -> torch.Tensor:
+ return nn.functional.cross_entropy(logits, torch.arange(len(logits), device=logits.device))
+
+
+# Copied from transformers.models.clip.modeling_clip.clip_loss with clip->owlv2
+def owlv2_loss(similarity: torch.Tensor) -> torch.Tensor:
+ caption_loss = contrastive_loss(similarity)
+ image_loss = contrastive_loss(similarity.t())
+ return (caption_loss + image_loss) / 2.0
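+ # The similarity matrix pairs image i with text i on its diagonal, so `contrastive_loss` is a
+ # cross-entropy whose target for row i is column i; `owlv2_loss` averages the image->text loss
+ # with the text->image loss computed on the transposed matrix.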
+
+
+@dataclass
+@auto_docstring
+class Owlv2Output(ModelOutput):
+ r"""
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
+ Contrastive loss for image-text similarity.
+ logits_per_image (`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`):
+ The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text
+ similarity scores.
+ logits_per_text (`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`):
+ The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image
+ similarity scores.
+ text_embeds (`torch.FloatTensor` of shape `(batch_size * num_max_text_queries, output_dim)`):
+ The text embeddings obtained by applying the projection layer to the pooled output of [`Owlv2TextModel`].
+ image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`):
+ The image embeddings obtained by applying the projection layer to the pooled output of
+ [`Owlv2VisionModel`].
+ text_model_output (`BaseModelOutputWithPooling`):
+ The output of the [`Owlv2TextModel`].
+ vision_model_output (`BaseModelOutputWithPooling`):
+ The output of the [`Owlv2VisionModel`].
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ logits_per_image: Optional[torch.FloatTensor] = None
+ logits_per_text: Optional[torch.FloatTensor] = None
+ text_embeds: Optional[torch.FloatTensor] = None
+ image_embeds: Optional[torch.FloatTensor] = None
+ text_model_output: BaseModelOutputWithPooling = None
+ vision_model_output: BaseModelOutputWithPooling = None
+
+ def to_tuple(self) -> tuple[Any]:
+ return tuple(
+ self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple()
+ for k in self.keys()
+ )
+
+
+# Copied from transformers.loss.loss_for_object_detection._upcast
+def _upcast(t: Tensor) -> Tensor:
+ # Protects from numerical overflows in multiplications by upcasting to the equivalent higher type
+ if t.is_floating_point():
+ return t if t.dtype in (torch.float32, torch.float64) else t.float()
+ else:
+ return t if t.dtype in (torch.int32, torch.int64) else t.int()
+
+
+# Copied from transformers.loss.loss_for_object_detection.box_area
+def box_area(boxes: Tensor) -> Tensor:
+ """
+ Computes the area of a set of bounding boxes, which are specified by its (x1, y1, x2, y2) coordinates.
+
+ Args:
+ boxes (`torch.FloatTensor` of shape `(number_of_boxes, 4)`):
+ Boxes for which the area will be computed. They are expected to be in (x1, y1, x2, y2) format with `0 <= x1
+ < x2` and `0 <= y1 < y2`.
+
+ Returns:
+ `torch.FloatTensor`: a tensor containing the area for each box.
+ """
+ boxes = _upcast(boxes)
+ return (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
+
+
+# Copied from transformers.loss.loss_for_object_detection.box_iou
+def box_iou(boxes1, boxes2):
+ area1 = box_area(boxes1)
+ area2 = box_area(boxes2)
+
+ left_top = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2]
+ right_bottom = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2]
+
+ width_height = (right_bottom - left_top).clamp(min=0) # [N,M,2]
+ inter = width_height[:, :, 0] * width_height[:, :, 1] # [N,M]
+
+ union = area1[:, None] + area2 - inter
+
+ iou = inter / union
+ return iou, union
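+ # Shapes: boxes1 is [N, 4] and boxes2 is [M, 4]; broadcasting yields `iou` and `union` of shape [N, M].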
+
+
+# Copied from transformers.loss.loss_for_object_detection.generalized_box_iou
+def generalized_box_iou(boxes1, boxes2):
+ """
+ Generalized IoU from https://giou.stanford.edu/. The boxes should be in [x0, y0, x1, y1] (corner) format.
+
+ Returns:
+ `torch.FloatTensor`: a [N, M] pairwise matrix, where N = len(boxes1) and M = len(boxes2)
+ """
+ # degenerate boxes gives inf / nan results
+ # so do an early check
+ if not (boxes1[:, 2:] >= boxes1[:, :2]).all():
+ raise ValueError(f"boxes1 must be in [x0, y0, x1, y1] (corner) format, but got {boxes1}")
+ if not (boxes2[:, 2:] >= boxes2[:, :2]).all():
+ raise ValueError(f"boxes2 must be in [x0, y0, x1, y1] (corner) format, but got {boxes2}")
+ iou, union = box_iou(boxes1, boxes2)
+
+ top_left = torch.min(boxes1[:, None, :2], boxes2[:, :2])
+ bottom_right = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])
+
+ width_height = (bottom_right - top_left).clamp(min=0) # [N,M,2]
+ area = width_height[:, :, 0] * width_height[:, :, 1]
+
+ return iou - (area - union) / area
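+ # Here `area` is the area of the smallest box enclosing each pair, so the return value is
+ # IoU - (enclosing_area - union) / enclosing_area, the standard GIoU formulation.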
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Output type of [`Owlv2ForObjectDetection`].
+ """
+)
+class Owlv2ObjectDetectionOutput(ModelOutput):
+ r"""
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` are provided):
+ Total loss as a linear combination of a negative log-likelihood (cross-entropy) for class prediction and a
+ bounding box loss. The latter is defined as a linear combination of the L1 loss and the generalized
+ scale-invariant IoU loss.
+ loss_dict (`Dict`, *optional*):
+ A dictionary containing the individual losses. Useful for logging.
+ logits (`torch.FloatTensor` of shape `(batch_size, num_patches, num_queries)`):
+ Classification logits (including no-object) for all queries.
+ objectness_logits (`torch.FloatTensor` of shape `(batch_size, num_patches, 1)`):
+ The objectness logits of all image patches. OWLv2 represents images as a set of image patches where the
+ total number of patches is (image_size / patch_size)**2.
+ pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_patches, 4)`):
+ Normalized boxes coordinates for all queries, represented as (center_x, center_y, width, height). These
+ values are normalized in [0, 1], relative to the size of each individual image in the batch (disregarding
+ possible padding). You can use [`~Owlv2ImageProcessor.post_process_object_detection`] to retrieve the
+ unnormalized bounding boxes.
+ text_embeds (`torch.FloatTensor` of shape `(batch_size, num_max_text_queries, output_dim)`):
+ The text embeddings obtained by applying the projection layer to the pooled output of [`Owlv2TextModel`].
+ image_embeds (`torch.FloatTensor` of shape `(batch_size, patch_size, patch_size, output_dim)`):
+ Pooled output of [`Owlv2VisionModel`]. OWLv2 represents images as a set of image patches and computes image
+ embeddings for each patch.
+ class_embeds (`torch.FloatTensor` of shape `(batch_size, num_patches, hidden_size)`):
+ Class embeddings of all image patches. OWLv2 represents images as a set of image patches where the total
+ number of patches is (image_size / patch_size)**2.
+ text_model_output (`BaseModelOutputWithPooling`):
+ The output of the [`Owlv2TextModel`].
+ vision_model_output (`BaseModelOutputWithPooling`):
+ The output of the [`Owlv2VisionModel`].
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ loss_dict: Optional[dict] = None
+ logits: Optional[torch.FloatTensor] = None
+ objectness_logits: Optional[torch.FloatTensor] = None
+ pred_boxes: Optional[torch.FloatTensor] = None
+ text_embeds: Optional[torch.FloatTensor] = None
+ image_embeds: Optional[torch.FloatTensor] = None
+ class_embeds: Optional[torch.FloatTensor] = None
+ text_model_output: BaseModelOutputWithPooling = None
+ vision_model_output: BaseModelOutputWithPooling = None
+
+ def to_tuple(self) -> tuple[Any]:
+ return tuple(
+ self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple()
+ for k in self.keys()
+ )
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Output type of [`Owlv2ForObjectDetection.image_guided_detection`].
+ """
+)
+# Copied from transformers.models.owlvit.modeling_owlvit.OwlViTImageGuidedObjectDetectionOutput with OwlViT->Owlv2,OWL-ViT->OWLv2
+class Owlv2ImageGuidedObjectDetectionOutput(ModelOutput):
+ r"""
+ logits (`torch.FloatTensor` of shape `(batch_size, num_patches, num_queries)`):
+ Classification logits (including no-object) for all queries.
+ image_embeds (`torch.FloatTensor` of shape `(batch_size, patch_size, patch_size, output_dim)`):
+ Pooled output of [`Owlv2VisionModel`]. OWLv2 represents images as a set of image patches and computes
+ image embeddings for each patch.
+ query_image_embeds (`torch.FloatTensor` of shape `(batch_size, patch_size, patch_size, output_dim)`):
+ Pooled output of [`Owlv2VisionModel`]. OWLv2 represents images as a set of image patches and computes
+ image embeddings for each patch.
+ target_pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_patches, 4)`):
+ Normalized boxes coordinates for all queries, represented as (center_x, center_y, width, height). These
+ values are normalized in [0, 1], relative to the size of each individual target image in the batch
+ (disregarding possible padding). You can use [`~Owlv2ImageProcessor.post_process_object_detection`] to
+ retrieve the unnormalized bounding boxes.
+ query_pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_patches, 4)`):
+ Normalized boxes coordinates for all queries, represented as (center_x, center_y, width, height). These
+ values are normalized in [0, 1], relative to the size of each individual query image in the batch
+ (disregarding possible padding). You can use [`~Owlv2ImageProcessor.post_process_object_detection`] to
+ retrieve the unnormalized bounding boxes.
+ class_embeds (`torch.FloatTensor` of shape `(batch_size, num_patches, hidden_size)`):
+ Class embeddings of all image patches. OWLv2 represents images as a set of image patches where the total
+ number of patches is (image_size / patch_size)**2.
+ text_model_output (`BaseModelOutputWithPooling`):
+ The output of the [`Owlv2TextModel`].
+ vision_model_output (`BaseModelOutputWithPooling`):
+ The output of the [`Owlv2VisionModel`].
+ """
+
+ logits: Optional[torch.FloatTensor] = None
+ image_embeds: Optional[torch.FloatTensor] = None
+ query_image_embeds: Optional[torch.FloatTensor] = None
+ target_pred_boxes: Optional[torch.FloatTensor] = None
+ query_pred_boxes: Optional[torch.FloatTensor] = None
+ class_embeds: Optional[torch.FloatTensor] = None
+ text_model_output: BaseModelOutputWithPooling = None
+ vision_model_output: BaseModelOutputWithPooling = None
+
+ def to_tuple(self) -> tuple[Any]:
+ return tuple(
+ self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple()
+ for k in self.keys()
+ )
+
+
+# Copied from transformers.models.owlvit.modeling_owlvit.OwlViTVisionEmbeddings with OwlViT->Owlv2
+class Owlv2VisionEmbeddings(nn.Module):
+ def __init__(self, config: Owlv2VisionConfig):
+ super().__init__()
+ self.patch_size = config.patch_size
+ self.config = config
+ self.embed_dim = config.hidden_size
+ self.class_embedding = nn.Parameter(torch.randn(config.hidden_size))
+
+ self.patch_embedding = nn.Conv2d(
+ in_channels=config.num_channels,
+ out_channels=self.embed_dim,
+ kernel_size=config.patch_size,
+ stride=config.patch_size,
+ bias=False,
+ )
+
+ self.num_patches = (config.image_size // config.patch_size) ** 2
+ self.num_positions = self.num_patches + 1
+ self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
+ self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)
+
+ # Copied from transformers.models.clip.modeling_clip.CLIPVisionEmbeddings.interpolate_pos_encoding
+ def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
+ """
+ This method interpolates the pre-trained position encodings so the model can be used on higher-resolution
+ images. It is also adapted to support torch.jit tracing.
+
+ Adapted from:
+ - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
+ - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
+ """
+
+ num_patches = embeddings.shape[1] - 1
+ position_embedding = self.position_embedding.weight.unsqueeze(0)
+ num_positions = position_embedding.shape[1] - 1
+
+ # always interpolate when tracing to ensure the exported model works for dynamic input shapes
+ if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
+ return self.position_embedding(self.position_ids)
+
+ class_pos_embed = position_embedding[:, :1]
+ patch_pos_embed = position_embedding[:, 1:]
+
+ dim = embeddings.shape[-1]
+
+ new_height = height // self.patch_size
+ new_width = width // self.patch_size
+
+ sqrt_num_positions = torch_int(num_positions**0.5)
+ patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
+ patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
+
+ patch_pos_embed = nn.functional.interpolate(
+ patch_pos_embed,
+ size=(new_height, new_width),
+ mode="bicubic",
+ align_corners=False,
+ )
+
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+
+ return torch.cat((class_pos_embed, patch_pos_embed), dim=1)
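+ # For example, a 1008x1008 input with patch_size=16 yields a 63x63 patch grid, so the pretrained
+ # sqrt_num_positions x sqrt_num_positions grid is bicubically resized to 63x63 before being
+ # flattened and concatenated back onto the class position embedding.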
+
+ def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:
+ batch_size, _, height, width = pixel_values.shape
+ patch_embeds = self.patch_embedding(pixel_values)  # shape = [batch_size, embed_dim, height // patch_size, width // patch_size]
+ patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
+
+ class_embeds = self.class_embedding.expand(batch_size, 1, -1)
+ embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
+ if interpolate_pos_encoding:
+ embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
+ else:
+ embeddings = embeddings + self.position_embedding(self.position_ids)
+ return embeddings
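+ # Output shape: (batch_size, 1 + num_patches, embed_dim), with the class embedding first.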
+
+
+# Copied from transformers.models.owlvit.modeling_owlvit.OwlViTTextEmbeddings with OwlViT->Owlv2
+class Owlv2TextEmbeddings(nn.Module):
+ def __init__(self, config: Owlv2TextConfig):
+ super().__init__()
+ self.token_embedding = nn.Embedding(config.vocab_size, config.hidden_size)
+ self.position_embedding = nn.Embedding(config.max_position_embeddings, config.hidden_size)
+
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
+ self.register_buffer(
+ "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
+ )
+
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ ) -> torch.Tensor:
+ seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
+
+ if position_ids is None:
+ position_ids = self.position_ids[:, :seq_length]
+
+ if inputs_embeds is None:
+ inputs_embeds = self.token_embedding(input_ids)
+
+ position_embeddings = self.position_embedding(position_ids)
+ embeddings = inputs_embeds + position_embeddings
+
+ return embeddings
+
+
+# Copied from transformers.models.owlvit.modeling_owlvit.OwlViTAttention with OwlViT->Owlv2
+class Owlv2Attention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.embed_dim = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.head_dim = self.embed_dim // self.num_heads
+ if self.head_dim * self.num_heads != self.embed_dim:
+ raise ValueError(
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
+ f" {self.num_heads})."
+ )
+ self.scale = self.head_dim**-0.5
+ self.dropout = config.attention_dropout
+
+ self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
+ self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
+ self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
+ self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
+
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ causal_attention_mask: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = False,
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+ """Input shape: Batch x Time x Channel"""
+
+ bsz, tgt_len, embed_dim = hidden_states.size()
+
+ # get query proj
+ query_states = self.q_proj(hidden_states) * self.scale
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+
+ proj_shape = (bsz * self.num_heads, -1, self.head_dim)
+ query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
+ key_states = key_states.view(*proj_shape)
+ value_states = value_states.view(*proj_shape)
+
+ src_len = key_states.size(1)
+ attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
+
+ if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
+ raise ValueError(
+ f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
+ f" {attn_weights.size()}"
+ )
+
+ # apply the causal_attention_mask first
+ if causal_attention_mask is not None:
+ if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len):
+ raise ValueError(
+ f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
+ f" {causal_attention_mask.size()}"
+ )
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + causal_attention_mask
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+ if attention_mask is not None:
+ if attention_mask.size() != (bsz, 1, tgt_len, src_len):
+ raise ValueError(
+ f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
+ )
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+
+ if output_attentions:
+ # this operation is a bit awkward, but it's required to
+ # make sure that attn_weights keeps its gradient.
+ # In order to do so, attn_weights has to be reshaped
+ # twice and reused in the following computation
+ attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+ attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
+ else:
+ attn_weights_reshaped = None
+
+ attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
+
+ # For int8 compatibility, sometimes the `attn_probs` are in `fp32`
+ attn_probs = attn_probs.to(value_states.dtype)
+
+ attn_output = torch.bmm(attn_probs, value_states)
+
+ if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
+ raise ValueError(
+ f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
+ )
+
+ attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
+ attn_output = attn_output.transpose(1, 2)
+ attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
+
+ attn_output = self.out_proj(attn_output)
+
+ return attn_output, attn_weights_reshaped
+
+
+# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->Owlv2
+class Owlv2MLP(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.activation_fn = ACT2FN[config.hidden_act]
+ self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
+ self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.fc1(hidden_states)
+ hidden_states = self.activation_fn(hidden_states)
+ hidden_states = self.fc2(hidden_states)
+ return hidden_states
+
+
+# Copied from transformers.models.altclip.modeling_altclip.AltCLIPEncoderLayer with AltCLIP->Owlv2
+class Owlv2EncoderLayer(GradientCheckpointingLayer):
+ def __init__(self, config: Owlv2Config):
+ super().__init__()
+ self.embed_dim = config.hidden_size
+ self.self_attn = Owlv2Attention(config)
+ self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
+ self.mlp = Owlv2MLP(config)
+ self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: torch.Tensor,
+ causal_attention_mask: torch.Tensor,
+ output_attentions: Optional[bool] = False,
+ ) -> tuple[torch.FloatTensor]:
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+ attention_mask (`torch.FloatTensor`): attention mask of size
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+ causal_attention_mask (`torch.FloatTensor`): causal attention mask of size `(batch, 1, tgt_len, src_len)`.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ """
+ residual = hidden_states
+
+ hidden_states = self.layer_norm1(hidden_states)
+ hidden_states, attn_weights = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ causal_attention_mask=causal_attention_mask,
+ output_attentions=output_attentions,
+ )
+ hidden_states = residual + hidden_states
+
+ residual = hidden_states
+ hidden_states = self.layer_norm2(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = residual + hidden_states
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (attn_weights,)
+
+ return outputs
+
+
+@auto_docstring
+# Copied from transformers.models.owlvit.modeling_owlvit.OwlViTPreTrainedModel with OwlViT->Owlv2,owlvit->owlv2
+class Owlv2PreTrainedModel(PreTrainedModel):
+ config: Owlv2Config
+ base_model_prefix = "owlv2"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["Owlv2EncoderLayer"]
+
+ def _init_weights(self, module: nn.Module):
+ """Initialize the weights"""
+ factor = self.config.initializer_factor
+ if isinstance(module, Owlv2TextEmbeddings):
+ module.token_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02)
+ module.position_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02)
+ elif isinstance(module, Owlv2VisionEmbeddings):
+ nn.init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor)
+ nn.init.normal_(module.patch_embedding.weight, std=module.config.initializer_range * factor)
+ nn.init.normal_(module.position_embedding.weight, std=module.config.initializer_range * factor)
+ elif isinstance(module, Owlv2Attention):
+ in_proj_std = (module.embed_dim**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
+ out_proj_std = (module.embed_dim**-0.5) * factor
+ nn.init.normal_(module.q_proj.weight, std=in_proj_std)
+ nn.init.normal_(module.k_proj.weight, std=in_proj_std)
+ nn.init.normal_(module.v_proj.weight, std=in_proj_std)
+ nn.init.normal_(module.out_proj.weight, std=out_proj_std)
+ elif isinstance(module, Owlv2MLP):
+ in_proj_std = (module.config.hidden_size**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
+ fc_std = (2 * module.config.hidden_size) ** -0.5 * factor
+ nn.init.normal_(module.fc1.weight, std=fc_std)
+ nn.init.normal_(module.fc2.weight, std=in_proj_std)
+ elif isinstance(module, Owlv2Model):
+ nn.init.normal_(
+ module.text_projection.weight,
+ std=module.text_embed_dim**-0.5 * factor,
+ )
+ nn.init.normal_(
+ module.visual_projection.weight,
+ std=module.vision_embed_dim**-0.5 * factor,
+ )
+ module.logit_scale.data.fill_(self.config.logit_scale_init_value)
+ if isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+ if isinstance(module, nn.Linear):
+ module.weight.data.normal_(mean=0.0, std=factor)
+ if module.bias is not None:
+ module.bias.data.zero_()
+
+
+# Copied from transformers.models.owlvit.modeling_owlvit.OwlViTEncoder with OwlViT->Owlv2
+class Owlv2Encoder(nn.Module):
+ """
+ Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
+ [`Owlv2EncoderLayer`].
+
+ Args:
+ config: Owlv2Config
+ """
+
+ def __init__(self, config: Owlv2Config):
+ super().__init__()
+ self.layers = nn.ModuleList([Owlv2EncoderLayer(config) for _ in range(config.num_hidden_layers)])
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ inputs_embeds,
+ attention_mask: Optional[torch.Tensor] = None,
+ causal_attention_mask: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[tuple, BaseModelOutput]:
+ r"""
+ Args:
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`).
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+ [What are attention masks?](../glossary#attention-mask)
+ causal_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Causal mask for the text model. Mask values selected in `[0, 1]`:
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+ [What are attention masks?](../glossary#attention-mask)
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+ for more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ encoder_states = () if output_hidden_states else None
+ all_attentions = () if output_attentions else None
+
+ hidden_states = inputs_embeds
+ for encoder_layer in self.layers:
+ if output_hidden_states:
+ encoder_states = encoder_states + (hidden_states,)
+ layer_outputs = encoder_layer(
+ hidden_states,
+ attention_mask,
+ causal_attention_mask,
+ output_attentions=output_attentions,
+ )
+
+ hidden_states = layer_outputs[0]
+
+ if output_attentions:
+ all_attentions = all_attentions + (layer_outputs[1],)
+
+ if output_hidden_states:
+ encoder_states = encoder_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
+ return BaseModelOutput(
+ last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
+ )
+
+
+# Copied from transformers.models.owlvit.modeling_owlvit.OwlViTTextTransformer with OWLVIT->OWLV2,OwlViT->Owlv2
+class Owlv2TextTransformer(nn.Module):
+ def __init__(self, config: Owlv2TextConfig):
+ super().__init__()
+ self.config = config
+ embed_dim = config.hidden_size
+ self.embeddings = Owlv2TextEmbeddings(config)
+ self.encoder = Owlv2Encoder(config)
+ self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[tuple, BaseModelOutputWithPooling]:
+ r"""
+ input_ids (`torch.LongTensor` of shape `(batch_size * num_max_text_queries, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. Indices can be obtained using [`AutoTokenizer`]. See
+ [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. [What are input
+ IDs?](../glossary#input-ids)
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ input_shape = input_ids.size()
+ input_ids = input_ids.view(-1, input_shape[-1])
+ hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids)
+
+ # num_samples, seq_len = input_shape where num_samples = batch_size * num_max_text_queries
+ # OWLV2's text model uses causal mask, prepare it here.
+ # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
+ causal_attention_mask = _create_4d_causal_attention_mask(
+ input_shape, hidden_states.dtype, device=hidden_states.device
+ )
+ # expand attention_mask
+ if attention_mask is not None:
+ # [num_samples, seq_len] -> [num_samples, 1, tgt_seq_len, src_seq_len]
+ attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)
+
+ encoder_outputs = self.encoder(
+ inputs_embeds=hidden_states,
+ attention_mask=attention_mask,
+ causal_attention_mask=causal_attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ last_hidden_state = encoder_outputs[0]
+ last_hidden_state = self.final_layer_norm(last_hidden_state)
+
+        # take features from the end-of-text (EOT) token embedding (the EOT token has the highest id in each sequence)
+ # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
+ pooled_output = last_hidden_state[
+ torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
+ input_ids.to(torch.int).argmax(dim=-1).to(last_hidden_state.device),
+ ]
+
+ if not return_dict:
+ return (last_hidden_state, pooled_output) + encoder_outputs[1:]
+
+ return BaseModelOutputWithPooling(
+ last_hidden_state=last_hidden_state,
+ pooler_output=pooled_output,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ )
+
+
+# Copied from transformers.models.owlvit.modeling_owlvit.OwlViTTextModel with google/owlvit-base-patch32->google/owlv2-base-patch16, OWLVIT->OWLV2,OwlViT->Owlv2
+class Owlv2TextModel(Owlv2PreTrainedModel):
+ config: Owlv2TextConfig
+
+ def __init__(self, config: Owlv2TextConfig):
+ super().__init__(config)
+ self.text_model = Owlv2TextTransformer(config)
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self) -> nn.Module:
+ return self.text_model.embeddings.token_embedding
+
+ def set_input_embeddings(self, value):
+ self.text_model.embeddings.token_embedding = value
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[tuple, BaseModelOutputWithPooling]:
+ r"""
+ input_ids (`torch.LongTensor` of shape `(batch_size * num_max_text_queries, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. Indices can be obtained using [`AutoTokenizer`]. See
+ [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. [What are input
+ IDs?](../glossary#input-ids)
+
+ Examples:
+ ```python
+ >>> from transformers import AutoProcessor, Owlv2TextModel
+
+ >>> model = Owlv2TextModel.from_pretrained("google/owlv2-base-patch16")
+ >>> processor = AutoProcessor.from_pretrained("google/owlv2-base-patch16")
+ >>> inputs = processor(
+        ...     text=[["a photo of a cat", "a photo of a dog"], ["photo of an astronaut"]], return_tensors="pt"
+ ... )
+ >>> outputs = model(**inputs)
+ >>> last_hidden_state = outputs.last_hidden_state
+ >>> pooled_output = outputs.pooler_output # pooled (EOS token) states
+ ```"""
+
+ # Get embeddings for all text queries in all batch samples
+ return self.text_model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+
+# Copied from transformers.models.owlvit.modeling_owlvit.OwlViTVisionTransformer with OWLVIT->OWLV2,OwlViT->Owlv2
+class Owlv2VisionTransformer(nn.Module):
+ def __init__(self, config: Owlv2VisionConfig):
+ super().__init__()
+ self.config = config
+
+ self.embeddings = Owlv2VisionEmbeddings(config)
+ self.pre_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.encoder = Owlv2Encoder(config)
+ self.post_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+ @auto_docstring
+ def forward(
+ self,
+ pixel_values: torch.FloatTensor,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ interpolate_pos_encoding: Optional[bool] = False,
+ return_dict: Optional[bool] = None,
+ ) -> Union[tuple, BaseModelOutputWithPooling]:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # Cast the input to the expected `dtype`
+ expected_input_dtype = self.embeddings.patch_embedding.weight.dtype
+ pixel_values = pixel_values.to(expected_input_dtype)
+
+ hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
+ hidden_states = self.pre_layernorm(hidden_states)
+
+ encoder_outputs = self.encoder(
+ inputs_embeds=hidden_states,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ last_hidden_state = encoder_outputs[0]
+ pooled_output = last_hidden_state[:, 0, :]
+
+ pooled_output = self.post_layernorm(pooled_output)
+
+ if not return_dict:
+ return (last_hidden_state, pooled_output) + encoder_outputs[1:]
+
+ return BaseModelOutputWithPooling(
+ last_hidden_state=last_hidden_state,
+ pooler_output=pooled_output,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ )
+
+
+# Copied from transformers.models.owlvit.modeling_owlvit.OwlViTVisionModel with OWLVIT->OWLV2,OwlViT->Owlv2,google/owlvit-base-patch32->google/owlv2-base-patch16
+class Owlv2VisionModel(Owlv2PreTrainedModel):
+ config: Owlv2VisionConfig
+ main_input_name = "pixel_values"
+
+ def __init__(self, config: Owlv2VisionConfig):
+ super().__init__(config)
+ self.vision_model = Owlv2VisionTransformer(config)
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self) -> nn.Module:
+ return self.vision_model.embeddings.patch_embedding
+
+ @auto_docstring
+ def forward(
+ self,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ interpolate_pos_encoding: bool = False,
+ return_dict: Optional[bool] = None,
+ ) -> Union[tuple, BaseModelOutputWithPooling]:
+ r"""
+ Examples:
+ ```python
+ >>> from PIL import Image
+ >>> import requests
+ >>> from transformers import AutoProcessor, Owlv2VisionModel
+
+ >>> model = Owlv2VisionModel.from_pretrained("google/owlv2-base-patch16")
+ >>> processor = AutoProcessor.from_pretrained("google/owlv2-base-patch16")
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> inputs = processor(images=image, return_tensors="pt")
+
+ >>> outputs = model(**inputs)
+ >>> last_hidden_state = outputs.last_hidden_state
+ >>> pooled_output = outputs.pooler_output # pooled CLS states
+ ```"""
+ return self.vision_model(
+ pixel_values=pixel_values,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ interpolate_pos_encoding=interpolate_pos_encoding,
+ return_dict=return_dict,
+ )
+
+
+@auto_docstring
+# Copied from transformers.models.owlvit.modeling_owlvit.OwlViTModel with google/owlvit-base-patch32->google/owlv2-base-patch16-ensemble, OWLVIT->OWLV2,OwlViT->Owlv2,owlvit->owlv2,OWL-ViT->OWLv2
+class Owlv2Model(Owlv2PreTrainedModel):
+ config: Owlv2Config
+
+ def __init__(self, config: Owlv2Config):
+ super().__init__(config)
+
+ if not isinstance(config.text_config, Owlv2TextConfig):
+ raise TypeError(
+ "config.text_config is expected to be of type Owlv2TextConfig but is of type"
+ f" {type(config.text_config)}."
+ )
+
+ if not isinstance(config.vision_config, Owlv2VisionConfig):
+ raise TypeError(
+ "config.vision_config is expected to be of type Owlv2VisionConfig but is of type"
+ f" {type(config.vision_config)}."
+ )
+
+ text_config = config.text_config
+ vision_config = config.vision_config
+
+ self.projection_dim = config.projection_dim
+ self.text_embed_dim = text_config.hidden_size
+ self.vision_embed_dim = vision_config.hidden_size
+
+ self.text_model = Owlv2TextTransformer(text_config)
+ self.vision_model = Owlv2VisionTransformer(vision_config)
+
+ self.visual_projection = nn.Linear(self.vision_embed_dim, self.projection_dim, bias=False)
+ self.text_projection = nn.Linear(self.text_embed_dim, self.projection_dim, bias=False)
+ self.logit_scale = nn.Parameter(torch.tensor(config.logit_scale_init_value))
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @filter_out_non_signature_kwargs()
+ @auto_docstring
+ def get_text_features(
+ self,
+ input_ids: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ ) -> torch.FloatTensor:
+ r"""
+ input_ids (`torch.LongTensor` of shape `(batch_size * num_max_text_queries, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. Indices can be obtained using [`AutoTokenizer`]. See
+ [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. [What are input
+ IDs?](../glossary#input-ids)
+
+ Returns:
+            text_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The text embeddings obtained by
+ applying the projection layer to the pooled output of [`Owlv2TextModel`].
+
+ Examples:
+ ```python
+ >>> import torch
+ >>> from transformers import AutoProcessor, Owlv2Model
+
+ >>> model = Owlv2Model.from_pretrained("google/owlv2-base-patch16-ensemble")
+ >>> processor = AutoProcessor.from_pretrained("google/owlv2-base-patch16-ensemble")
+ >>> inputs = processor(
+        ...     text=[["a photo of a cat", "a photo of a dog"], ["photo of an astronaut"]], return_tensors="pt"
+ ... )
+ >>> with torch.inference_mode():
+ ... text_features = model.get_text_features(**inputs)
+ ```"""
+ # Get embeddings for all text queries in all batch samples
+ text_outputs: BaseModelOutputWithPooling = self.text_model(input_ids=input_ids, attention_mask=attention_mask)
+ text_features = self.text_projection(text_outputs.pooler_output)
+
+ return text_features
+
+ @filter_out_non_signature_kwargs()
+ @auto_docstring
+ def get_image_features(
+ self,
+ pixel_values: torch.Tensor,
+ interpolate_pos_encoding: bool = False,
+ ) -> torch.FloatTensor:
+ r"""
+ Returns:
+            image_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The image embeddings obtained by
+ applying the projection layer to the pooled output of [`Owlv2VisionModel`].
+
+ Examples:
+ ```python
+ >>> import torch
+ >>> from transformers.image_utils import load_image
+ >>> from transformers import AutoProcessor, Owlv2Model
+
+ >>> model = Owlv2Model.from_pretrained("google/owlv2-base-patch16-ensemble")
+ >>> processor = AutoProcessor.from_pretrained("google/owlv2-base-patch16-ensemble")
+
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = load_image(url)
+
+ >>> inputs = processor(images=image, return_tensors="pt")
+ >>> with torch.inference_mode():
+ ... image_features = model.get_image_features(**inputs)
+ ```"""
+ vision_outputs: BaseModelOutputWithPooling = self.vision_model(
+ pixel_values=pixel_values,
+ interpolate_pos_encoding=interpolate_pos_encoding,
+ )
+ image_features = self.visual_projection(vision_outputs.pooler_output)
+
+ return image_features
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ return_loss: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ interpolate_pos_encoding: bool = False,
+ return_base_image_embeds: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[tuple, Owlv2Output]:
+ r"""
+ return_loss (`bool`, *optional*):
+ Whether or not to return the contrastive loss.
+ return_base_image_embeds (`bool`, *optional*):
+ Whether or not to return the base image embeddings.
+
+ Examples:
+ ```python
+ >>> from PIL import Image
+ >>> import requests
+ >>> from transformers import AutoProcessor, Owlv2Model
+
+ >>> model = Owlv2Model.from_pretrained("google/owlv2-base-patch16-ensemble")
+ >>> processor = AutoProcessor.from_pretrained("google/owlv2-base-patch16-ensemble")
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+ >>> inputs = processor(text=[["a photo of a cat", "a photo of a dog"]], images=image, return_tensors="pt")
+ >>> outputs = model(**inputs)
+ >>> logits_per_image = outputs.logits_per_image # this is the image-text similarity score
+ >>> probs = logits_per_image.softmax(dim=1) # we can take the softmax to get the label probabilities
+ ```"""
+ # Use OWLv2 model's config for some fields (if specified) instead of those of vision & text components.
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ vision_outputs = self.vision_model(
+ pixel_values=pixel_values,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ interpolate_pos_encoding=interpolate_pos_encoding,
+ return_dict=return_dict,
+ )
+
+ # Get embeddings for all text queries in all batch samples
+ text_outputs = self.text_model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ text_embeds = text_outputs[1]
+ text_embeds = self.text_projection(text_embeds)
+ image_embeds = vision_outputs[1]
+ image_embeds = self.visual_projection(image_embeds)
+
+ # normalized features
+ image_embeds = image_embeds / torch.linalg.norm(image_embeds, ord=2, dim=-1, keepdim=True)
+ text_embeds_norm = text_embeds / torch.linalg.norm(text_embeds, ord=2, dim=-1, keepdim=True)
+
+ # cosine similarity as logits and set it on the correct device
+ logit_scale = self.logit_scale.exp().to(image_embeds.device)
+
+ logits_per_text = torch.matmul(text_embeds_norm, image_embeds.t()) * logit_scale
+ logits_per_image = logits_per_text.t()
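+        # logits_per_text: (num_text_inputs, num_images); logits_per_image is its transpose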
+
+ loss = None
+ if return_loss:
+ loss = owlv2_loss(logits_per_text)
+
+ text_embeds = text_embeds_norm
+
+ if not return_dict:
+ output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs)
+ return ((loss,) + output) if loss is not None else output
+
+ return Owlv2Output(
+ loss=loss,
+ logits_per_image=logits_per_image,
+ logits_per_text=logits_per_text,
+ text_embeds=text_embeds,
+ image_embeds=image_embeds,
+ text_model_output=text_outputs,
+ vision_model_output=vision_outputs,
+ )
+
+
+# Copied from transformers.models.owlvit.modeling_owlvit.OwlViTBoxPredictionHead with OwlViT->Owlv2
+class Owlv2BoxPredictionHead(nn.Module):
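+    """Three-layer MLP head mapping per-patch image features of size `hidden_size` to `out_dim` outputs (4 box coordinates by default)."""
+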
+ def __init__(self, config: Owlv2Config, out_dim: int = 4):
+ super().__init__()
+
+ width = config.vision_config.hidden_size
+ self.dense0 = nn.Linear(width, width)
+ self.dense1 = nn.Linear(width, width)
+ self.gelu = nn.GELU()
+ self.dense2 = nn.Linear(width, out_dim)
+
+ def forward(self, image_features: torch.Tensor) -> torch.FloatTensor:
+ output = self.dense0(image_features)
+ output = self.gelu(output)
+ output = self.dense1(output)
+ output = self.gelu(output)
+ output = self.dense2(output)
+ return output
+
+
+# Copied from transformers.models.owlvit.modeling_owlvit.OwlViTClassPredictionHead with OwlViT->Owlv2
+class Owlv2ClassPredictionHead(nn.Module):
+ def __init__(self, config: Owlv2Config):
+ super().__init__()
+
+ out_dim = config.text_config.hidden_size
+ self.query_dim = config.vision_config.hidden_size
+
+ self.dense0 = nn.Linear(self.query_dim, out_dim)
+ self.logit_shift = nn.Linear(self.query_dim, 1)
+ self.logit_scale = nn.Linear(self.query_dim, 1)
+ self.elu = nn.ELU()
+
+ def forward(
+ self,
+ image_embeds: torch.FloatTensor,
+ query_embeds: Optional[torch.FloatTensor],
+ query_mask: Optional[torch.Tensor],
+ ) -> tuple[torch.FloatTensor]:
+ image_class_embeds = self.dense0(image_embeds)
+ if query_embeds is None:
+ device = image_class_embeds.device
+ batch_size, num_patches = image_class_embeds.shape[:2]
+ pred_logits = torch.zeros((batch_size, num_patches, self.query_dim)).to(device)
+ return (pred_logits, image_class_embeds)
+
+ # Normalize image and text features
+ image_class_embeds = image_class_embeds / (torch.linalg.norm(image_class_embeds, dim=-1, keepdim=True) + 1e-6)
+ query_embeds = query_embeds / (torch.linalg.norm(query_embeds, dim=-1, keepdim=True) + 1e-6)
+
+ # Get class predictions
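+        # pred_logits: (batch_size, num_patches, num_queries) -- similarity of each patch embedding to each query embedding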
+ pred_logits = torch.einsum("...pd,...qd->...pq", image_class_embeds, query_embeds)
+
+ # Apply a learnable shift and scale to logits
+ logit_shift = self.logit_shift(image_embeds)
+ logit_scale = self.logit_scale(image_embeds)
+ logit_scale = self.elu(logit_scale) + 1
+ pred_logits = (pred_logits + logit_shift) * logit_scale
+
+ if query_mask is not None:
+ if query_mask.ndim > 1:
+ query_mask = torch.unsqueeze(query_mask, dim=-2)
+
+ pred_logits = torch.where(query_mask == 0, torch.finfo(pred_logits.dtype).min, pred_logits)
+ pred_logits = pred_logits.to(torch.float32)
+
+ return (pred_logits, image_class_embeds)
+
+
+class Owlv2ForObjectDetection(Owlv2PreTrainedModel):
+ config: Owlv2Config
+
+ def __init__(self, config: Owlv2Config):
+ super().__init__(config)
+
+ self.owlv2 = Owlv2Model(config)
+ self.class_head = Owlv2ClassPredictionHead(config)
+ self.box_head = Owlv2BoxPredictionHead(config)
+ self.objectness_head = Owlv2BoxPredictionHead(config, out_dim=1)
+
+ self.layer_norm = nn.LayerNorm(config.vision_config.hidden_size, eps=config.vision_config.layer_norm_eps)
+ self.sigmoid = nn.Sigmoid()
+ self.config = config
+ self.num_patches_height = self.config.vision_config.image_size // self.config.vision_config.patch_size
+ self.num_patches_width = self.config.vision_config.image_size // self.config.vision_config.patch_size
+ self.box_bias = self.compute_box_bias(self.num_patches_height, self.num_patches_width)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @staticmethod
+ # Copied from transformers.models.owlvit.modeling_owlvit.OwlViTForObjectDetection.normalize_grid_corner_coordinates
+ def normalize_grid_corner_coordinates(num_patches_height: int, num_patches_width: int) -> torch.Tensor:
+ # Create grid coordinates using torch
+ x_coordinates = torch.arange(1, num_patches_width + 1, dtype=torch.float32)
+ y_coordinates = torch.arange(1, num_patches_height + 1, dtype=torch.float32)
+ xx, yy = torch.meshgrid(x_coordinates, y_coordinates, indexing="xy")
+
+ # Stack the coordinates and divide by their respective patch counts
+ box_coordinates = torch.stack((xx, yy), dim=-1)
+ box_coordinates[..., 0] /= num_patches_width
+ box_coordinates[..., 1] /= num_patches_height
+
+ # Flatten (h, w, 2) -> (h*w, 2)
+ box_coordinates = box_coordinates.view(-1, 2)
+
+ return box_coordinates
+
+ def objectness_predictor(self, image_features: torch.FloatTensor) -> torch.FloatTensor:
+ """Predicts the probability that each image feature token is an object.
+
+ Args:
+            image_features (`torch.FloatTensor` of shape `(batch_size, num_patches, hidden_dim)`):
+ Features extracted from the image.
+ Returns:
+ Objectness scores.
+ """
+ image_features = image_features.detach()
+ objectness_logits = self.objectness_head(image_features)
+ objectness_logits = objectness_logits[..., 0]
+ return objectness_logits
+
+ @lru_cache(maxsize=2)
+ # Copied from transformers.models.owlvit.modeling_owlvit.OwlViTForObjectDetection.compute_box_bias
+ def compute_box_bias(
+ self, num_patches_height: int, num_patches_width: int, feature_map: Optional[torch.FloatTensor] = None
+ ) -> torch.Tensor:
+ if feature_map is not None:
+ raise ValueError("feature_map has been deprecated as an input. Please pass in num_patches instead")
+ # The box center is biased to its position on the feature grid
+ box_coordinates = self.normalize_grid_corner_coordinates(num_patches_height, num_patches_width)
+ box_coordinates = torch.clip(box_coordinates, 0.0, 1.0)
+
+ # Unnormalize xy
+ box_coord_bias = torch.log(box_coordinates + 1e-4) - torch.log1p(-box_coordinates + 1e-4)
+
+ # The box size is biased to the patch size
+ box_size = torch.full_like(box_coord_bias, 1.0)
+ box_size[..., 0] /= num_patches_width
+ box_size[..., 1] /= num_patches_height
+ box_size_bias = torch.log(box_size + 1e-4) - torch.log1p(-box_size + 1e-4)
+
+ # Compute box bias
+ box_bias = torch.cat([box_coord_bias, box_size_bias], dim=-1)
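+        # box_bias: (num_patches_height * num_patches_width, 4); adding it before the sigmoid anchors each prediction
+        # at its grid cell with a patch-sized default box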
+ return box_bias
+
+ # Copied from transformers.models.owlvit.modeling_owlvit.OwlViTForObjectDetection.box_predictor
+ def box_predictor(
+ self,
+ image_feats: torch.FloatTensor,
+ feature_map: torch.FloatTensor,
+ interpolate_pos_encoding: bool = False,
+ ) -> torch.FloatTensor:
+ """
+ Args:
+ image_feats:
+ Features extracted from the image, returned by the `image_text_embedder` method.
+ feature_map:
+ A spatial re-arrangement of image_features, also returned by the `image_text_embedder` method.
+ interpolate_pos_encoding:
+ Whether to interpolate the pre-trained position encodings.
+ Returns:
+ pred_boxes:
+                Predicted boxes in (center_x, center_y, width, height) format, normalized to [0, 1], of shape
+                `(batch_size, num_patches, 4)`.
+ """
+ # Bounding box detection head [batch_size, num_boxes, 4].
+ pred_boxes = self.box_head(image_feats)
+
+ # Compute the location of each token on the grid and use it to compute a bias for the bbox prediction
+ if interpolate_pos_encoding:
+ _, num_patches_height, num_patches_width, _ = feature_map.shape
+ box_bias = self.compute_box_bias(num_patches_height, num_patches_width)
+ else:
+ box_bias = self.box_bias
+
+ box_bias = box_bias.to(feature_map.device)
+ pred_boxes += box_bias
+ pred_boxes = self.sigmoid(pred_boxes)
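+        # map back to normalized (center_x, center_y, width, height) in [0, 1]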
+ return pred_boxes
+
+ # Copied from transformers.models.owlvit.modeling_owlvit.OwlViTForObjectDetection.class_predictor
+ def class_predictor(
+ self,
+ image_feats: torch.FloatTensor,
+ query_embeds: Optional[torch.FloatTensor] = None,
+ query_mask: Optional[torch.Tensor] = None,
+ ) -> tuple[torch.FloatTensor]:
+ """
+ Args:
+ image_feats:
+ Features extracted from the `image_text_embedder`.
+ query_embeds:
+ Text query embeddings.
+ query_mask:
+                Must be provided together with `query_embeds`. A mask indicating which query embeddings are valid.
+ """
+ (pred_logits, image_class_embeds) = self.class_head(image_feats, query_embeds, query_mask)
+
+ return (pred_logits, image_class_embeds)
+
+ # Copied from transformers.models.owlvit.modeling_owlvit.OwlViTForObjectDetection.image_text_embedder with owlvit->owlv2
+ def image_text_embedder(
+ self,
+ input_ids: torch.Tensor,
+ pixel_values: torch.FloatTensor,
+ attention_mask: torch.Tensor,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ interpolate_pos_encoding: bool = False,
+ ) -> tuple[torch.FloatTensor]:
+ # Encode text and image
+ outputs = self.owlv2(
+ pixel_values=pixel_values,
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ interpolate_pos_encoding=interpolate_pos_encoding,
+ return_dict=True,
+ )
+
+ if interpolate_pos_encoding:
+ _, _, height, width = pixel_values.shape
+ num_patches_height = height // self.config.vision_config.patch_size
+ num_patches_width = width // self.config.vision_config.patch_size
+ else:
+ num_patches_height = self.num_patches_height
+ num_patches_width = self.num_patches_width
+
+ # Get image embeddings
+ last_hidden_state = outputs.vision_model_output[0]
+ image_embeds = self.owlv2.vision_model.post_layernorm(last_hidden_state)
+
+ # Resize class token
+ class_token_out = torch.broadcast_to(image_embeds[:, :1, :], image_embeds[:, :-1].shape)
+
+ # Merge image embedding with class tokens
+ image_embeds = image_embeds[:, 1:, :] * class_token_out
+ image_embeds = self.layer_norm(image_embeds)
+
+ # Resize to [batch_size, num_patches_height, num_patches_width, hidden_size]
+ new_size = (
+ image_embeds.shape[0],
+ num_patches_height,
+ num_patches_width,
+ image_embeds.shape[-1],
+ )
+ image_embeds = image_embeds.reshape(new_size)
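+        # Owlv2Output with loss=None indexes as (logits_per_image, logits_per_text, text_embeds, image_embeds,
+        # text_model_output, vision_model_output), so index -4 is the projected text embeddings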
+ text_embeds = outputs[-4]
+
+ return (text_embeds, image_embeds, outputs)
+
+ # Copied from transformers.models.owlvit.modeling_owlvit.OwlViTForObjectDetection.image_embedder with owlvit->owlv2, OwlViTModel->Owlv2Model
+ def image_embedder(
+ self,
+ pixel_values: torch.FloatTensor,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ interpolate_pos_encoding: bool = False,
+ ) -> tuple[torch.FloatTensor]:
+ # Get Owlv2Model vision embeddings (same as CLIP)
+ vision_outputs = self.owlv2.vision_model(
+ pixel_values=pixel_values, interpolate_pos_encoding=interpolate_pos_encoding, return_dict=True
+ )
+
+ if interpolate_pos_encoding:
+ _, _, height, width = pixel_values.shape
+ num_patches_height = height // self.config.vision_config.patch_size
+ num_patches_width = width // self.config.vision_config.patch_size
+ else:
+ num_patches_height = self.num_patches_height
+ num_patches_width = self.num_patches_width
+
+ # Apply post_layernorm to last_hidden_state, return non-projected output
+ last_hidden_state = vision_outputs[0]
+ image_embeds = self.owlv2.vision_model.post_layernorm(last_hidden_state)
+
+ # Resize class token
+ class_token_out = torch.broadcast_to(image_embeds[:, :1, :], image_embeds[:, :-1].shape)
+
+ # Merge image embedding with class tokens
+ image_embeds = image_embeds[:, 1:, :] * class_token_out
+ image_embeds = self.layer_norm(image_embeds)
+
+ # Resize to [batch_size, num_patches_height, num_patches_width, hidden_size]
+ new_size = (
+ image_embeds.shape[0],
+ num_patches_height,
+ num_patches_width,
+ image_embeds.shape[-1],
+ )
+ image_embeds = image_embeds.reshape(new_size)
+
+ return (image_embeds, vision_outputs)
+
+ # Copied from transformers.models.owlvit.modeling_owlvit.OwlViTForObjectDetection.embed_image_query
+ def embed_image_query(
+ self,
+ query_image_features: torch.FloatTensor,
+ query_feature_map: torch.FloatTensor,
+ interpolate_pos_encoding: bool = False,
+ ) -> torch.FloatTensor:
+ _, class_embeds = self.class_predictor(query_image_features)
+ pred_boxes = self.box_predictor(query_image_features, query_feature_map, interpolate_pos_encoding)
+ pred_boxes_as_corners = center_to_corners_format(pred_boxes)
+
+ # Loop over query images
+ best_class_embeds = []
+ best_box_indices = []
+ pred_boxes_device = pred_boxes_as_corners.device
+
+ for i in range(query_image_features.shape[0]):
+ each_query_box = torch.tensor([[0, 0, 1, 1]], device=pred_boxes_device)
+ each_query_pred_boxes = pred_boxes_as_corners[i]
+ ious, _ = box_iou(each_query_box, each_query_pred_boxes)
+
+ # If there are no overlapping boxes, fall back to generalized IoU
+ if torch.all(ious[0] == 0.0):
+ ious = generalized_box_iou(each_query_box, each_query_pred_boxes)
+
+ # Use an adaptive threshold to include all boxes within 80% of the best IoU
+ iou_threshold = torch.max(ious) * 0.8
+
+ selected_inds = (ious[0] >= iou_threshold).nonzero()
+ if selected_inds.numel():
+ selected_embeddings = class_embeds[i][selected_inds.squeeze(1)]
+ mean_embeds = torch.mean(class_embeds[i], axis=0)
+ mean_sim = torch.einsum("d,id->i", mean_embeds, selected_embeddings)
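+                # Among the overlapping candidates, keep the box whose embedding is least similar (by dot product)
+                # to the mean patch embedding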
+ best_box_ind = selected_inds[torch.argmin(mean_sim)]
+ best_class_embeds.append(class_embeds[i][best_box_ind])
+ best_box_indices.append(best_box_ind)
+
+ if best_class_embeds:
+ query_embeds = torch.stack(best_class_embeds)
+ box_indices = torch.stack(best_box_indices)
+ else:
+ query_embeds, box_indices = None, None
+
+ return query_embeds, box_indices, pred_boxes
+
+ @auto_docstring
+ def image_guided_detection(
+ self,
+ pixel_values: torch.FloatTensor,
+ query_pixel_values: Optional[torch.FloatTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ interpolate_pos_encoding: bool = False,
+ return_dict: Optional[bool] = None,
+ ) -> Owlv2ImageGuidedObjectDetectionOutput:
+ r"""
+ query_pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+ Pixel values of query image(s) to be detected. Pass in one query image per target image.
+
+ Examples:
+ ```python
+ >>> import requests
+ >>> from PIL import Image
+ >>> import torch
+ >>> from transformers import AutoProcessor, Owlv2ForObjectDetection
+
+ >>> processor = AutoProcessor.from_pretrained("google/owlv2-base-patch16-ensemble")
+ >>> model = Owlv2ForObjectDetection.from_pretrained("google/owlv2-base-patch16-ensemble")
+
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+ >>> query_url = "http://images.cocodataset.org/val2017/000000001675.jpg"
+ >>> query_image = Image.open(requests.get(query_url, stream=True).raw)
+ >>> inputs = processor(images=image, query_images=query_image, return_tensors="pt")
+
+ >>> # forward pass
+ >>> with torch.no_grad():
+ ... outputs = model.image_guided_detection(**inputs)
+
+ >>> target_sizes = torch.Tensor([image.size[::-1]])
+
+ >>> # Convert outputs (bounding boxes and class logits) to Pascal VOC format (xmin, ymin, xmax, ymax)
+ >>> results = processor.post_process_image_guided_detection(
+ ... outputs=outputs, threshold=0.9, nms_threshold=0.3, target_sizes=target_sizes
+ ... )
+ >>> i = 0 # Retrieve predictions for the first image
+ >>> boxes, scores = results[i]["boxes"], results[i]["scores"]
+ >>> for box, score in zip(boxes, scores):
+ ... box = [round(i, 2) for i in box.tolist()]
+ ... print(f"Detected similar object with confidence {round(score.item(), 3)} at location {box}")
+ Detected similar object with confidence 0.938 at location [327.31, 54.94, 547.39, 268.06]
+ Detected similar object with confidence 0.959 at location [5.78, 360.65, 619.12, 366.39]
+ Detected similar object with confidence 0.902 at location [2.85, 360.01, 627.63, 380.8]
+ Detected similar object with confidence 0.985 at location [176.98, -29.45, 672.69, 182.83]
+ Detected similar object with confidence 1.0 at location [6.53, 14.35, 624.87, 470.82]
+ Detected similar object with confidence 0.998 at location [579.98, 29.14, 615.49, 489.05]
+ Detected similar object with confidence 0.985 at location [206.15, 10.53, 247.74, 466.01]
+ Detected similar object with confidence 0.947 at location [18.62, 429.72, 646.5, 457.72]
+ Detected similar object with confidence 0.996 at location [523.88, 20.69, 586.84, 483.18]
+ Detected similar object with confidence 0.998 at location [3.39, 360.59, 617.29, 499.21]
+ Detected similar object with confidence 0.969 at location [4.47, 449.05, 614.5, 474.76]
+ Detected similar object with confidence 0.966 at location [31.44, 463.65, 654.66, 471.07]
+ Detected similar object with confidence 0.924 at location [30.93, 468.07, 635.35, 475.39]
+ ```"""
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
+
+ # Compute feature maps for the input and query images
+ query_feature_map = self.image_embedder(
+ pixel_values=query_pixel_values, interpolate_pos_encoding=interpolate_pos_encoding
+ )[0]
+ feature_map, vision_outputs = self.image_embedder(
+ pixel_values=pixel_values,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ interpolate_pos_encoding=interpolate_pos_encoding,
+ )
+
+ batch_size, num_patches_height, num_patches_width, hidden_dim = feature_map.shape
+ image_feats = torch.reshape(feature_map, (batch_size, num_patches_height * num_patches_width, hidden_dim))
+
+ batch_size, num_patches_height, num_patches_width, hidden_dim = query_feature_map.shape
+ query_image_feats = torch.reshape(
+ query_feature_map, (batch_size, num_patches_height * num_patches_width, hidden_dim)
+ )
+ # Get top class embedding and best box index for each query image in batch
+ query_embeds, best_box_indices, query_pred_boxes = self.embed_image_query(
+ query_image_feats, query_feature_map, interpolate_pos_encoding
+ )
+
+ # Predict object classes [batch_size, num_patches, num_queries+1]
+ (pred_logits, class_embeds) = self.class_predictor(image_feats=image_feats, query_embeds=query_embeds)
+
+ # Predict object boxes
+ target_pred_boxes = self.box_predictor(image_feats, feature_map, interpolate_pos_encoding)
+
+ if not return_dict:
+ output = (
+ feature_map,
+ query_feature_map,
+ target_pred_boxes,
+ query_pred_boxes,
+ pred_logits,
+ class_embeds,
+ vision_outputs.to_tuple(),
+ )
+ output = tuple(x for x in output if x is not None)
+ return output
+
+ return Owlv2ImageGuidedObjectDetectionOutput(
+ image_embeds=feature_map,
+ query_image_embeds=query_feature_map,
+ target_pred_boxes=target_pred_boxes,
+ query_pred_boxes=query_pred_boxes,
+ logits=pred_logits,
+ class_embeds=class_embeds,
+ text_model_output=None,
+ vision_model_output=vision_outputs,
+ )
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ pixel_values: torch.FloatTensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ interpolate_pos_encoding: bool = False,
+ return_dict: Optional[bool] = None,
+ ) -> Owlv2ObjectDetectionOutput:
+ r"""
+ input_ids (`torch.LongTensor` of shape `(batch_size * num_max_text_queries, sequence_length)`, *optional*):
+ Indices of input sequence tokens in the vocabulary. Indices can be obtained using [`AutoTokenizer`]. See
+ [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. [What are input
+ IDs?](../glossary#input-ids).
+        output_hidden_states (`bool`, *optional*):
+            Whether or not to return the hidden states of all layers. See the `text_model_output` and
+            `vision_model_output` fields of the returned output for more detail.
+
+ Examples:
+ ```python
+ >>> import requests
+ >>> from PIL import Image
+ >>> import torch
+
+ >>> from transformers import Owlv2Processor, Owlv2ForObjectDetection
+
+ >>> processor = Owlv2Processor.from_pretrained("google/owlv2-base-patch16-ensemble")
+ >>> model = Owlv2ForObjectDetection.from_pretrained("google/owlv2-base-patch16-ensemble")
+
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+ >>> text_labels = [["a photo of a cat", "a photo of a dog"]]
+ >>> inputs = processor(text=text_labels, images=image, return_tensors="pt")
+ >>> outputs = model(**inputs)
+
+ >>> # Target image sizes (height, width) to rescale box predictions [batch_size, 2]
+ >>> target_sizes = torch.tensor([(image.height, image.width)])
+ >>> # Convert outputs (bounding boxes and class logits) to Pascal VOC format (xmin, ymin, xmax, ymax)
+ >>> results = processor.post_process_grounded_object_detection(
+ ... outputs=outputs, target_sizes=target_sizes, threshold=0.1, text_labels=text_labels
+ ... )
+ >>> # Retrieve predictions for the first image for the corresponding text queries
+ >>> result = results[0]
+ >>> boxes, scores, text_labels = result["boxes"], result["scores"], result["text_labels"]
+ >>> for box, score, text_label in zip(boxes, scores, text_labels):
+ ... box = [round(i, 2) for i in box.tolist()]
+ ... print(f"Detected {text_label} with confidence {round(score.item(), 3)} at location {box}")
+ Detected a photo of a cat with confidence 0.614 at location [341.67, 23.39, 642.32, 371.35]
+ Detected a photo of a cat with confidence 0.665 at location [6.75, 51.96, 326.62, 473.13]
+ ```"""
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
+
+ # Embed images and text queries
+ query_embeds, feature_map, outputs = self.image_text_embedder(
+ input_ids=input_ids,
+ pixel_values=pixel_values,
+ attention_mask=attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ interpolate_pos_encoding=interpolate_pos_encoding,
+ )
+
+ # Text and vision model outputs
+ text_outputs = outputs.text_model_output
+ vision_outputs = outputs.vision_model_output
+
+ batch_size, num_patches_height, num_patches_width, hidden_dim = feature_map.shape
+ image_feats = torch.reshape(feature_map, (batch_size, num_patches_height * num_patches_width, hidden_dim))
+
+ # Reshape from [batch_size * max_text_queries, hidden_dim] -> [batch_size, max_text_queries, hidden_dim]
+ max_text_queries = input_ids.shape[0] // batch_size
+ query_embeds = query_embeds.reshape(batch_size, max_text_queries, query_embeds.shape[-1])
+
+ # If first token is 0, then this is a padded query [batch_size, num_queries].
+ input_ids = input_ids.reshape(batch_size, max_text_queries, input_ids.shape[-1])
+ query_mask = input_ids[..., 0] > 0
+
+ # Predict object classes [batch_size, num_patches, num_queries+1]
+ (pred_logits, class_embeds) = self.class_predictor(image_feats, query_embeds, query_mask)
+
+ # Predict objectness
+ objectness_logits = self.objectness_predictor(image_feats)
+
+ # Predict object boxes
+ pred_boxes = self.box_predictor(image_feats, feature_map, interpolate_pos_encoding)
+
+ if not return_dict:
+ output = (
+ pred_logits,
+ objectness_logits,
+ pred_boxes,
+ query_embeds,
+ feature_map,
+ class_embeds,
+ text_outputs.to_tuple(),
+ vision_outputs.to_tuple(),
+ )
+ output = tuple(x for x in output if x is not None)
+ return output
+
+ return Owlv2ObjectDetectionOutput(
+ image_embeds=feature_map,
+ text_embeds=query_embeds,
+ pred_boxes=pred_boxes,
+ logits=pred_logits,
+ objectness_logits=objectness_logits,
+ class_embeds=class_embeds,
+ text_model_output=text_outputs,
+ vision_model_output=vision_outputs,
+ )
+
+
+__all__ = ["Owlv2Model", "Owlv2PreTrainedModel", "Owlv2TextModel", "Owlv2VisionModel", "Owlv2ForObjectDetection"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/owlv2/modular_owlv2.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/owlv2/modular_owlv2.py
new file mode 100644
index 0000000000000000000000000000000000000000..66acd2088399c74751edea26e7f307f380890926
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/owlv2/modular_owlv2.py
@@ -0,0 +1,228 @@
+# coding=utf-8
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Fast Image processor class for OWLv2."""
+
+import warnings
+from typing import Optional, Union
+
+import torch
+from torchvision.transforms.v2 import functional as F
+
+from ...image_processing_utils_fast import (
+ BaseImageProcessorFast,
+ BatchFeature,
+ DefaultFastImageProcessorKwargs,
+)
+from ...image_transforms import group_images_by_shape, reorder_images
+from ...image_utils import (
+ OPENAI_CLIP_MEAN,
+ OPENAI_CLIP_STD,
+ ChannelDimension,
+ ImageInput,
+ PILImageResampling,
+ SizeDict,
+)
+from ...processing_utils import Unpack
+from ...utils import (
+ TensorType,
+ auto_docstring,
+)
+from ..owlvit.image_processing_owlvit_fast import OwlViTImageProcessorFast
+
+
+class Owlv2FastImageProcessorKwargs(DefaultFastImageProcessorKwargs): ...
+
+
+@auto_docstring
+class Owlv2ImageProcessorFast(OwlViTImageProcessorFast):
+ resample = PILImageResampling.BILINEAR
+ image_mean = OPENAI_CLIP_MEAN
+ image_std = OPENAI_CLIP_STD
+ size = {"height": 960, "width": 960}
+ rescale_factor = 1 / 255
+ do_resize = True
+ do_rescale = True
+ do_normalize = True
+ do_pad = True
+ valid_kwargs = Owlv2FastImageProcessorKwargs
+ crop_size = None
+ do_center_crop = None
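+
+    # Minimal usage sketch (assumes a PIL image `image`; the checkpoint id mirrors the ones used in
+    # modeling_owlv2.py and is an assumption here, not prescribed by this class):
+    #   processor = Owlv2ImageProcessorFast.from_pretrained("google/owlv2-base-patch16-ensemble")
+    #   pixel_values = processor(images=image, return_tensors="pt").pixel_values  # (1, 3, 960, 960) with the default size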
+
+ def __init__(self, **kwargs: Unpack[Owlv2FastImageProcessorKwargs]):
+ BaseImageProcessorFast.__init__(self, **kwargs)
+
+ @auto_docstring
+ def preprocess(self, images: ImageInput, **kwargs: Unpack[Owlv2FastImageProcessorKwargs]):
+ return BaseImageProcessorFast.preprocess(self, images, **kwargs)
+
+ def _pad_images(self, images: "torch.Tensor", constant_value: float = 0.5) -> "torch.Tensor":
+ """
+        Pad an image on the bottom and right with a constant value so that it becomes square.
+ """
+ height, width = images.shape[-2:]
+ size = max(height, width)
+ pad_bottom = size - height
+ pad_right = size - width
+
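+        # torchvision's F.pad takes (left, top, right, bottom): pad only right and bottom so the image stays
+        # anchored at the top-left corner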
+ padding = (0, 0, pad_right, pad_bottom)
+ padded_image = F.pad(images, padding, fill=constant_value)
+ return padded_image
+
+ def pad(
+ self,
+ images: list["torch.Tensor"],
+ disable_grouping: Optional[bool],
+ constant_value: float = 0.5,
+ **kwargs,
+ ) -> list["torch.Tensor"]:
+ """
+        Unlike the base class `pad`, which pads all images to the maximum image size in the batch,
+        OWLv2 pads each image to a square.
+ """
+ grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
+ processed_images_grouped = {}
+ for shape, stacked_images in grouped_images.items():
+ stacked_images = self._pad_images(
+ stacked_images,
+ constant_value=constant_value,
+ )
+ processed_images_grouped[shape] = stacked_images
+
+ processed_images = reorder_images(processed_images_grouped, grouped_images_index)
+
+ return processed_images
+
+ def resize(
+ self,
+ image: "torch.Tensor",
+ size: SizeDict,
+ anti_aliasing: bool = True,
+ anti_aliasing_sigma=None,
+ **kwargs,
+ ) -> "torch.Tensor":
+ """
+ Resize an image as per the original implementation.
+
+ Args:
+ image (`Tensor`):
+ Image to resize.
+ size (`dict[str, int]`):
+ Dictionary containing the height and width to resize the image to.
+ anti_aliasing (`bool`, *optional*, defaults to `True`):
+ Whether to apply anti-aliasing when downsampling the image.
+ anti_aliasing_sigma (`float`, *optional*, defaults to `None`):
+ Standard deviation for Gaussian kernel when downsampling the image. If `None`, it will be calculated
+ automatically.
+ """
+ output_shape = (size.height, size.width)
+
+ input_shape = image.shape
+
+ # select height and width from input tensor
+ factors = torch.tensor(input_shape[2:]).to(image.device) / torch.tensor(output_shape).to(image.device)
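+        # per-axis scale factor = input_size / output_size; values greater than 1 mean downsampling along that axis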
+
+ if anti_aliasing:
+ if anti_aliasing_sigma is None:
+ anti_aliasing_sigma = ((factors - 1) / 2).clamp(min=0)
+ else:
+ anti_aliasing_sigma = torch.atleast_1d(anti_aliasing_sigma) * torch.ones_like(factors)
+ if torch.any(anti_aliasing_sigma < 0):
+ raise ValueError("Anti-aliasing standard deviation must be greater than or equal to zero")
+ elif torch.any((anti_aliasing_sigma > 0) & (factors <= 1)):
+ warnings.warn(
+ "Anti-aliasing standard deviation greater than zero but not down-sampling along all axes"
+ )
+ if torch.any(anti_aliasing_sigma == 0):
+ filtered = image
+ else:
+ kernel_sizes = 2 * torch.ceil(3 * anti_aliasing_sigma).int() + 1
+
+ filtered = F.gaussian_blur(
+ image, (kernel_sizes[0], kernel_sizes[1]), sigma=anti_aliasing_sigma.tolist()
+ )
+
+ else:
+ filtered = image
+
+ out = F.resize(filtered, size=(size.height, size.width), antialias=False)
+
+ return out
+
+ def _preprocess(
+ self,
+ images: list["torch.Tensor"],
+ do_resize: bool,
+ size: SizeDict,
+ interpolation: Optional["F.InterpolationMode"],
+ do_pad: bool,
+ do_rescale: bool,
+ rescale_factor: float,
+ do_normalize: bool,
+ image_mean: Optional[Union[float, list[float]]],
+ image_std: Optional[Union[float, list[float]]],
+ disable_grouping: Optional[bool],
+ return_tensors: Optional[Union[str, TensorType]],
+ **kwargs,
+ ) -> BatchFeature:
+ # Group images by size for batched resizing
+ grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
+ processed_images_grouped = {}
+
+ for shape, stacked_images in grouped_images.items():
+ # Rescale images before other operations as done in original implementation
+ stacked_images = self.rescale_and_normalize(
+ stacked_images, do_rescale, rescale_factor, False, image_mean, image_std
+ )
+ processed_images_grouped[shape] = stacked_images
+
+ processed_images = reorder_images(processed_images_grouped, grouped_images_index)
+
+ if do_pad:
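+            # with the default do_rescale=True the pixel values are already in [0, 1], so 0.5 pads with mid-gray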
+ processed_images = self.pad(processed_images, constant_value=0.5, disable_grouping=disable_grouping)
+
+ grouped_images, grouped_images_index = group_images_by_shape(
+ processed_images, disable_grouping=disable_grouping
+ )
+ resized_images_grouped = {}
+ for shape, stacked_images in grouped_images.items():
+ if do_resize:
+ resized_stack = self.resize(
+ image=stacked_images,
+ size=size,
+ interpolation=interpolation,
+ input_data_format=ChannelDimension.FIRST,
+ )
+ resized_images_grouped[shape] = resized_stack
+ resized_images = reorder_images(resized_images_grouped, grouped_images_index)
+
+ # Group images by size for further processing
+ # Needed in case do_resize is False, or resize returns images with different sizes
+ grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping)
+ processed_images_grouped = {}
+ for shape, stacked_images in grouped_images.items():
+ # Fused rescale and normalize
+ stacked_images = self.rescale_and_normalize(
+ stacked_images, False, rescale_factor, do_normalize, image_mean, image_std
+ )
+ processed_images_grouped[shape] = stacked_images
+
+ processed_images = reorder_images(processed_images_grouped, grouped_images_index)
+
+ processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
+
+ return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)
+
+
+__all__ = ["Owlv2ImageProcessorFast"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/owlv2/processing_owlv2.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/owlv2/processing_owlv2.py
new file mode 100644
index 0000000000000000000000000000000000000000..d12ee5995535a59eadaae56c1b0338ef863a42f3
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/owlv2/processing_owlv2.py
@@ -0,0 +1,292 @@
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Image/Text processor class for OWLv2
+"""
+
+import warnings
+from typing import TYPE_CHECKING, Optional, Union
+
+import numpy as np
+
+from ...image_processing_utils import BatchFeature
+from ...image_utils import ImageInput
+from ...processing_utils import (
+ ImagesKwargs,
+ ProcessingKwargs,
+ ProcessorMixin,
+ Unpack,
+)
+from ...tokenization_utils_base import PreTokenizedInput, TextInput
+from ...utils import TensorType, is_flax_available, is_tf_available, is_torch_available
+
+
+if TYPE_CHECKING:
+ from .modeling_owlv2 import Owlv2ImageGuidedObjectDetectionOutput, Owlv2ObjectDetectionOutput
+
+
+class Owlv2ImagesKwargs(ImagesKwargs, total=False):
+ query_images: Optional[ImageInput]
+
+
+class Owlv2ProcessorKwargs(ProcessingKwargs, total=False):
+ images_kwargs: Owlv2ImagesKwargs
+ _defaults = {
+ "text_kwargs": {
+ "padding": "max_length",
+ },
+ "images_kwargs": {},
+ "common_kwargs": {
+ "return_tensors": "np",
+ },
+ }
+
+
+class Owlv2Processor(ProcessorMixin):
+ r"""
+ Constructs an Owlv2 processor which wraps [`Owlv2ImageProcessor`]/[`Owlv2ImageProcessorFast`] and [`CLIPTokenizer`]/[`CLIPTokenizerFast`] into
+ a single processor that inherits both the image processor and tokenizer functionalities. See the
+ [`~OwlViTProcessor.__call__`] and [`~OwlViTProcessor.decode`] for more information.
+
+ Args:
+ image_processor ([`Owlv2ImageProcessor`, `Owlv2ImageProcessorFast`]):
+ The image processor is a required input.
+ tokenizer ([`CLIPTokenizer`, `CLIPTokenizerFast`]):
+ The tokenizer is a required input.
+ """
+
+ attributes = ["image_processor", "tokenizer"]
+ image_processor_class = ("Owlv2ImageProcessor", "Owlv2ImageProcessorFast")
+ tokenizer_class = ("CLIPTokenizer", "CLIPTokenizerFast")
+
+ def __init__(self, image_processor, tokenizer, **kwargs):
+ super().__init__(image_processor, tokenizer)
+
+ # Copied from transformers.models.owlvit.processing_owlvit.OwlViTProcessor.__call__ with OwlViT->Owlv2
+ def __call__(
+ self,
+ images: Optional[ImageInput] = None,
+ text: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]] = None,
+ audio=None,
+ videos=None,
+ **kwargs: Unpack[Owlv2ProcessorKwargs],
+ ) -> BatchFeature:
+ """
+ Main method to prepare for the model one or several text(s) and image(s). This method forwards the `text` and
+ `kwargs` arguments to CLIPTokenizerFast's [`~CLIPTokenizerFast.__call__`] if `text` is not `None` to encode:
+ the text. To prepare the image(s), this method forwards the `images` and `kwargs` arguments to
+ CLIPImageProcessor's [`~CLIPImageProcessor.__call__`] if `images` is not `None`. Please refer to the docstring
+ of the above two methods for more information.
+
+ Args:
+ images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `list[PIL.Image.Image]`, `list[np.ndarray]`,
+ `list[torch.Tensor]`):
+ The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
+ tensor. Both channels-first and channels-last formats are supported.
+ text (`str`, `list[str]`, `list[list[str]]`):
+ The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
+ (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
+ `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
+ query_images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `list[PIL.Image.Image]`, `list[np.ndarray]`, `list[torch.Tensor]`):
+ The query image to be prepared, one query image is expected per target image to be queried. Each image
+ can be a PIL image, NumPy array or PyTorch tensor. In case of a NumPy array/PyTorch tensor, each image
+ should be of shape (C, H, W), where C is a number of channels, H and W are image height and width.
+ return_tensors (`str` or [`~utils.TensorType`], *optional*):
+ If set, will return tensors of a particular framework. Acceptable values are:
+ - `'tf'`: Return TensorFlow `tf.constant` objects.
+ - `'pt'`: Return PyTorch `torch.Tensor` objects.
+ - `'np'`: Return NumPy `np.ndarray` objects.
+ - `'jax'`: Return JAX `jnp.ndarray` objects.
+
+ Returns:
+ [`BatchFeature`]: A [`BatchFeature`] with the following fields:
+ - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
+ - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
+ `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
+ `None`).
+ - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
+ - **query_pixel_values** -- Pixel values of the query images to be fed to a model. Returned when `query_images` is not `None`.
+ """
+ output_kwargs = self._merge_kwargs(
+ Owlv2ProcessorKwargs,
+ tokenizer_init_kwargs=self.tokenizer.init_kwargs,
+ **kwargs,
+ )
+ query_images = output_kwargs["images_kwargs"].pop("query_images", None)
+ return_tensors = output_kwargs["common_kwargs"]["return_tensors"]
+
+ if text is None and query_images is None and images is None:
+ raise ValueError(
+ "You have to specify at least one text or query image or image. All three cannot be none."
+ )
+
+ data = {}
+ if text is not None:
+ if isinstance(text, str) or (isinstance(text, list) and not isinstance(text[0], list)):
+ encodings = [self.tokenizer(text, **output_kwargs["text_kwargs"])]
+
+ elif isinstance(text, list) and isinstance(text[0], list):
+ encodings = []
+
+ # Maximum number of queries across batch
+ max_num_queries = max(len(text_single) for text_single in text)
+
+ # Pad all batch samples to max number of text queries
+ for text_single in text:
+ if len(text_single) != max_num_queries:
+ text_single = text_single + [" "] * (max_num_queries - len(text_single))
+
+ encoding = self.tokenizer(text_single, **output_kwargs["text_kwargs"])
+ encodings.append(encoding)
+ else:
+ raise TypeError("Input text should be a string, a list of strings or a nested list of strings")
+
+ if return_tensors == "np":
+ input_ids = np.concatenate([encoding["input_ids"] for encoding in encodings], axis=0)
+ attention_mask = np.concatenate([encoding["attention_mask"] for encoding in encodings], axis=0)
+
+ elif return_tensors == "jax" and is_flax_available():
+ import jax.numpy as jnp
+
+ input_ids = jnp.concatenate([encoding["input_ids"] for encoding in encodings], axis=0)
+ attention_mask = jnp.concatenate([encoding["attention_mask"] for encoding in encodings], axis=0)
+
+ elif return_tensors == "pt" and is_torch_available():
+ import torch
+
+ input_ids = torch.cat([encoding["input_ids"] for encoding in encodings], dim=0)
+ attention_mask = torch.cat([encoding["attention_mask"] for encoding in encodings], dim=0)
+
+ elif return_tensors == "tf" and is_tf_available():
+ import tensorflow as tf
+
+ input_ids = tf.stack([encoding["input_ids"] for encoding in encodings], axis=0)
+ attention_mask = tf.stack([encoding["attention_mask"] for encoding in encodings], axis=0)
+
+ else:
+ raise ValueError("Target return tensor type could not be returned")
+
+ data["input_ids"] = input_ids
+ data["attention_mask"] = attention_mask
+
+ if query_images is not None:
+ query_pixel_values = self.image_processor(query_images, **output_kwargs["images_kwargs"]).pixel_values
+ # Query images always override the text prompt
+ data = {"query_pixel_values": query_pixel_values}
+
+ if images is not None:
+ image_features = self.image_processor(images, **output_kwargs["images_kwargs"])
+ data["pixel_values"] = image_features.pixel_values
+
+ return BatchFeature(data=data, tensor_type=return_tensors)
+
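+ # Illustrative usage sketch (the checkpoint name, image size and query texts below are assumptions
+ # for demonstration only, not part of this file):
+ #
+ # >>> from PIL import Image
+ # >>> from transformers import Owlv2Processor
+ # >>> processor = Owlv2Processor.from_pretrained("google/owlv2-base-patch16-ensemble")
+ # >>> image = Image.new("RGB", (960, 640))
+ # >>> inputs = processor(text=[["a photo of a cat", "a photo of a dog"]], images=image, return_tensors="pt")
+ # >>> sorted(inputs.keys())
+ # ['attention_mask', 'input_ids', 'pixel_values']
+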
+ # Copied from transformers.models.owlvit.processing_owlvit.OwlViTProcessor.post_process_object_detection with OwlViT->Owlv2
+ def post_process_object_detection(self, *args, **kwargs):
+ """
+ This method forwards all its arguments to [`Owlv2ImageProcessor.post_process_object_detection`]. Please refer
+ to the docstring of this method for more information.
+ """
+ warnings.warn(
+ "`post_process_object_detection` method is deprecated for OwlVitProcessor and will be removed in v5. "
+ "Use `post_process_grounded_object_detection` instead.",
+ FutureWarning,
+ )
+ return self.image_processor.post_process_object_detection(*args, **kwargs)
+
+ # Copied from transformers.models.owlvit.processing_owlvit.OwlViTProcessor.post_process_grounded_object_detection with OwlViT->Owlv2
+ def post_process_grounded_object_detection(
+ self,
+ outputs: "Owlv2ObjectDetectionOutput",
+ threshold: float = 0.1,
+ target_sizes: Optional[Union[TensorType, list[tuple]]] = None,
+ text_labels: Optional[list[list[str]]] = None,
+ ):
+ """
+ Converts the raw output of [`Owlv2ForObjectDetection`] into final bounding boxes in (top_left_x, top_left_y,
+ bottom_right_x, bottom_right_y) format.
+
+ Args:
+ outputs ([`Owlv2ObjectDetectionOutput`]):
+ Raw outputs of the model.
+ threshold (`float`, *optional*, defaults to 0.1):
+ Score threshold to keep object detection predictions.
+ target_sizes (`torch.Tensor` or `list[tuple[int, int]]`, *optional*):
+ Tensor of shape `(batch_size, 2)` or list of tuples (`tuple[int, int]`) containing the target size
+ `(height, width)` of each image in the batch. If unset, predictions will not be resized.
+ text_labels (`list[list[str]]`, *optional*):
+ List of lists of text labels for each image in the batch. If unset, "text_labels" in output will be
+ set to `None`.
+
+ Returns:
+ `list[Dict]`: A list of dictionaries, each dictionary containing the following keys:
+ - "scores": The confidence scores for each predicted box on the image.
+ - "labels": Indexes of the classes predicted by the model on the image.
+ - "boxes": Image bounding boxes in (top_left_x, top_left_y, bottom_right_x, bottom_right_y) format.
+ - "text_labels": The text labels for each predicted bounding box on the image.
+ """
+ output = self.image_processor.post_process_object_detection(
+ outputs=outputs, threshold=threshold, target_sizes=target_sizes
+ )
+
+ if text_labels is not None and len(text_labels) != len(output):
+ raise ValueError("Make sure that you pass in as many lists of text labels as images")
+
+ # adding text labels to the output
+ if text_labels is not None:
+ for image_output, image_text_labels in zip(output, text_labels):
+ object_text_labels = [image_text_labels[i] for i in image_output["labels"]]
+ image_output["text_labels"] = object_text_labels
+ else:
+ for image_output in output:
+ image_output["text_labels"] = None
+
+ return output
+
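+ # Illustrative post-processing sketch (`outputs`, `image` and `processor` are assumed to come from an
+ # `Owlv2ForObjectDetection` forward pass on the inputs prepared above; they are not defined in this file):
+ #
+ # >>> results = processor.post_process_grounded_object_detection(
+ # ...     outputs, threshold=0.2, target_sizes=[(image.height, image.width)],
+ # ...     text_labels=[["a photo of a cat", "a photo of a dog"]],
+ # ... )
+ # >>> sorted(results[0].keys())
+ # ['boxes', 'labels', 'scores', 'text_labels']
+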
+ # Copied from transformers.models.owlvit.processing_owlvit.OwlViTProcessor.post_process_image_guided_detection with OwlViT->Owlv2
+ def post_process_image_guided_detection(
+ self,
+ outputs: "Owlv2ImageGuidedObjectDetectionOutput",
+ threshold: float = 0.0,
+ nms_threshold: float = 0.3,
+ target_sizes: Optional[Union[TensorType, list[tuple]]] = None,
+ ):
+ """
+ Converts the output of [`Owlv2ForObjectDetection.image_guided_detection`] into the format expected by the COCO
+ api.
+
+ Args:
+ outputs ([`Owlv2ImageGuidedObjectDetectionOutput`]):
+ Raw outputs of the model.
+ threshold (`float`, *optional*, defaults to 0.0):
+ Minimum confidence threshold to use to filter out predicted boxes.
+ nms_threshold (`float`, *optional*, defaults to 0.3):
+ IoU threshold for non-maximum suppression of overlapping boxes.
+ target_sizes (`torch.Tensor`, *optional*):
+ Tensor of shape (batch_size, 2) where each entry is the (height, width) of the corresponding image in
+ the batch. If set, predicted normalized bounding boxes are rescaled to the target sizes. If left to
+ None, predictions will not be unnormalized.
+
+ Returns:
+ `list[Dict]`: A list of dictionaries, each dictionary containing the following keys:
+ - "scores": The confidence scores for each predicted box on the image.
+ - "boxes": Image bounding boxes in (top_left_x, top_left_y, bottom_right_x, bottom_right_y) format.
+ - "labels": Set to `None`.
+ """
+ return self.image_processor.post_process_image_guided_detection(
+ outputs=outputs, threshold=threshold, nms_threshold=nms_threshold, target_sizes=target_sizes
+ )
+
+
+__all__ = ["Owlv2Processor"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/paligemma/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/paligemma/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9048afe6adbdc0ad36007e02f60e899cae677c55
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/paligemma/__init__.py
@@ -0,0 +1,28 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_paligemma import *
+ from .modeling_paligemma import *
+ from .processing_paligemma import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/paligemma/configuration_paligemma.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/paligemma/configuration_paligemma.py
new file mode 100644
index 0000000000000000000000000000000000000000..e4ee4b3b45c2637d804f7eafa7f9cc7afa2638f8
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/paligemma/configuration_paligemma.py
@@ -0,0 +1,128 @@
+# coding=utf-8
+# Copyright 2024 Microsoft Research & University of Wisconsin-Madison and the HuggingFace Inc. team. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PaliGemmamodel configuration"""
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+from ..auto import CONFIG_MAPPING, AutoConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+class PaliGemmaConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`PaliGemmaForConditionalGeneration`]. It is used to instantiate a
+ PaliGemma model according to the specified arguments, defining the model architecture. Instantiating a configuration
+ with the defaults will yield a similar configuration to that of the PaliGemma-2B.
+
+ e.g. [paligemma-hf/paligemma-2b](https://huggingface.co/paligemma-hf/paligemma-2b)
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ vision_config (`PaliGemmaVisionConfig`, *optional*):
+ Custom vision config or dict
+ text_config (`Union[AutoConfig, dict]`, *optional*):
+ The config object of the text backbone. Typically a `GemmaConfig` (PaliGemma) or `Gemma2Config` (PaliGemma 2).
+ image_token_index (`int`, *optional*, defaults to 256000):
+ The image token index to encode the image prompt.
+ vocab_size (`int`, *optional*, defaults to 257152):
+ Vocabulary size of the PaliGemma model. Defines the number of different tokens that can be represented by the
+ `inputs_ids` passed when calling [`~PaliGemmaForConditionalGeneration`]
+ projection_dim (`int`, *optional*, defaults to 2048):
+ Dimension of the multimodal projection space.
+ hidden_size (`int`, *optional*, defaults to 2048):
+ Dimension of the hidden layer of the Language model.
+
+ Example:
+
+ ```python
+ >>> from transformers import PaliGemmaForConditionalGeneration, PaliGemmaConfig, SiglipVisionConfig, GemmaConfig
+
+ >>> # Initializing a Siglip-like vision config
+ >>> vision_config = SiglipVisionConfig()
+
+ >>> # Initializing a PaliGemma config
+ >>> text_config = GemmaConfig()
+
+ >>> # Initializing a PaliGemma paligemma-3b-224 style configuration
+ >>> configuration = PaliGemmaConfig(vision_config, text_config)
+
+ >>> # Initializing a model from the paligemma-3b-224 style configuration
+ >>> model = PaliGemmaForConditionalGeneration(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "paligemma"
+ attribute_map = {
+ "image_token_id": "image_token_index",
+ }
+ sub_configs = {"text_config": AutoConfig, "vision_config": AutoConfig}
+ keys_to_ignore_at_inference = ["past_key_values"]
+
+ def __init__(
+ self,
+ vision_config=None,
+ text_config=None,
+ image_token_index=256000,
+ vocab_size=257152,
+ projection_dim=2048,
+ hidden_size=2048,
+ **kwargs,
+ ):
+ self.image_token_index = image_token_index
+ self.projection_dim = projection_dim
+ self.hidden_size = hidden_size
+ self.vision_config = vision_config
+ self.is_encoder_decoder = False
+
+ if isinstance(self.vision_config, dict):
+ vision_config["model_type"] = vision_config.get("model_type", "siglip_vision_model")
+ self.vision_config = CONFIG_MAPPING[vision_config["model_type"]](**vision_config)
+ elif vision_config is None:
+ self.vision_config = CONFIG_MAPPING["siglip_vision_model"](
+ intermediate_size=4096,
+ hidden_size=1152,
+ patch_size=14,
+ image_size=224,
+ num_hidden_layers=27,
+ num_attention_heads=16,
+ vocab_size=257152,
+ vision_use_head=False,
+ )
+
+ self.text_config = text_config
+ if isinstance(self.text_config, dict):
+ text_config["model_type"] = text_config.get("model_type", "gemma")
+ self.text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
+ elif text_config is None:
+ self.text_config = CONFIG_MAPPING["gemma"](
+ hidden_size=2048,
+ num_hidden_layers=18,
+ intermediate_size=16384,
+ num_attention_heads=8,
+ num_key_value_heads=1,
+ is_encoder_decoder=False,
+ vocab_size=vocab_size,
+ )
+ self.text_config.num_image_tokens = (self.vision_config.image_size // self.vision_config.patch_size) ** 2
+ self.vision_config.projection_dim = projection_dim
+ super().__init__(**kwargs)
+
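+# Worked example (derived directly from the defaults set in `__init__` above): with the default SigLIP
+# vision settings (image_size=224, patch_size=14), the number of image placeholder tokens is
+# (224 // 14) ** 2 = 16 ** 2 = 256. A minimal sketch, no checkpoint download involved:
+#
+# >>> from transformers import PaliGemmaConfig
+# >>> config = PaliGemmaConfig()
+# >>> config.text_config.num_image_tokens
+# 256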
+
+__all__ = ["PaliGemmaConfig"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/paligemma/modeling_paligemma.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/paligemma/modeling_paligemma.py
new file mode 100644
index 0000000000000000000000000000000000000000..abd8595e24abb7d3850a19144ebe67707b85f7ae
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/paligemma/modeling_paligemma.py
@@ -0,0 +1,625 @@
+# coding=utf-8
+# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch PaliGemmamodel."""
+
+from dataclasses import dataclass
+from typing import Optional, Union
+
+import torch
+from torch import nn
+
+from ...cache_utils import Cache, StaticCache
+from ...generation import GenerationMixin
+from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_outputs import BaseModelOutputWithPast
+from ...modeling_utils import PreTrainedModel
+from ...processing_utils import Unpack
+from ...utils import (
+ ModelOutput,
+ TransformersKwargs,
+ auto_docstring,
+ can_return_tuple,
+ logging,
+)
+from ..auto import AutoModel
+from .configuration_paligemma import PaliGemmaConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Base class for Paligemma outputs, with hidden states and attentions.
+ """
+)
+class PaligemmaModelOutputWithPast(BaseModelOutputWithPast):
+ r"""
+ image_hidden_states (`torch.FloatTensor`, *optional*):
+ A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
+ Image hidden states produced by the vision encoder, after projecting the last hidden state.
+ """
+
+ image_hidden_states: Optional[torch.FloatTensor] = None
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Base class for PaliGemma causal language model (or autoregressive) outputs.
+ """
+)
+class PaliGemmaCausalLMOutputWithPast(ModelOutput):
+ r"""
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Language modeling loss (for next-token prediction).
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.text_config.vocab_size)`):
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
+ `past_key_values` input) to speed up sequential decoding.
+ image_hidden_states (`torch.FloatTensor`, *optional*):
+ A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
+ Image hidden states produced by the vision encoder, after projecting the last hidden state.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ logits: Optional[torch.FloatTensor] = None
+ past_key_values: Optional[Cache] = None
+ hidden_states: Optional[tuple[torch.FloatTensor]] = None
+ attentions: Optional[tuple[torch.FloatTensor]] = None
+ image_hidden_states: Optional[torch.FloatTensor] = None
+
+
+class PaliGemmaMultiModalProjector(nn.Module):
+ def __init__(self, config: PaliGemmaConfig):
+ super().__init__()
+ self.linear = nn.Linear(config.vision_config.hidden_size, config.vision_config.projection_dim, bias=True)
+
+ def forward(self, image_features):
+ hidden_states = self.linear(image_features)
+
+ return hidden_states
+
+
+@auto_docstring
+class PaliGemmaPreTrainedModel(PreTrainedModel):
+ config: PaliGemmaConfig
+ base_model_prefix = ""
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["PaliGemmaMultiModalProjector"]
+ _skip_keys_device_placement = "past_key_values"
+
+ _can_compile_fullgraph = False
+ _supports_flash_attn = True
+ _supports_sdpa = True
+ _supports_flex_attn = True
+ _supports_attention_backend = True
+
+ def _init_weights(self, module):
+ # important: this ported version of PaliGemma isn't meant for training from scratch - only
+ # inference and fine-tuning
+ std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range)
+
+ if isinstance(module, nn.Linear):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.bias is not None:
+ module.bias.data.zero_()
+
+
+@auto_docstring(
+ custom_intro="""
+ The base PaliGemma model, which consists of a vision backbone and a language model without a language modeling head.
+ """
+)
+class PaliGemmaModel(PaliGemmaPreTrainedModel):
+ _checkpoint_conversion_mapping = {"language_model.model": "language_model"}
+ # we are filtering the logits/labels so we shouldn't divide the loss based on num_items_in_batch
+ accepts_loss_kwargs = False
+
+ def __init__(self, config: PaliGemmaConfig):
+ super().__init__(config)
+ self.vision_tower = AutoModel.from_config(config=config.vision_config)
+ self.multi_modal_projector = PaliGemmaMultiModalProjector(config)
+ self.vocab_size = config.text_config.vocab_size
+
+ language_model = AutoModel.from_config(config=config.text_config)
+ self.language_model = language_model
+
+ self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
+ self.text_config_dtype = self.config.get_text_config().dtype or self.dtype
+ self.post_init()
+
+ # Copied from transformers.models.llava.modeling_llava.LlavaModel.get_input_embeddings with Llava->PaliGemma
+ def get_input_embeddings(self):
+ return self.language_model.get_input_embeddings()
+
+ # Copied from transformers.models.llava.modeling_llava.LlavaModel.set_input_embeddings with Llava->PaliGemma
+ def set_input_embeddings(self, value):
+ self.language_model.set_input_embeddings(value)
+
+ def set_decoder(self, decoder):
+ self.language_model = decoder
+
+ def get_decoder(self):
+ return self.language_model
+
+ def _update_causal_mask(
+ self,
+ attention_mask,
+ token_type_ids=None,
+ past_key_values=None,
+ cache_position=None,
+ input_tensor=None,
+ is_training: Optional[bool] = None,
+ ):
+ if self.config.text_config._attn_implementation == "flash_attention_2":
+ if attention_mask is not None and 0.0 in attention_mask:
+ return attention_mask
+ return None
+ is_training = is_training if is_training is not None else self.training
+ using_static_cache = isinstance(past_key_values, StaticCache)
+ min_dtype = torch.finfo(self.text_config_dtype).min
+ if input_tensor is None:
+ input_tensor = attention_mask
+
+ inputs_lead_dim, sequence_length = input_tensor.shape[:2]
+ if using_static_cache:
+ target_length = past_key_values.get_max_cache_shape()
+ else:
+ target_length = (
+ attention_mask.shape[-1]
+ if isinstance(attention_mask, torch.Tensor)
+ else cache_position[0] + sequence_length + 1
+ )
+
+ if attention_mask is not None and attention_mask.dim() == 4:
+ # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
+ return attention_mask
+
+ causal_mask = torch.full(
+ (sequence_length, target_length),
+ fill_value=min_dtype,
+ dtype=self.text_config_dtype,
+ device=cache_position.device,
+ )
+ # Causal diagonal mask only if training, otherwise attend to the whole prefix. Training-specific attn for prefix is handled below
+ if sequence_length != 1:
+ if is_training:
+ causal_mask = torch.triu(causal_mask, diagonal=1)
+ else:
+ causal_mask[:, :sequence_length] = 0.0
+
+ causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
+ causal_mask = causal_mask[None, None, :, :].expand(inputs_lead_dim, 1, -1, -1)
+ if attention_mask is not None:
+ causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
+ mask_length = attention_mask.shape[-1]
+
+ # First unmask prefix tokens during training
+ if is_training:
+ if token_type_ids is None:
+ raise ValueError("Token type ids must be provided during training")
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+ token_type_ids[:, None, None, :].to(causal_mask.device) == 0, 0
+ )
+
+ # Then apply padding mask (will mask pad tokens)
+ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(causal_mask.device)
+ padding_mask = padding_mask == 0
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+ padding_mask, min_dtype
+ )
+
+ return causal_mask
+
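+ # A minimal standalone sketch of the training-time prefix-LM mask built above (toy sizes, hypothetical
+ # values, not tied to this class): prefix columns (token_type_ids == 0, i.e. image + text prefix) are
+ # fully visible to every query position, while suffix/label columns stay causal.
+ #
+ # >>> import torch
+ # >>> seq_len, min_dtype = 5, torch.finfo(torch.float32).min
+ # >>> token_type_ids = torch.tensor([[0, 0, 0, 1, 1]])  # 3 prefix tokens, 2 suffix tokens
+ # >>> mask = torch.full((seq_len, seq_len), min_dtype).triu(diagonal=1)[None, None]
+ # >>> mask = mask.masked_fill(token_type_ids[:, None, None, :] == 0, 0)
+ # >>> (mask[0, 0] == 0).long()  # 1 = attention allowed, 0 = masked
+ # tensor([[1, 1, 1, 0, 0],
+ #         [1, 1, 1, 0, 0],
+ #         [1, 1, 1, 0, 0],
+ #         [1, 1, 1, 1, 0],
+ #         [1, 1, 1, 1, 1]])
+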
+ def get_image_features(self, pixel_values: torch.FloatTensor):
+ """
+ Obtains image last hidden states from the vision tower and applies the multimodal projection.
+
+ Args:
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, channels, height, width)`):
+ The tensors corresponding to the input images.
+ Returns:
+ image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`.
+ """
+ image_outputs = self.vision_tower(pixel_values)
+ selected_image_feature = image_outputs.last_hidden_state
+ image_features = self.multi_modal_projector(selected_image_feature)
+ image_features = image_features / (self.config.text_config.hidden_size**0.5)
+ return image_features
+
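+ # Note (illustrative): the projected image features are divided by sqrt(text hidden size) before being
+ # merged into the text embeddings, pre-compensating for Gemma scaling its input embeddings by
+ # sqrt(hidden_size). With the default `hidden_size=2048` the divisor works out to 2048 ** 0.5 ≈ 45.25.
+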
+ def get_placeholder_mask(
+ self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, image_features: torch.FloatTensor
+ ):
+ """
+ Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is
+ equal to the length of multimodal features. If the lengths are different, an error is raised.
+ """
+ if input_ids is None:
+ special_image_mask = inputs_embeds == self.get_input_embeddings()(
+ torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
+ )
+ special_image_mask = special_image_mask.all(-1)
+ else:
+ special_image_mask = input_ids == self.config.image_token_id
+
+ n_image_tokens = special_image_mask.sum()
+ special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
+ n_image_features = image_features.shape[0] * image_features.shape[1]
+ if inputs_embeds[special_image_mask].numel() != image_features.numel():
+ raise ValueError(
+ f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
+ )
+ return special_image_mask
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ **kwargs: Unpack[FlashAttentionKwargs],
+ ) -> Union[tuple, PaligemmaModelOutputWithPast]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`.
+
+ Example:
+
+ ```python
+ >>> from PIL import Image
+ >>> import requests
+ >>> from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
+
+ >>> model = PaliGemmaForConditionalGeneration.from_pretrained("google/paligemma2-3b-mix-224")
+ >>> processor = AutoProcessor.from_pretrained("google/paligemma2-3b-mix-224")
+
+ >>> prompt = "Where is the cat standing?"
+ >>> url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> inputs = processor(images=image, text=prompt, return_tensors="pt")
+
+ >>> # Generate
+ >>> generate_ids = model.generate(**inputs,)
+ >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ "Where is the cat standing?\nsnow"
+ ```"""
+
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ is_training = token_type_ids is not None and labels is not None
+
+ # Replace the image token id with PAD if it is OOV, to avoid index errors
+ if input_ids is not None and self.config.image_token_id >= self.vocab_size:
+ special_image_mask = input_ids == self.config.image_token_id
+ llm_input_ids = input_ids.clone()
+ llm_input_ids[special_image_mask] = 0
+ else:
+ llm_input_ids = input_ids
+
+ if inputs_embeds is None:
+ inputs_embeds = self.get_input_embeddings()(llm_input_ids)
+
+ if cache_position is None:
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ cache_position = torch.arange(
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+ )
+
+ if position_ids is None:
+ position_ids = cache_position.unsqueeze(0) + 1 # Paligemma positions are 1-indexed
+
+ # Merge text and images
+ if pixel_values is not None:
+ image_features = self.get_image_features(pixel_values)
+ image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
+ special_image_mask = self.get_placeholder_mask(
+ input_ids, inputs_embeds=inputs_embeds, image_features=image_features
+ )
+ inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
+
+ causal_mask = self._update_causal_mask(
+ attention_mask, token_type_ids, past_key_values, cache_position, inputs_embeds, is_training
+ )
+ outputs = self.language_model(
+ attention_mask=causal_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=True,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+ return PaligemmaModelOutputWithPast(
+ last_hidden_state=outputs.last_hidden_state,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ image_hidden_states=image_features if pixel_values is not None else None,
+ )
+
+
+@auto_docstring(
+ custom_intro="""
+ The PaliGemma model, which consists of a vision backbone and a language model, with a language modeling head on top for conditional generation.
+ """
+)
+class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel, GenerationMixin):
+ _checkpoint_conversion_mapping = {
+ "^language_model.model": "model.language_model",
+ "^vision_tower": "model.vision_tower",
+ "^multi_modal_projector": "model.multi_modal_projector",
+ "^language_model.lm_head": "lm_head",
+ }
+ _tied_weights_keys = ["lm_head.weight"]
+
+ def __init__(self, config: PaliGemmaConfig):
+ super().__init__(config)
+ self.model = PaliGemmaModel(config)
+ self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.get_input_embeddings()
+
+ def set_input_embeddings(self, value):
+ self.model.set_input_embeddings(value)
+
+ def set_decoder(self, decoder):
+ self.model.set_decoder(decoder)
+
+ def get_decoder(self):
+ return self.model.get_decoder()
+
+ def get_image_features(self, pixel_values):
+ return self.model.get_image_features(pixel_values)
+
+ # Make modules available through conditional class for BC
+ @property
+ def language_model(self):
+ return self.model.language_model
+
+ @property
+ def vision_tower(self):
+ return self.model.vision_tower
+
+ @property
+ def multi_modal_projector(self):
+ return self.model.multi_modal_projector
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ logits_to_keep: Union[int, torch.Tensor] = 0,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> Union[tuple, PaliGemmaCausalLMOutputWithPast]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`.
+
+ Example:
+
+ ```python
+ >>> from PIL import Image
+ >>> import requests
+ >>> from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
+
+ >>> model = PaliGemmaForConditionalGeneration.from_pretrained("google/paligemma2-3b-mix-224")
+ >>> processor = AutoProcessor.from_pretrained("google/paligemma2-3b-mix-224")
+
+ >>> prompt = "Where is the cat standing?"
+ >>> url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> inputs = processor(images=image, text=prompt, return_tensors="pt")
+
+ >>> # Generate
+ >>> generate_ids = model.generate(**inputs,)
+ >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ "Where is the cat standing?\nsnow"
+ ```"""
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.model(
+ input_ids=input_ids,
+ pixel_values=pixel_values,
+ token_type_ids=token_type_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ labels=labels,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=True,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+ hidden_states = outputs[0]
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(
+ logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
+ )
+
+ return PaliGemmaCausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ image_hidden_states=outputs.image_hidden_states,
+ )
+
+ def prepare_inputs_for_generation(
+ self,
+ input_ids,
+ past_key_values=None,
+ inputs_embeds=None,
+ cache_position=None,
+ position_ids=None,
+ pixel_values=None,
+ attention_mask=None,
+ token_type_ids=None,
+ use_cache=True,
+ logits_to_keep=None,
+ labels=None,
+ **kwargs,
+ ):
+ # Overwritten -- custom `position_ids` and `pixel_values` handling
+ model_inputs = super().prepare_inputs_for_generation(
+ input_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ cache_position=cache_position,
+ use_cache=use_cache,
+ logits_to_keep=logits_to_keep,
+ token_type_ids=token_type_ids,
+ **kwargs,
+ )
+
+ # position_ids in Paligemma are 1-indexed
+ if model_inputs.get("position_ids") is not None:
+ model_inputs["position_ids"] += 1
+ # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
+ # Otherwise we need pixel values to be passed to model. NOTE: use_cache=False needs pixel_values always
+ if cache_position[0] == 0:
+ model_inputs["pixel_values"] = pixel_values
+ is_training = token_type_ids is not None and labels is not None
+ is_static_hybrid_cache = isinstance(past_key_values, StaticCache) and any(past_key_values.is_sliding)
+ if cache_position[0] == 0 and is_static_hybrid_cache:
+ input_tensor = inputs_embeds if inputs_embeds is not None else input_ids
+ causal_mask = self.model._update_causal_mask(
+ attention_mask, token_type_ids, past_key_values, cache_position, input_tensor, is_training
+ )
+ model_inputs["attention_mask"] = causal_mask
+
+ return model_inputs
+
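+ # Illustrative note: `pixel_values` are only forwarded on the prefill step (`cache_position[0] == 0`);
+ # later decoding steps reuse the KV cache and contain no image tokens. Position ids are shifted to be
+ # 1-indexed, so a 4-token prompt uses positions [1, 2, 3, 4] rather than [0, 1, 2, 3].
+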
+ @staticmethod
+ # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position
+ def _prepare_4d_causal_attention_mask_with_cache_position(
+ attention_mask: torch.Tensor,
+ sequence_length: int,
+ target_length: int,
+ dtype: torch.dtype,
+ cache_position: torch.Tensor,
+ batch_size: int,
+ **kwargs,
+ ):
+ """
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
+ `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
+
+ Args:
+ attention_mask (`torch.Tensor`):
+ A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
+ `(batch_size, 1, query_length, key_value_length)`.
+ sequence_length (`int`):
+ The sequence length being processed.
+ target_length (`int`):
+ The target length: when generating with static cache, the mask should be as long as the static cache,
+ to account for the 0 padding, the part of the cache that is not filled yet.
+ dtype (`torch.dtype`):
+ The dtype to use for the 4D attention mask.
+ cache_position (`torch.Tensor`):
+ Indices depicting the position of the input sequence tokens in the sequence.
+ batch_size (`torch.Tensor`):
+ Batch size.
+ """
+ if attention_mask is not None and attention_mask.dim() == 4:
+ # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
+ causal_mask = attention_mask
+ else:
+ min_dtype = torch.finfo(dtype).min
+ causal_mask = torch.full(
+ (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
+ )
+ if sequence_length != 1:
+ causal_mask = torch.triu(causal_mask, diagonal=1)
+ causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
+ causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
+ if attention_mask is not None:
+ causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
+ mask_length = attention_mask.shape[-1]
+ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
+ causal_mask.device
+ )
+ padding_mask = padding_mask == 0
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+ padding_mask, min_dtype
+ )
+
+ return causal_mask
+
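+# A minimal standalone sketch of the helper above (toy sizes, not tied to any checkpoint): a 2D padding
+# mask of shape (batch, key_len) is expanded into an additive 4D mask of shape (batch, 1, query_len,
+# key_len), with `min_dtype` on disallowed positions.
+#
+# >>> import torch
+# >>> attention_mask = torch.tensor([[0, 1, 1, 1]])  # first key position is padding
+# >>> mask = PaliGemmaForConditionalGeneration._prepare_4d_causal_attention_mask_with_cache_position(
+# ...     attention_mask, sequence_length=4, target_length=4, dtype=torch.float32,
+# ...     cache_position=torch.arange(4), batch_size=1,
+# ... )
+# >>> mask.shape
+# torch.Size([1, 1, 4, 4])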
+
+__all__ = ["PaliGemmaForConditionalGeneration", "PaliGemmaPreTrainedModel", "PaliGemmaModel"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/paligemma/processing_paligemma.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/paligemma/processing_paligemma.py
new file mode 100644
index 0000000000000000000000000000000000000000..242627a0eb71a98f71ca0aadf260cc51df3fb085
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/paligemma/processing_paligemma.py
@@ -0,0 +1,336 @@
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Processor class for PaliGemma.
+"""
+
+from typing import Optional, Union
+
+import numpy as np
+
+from ...feature_extraction_utils import BatchFeature
+from ...image_utils import ImageInput, is_valid_image
+from ...processing_utils import (
+ ImagesKwargs,
+ MultiModalData,
+ ProcessingKwargs,
+ ProcessorMixin,
+ TextKwargs,
+ Unpack,
+)
+from ...tokenization_utils_base import AddedToken, PreTokenizedInput, TextInput
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+IMAGE_TOKEN = "<image>"
+EXTRA_TOKENS = [f"<loc{i:0>4}>" for i in range(1024)] + [f"<seg{i:0>3}>" for i in range(128)]
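+# The extra vocabulary covers the detection/segmentation codebook used by PaliGemma checkpoints:
+# 1024 location tokens ("<loc0000>" ... "<loc1023>") and 128 segmentation tokens ("<seg000>" ... "<seg127>").
+# For instance, f"<loc{7:0>4}>" evaluates to "<loc0007>".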
+
+
+class PaliGemmaTextKwargs(TextKwargs):
+ suffix: Optional[Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]]]
+
+
+class PaliGemmaImagesKwargs(ImagesKwargs):
+ do_convert_rgb: Optional[bool]
+
+
+class PaliGemmaProcessorKwargs(ProcessingKwargs, total=False):
+ text_kwargs: PaliGemmaTextKwargs
+ images_kwargs: PaliGemmaImagesKwargs
+ _defaults = {
+ "text_kwargs": {
+ "padding": False,
+ "return_mm_token_type_ids": False,
+ },
+ "images_kwargs": {
+ "data_format": "channels_first",
+ },
+ }
+
+
+# Copied from transformers.models.idefics2.processing_idefics2.is_url
+def is_url(val) -> bool:
+ return isinstance(val, str) and val.startswith("http")
+
+
+# Copied from transformers.models.idefics2.processing_idefics2.is_image_or_image_url
+def is_image_or_image_url(elem):
+ return is_url(elem) or is_valid_image(elem)
+
+
+def _is_str_or_image(elem):
+ return isinstance(elem, (str)) or is_image_or_image_url(elem)
+
+
+def build_string_from_input(prompt, bos_token, image_seq_len, image_token, num_images):
+ """
+ Builds a string from the input prompt and image tokens.
+ For example, for the call:
+ build_string_from_input(
+ prompt="Prefix str",
+ bos_token="<s>",
+ image_seq_len=3,
+ image_token="<im>",
+ num_images=1,
+ )
+ The output will be:
+ "<im><im><im><s>Prefix str\n"
+ Args:
+ prompt (`list[Union[str, ImageInput]]`): The input prompt.
+ bos_token (`str`): The beginning of sentence token.
+ image_seq_len (`int`): The length of the image sequence.
+ image_token (`str`): The image token.
+ num_images (`int`): Number of images in the prompt.
+ """
+ return f"{image_token * image_seq_len * num_images}{bos_token}{prompt}\n"
+
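+# For instance, with two images per prompt the placeholder block doubles (evaluated directly from the
+# helper above; the token strings are just literal arguments):
+# build_string_from_input("describe en", "<bos>", 2, "<image>", num_images=2)
+# returns "<image><image><image><image><bos>describe en\n".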
+
+class PaliGemmaProcessor(ProcessorMixin):
+ r"""
+ Constructs a PaliGemma processor which wraps a PaliGemma image processor and a PaliGemma tokenizer into a single processor.
+
+ [`PaliGemmaProcessor`] offers all the functionalities of [`SiglipImageProcessor`] and [`GemmaTokenizerFast`]. See the
+ [`~PaliGemmaProcessor.__call__`] and [`~PaliGemmaProcessor.decode`] for more information.
+
+ Args:
+ image_processor ([`SiglipImageProcessor`], *optional*):
+ The image processor is a required input.
+ tokenizer ([`GemmaTokenizerFast`], *optional*):
+ The tokenizer is a required input.
+ chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
+ in a chat into a tokenizable string.
+ """
+
+ attributes = ["image_processor", "tokenizer"]
+ image_processor_class = ("SiglipImageProcessor", "SiglipImageProcessorFast")
+ tokenizer_class = ("GemmaTokenizer", "GemmaTokenizerFast")
+
+ def __init__(
+ self,
+ image_processor=None,
+ tokenizer=None,
+ chat_template=None,
+ **kwargs,
+ ):
+ if not hasattr(image_processor, "image_seq_length"):
+ raise ValueError("Image processor is missing an `image_seq_length` attribute.")
+
+ self.image_seq_length = image_processor.image_seq_length
+
+ if not hasattr(tokenizer, "image_token"):
+ image_token = AddedToken(IMAGE_TOKEN, normalized=False, special=True)
+ tokens_to_add = {"additional_special_tokens": [image_token]}
+ tokenizer.add_special_tokens(tokens_to_add)
+ self.image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
+ self.image_token = IMAGE_TOKEN
+ else:
+ self.image_token_id = tokenizer.image_token_id
+ self.image_token = tokenizer.image_token
+
+ tokenizer.add_tokens(EXTRA_TOKENS)
+ tokenizer.add_bos_token = False
+ tokenizer.add_eos_token = False
+
+ super().__init__(image_processor, tokenizer, chat_template=chat_template)
+
+ def __call__(
+ self,
+ images: Optional[ImageInput] = None,
+ text: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]] = None,
+ audio=None,
+ videos=None,
+ **kwargs: Unpack[PaliGemmaProcessorKwargs],
+ ) -> BatchFeature:
+ """
+ Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text`
+ and `kwargs` arguments to GemmaTokenizerFast's [`~GemmaTokenizerFast.__call__`] if `text` is not `None` to encode
+ the text. To prepare the image(s), this method forwards the `images` and `kwargs` arguments to
+ SiglipImageProcessor's [`~SiglipImageProcessor.__call__`] if `images` is not `None`. Please refer to the docstring
+ of the above two methods for more information.
+
+ The usage for PaliGemma fine-tuning preparation is slightly different than usual. The `suffix` passed contains the
+ suffixes to the prompt in `text`, and will be placed after the prompt. This is because attention is handled
+ differently for the prefix and the suffix. For instance,
+ ```python
+ image = PIL_cow_image
+ prompt = "answer en Where is the cow standing?"
+ suffix = "on the beach"
+ inputs = processor(text=prompt, images=image, suffix=suffix)
+ ```
+ Here `inputs` will contain the `input_ids` and `token_type_ids` that follow
+ ```python
+ inputs["input_ids"][:, 256:]
+ # tensor([[ 2, 6006, 603, 573, 13910, 9980, 235336, 108, 477, 573, 8318]])
+ inputs["token_type_ids"][:, 256:]
+ # tensor([[0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1]])
+ ```
+ Meaning the last three tokens are of "label" ("suffix") type while the other ones are of "prefix" type.
+
+
+ Args:
+ images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `list[PIL.Image.Image]`, `list[np.ndarray]`, `list[torch.Tensor]`):
+ The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
+ tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
+ number of channels, H and W are image height and width.
+ text (`str`, `list[str]`, `list[list[str]]`):
+ The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
+ (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
+ `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
+ return_tensors (`str` or [`~utils.TensorType`], *optional*):
+ If set, will return tensors of a particular framework. Acceptable values are:
+
+ - `'tf'`: Return TensorFlow `tf.constant` objects.
+ - `'pt'`: Return PyTorch `torch.Tensor` objects.
+ - `'np'`: Return NumPy `np.ndarray` objects.
+ - `'jax'`: Return JAX `jnp.ndarray` objects.
+ suffix (`str`, `list[str]`, `list[list[str]]`):
+ The suffixes or batch of suffixes to be encoded. Only necessary for finetuning. See https://github.com/google-research/big_vision/blob/main/big_vision/configs/proj/paligemma/README.md
+ for more information. If your prompt is "<image> What is on the image", the suffix corresponds to the expected prediction "a cow sitting on a bench".
+
+ Returns:
+ [`BatchFeature`]: A [`BatchFeature`] with the following fields:
+
+ - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`. If `suffix`
+ is provided, the `input_ids` will also contain the suffix input ids.
+ - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
+ `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
+ `None`).
+ - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
+ - **labels** -- Labels compatible with training if `suffix` is not None
+ """
+
+ output_kwargs = self._merge_kwargs(
+ PaliGemmaProcessorKwargs,
+ tokenizer_init_kwargs=self.tokenizer.init_kwargs,
+ **kwargs,
+ )
+ suffix = output_kwargs["text_kwargs"].pop("suffix", None)
+
+ return_token_type_ids = suffix is not None
+
+ if images is None:
+ raise ValueError("`images` are expected as arguments to a `PaliGemmaProcessor` instance.")
+ if text is None:
+ logger.warning_once(
+ "You are using PaliGemma without a text prefix. It will perform as a picture-captioning model."
+ )
+ text = ""
+
+ if _is_str_or_image(text):
+ text = [text]
+ elif isinstance(text, list) and _is_str_or_image(text[0]):
+ pass
+
+ if text is not None and images is not None:
+ if not any(IMAGE_TOKEN in sample for sample in text):
+ logger.warning(
+ "You are passing both `text` and `images` to `PaliGemmaProcessor`. The processor expects special "
+ "image tokens in the text, as many tokens as there are images per each text. It is recommended to "
+ "add `` tokens in the very beginning of your text. For this call, we will infer how many images "
+ "each text has and add special tokens."
+ )
+
+ if isinstance(text, list) and isinstance(images, list):
+ if len(images) != len(text):
+ raise ValueError(
+ f"Received {len(images)} images for {len(text)} prompts. Each prompt should be associated with an image or list of images."
+ )
+
+ # make a nested list of lists to be able to iterate over the images and text below
+ if is_valid_image(images):
+ images = [[images]]
+ elif isinstance(images, (list, tuple)) and is_valid_image(images[0]):
+ images = [[image] for image in images]
+ elif not (
+ isinstance(images, (list, tuple))
+ and isinstance(images[0], (list, tuple))
+ and is_valid_image(images[0][0])
+ ):
+ raise ValueError("images must be an image, list of images or list of list of images")
+
+ input_strings = [
+ build_string_from_input(
+ prompt=prompt,
+ bos_token=self.tokenizer.bos_token,
+ image_seq_len=self.image_seq_length,
+ image_token=IMAGE_TOKEN,
+ num_images=len(image_list) if isinstance(image_list, list) else 1,
+ )
+ for prompt, image_list in zip(text, images)
+ ]
+ else:
+ expanded_samples = []
+ for sample in text:
+ expanded_sample = sample.replace(IMAGE_TOKEN, IMAGE_TOKEN * self.image_seq_length)
+ bos_rfind_index = expanded_sample.rfind(IMAGE_TOKEN)
+ bos_index = bos_rfind_index + len(IMAGE_TOKEN) if bos_rfind_index != -1 else 0
+ expanded_sample = (
+ expanded_sample[:bos_index] + self.tokenizer.bos_token + expanded_sample[bos_index:]
+ )
+ expanded_samples.append(expanded_sample)
+ input_strings = [f"{sample}\n" for sample in expanded_samples]
+
+ if suffix is not None and _is_str_or_image(suffix):
+ suffix = [suffix]
+ if suffix is not None:
+ suffix = [sfx + self.tokenizer.eos_token for sfx in suffix]
+ pixel_values = self.image_processor(images, **output_kwargs["images_kwargs"])["pixel_values"]
+
+ return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
+ return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", None)
+ inputs = self.tokenizer(
+ input_strings,
+ text_pair=suffix,
+ return_token_type_ids=return_token_type_ids,
+ **output_kwargs["text_kwargs"],
+ )
+ self._check_special_mm_tokens(input_strings, inputs, modalities=["image"])
+
+ return_data = {**inputs, "pixel_values": pixel_values}
+
+ if return_token_type_ids:
+ labels = np.array(inputs["input_ids"])
+ labels[np.array(inputs["token_type_ids"]) == 0] = -100
+ return_data.update({"labels": labels})
+
+ if return_mm_token_type_ids:
+ array_ids = np.array(return_data["input_ids"])
+ mm_token_type_ids = np.zeros_like(return_data["input_ids"])
+ mm_token_type_ids[array_ids == self.image_token_id] = 1
+ return_data["mm_token_type_ids"] = mm_token_type_ids.tolist()
+
+ return BatchFeature(data=return_data, tensor_type=return_tensors)
+
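+ # Illustrative sketch (the checkpoint name and strings below are assumptions for demonstration only):
+ # passing `suffix` turns on `token_type_ids`, and `labels` are set to -100 over the image/prefix part so
+ # the loss only covers the suffix.
+ #
+ # >>> from PIL import Image
+ # >>> from transformers import AutoProcessor
+ # >>> processor = AutoProcessor.from_pretrained("google/paligemma-3b-pt-224")
+ # >>> image = Image.new("RGB", (224, 224))
+ # >>> out = processor(text="answer en Where is the cow standing?", images=image,
+ # ...                 suffix="on the beach", return_tensors="pt")
+ # >>> (out["labels"][0] == -100).sum() >= processor.image_seq_length
+ # tensor(True)
+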
+ def _get_num_multimodal_tokens(self, image_sizes=None, **kwargs):
+ """
+ Computes the number of placeholder tokens needed for multimodal inputs with the given sizes.
+
+ Args:
+ image_sizes (list[list[str]], *optional*):
+ The input sizes formatted as (height, width) per each image.
+ Returns:
+ `MultiModalData`: A `MultiModalData` object holding number of tokens per each of the provided
+ input modalities, along with other useful data.
+ """
+ vision_data = {}
+ if image_sizes is not None:
+ num_image_tokens = [self.image_seq_length] * len(image_sizes)
+ num_image_patches = [1] * len(image_sizes)
+ vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches})
+ return MultiModalData(**vision_data)
+
+
+__all__ = ["PaliGemmaProcessor"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/patchtsmixer/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/patchtsmixer/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..285c1970308a47827806fca349d130703f40a2c8
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/patchtsmixer/__init__.py
@@ -0,0 +1,27 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_patchtsmixer import *
+ from .modeling_patchtsmixer import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/patchtsmixer/configuration_patchtsmixer.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/patchtsmixer/configuration_patchtsmixer.py
new file mode 100644
index 0000000000000000000000000000000000000000..9725bd754634272da0538328f533724091214a2f
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/patchtsmixer/configuration_patchtsmixer.py
@@ -0,0 +1,235 @@
+# coding=utf-8
+# Copyright 2023 IBM and HuggingFace Inc. team. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PatchTSMixer model configuration"""
+
+from typing import Optional, Union
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class PatchTSMixerConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`PatchTSMixerModel`]. It is used to instantiate a
+ PatchTSMixer model according to the specified arguments, defining the model architecture. Instantiating a
+ configuration with the defaults will yield a similar configuration to that of the PatchTSMixer
+ [ibm/patchtsmixer-etth1-pretrain](https://huggingface.co/ibm/patchtsmixer-etth1-pretrain) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ context_length (`int`, *optional*, defaults to 32):
+ The context/history length for the input sequence.
+ patch_length (`int`, *optional*, defaults to 8):
+ The patch length for the input sequence.
+ num_input_channels (`int`, *optional*, defaults to 1):
+ Number of input variates. For Univariate, set it to 1.
+ patch_stride (`int`, *optional*, defaults to 8):
+ Determines the overlap between two consecutive patches. Set it to `patch_length` (or greater) for
+ non-overlapping patches.
+ num_parallel_samples (`int`, *optional*, defaults to 100):
+ The number of samples to generate in parallel for probabilistic forecast.
+ d_model (`int`, *optional*, defaults to 8):
+ Hidden dimension of the model. Recommended to be a multiple of `patch_length` (i.e. 2-5X of
+ `patch_length`). A larger value indicates a more complex model.
+ expansion_factor (`int`, *optional*, defaults to 2):
+ Expansion factor to use inside the MLP. Recommended range is 2-5. A larger value indicates a more complex model.
+ num_layers (`int`, *optional*, defaults to 3):
+ Number of layers to use. Recommended range is 3-15. A larger value indicates a more complex model.
+ dropout (`float`, *optional*, defaults to 0.2):
+ The dropout probability for the `PatchTSMixer` backbone. Recommended range is 0.2-0.7.
+ mode (`str`, *optional*, defaults to `"common_channel"`):
+ Mixer mode. Determines how the channels are processed. Allowed values: "common_channel", "mix_channel". In
+ "common_channel" mode, we follow channel-independent modelling with no explicit channel-mixing; channel
+ mixing happens implicitly via shared weights across channels (preferred first approach). In
+ "mix_channel" mode, we perform explicit channel-mixing in addition to patch and feature mixing (preferred
+ when channel correlations are very important to model).
+ gated_attn (`bool`, *optional*, defaults to `True`):
+ Enable Gated Attention.
+ norm_mlp (`str`, *optional*, defaults to `"LayerNorm"`):
+ Normalization layer (BatchNorm or LayerNorm).
+ self_attn (`bool`, *optional*, defaults to `False`):
+ Enable tiny self-attention across patches. This can be enabled when the output of vanilla PatchTSMixer with
+ gated attention is not satisfactory. Enabling this leads to explicit pair-wise attention and modelling
+ across patches.
+ self_attn_heads (`int`, *optional*, defaults to 1):
+ Number of self-attention heads. Works only when `self_attn` is set to `True`.
+ use_positional_encoding (`bool`, *optional*, defaults to `False`):
+ Enable the use of positional embedding for the tiny self-attention layers. Works only when `self_attn` is
+ set to `True`.
+ positional_encoding_type (`str`, *optional*, defaults to `"sincos"`):
+ Positional encodings. Options `"random"` and `"sincos"` are supported. Works only when
+ `use_positional_encoding` is set to `True`.
+ scaling (`string` or `bool`, *optional*, defaults to `"std"`):
+ Whether to scale the input targets via "mean" scaler, "std" scaler or no scaler if `None`. If `True`, the
+ scaler is set to "mean".
+ loss (`string`, *optional*, defaults to `"mse"`):
+ The loss function for the model corresponding to the `distribution_output` head. For parametric
+ distributions it is the negative log likelihood ("nll") and for point estimates it is the mean squared
+ error "mse".
+ init_std (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated normal weight initialization distribution.
+ post_init (`bool`, *optional*, defaults to `False`):
+ Whether to use custom weight initialization from the `transformers` library, or the default initialization in
+ `PyTorch`. Setting it to `False` performs `PyTorch` weight initialization.
+ norm_eps (`float`, *optional*, defaults to 1e-05):
+ A value added to the denominator for numerical stability of normalization.
+ mask_type (`str`, *optional*, defaults to `"random"`):
+ Type of masking to use for Masked Pretraining mode. Allowed values are "random", "forecast". In Random
+ masking, points are masked randomly. In Forecast masking, points are masked towards the end.
+ random_mask_ratio (`float`, *optional*, defaults to 0.5):
+ Masking ratio to use when `mask_type` is `random`. Higher value indicates more masking.
+ num_forecast_mask_patches (`int` or `list`, *optional*, defaults to `[2]`):
+ Number of patches to be masked at the end of each batch sample. If it is an integer, all the samples in the
+ batch will have the same number of masked patches. If it is a list, samples in the batch will be randomly
+ masked by numbers defined in the list. This argument is only used for forecast pretraining.
+ mask_value (`float`, *optional*, defaults to `0.0`):
+ Mask value to use.
+ masked_loss (`bool`, *optional*, defaults to `True`):
+ Whether to compute pretraining loss only at the masked portions, or on the entire output.
+ channel_consistent_masking (`bool`, *optional*, defaults to `True`):
+ When true, masking will be the same across all channels of a time series. Otherwise, masking positions will vary
+ across channels.
+ unmasked_channel_indices (`list`, *optional*):
+ Channels that are not masked during pretraining.
+ head_dropout (`float`, *optional*, defaults to 0.2):
+ The dropout probability for the `PatchTSMixer` head.
+ distribution_output (`string`, *optional*, defaults to `"student_t"`):
+ The distribution emission head for the model when loss is "nll". Could be either "student_t", "normal" or
+ "negative_binomial".
+ prediction_length (`int`, *optional*, defaults to 16):
+ Number of time steps to forecast for a forecasting task. Also known as the Forecast Horizon.
+ prediction_channel_indices (`list`, *optional*):
+ List of channel indices to forecast. If None, forecast all channels. Target data is expected to have all
+ channels and we explicitly filter the channels in prediction and target before loss computation.
+ num_targets (`int`, *optional*, defaults to 3):
+ Number of targets (dimensionality of the regressed variable) for a regression task.
+ output_range (`list`, *optional*):
+ Output range to restrict for the regression task. Defaults to None.
+ head_aggregation (`str`, *optional*, defaults to `"max_pool"`):
+ Aggregation mode to enable for classification or regression tasks. Allowed values are `None`, "use_last",
+ "max_pool", "avg_pool".
+
+ Example:
+
+ ```python
+ >>> from transformers import PatchTSMixerConfig, PatchTSMixerModel
+
+ >>> # Initializing a default PatchTSMixer configuration
+ >>> configuration = PatchTSMixerConfig()
+
+ >>> # Initializing a model (with random weights) from the configuration
+ >>> model = PatchTSMixerModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "patchtsmixer"
+ attribute_map = {
+ "hidden_size": "d_model",
+ "num_hidden_layers": "num_layers",
+ }
+
+ def __init__(
+ self,
+ # Time series specific configuration
+ context_length: int = 32,
+ patch_length: int = 8,
+ num_input_channels: int = 1,
+ patch_stride: int = 8,
+ num_parallel_samples: int = 100,
+ # General model configuration
+ d_model: int = 8,
+ expansion_factor: int = 2,
+ num_layers: int = 3,
+ dropout: float = 0.2,
+ mode: str = "common_channel",
+ gated_attn: bool = True,
+ norm_mlp: str = "LayerNorm",
+ self_attn: bool = False,
+ self_attn_heads: int = 1,
+ use_positional_encoding: bool = False,
+ positional_encoding_type: str = "sincos",
+ scaling: Optional[Union[str, bool]] = "std",
+ loss: str = "mse",
+ init_std: float = 0.02,
+ post_init: bool = False,
+ norm_eps: float = 1e-5,
+ # Pretrain model configuration
+ mask_type: str = "random",
+ random_mask_ratio: float = 0.5,
+ num_forecast_mask_patches: Optional[Union[list[int], int]] = [2],
+ mask_value: int = 0,
+ masked_loss: bool = True,
+ channel_consistent_masking: bool = True,
+ unmasked_channel_indices: Optional[list[int]] = None,
+ # General head configuration
+ head_dropout: float = 0.2,
+ distribution_output: str = "student_t",
+ # Prediction head configuration
+ prediction_length: int = 16,
+ prediction_channel_indices: Optional[list] = None,
+ # Classification/Regression configuration
+ num_targets: int = 3,
+ output_range: Optional[list] = None,
+ head_aggregation: str = "max_pool",
+ **kwargs,
+ ):
+ self.num_input_channels = num_input_channels
+ self.context_length = context_length
+ self.patch_length = patch_length
+ self.patch_stride = patch_stride
+ self.d_model = d_model
+ self.expansion_factor = expansion_factor
+ self.num_layers = num_layers
+ self.dropout = dropout
+ self.mode = mode
+ self.gated_attn = gated_attn
+ self.norm_mlp = norm_mlp
+ self.scaling = scaling
+ self.head_dropout = head_dropout
+ self.num_patches = (max(context_length, patch_length) - patch_length) // patch_stride + 1
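+ # With the defaults (context_length=32, patch_length=8, patch_stride=8): (32 - 8) // 8 + 1 = 4 patches.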
+ self.mask_type = mask_type
+ self.random_mask_ratio = random_mask_ratio
+ self.num_forecast_mask_patches = num_forecast_mask_patches
+ self.mask_value = mask_value
+ self.channel_consistent_masking = channel_consistent_masking
+ self.masked_loss = masked_loss
+ self.patch_last = True
+ self.use_positional_encoding = use_positional_encoding
+ self.positional_encoding_type = positional_encoding_type
+ self.prediction_length = prediction_length
+ self.prediction_channel_indices = prediction_channel_indices
+ self.num_targets = num_targets
+ self.output_range = output_range
+ self.head_aggregation = head_aggregation
+ self.self_attn = self_attn
+ self.self_attn_heads = self_attn_heads
+ self.init_std = init_std
+ self.post_init = post_init
+ self.distribution_output = distribution_output
+ self.loss = loss
+ self.num_parallel_samples = num_parallel_samples
+ self.unmasked_channel_indices = unmasked_channel_indices
+ self.norm_eps = norm_eps
+ super().__init__(**kwargs)
+
+
+__all__ = ["PatchTSMixerConfig"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/patchtsmixer/modeling_patchtsmixer.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/patchtsmixer/modeling_patchtsmixer.py
new file mode 100644
index 0000000000000000000000000000000000000000..b830a516804e7aa0d3f5fcfd2800f5e95cce7354
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/patchtsmixer/modeling_patchtsmixer.py
@@ -0,0 +1,2120 @@
+# coding=utf-8
+# Copyright 2023 IBM and HuggingFace Inc. team. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch PatchTSMixer model."""
+
+import math
+from dataclasses import dataclass
+from typing import Callable, Optional, Union
+
+import torch
+import torch.nn as nn
+
+from transformers.modeling_utils import PreTrainedModel
+from transformers.utils import ModelOutput
+
+from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
+from ...processing_utils import Unpack
+from ...time_series_utils import NegativeBinomialOutput, NormalOutput, StudentTOutput
+from ...utils import auto_docstring, logging
+from .configuration_patchtsmixer import PatchTSMixerConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+class PatchTSMixerGatedAttention(nn.Module):
+ """
+ Module that applies gated attention to input data.
+
+ Args:
+ in_size (`int`): The input size.
+ out_size (`int`): The output size.
+ """
+
+ def __init__(self, in_size: int, out_size: int):
+ super().__init__()
+ self.attn_layer = nn.Linear(in_size, out_size)
+ self.attn_softmax = nn.Softmax(dim=-1)
+
+ def forward(self, inputs):
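+ # Softmax over the last dimension yields per-feature gates in [0, 1] that rescale the input element-wise.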
+ attn_weight = self.attn_softmax(self.attn_layer(inputs))
+ inputs = inputs * attn_weight
+ return inputs
+
+
+# Copied from transformers.models.patchtst.modeling_patchtst.PatchTSTBatchNorm with PatchTST->PatchTSMixer
+class PatchTSMixerBatchNorm(nn.Module):
+ """
+ Compute batch normalization over the sequence length (time) dimension.
+ """
+
+ def __init__(self, config: PatchTSMixerConfig):
+ super().__init__()
+ self.batchnorm = nn.BatchNorm1d(config.d_model, eps=config.norm_eps)
+
+ def forward(self, inputs: torch.Tensor):
+ """
+ Parameters:
+ inputs (`torch.Tensor` of shape `(batch_size, sequence_length, d_model)`):
+ input for Batch norm calculation
+ Returns:
+ `torch.Tensor` of shape `(batch_size, sequence_length, d_model)`
+ """
+ output = inputs.transpose(1, 2) # output: (batch_size, d_model, sequence_length)
+ output = self.batchnorm(output)
+ return output.transpose(1, 2)
+
+
+class PatchTSMixerPositionalEncoding(nn.Module):
+ """
+ Class for positional encoding
+ """
+
+ def __init__(self, config: PatchTSMixerConfig):
+ super().__init__()
+ # positional encoding: [num_patches x d_model]
+ if config.use_positional_encoding:
+ self.position_enc = self._init_pe(config)
+ else:
+ self.position_enc = nn.Parameter(torch.zeros(config.num_patches, config.d_model))
+
+ @staticmethod
+ def _init_pe(config: PatchTSMixerConfig) -> nn.Parameter:
+ # Positional encoding
+ if config.positional_encoding_type == "random":
+ position_enc = nn.Parameter(torch.randn(config.num_patches, config.d_model), requires_grad=True)
+ elif config.positional_encoding_type == "sincos":
+ position_enc = torch.zeros(config.num_patches, config.d_model)
+ position = torch.arange(0, config.num_patches).unsqueeze(1)
+ div_term = torch.exp(torch.arange(0, config.d_model, 2) * -(math.log(10000.0) / config.d_model))
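+ # div_term equals 1 / 10000^(2i / d_model), computed in log space, giving the standard sinusoidal encoding below.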
+ position_enc[:, 0::2] = torch.sin(position * div_term)
+ position_enc[:, 1::2] = torch.cos(position * div_term)
+ position_enc = position_enc - position_enc.mean()
+ position_enc = position_enc / (position_enc.std() * 10)
+ position_enc = nn.Parameter(position_enc, requires_grad=False)
+ else:
+ raise ValueError(
+ f"{config.positional_encoding_type} is not a valid positional encoder. Available types are 'random' and 'sincos'."
+ )
+ return position_enc
+
+ def forward(self, patch_input: torch.Tensor):
+ # hidden_state: [bs x num_channels x num_patches x d_model]
+ hidden_state = patch_input + self.position_enc
+ return hidden_state
+
+
+class PatchTSMixerNormLayer(nn.Module):
+ """Normalization block
+
+ Args:
+ config (`PatchTSMixerConfig`):
+ Configuration.
+ """
+
+ def __init__(self, config: PatchTSMixerConfig):
+ super().__init__()
+
+ self.norm_mlp = config.norm_mlp
+
+ if "batch" in config.norm_mlp.lower():
+ self.norm = PatchTSMixerBatchNorm(config)
+ else:
+ self.norm = nn.LayerNorm(config.d_model, eps=config.norm_eps)
+
+ def forward(self, inputs: torch.Tensor):
+ """
+ Args:
+ inputs (`torch.Tensor` of shape `(batch_size, num_channels, num_patches, d_model)`):
+ Input to the normalization layer.
+ Returns:
+ `torch.Tensor` of shape `(batch_size, num_channels, num_patches, d_model)`
+ """
+ if "batch" in self.norm_mlp.lower():
+ # reshape the data
+ inputs_reshaped = torch.reshape(
+ inputs,
+ (
+ inputs.shape[0] * inputs.shape[1],
+ inputs.shape[2],
+ inputs.shape[3],
+ ),
+ ) # inputs_reshaped: [batch_size*num_channels, num_patches, d_model]
+
+ # inputs_reshaped: [batch_size*num_channels, num_patches, d_model]
+ inputs_reshaped = self.norm(inputs_reshaped)
+
+ # put back data to the original shape
+ inputs = torch.reshape(inputs_reshaped, inputs.shape)
+
+ else:
+ inputs = self.norm(inputs)
+
+ return inputs
+
+
+class PatchTSMixerMLP(nn.Module):
+ def __init__(self, in_features, out_features, config):
+ super().__init__()
+ num_hidden = in_features * config.expansion_factor
+ self.fc1 = nn.Linear(in_features, num_hidden)
+ self.dropout1 = nn.Dropout(config.dropout)
+ self.fc2 = nn.Linear(num_hidden, out_features)
+ self.dropout2 = nn.Dropout(config.dropout)
+
+ def forward(self, inputs: torch.Tensor):
+ """
+ Args:
+ inputs (`torch.Tensor` of shape `(batch_size, num_channels, num_patches, d_model)`):
+ Input to the MLP layer.
+ Returns:
+ `torch.Tensor` of the same shape as `inputs`
+ """
+ inputs = self.dropout1(nn.functional.gelu(self.fc1(inputs)))
+ inputs = self.fc2(inputs)
+ inputs = self.dropout2(inputs)
+ return inputs
+
+
+class PatchTSMixerChannelFeatureMixerBlock(nn.Module):
+ """This module mixes the features in the channel dimension.
+
+ Args:
+ config (`PatchTSMixerConfig`):
+ Configuration.
+ """
+
+ def __init__(self, config: PatchTSMixerConfig):
+ super().__init__()
+
+ self.norm = PatchTSMixerNormLayer(config)
+ self.gated_attn = config.gated_attn
+ self.mlp = PatchTSMixerMLP(
+ in_features=config.num_input_channels,
+ out_features=config.num_input_channels,
+ config=config,
+ )
+
+ if config.gated_attn:
+ self.gating_block = PatchTSMixerGatedAttention(
+ in_size=config.num_input_channels, out_size=config.num_input_channels
+ )
+
+ def forward(self, inputs: torch.Tensor):
+ """
+ Args:
+ inputs (`torch.Tensor` of shape `(batch_size, num_channels, num_patches, d_model)`):
+ Input to the MLP layer.
+ Returns:
+ `torch.Tensor` of the same shape as `inputs`
+ """
+ residual = inputs
+ inputs = self.norm(inputs)
+
+ inputs = inputs.permute(0, 3, 2, 1)
+
+ if self.gated_attn:
+ inputs = self.gating_block(inputs)
+
+ inputs = self.mlp(inputs)
+
+ inputs = inputs.permute(0, 3, 2, 1)
+
+ out = inputs + residual
+ return out
+
+
+# Copied from transformers.models.bart.modeling_bart.eager_attention_forward
+def eager_attention_forward(
+ module: nn.Module,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_mask: Optional[torch.Tensor],
+ scaling: Optional[float] = None,
+ dropout: float = 0.0,
+ head_mask: Optional[torch.Tensor] = None,
+ **kwargs,
+):
+ if scaling is None:
+ scaling = query.size(-1) ** -0.5
+
+ attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
+ if attention_mask is not None:
+ attn_weights = attn_weights + attention_mask
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+
+ if head_mask is not None:
+ attn_weights = attn_weights * head_mask.view(1, -1, 1, 1)
+
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+ attn_output = torch.matmul(attn_weights, value)
+ attn_output = attn_output.transpose(1, 2).contiguous()
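+ # attn_output: (batch, tgt_len, num_heads, head_dim); callers flatten the last two dims back to embed_dim.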
+
+ return attn_output, attn_weights
+
+
+# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Attention with Wav2Vec2->PatchTSMixer
+class PatchTSMixerAttention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(
+ self,
+ embed_dim: int,
+ num_heads: int,
+ dropout: float = 0.0,
+ is_decoder: bool = False,
+ bias: bool = True,
+ is_causal: bool = False,
+ config: Optional[PatchTSMixerConfig] = None,
+ ):
+ super().__init__()
+ self.embed_dim = embed_dim
+ self.num_heads = num_heads
+ self.dropout = dropout
+ self.head_dim = embed_dim // num_heads
+ self.config = config
+
+ if (self.head_dim * num_heads) != self.embed_dim:
+ raise ValueError(
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
+ f" and `num_heads`: {num_heads})."
+ )
+ self.scaling = self.head_dim**-0.5
+ self.is_decoder = is_decoder
+ self.is_causal = is_causal
+
+ self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+ self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+ self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+ self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ key_value_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ layer_head_mask: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = False,
+ # TODO: we need a refactor so that the different attention modules can get their specific kwargs
+ # ATM, we have mixed things encoder, decoder, and encoder-decoder attn
+ **kwargs: Unpack[FlashAttentionKwargs],
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
+ """Input shape: Batch x Time x Channel"""
+
+ # if key_value_states are provided this layer is used as a cross-attention layer
+ # for the decoder
+ is_cross_attention = key_value_states is not None
+
+ # determine input shapes
+ bsz, tgt_len = hidden_states.shape[:-1]
+ src_len = key_value_states.shape[1] if is_cross_attention else tgt_len
+
+ q_input_shape = (bsz, tgt_len, -1, self.head_dim)
+ kv_input_shape = (bsz, src_len, -1, self.head_dim)
+
+ # get query proj
+ query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2)
+
+ current_states = key_value_states if is_cross_attention else hidden_states
+ key_states = self.k_proj(current_states).view(*kv_input_shape).transpose(1, 2)
+ value_states = self.v_proj(current_states).view(*kv_input_shape).transpose(1, 2)
+
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ attn_output, attn_weights = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ dropout=0.0 if not self.training else self.dropout,
+ scaling=self.scaling,
+ output_attentions=output_attentions,
+ head_mask=layer_head_mask,
+ **kwargs,
+ )
+
+ attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous()
+ attn_output = self.out_proj(attn_output)
+
+ return attn_output, attn_weights, None
+
+
+class PatchMixerBlock(nn.Module):
+ """This module mixes the patch dimension.
+
+ Args:
+ config (`PatchTSMixerConfig`):
+ Configuration.
+ """
+
+ def __init__(self, config: PatchTSMixerConfig):
+ super().__init__()
+
+ self.norm = PatchTSMixerNormLayer(config)
+
+ self.self_attn = config.self_attn
+ self.gated_attn = config.gated_attn
+
+ self.mlp = PatchTSMixerMLP(
+ in_features=config.num_patches,
+ out_features=config.num_patches,
+ config=config,
+ )
+
+ if config.gated_attn:
+ self.gating_block = PatchTSMixerGatedAttention(in_size=config.num_patches, out_size=config.num_patches)
+
+ if config.self_attn:
+ self.self_attn_layer = PatchTSMixerAttention(
+ embed_dim=config.d_model,
+ num_heads=config.self_attn_heads,
+ dropout=config.dropout,
+ config=config,
+ )
+ self.norm_attn = PatchTSMixerNormLayer(config)
+
+ def forward(self, hidden_state):
+ """
+ Args:
+ hidden_state (`torch.Tensor`): Input tensor.
+
+ Returns:
+ `torch.Tensor`: Transformed tensor.
+ """
+ residual = hidden_state
+
+ hidden_state = self.norm(hidden_state)
+
+ if self.self_attn:
+ batch_size, n_vars, num_patches, d_model = hidden_state.shape
+ hidden_state_reshaped = hidden_state.reshape(batch_size * n_vars, num_patches, d_model)
+
+ x_attn, _, _ = self.self_attn_layer(hidden_state_reshaped, output_attentions=False)
+ x_attn = x_attn.reshape(batch_size, n_vars, num_patches, d_model)
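+ # x_attn is kept aside and added back through the norm_attn residual connection after patch mixing below.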
+
+ # Transpose so that num_patches is the last dimension
+ hidden_state = hidden_state.transpose(2, 3)
+ hidden_state = self.mlp(hidden_state)
+
+ if self.gated_attn:
+ hidden_state = self.gating_block(hidden_state)
+
+ # Transpose back
+ hidden_state = hidden_state.transpose(2, 3)
+
+ if self.self_attn:
+ hidden_state = self.norm_attn(hidden_state + x_attn)
+
+ out = hidden_state + residual
+ return out
+
+
+class FeatureMixerBlock(nn.Module):
+ """This module mixes the hidden feature dimension.
+
+ Args:
+ config (`PatchTSMixerConfig`):
+ Configuration.
+
+ """
+
+ def __init__(self, config: PatchTSMixerConfig):
+ super().__init__()
+
+ self.norm = PatchTSMixerNormLayer(config)
+
+ self.gated_attn = config.gated_attn
+
+ self.mlp = PatchTSMixerMLP(
+ in_features=config.d_model,
+ out_features=config.d_model,
+ config=config,
+ )
+
+ if config.gated_attn:
+ self.gating_block = PatchTSMixerGatedAttention(in_size=config.d_model, out_size=config.d_model)
+
+ def forward(self, hidden: torch.Tensor):
+ """
+ Args:
+ hidden (`torch.Tensor` of shape `(batch_size, num_patches, d_model)`):
+ Input tensor to the layer.
+
+ Returns:
+ `torch.Tensor`: Transformed tensor.
+ """
+ residual = hidden
+ hidden = self.norm(hidden)
+ hidden = self.mlp(hidden)
+
+ if self.gated_attn:
+ hidden = self.gating_block(hidden)
+
+ out = hidden + residual
+ return out
+
+
+class PatchTSMixerLayer(nn.Module):
+ """
+ The `PatchTSMixer` layer that does all three kinds of mixing.
+
+ Args:
+ config (`PatchTSMixerConfig`):
+ Configuration.
+
+ """
+
+ def __init__(self, config: PatchTSMixerConfig):
+ super().__init__()
+
+ self.patch_mixer = PatchMixerBlock(config=config)
+ self.feature_mixer = FeatureMixerBlock(config=config)
+
+ self.mode = config.mode
+
+ if config.mode == "mix_channel":
+ self.channel_feature_mixer = PatchTSMixerChannelFeatureMixerBlock(config=config)
+
+ def forward(self, hidden: torch.Tensor):
+ """
+ Args:
+ hidden (`torch.Tensor` of shape `(batch_size, num_patches, d_model)`):
+ Input tensor to the layer.
+
+ Returns:
+ `torch.Tensor`: Transformed tensor.
+ """
+ if self.mode == "mix_channel":
+ hidden = self.channel_feature_mixer(hidden)
+
+ hidden = self.patch_mixer(hidden)
+ hidden = self.feature_mixer(hidden) # hidden: (batch_size x num_patches x d_model)
+ return hidden
+
+
+class PatchTSMixerBlock(nn.Module):
+ """The main computing framework of the `PatchTSMixer` model.
+
+ Args:
+ config (`PatchTSMixerConfig`):
+ Configuration.
+ """
+
+ def __init__(self, config: PatchTSMixerConfig):
+ super().__init__()
+
+ num_layers = config.num_layers
+
+ self.mixers = nn.ModuleList([PatchTSMixerLayer(config=config) for _ in range(num_layers)])
+
+ def forward(self, hidden_state, output_hidden_states: bool = False):
+ """
+ Args:
+ hidden_state (`torch.Tensor`): The input tensor.
+ output_hidden_states (`bool`, *optional*, defaults to `False`):
+ Whether to output the hidden states as well.
+
+ Returns:
+ `torch.Tensor`: The embedding. `list`: List of all hidden states if `output_hidden_states` is set to
+ `True`.
+ """
+ all_hidden_states = []
+
+ embedding = hidden_state
+
+ for mod in self.mixers:
+ embedding = mod(embedding)
+ if output_hidden_states:
+ all_hidden_states.append(embedding)
+
+ if output_hidden_states:
+ return embedding, all_hidden_states
+ else:
+ return embedding, None
+
+
+class PatchTSMixerForPredictionHead(nn.Module):
+ """Prediction Head for Forecasting
+
+ Args:
+ config (`PatchTSMixerConfig`):
+ Configuration.
+ """
+
+ def __init__(self, config: PatchTSMixerConfig, distribution_output=None):
+ super().__init__()
+
+ self.prediction_channel_indices = config.prediction_channel_indices
+
+ if self.prediction_channel_indices is not None:
+ self.prediction_channel_indices.sort()
+
+ self.dropout_layer = nn.Dropout(config.head_dropout)
+ if distribution_output is None:
+ self.base_forecast_block = nn.Linear((config.num_patches * config.d_model), config.prediction_length)
+ else:
+ self.base_forecast_block = distribution_output.get_parameter_projection(
+ config.num_patches * config.d_model
+ )
+
+ self.flatten = nn.Flatten(start_dim=-2)
+
+ def forward(self, hidden_features):
+ """
+
+ Args:
+ hidden_features (`torch.Tensor` of shape `(batch_size, num_patch, d_model)` in `flatten` mode
+ or `(batch_size, n_vars, num_patch, d_model)` in `common_channel`/`mix_channel` mode.): Input hidden
+ features.
+
+ Returns:
+ `torch.Tensor` of shape `(batch_size, prediction_length, nvars)`.
+
+ """
+
+ hidden_features = self.flatten(hidden_features) # [batch_size x n_vars x num_patch * d_model]
+ hidden_features = self.dropout_layer(hidden_features) # [batch_size x n_vars x num_patch * d_model]
+ forecast = self.base_forecast_block(hidden_features) # [batch_size x n_vars x prediction_length]
+ if isinstance(forecast, tuple):
+ forecast = tuple(z.transpose(-1, -2) for z in forecast)
+ else:
+ forecast = forecast.transpose(-1, -2) # [batch_size x prediction_length x n_vars]
+
+ if self.prediction_channel_indices is not None:
+ if isinstance(forecast, tuple):
+ forecast = tuple(z[..., self.prediction_channel_indices] for z in forecast)
+ else:
+ forecast = forecast[..., self.prediction_channel_indices] # [batch_size x prediction_length x n_vars]
+
+ return forecast
+
+
+class PatchTSMixerLinearHead(nn.Module):
+ """Linear head for Classification and Regression.
+
+ Args:
+ config (`PatchTSMixerConfig`):
+ Configuration.
+ """
+
+ def __init__(self, config: PatchTSMixerConfig, distribution_output=None):
+ super().__init__()
+
+ self.head_aggregation = config.head_aggregation
+ self.output_range = config.output_range
+
+ if config.head_aggregation is None:
+ mul_factor = config.num_patches
+ else:
+ mul_factor = 1
+ self.distribution_output = distribution_output
+ if distribution_output is None:
+ self.projection = nn.Linear(
+ config.d_model * config.num_input_channels * mul_factor,
+ config.num_targets,
+ )
+ else:
+ self.projection = distribution_output.get_parameter_projection(
+ config.d_model * config.num_input_channels * mul_factor
+ )
+
+ if config.head_aggregation is None:
+ self.flatten = nn.Flatten(start_dim=-3)
+ else:
+ self.flatten = nn.Flatten(start_dim=-2)
+
+ self.dropout = nn.Dropout(config.head_dropout)
+
+ def forward(self, hidden_features):
+ """
+ Args:
+ hidden_features (`torch.Tensor` of shape `(batch_size x num_patch x d_model)` in `flatten` mode
+ or `(batch_size x n_vars x num_patch x d_model)` in `common_channel`/`mix_channel` mode.): Input hidden
+ features.
+
+ Returns:
+ `torch.Tensor` of shape `(batch_size x num_targets)`.
+ """
+
+ # batch_size x d_model x num_patch or batch_size x n_vars x d_model x num_patch
+ hidden_features = hidden_features.transpose(-1, -2)
+ if self.head_aggregation == "use_last":
+ # batch_size x d_model (flatten) or # batch_size x n_vars x d_model (common_channel)
+ hidden_features = hidden_features[..., -1]
+ elif self.head_aggregation == "max_pool":
+ # batch_size x n_vars x d_model or batch_size x d_model
+ hidden_features = hidden_features.max(dim=-1).values
+ elif self.head_aggregation == "avg_pool":
+ # batch_size x n_vars x d_model or batch_size x d_model
+ hidden_features = hidden_features.mean(dim=-1)
+
+ if self.flatten:
+ hidden_features = self.flatten(hidden_features)
+ hidden_features = self.dropout(hidden_features)
+ hidden_features = self.projection(hidden_features) # batch_size x num_targets
+
+ if (self.distribution_output is None) and (self.output_range is not None):
+ hidden_features = (
+ torch.sigmoid(hidden_features) * (self.output_range[1] - self.output_range[0]) + self.output_range[0]
+ )
+ return hidden_features
+
+
+@auto_docstring
+class PatchTSMixerPreTrainedModel(PreTrainedModel):
+ # Weight initialization
+ config: PatchTSMixerConfig
+ base_model_prefix = "model"
+ main_input_name = "past_values"
+ supports_gradient_checkpointing = False
+
+ def _init_weights(self, module):
+ """Initialize weights"""
+ if isinstance(module, PatchTSMixerPositionalEncoding):
+ # initialize positional encoding
+ if self.config.positional_encoding_type == "random":
+ nn.init.normal_(module.position_enc, mean=0.0, std=0.1)
+ elif isinstance(module, (nn.LayerNorm, nn.BatchNorm1d)):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+ elif isinstance(module, PatchTSMixerBatchNorm):
+ module.batchnorm.bias.data.zero_()
+ module.batchnorm.weight.data.fill_(1.0)
+ elif isinstance(module, nn.Linear):
+ module.weight.data.normal_(mean=0.0, std=self.config.init_std)
+ if module.bias is not None:
+ module.bias.data.zero_()
+
+
+class PatchTSMixerPretrainHead(nn.Module):
+ """Pretraining head.
+
+ Args:
+ config (`PatchTSMixerConfig`):
+ Configuration.
+ """
+
+ def __init__(self, config: PatchTSMixerConfig):
+ super().__init__()
+
+ self.dropout_layer = nn.Dropout(config.head_dropout)
+ self.base_pt_block = nn.Linear(config.d_model, config.patch_length)
+
+ def forward(self, hidden_features):
+ """
+ Args:
+ hidden_features (`torch.Tensor` of shape `(batch_size x num_patch x d_model)` in `flatten` mode
+ or `(batch_size x n_vars x num_patch x d_model)` in `common_channel`/`mix_channel` mode.): Input hidden
+ features.
+
+ Returns:
+ `torch.Tensor` of shape `(batch_size x n_vars x num_patch x patch_length)`.
+ """
+
+ hidden_features = self.dropout_layer(hidden_features)
+ forecast = self.base_pt_block(hidden_features) # [batch_size x n_vars x num_patch x patch_length]
+ return forecast
+
+
+# Copied from transformers.models.patchtst.modeling_patchtst.random_masking
+def random_masking(
+ inputs: torch.Tensor,
+ mask_ratio: float,
+ unmasked_channel_indices: Optional[list] = None,
+ channel_consistent_masking: bool = False,
+ mask_value: int = 0,
+):
+ """random_masking: Mask the input considering the control variables.
+
+ Args:
+ inputs (`torch.Tensor` of shape `(batch_size, num_channels, sequence_length, num_features)`):
+ The input tensor to mask.
+ mask_ratio (`float`):
+ Masking ratio applied to mask the input data during random pretraining. It is a number between 0 and 1.
+ unmasked_channel_indices (list, *optional*):
+ Indices of channels that will not be masked.
+ channel_consistent_masking (bool, *optional*, defaults to `False`):
+ When true, masking will be the same across all channels of a time series. Otherwise, masking positions will vary
+ across channels.
+ mask_value (int, *optional*, defaults to 0):
+ Define the value of masked patches for pretraining.
+
+ Returns:
+ `tuple(torch.Tensor)`: the masked input (same shape as the input tensor) and a mask tensor of shape
+ `[bs x c x n]`.
+ """
+ if mask_ratio < 0 or mask_ratio >= 1:
+ raise ValueError(f"Mask ratio {mask_ratio} has to be between 0 and 1.")
+
+ batch_size, num_channels, sequence_length, num_features = inputs.shape
+ device = inputs.device
+
+ len_keep = int(sequence_length * (1 - mask_ratio))
+
+ if channel_consistent_masking:
+ noise = torch.rand(batch_size, 1, sequence_length, device=device) # noise in [0, 1], bs x 1 x L
+ noise = noise.repeat(1, num_channels, 1) # bs x num_channels x time
+ else:
+ # noise in [0, 1], bs x num_channels x L
+ noise = torch.rand(batch_size, num_channels, sequence_length, device=device)
+
+ # mask: [bs x num_channels x num_patch]
+ mask = torch.ones(batch_size, num_channels, sequence_length, device=device)
+ mask[:, :, :len_keep] = 0
+
+ # sort noise for each sample
+ ids_shuffle = torch.argsort(noise, dim=-1) # ascend: small is keep, large is remove
+ ids_restore = torch.argsort(ids_shuffle, dim=-1) # ids_restore: [bs x num_channels x L]
+
+ mask = torch.gather(mask, dim=-1, index=ids_restore)
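+ # Gathering with ids_restore scatters the zeros so the len_keep positions with the smallest noise stay unmasked.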
+ mask = mask.unsqueeze(-1).repeat(1, 1, 1, num_features) # mask: [bs x num_channels x num_patches x patch_length]
+ if unmasked_channel_indices is not None:
+ mask[:, unmasked_channel_indices, :, :] = 0
+
+ inputs_mask = inputs.masked_fill(mask.bool(), mask_value)
+ return inputs_mask, mask[..., 0]
+
+
+# Copied from transformers.models.patchtst.modeling_patchtst.forecast_masking
+def forecast_masking(
+ inputs: torch.Tensor,
+ num_forecast_mask_patches: Union[list, int],
+ unmasked_channel_indices: Optional[list] = None,
+ mask_value: int = 0,
+):
+ """Forecast masking that masks the last K patches where K is from the num_forecast_mask_patches.
+ If num_forecast_mask_patches is a list, samples in the batch will be randomly masked by numbers defined in the list.
+
+ Parameters:
+ inputs (`torch.Tensor`):
+ Input of shape `(bs, num_channels, num_patch, patch_length)`
+ num_forecast_mask_patches (`list`):
+ Number of patches to be masked at the end of each batch sample. e.g. 4 or [3, 5].
+ unmasked_channel_indices (`list`, *optional*):
+ Indices of channels that are not masked.
+ mask_value (`int`, *optional*, defaults to 0):
+ Values in the masked patches will be filled by `mask_value`.
+
+ Returns:
+ `tuple(torch.Tensor)`: the masked input (same shape as the input tensor) and a mask tensor of shape `(bs,
+ num_channels, num_patch)` or `(bs, tsg1, tsg2, num_channels, num_patch)`.
+ """
+
+ if isinstance(num_forecast_mask_patches, int):
+ num_forecast_mask_patches = [num_forecast_mask_patches]
+ forecast_mask_ratios = [1 for _ in num_forecast_mask_patches]
+
+ batch_size, num_channels, sequence_length, num_features = inputs.shape
+ mask = torch.zeros(batch_size, num_channels, sequence_length, device=inputs.device)
+
+ t_list = []
+ total_length = 0
+ total_ratio = sum(forecast_mask_ratios)
+
+ for patch_length, ratio in zip(num_forecast_mask_patches, forecast_mask_ratios):
+ if patch_length <= 0 or patch_length >= sequence_length:
+ raise ValueError(
+ f"num_forecast_mask_patches {patch_length} should be greater than 0 and less than total patches."
+ )
+ temp_len = int(batch_size * ratio / total_ratio)
+ t_list.append([patch_length, ratio, temp_len])
+ total_length += temp_len
+
+ t_list = sorted(t_list, key=lambda x: x[2])
+
+ if total_length < batch_size:
+ t_list[0][2] = t_list[0][2] + (batch_size - total_length)
+ elif total_length > batch_size:
+ t_list[-1][2] = t_list[-1][2] + (total_length - batch_size)
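+ # e.g. num_forecast_mask_patches=[2, 4] with batch_size=10 masks the last 2 patches for 5 samples and the
+ # last 4 patches for the remaining 5; the adjustment above keeps the per-group counts summing to batch_size.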
+
+ batch1 = 0
+ for patch_len, _, temp_len in t_list:
+ batch2 = batch1 + temp_len
+ mask[batch1:batch2, :, -patch_len:] = 1
+ batch1 = batch2
+
+ perm = torch.randperm(mask.shape[0])
+ mask = mask[perm]
+
+ mask = mask.unsqueeze(-1).repeat(1, 1, 1, num_features) # mask: [bs x num_channels x num_patch x patch_len]
+ if unmasked_channel_indices is not None:
+ mask[:, unmasked_channel_indices, :, :] = 0
+
+ inputs_mask = inputs.masked_fill(mask.bool(), mask_value)
+ return inputs_mask, mask[..., 0]
+
+
+# Copied from transformers.models.patchtst.modeling_patchtst.PatchTSTPatchify with PatchTST->PatchTSMixer
+class PatchTSMixerPatchify(nn.Module):
+ """
+ A class to patchify the time series sequence into different patches
+
+ Returns:
+ `torch.Tensor` of shape `(batch_size, num_channels, num_patches, patch_length)`
+ """
+
+ def __init__(self, config: PatchTSMixerConfig):
+ super().__init__()
+
+ self.sequence_length = config.context_length
+ self.patch_length = config.patch_length
+ self.patch_stride = config.patch_stride
+
+ if self.sequence_length <= self.patch_length:
+ raise ValueError(
+ f"Sequence length ({self.sequence_length}) has to be greater than the patch length ({self.patch_length})"
+ )
+
+ # get the number of patches
+ self.num_patches = (max(self.sequence_length, self.patch_length) - self.patch_length) // self.patch_stride + 1
+ new_sequence_length = self.patch_length + self.patch_stride * (self.num_patches - 1)
+ self.sequence_start = self.sequence_length - new_sequence_length
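+ # With the defaults (context_length=32, patch_length=8, patch_stride=8): num_patches = 4,
+ # new_sequence_length = 8 + 8 * 3 = 32 and sequence_start = 0, i.e. the whole context window is patchified.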
+
+ def forward(self, past_values: torch.Tensor):
+ """
+ Parameters:
+ past_values (`torch.Tensor` of shape `(batch_size, sequence_length, num_channels)`, *required*):
+ Input for patchification
+
+ Returns:
+ `torch.Tensor` of shape `(batch_size, num_channels, num_patches, patch_length)`
+ """
+ sequence_length = past_values.shape[-2]
+ if sequence_length != self.sequence_length:
+ raise ValueError(
+ f"Input sequence length ({sequence_length}) doesn't match model configuration ({self.sequence_length})."
+ )
+ # output: [bs x new_sequence_length x num_channels]
+ output = past_values[:, self.sequence_start :, :]
+ # output: [bs x num_patches x num_input_channels x patch_length]
+ output = output.unfold(dimension=-2, size=self.patch_length, step=self.patch_stride)
+ # output: [bs x num_input_channels x num_patches x patch_length]
+ output = output.transpose(-2, -3).contiguous()
+ return output
+
+
+# Copied from transformers.models.patchtst.modeling_patchtst.PatchTSTMasking with PatchTST->PatchTSMixer
+class PatchTSMixerMasking(nn.Module):
+ """
+ Class to perform random or forecast masking.
+
+ Parameters:
+ config (`PatchTSMixerConfig`): model config
+ Returns:
+ x_mask (`torch.Tensor` of shape `(batch_size, num_channels, num_patches, patch_length)`)
+ Masked patched input
+ mask (`torch.Tensor` of shape `(batch_size, num_channels, num_patches)`)
+ Bool tensor indicating True on masked points
+ """
+
+ def __init__(self, config: PatchTSMixerConfig):
+ super().__init__()
+ self.random_mask_ratio = config.random_mask_ratio
+ self.channel_consistent_masking = config.channel_consistent_masking
+ self.mask_type = config.mask_type
+ self.num_forecast_mask_patches = config.num_forecast_mask_patches
+ self.unmasked_channel_indices = config.unmasked_channel_indices
+ self.mask_value = config.mask_value
+ if self.unmasked_channel_indices is not None:
+ self.unmasked_channel_indices = sorted(self.unmasked_channel_indices)
+
+ def forward(self, patch_input: torch.Tensor):
+ """
+ Parameters:
+ patch_input (`torch.Tensor` of shape `(batch_size, num_channels, num_patches, patch_length)`, *required*):
+ Patch input
+
+ Return:
+ masked_input (`torch.Tensor` of shape `(batch_size, num_channels, num_patches, patch_length)`)
+ Masked patched input
+ mask (`torch.Tensor` of shape `(batch_size, num_channels, num_patches)`)
+ Bool tensor indicating True on masked points
+
+ """
+ if self.mask_type == "random":
+ masked_input, mask = random_masking(
+ inputs=patch_input,
+ mask_ratio=self.random_mask_ratio,
+ unmasked_channel_indices=self.unmasked_channel_indices,
+ channel_consistent_masking=self.channel_consistent_masking,
+ mask_value=self.mask_value,
+ )
+ elif self.mask_type == "forecast":
+ masked_input, mask = forecast_masking(
+ inputs=patch_input,
+ num_forecast_mask_patches=self.num_forecast_mask_patches,
+ unmasked_channel_indices=self.unmasked_channel_indices,
+ mask_value=self.mask_value,
+ )
+ else:
+ raise ValueError(f"Invalid mask type {self.mask_type}.")
+
+ # mask: [bs x num_input_channels x num_patch]
+ mask = mask.bool()
+ return masked_input, mask
+
+
+# Copied from transformers.models.patchtst.modeling_patchtst.PatchTSTStdScaler with PatchTST->PatchTSMixer
+class PatchTSMixerStdScaler(nn.Module):
+ """
+ Standardizes features by computing the mean and standard deviation along the first dimension, and then
+ normalizes the data by subtracting the mean and dividing by the standard deviation.
+ """
+
+ def __init__(self, config: PatchTSMixerConfig):
+ super().__init__()
+ self.dim = config.scaling_dim if hasattr(config, "scaling_dim") else 1
+ self.keepdim = config.keepdim if hasattr(config, "keepdim") else True
+ self.minimum_scale = config.minimum_scale if hasattr(config, "minimum_scale") else 1e-5
+
+ def forward(
+ self, data: torch.Tensor, observed_indicator: torch.Tensor
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """
+ Parameters:
+ data (`torch.Tensor` of shape `(batch_size, sequence_length, num_input_channels)`):
+ Input data for the scaler calculation.
+ observed_indicator (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`):
+ Calculating the scale on the observed indicator.
+ Returns:
+ tuple of `torch.Tensor` of shapes
+ (`(batch_size, sequence_length, num_input_channels)`,`(batch_size, 1, num_input_channels)`,
+ `(batch_size, 1, num_input_channels)`)
+ """
+ denominator = observed_indicator.sum(self.dim, keepdim=self.keepdim)
+ denominator = denominator.clamp_min(1.0)
+ loc = (data * observed_indicator).sum(self.dim, keepdim=self.keepdim) / denominator
+
+ variance = (((data - loc) * observed_indicator) ** 2).sum(self.dim, keepdim=self.keepdim) / denominator
+ scale = torch.sqrt(variance + self.minimum_scale)
+ return (data - loc) / scale, loc, scale
+
+
+# Copied from transformers.models.patchtst.modeling_patchtst.PatchTSTMeanScaler with PatchTST->PatchTSMixer
+class PatchTSMixerMeanScaler(nn.Module):
+ """
+ Computes a scaling factor as the weighted average absolute value along the first dimension, and scales the data
+ accordingly.
+ """
+
+ def __init__(self, config: PatchTSMixerConfig):
+ super().__init__()
+ self.dim = config.scaling_dim if hasattr(config, "scaling_dim") else 1
+ self.keepdim = config.keepdim if hasattr(config, "keepdim") else True
+ self.minimum_scale = config.minimum_scale if hasattr(config, "minimum_scale") else 1e-10
+ self.default_scale = config.default_scale if hasattr(config, "default_scale") else None
+
+ def forward(
+ self, data: torch.Tensor, observed_indicator: torch.Tensor
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """
+ Parameters:
+ data (`torch.Tensor` of shape `(batch_size, sequence_length, num_input_channels)`):
+ Input data for the scaler calculation.
+ observed_indicator (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`):
+ Calculating the scale on the observed indicator.
+ Returns:
+ tuple of `torch.Tensor` of shapes
+ (`(batch_size, sequence_length, num_input_channels)`,`(batch_size, 1, num_input_channels)`,
+ `(batch_size, 1, num_input_channels)`)
+ """
+ ts_sum = (data * observed_indicator).abs().sum(self.dim, keepdim=True)
+ num_observed = observed_indicator.sum(self.dim, keepdim=True)
+
+ scale = ts_sum / torch.clamp(num_observed, min=1)
+
+ # If `default_scale` is provided, we use it, otherwise we use the scale
+ # of the batch.
+ if self.default_scale is None:
+ batch_sum = ts_sum.sum(dim=0)
+ batch_observations = torch.clamp(num_observed.sum(0), min=1)
+ default_scale = torch.squeeze(batch_sum / batch_observations)
+ else:
+ default_scale = self.default_scale * torch.ones_like(scale)
+
+ # apply default scale where there are no observations
+ scale = torch.where(num_observed > 0, scale, default_scale)
+
+ # ensure the scale is at least `self.minimum_scale`
+ scale = torch.clamp(scale, min=self.minimum_scale)
+ scaled_data = data / scale
+
+ if not self.keepdim:
+ scale = scale.squeeze(dim=self.dim)
+
+ return scaled_data, torch.zeros_like(scale), scale
+
+
+# Copied from transformers.models.patchtst.modeling_patchtst.PatchTSTNOPScaler with PatchTST->PatchTSMixer
+class PatchTSMixerNOPScaler(nn.Module):
+ """
+ Assigns a scaling factor equal to 1 along the first dimension, and therefore applies no scaling to the input data.
+ """
+
+ def __init__(self, config: PatchTSMixerConfig):
+ super().__init__()
+ self.dim = config.scaling_dim if hasattr(config, "scaling_dim") else 1
+ self.keepdim = config.keepdim if hasattr(config, "keepdim") else True
+
+ def forward(
+ self, data: torch.Tensor, observed_indicator: Optional[torch.Tensor] = None
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """
+ Parameters:
+ data (`torch.Tensor` of shape `(batch_size, sequence_length, num_input_channels)`):
+ Input data for the scaler calculation.
+ Returns:
+ tuple of `torch.Tensor` of shapes
+ (`(batch_size, sequence_length, num_input_channels)`,`(batch_size, 1, num_input_channels)`,
+ `(batch_size, 1, num_input_channels)`)
+ """
+ scale = torch.ones_like(data, requires_grad=False).mean(dim=self.dim, keepdim=self.keepdim)
+ loc = torch.zeros_like(data, requires_grad=False).mean(dim=self.dim, keepdim=self.keepdim)
+ return data, loc, scale
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Base class for `PatchTSMixerEncoderOutput`, with potential hidden states.
+ """
+)
+class PatchTSMixerEncoderOutput(ModelOutput):
+ r"""
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, num_patches, d_model)`):
+ Hidden-state at the output of the last layer of the model.
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*):
+ Hidden-states of the model at the output of each layer.
+ """
+
+ last_hidden_state: Optional[torch.FloatTensor] = None
+ hidden_states: Optional[tuple[torch.FloatTensor]] = None
+
+
+class PatchTSMixerEncoder(PatchTSMixerPreTrainedModel):
+ """
+ Encoder for PatchTSMixer which takes patched time series as input and outputs patch embeddings.
+
+ Args:
+ config (`PatchTSMixerConfig`):
+ Configuration.
+ """
+
+ def __init__(self, config: PatchTSMixerConfig):
+ super().__init__(config)
+
+ self.use_return_dict = config.use_return_dict
+
+ self.patcher = nn.Linear(config.patch_length, config.d_model)
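+ # Each patch of length patch_length is linearly projected to a d_model-dimensional embedding.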
+ if config.use_positional_encoding:
+ self.positional_encoder = PatchTSMixerPositionalEncoding(config=config)
+ else:
+ self.positional_encoder = None
+ self.mlp_mixer_encoder = PatchTSMixerBlock(config=config)
+
+ # Initialize weights and apply final processing
+ if config.post_init:
+ self.post_init()
+
+ @auto_docstring
+ def forward(
+ self,
+ past_values: torch.Tensor,
+ output_hidden_states: Optional[bool] = False,
+ return_dict: Optional[bool] = None,
+ ) -> Union[tuple, PatchTSMixerEncoderOutput]:
+ r"""
+ past_values (`torch.FloatTensor` of shape `(batch_size, seq_length, num_input_channels)`):
+ Context values of the time series. For a pretraining task, this denotes the input time series to
+ predict the masked portion. For a forecasting task, this denotes the history/past time series values.
+ Similarly, for classification or regression tasks, it denotes the appropriate context values of the
+ time series.
+
+ For univariate time series, `num_input_channels` dimension should be 1. For multivariate time series,
+ it is greater than 1.
+
+ Returns:
+ `torch.FloatTensor` of shape `(batch_size, n_vars, num_patches, d_model)`
+ """
+
+ return_dict = return_dict if return_dict is not None else self.use_return_dict
+
+ # flatten [bs x num_patch x d_model]. common_channel/mix_channel: [bs x n_vars x num_patch x d_model]
+ patches = self.patcher(past_values)
+
+ # add positional encoder
+ if self.positional_encoder is not None:
+ patches = self.positional_encoder(patches)
+
+ last_hidden_state, hidden_states = self.mlp_mixer_encoder(patches, output_hidden_states=output_hidden_states)
+
+ if not return_dict:
+ return tuple(
+ v
+ for v in [
+ last_hidden_state,
+ hidden_states,
+ ]
+ )
+
+ return PatchTSMixerEncoderOutput(last_hidden_state=last_hidden_state, hidden_states=hidden_states)
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Base class for model's outputs, with potential hidden states.
+ """
+)
+class PatchTSMixerModelOutput(ModelOutput):
+ r"""
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, num_patches, d_model)`):
+ Hidden-state at the output of the last layer of the model.
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*):
+ Hidden-states of the model at the output of each layer.
+ patch_input (`torch.FloatTensor` of shape `(batch_size, num_channels, num_patches, patch_length)`):
+ Patched input data to the model.
+ mask (`torch.FloatTensor` of shape `(batch_size, num_channels, num_patches)`, *optional*):
+ Bool Tensor indicating True in masked patches and False otherwise.
+ loc (`torch.FloatTensor` of shape `(batch_size, 1, num_channels)`, *optional*):
+ Gives the mean of the context window per channel. Used for revin denorm outside the model, if revin
+ enabled.
+ scale (`torch.FloatTensor` of shape `(batch_size, 1, num_channels)`, *optional*):
+ Gives the std dev of the context window per channel. Used for revin denorm outside the model, if revin
+ enabled.
+ """
+
+ last_hidden_state: Optional[torch.FloatTensor] = None
+ hidden_states: Optional[tuple[torch.FloatTensor]] = None
+ patch_input: Optional[torch.FloatTensor] = None
+ mask: Optional[torch.FloatTensor] = None
+ loc: Optional[torch.FloatTensor] = None
+ scale: Optional[torch.FloatTensor] = None
+
+
+@auto_docstring(
+ custom_intro="""
+ The PatchTSMixer Model for time-series forecasting.
+ """
+)
+class PatchTSMixerModel(PatchTSMixerPreTrainedModel):
+ def __init__(self, config: PatchTSMixerConfig, mask_input: bool = False):
+ r"""
+ mask_input (bool, *optional*, defaults to `False`):
+ Whether to mask the input using the [`PatchTSMixerMasking`] module.
+ """
+ super().__init__(config)
+
+ self.use_return_dict = config.use_return_dict
+ self.encoder = PatchTSMixerEncoder(config)
+ self.patching = PatchTSMixerPatchify(config)
+
+ if mask_input is True:
+ self.masking = PatchTSMixerMasking(config)
+ else:
+ self.masking = None
+
+ if config.scaling == "mean":
+ self.scaler = PatchTSMixerMeanScaler(config)
+ elif config.scaling == "std" or config.scaling is True:
+ self.scaler = PatchTSMixerStdScaler(config)
+ else:
+ self.scaler = PatchTSMixerNOPScaler(config)
+
+ # Initialize weights and apply final processing
+ if config.post_init:
+ self.post_init()
+
+ @auto_docstring
+ def forward(
+ self,
+ past_values: torch.Tensor,
+ observed_mask: Optional[torch.Tensor] = None,
+ output_hidden_states: Optional[bool] = False,
+ return_dict: Optional[bool] = None,
+ ) -> PatchTSMixerModelOutput:
+ r"""
+ past_values (`torch.FloatTensor` of shape `(batch_size, seq_length, num_input_channels)`):
+ Context values of the time series. For a pretraining task, this denotes the input time series to predict
+ the masked portion. For a forecasting task, this denotes the history/past time series values. Similarly,
+ for classification or regression tasks, it denotes the appropriate context values of the time series.
+
+ For univariate time series, `num_input_channels` dimension should be 1. For multivariate time series, it is
+ greater than 1.
+ observed_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_input_channels)`, *optional*):
+ Boolean mask to indicate which `past_values` were observed and which were missing. Mask values selected
+ in `[0, 1]`:
+ - 1 for values that are **observed**,
+ - 0 for values that are **missing** (i.e. NaNs that were replaced by zeros).
+ """
+ return_dict = return_dict if return_dict is not None else self.use_return_dict
+
+ mask = None
+ if observed_mask is None:
+ observed_mask = torch.ones_like(past_values)
+ scaled_past_values, loc, scale = self.scaler(past_values, observed_mask)
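+ # loc and scale hold the per-channel statistics later used to de-normalize predictions outside the backbone.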
+
+ patched_x = self.patching(scaled_past_values) # [batch_size x num_input_channels x num_patch x patch_length]
+
+ enc_input = patched_x
+ if self.masking is not None:
+ enc_input, mask = self.masking(patched_x)
+ # enc_input: [batch_size x num_input_channels x num_patch x patch_length]
+ # mask: [batch_size x num_input_channels x num_patch]
+
+ encoder_output = self.encoder(
+ enc_input,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ if isinstance(encoder_output, tuple):
+ encoder_output = PatchTSMixerEncoderOutput(*encoder_output)
+
+ if not return_dict:
+ return tuple(
+ v
+ for v in [
+ encoder_output.last_hidden_state,
+ encoder_output.hidden_states,
+ patched_x,
+ mask,
+ loc,
+ scale,
+ ]
+ )
+
+ return PatchTSMixerModelOutput(
+ last_hidden_state=encoder_output.last_hidden_state,
+ hidden_states=encoder_output.hidden_states,
+ patch_input=patched_x,
+ mask=mask,
+ loc=loc,
+ scale=scale,
+ )
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Output type of [`PatchTSMixerForPretraining`].
+ """
+)
+class PatchTSMixerForPreTrainingOutput(ModelOutput):
+ r"""
+ loss (*optional*, returned when `y` is provided, `torch.FloatTensor` of shape `()`):
+ Total loss
+ prediction_outputs (`torch.FloatTensor` of shape `(batch_size, num_input_channels, num_patches, patch_length)`):
+ Prediction output from the pretrain head.
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_input_channels, num_patches, d_model)`):
+ Backbone embeddings before passing through the head.
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*):
+ Hidden-states of the model at the output of each layer.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ prediction_outputs: Optional[torch.FloatTensor] = None
+ last_hidden_state: Optional[torch.FloatTensor] = None
+ hidden_states: Optional[tuple[torch.FloatTensor]] = None
+
+
+@auto_docstring(
+ custom_intro="""
+ `PatchTSMixer` for mask pretraining.
+ """
+)
+class PatchTSMixerForPretraining(PatchTSMixerPreTrainedModel):
+ def __init__(self, config: PatchTSMixerConfig):
+ super().__init__(config)
+ self.model = PatchTSMixerModel(config, mask_input=True)
+ self.head = PatchTSMixerPretrainHead(config=config)
+ self.masked_loss = config.masked_loss
+ self.use_return_dict = config.use_return_dict
+
+ # Initialize weights and apply final processing
+ if config.post_init:
+ self.post_init()
+
+ @auto_docstring
+ def forward(
+ self,
+ past_values: torch.Tensor,
+ observed_mask: Optional[torch.Tensor] = None,
+ output_hidden_states: Optional[bool] = False,
+ return_loss: bool = True,
+ return_dict: Optional[bool] = None,
+ ) -> PatchTSMixerForPreTrainingOutput:
+ r"""
+ past_values (`torch.FloatTensor` of shape `(batch_size, seq_length, num_input_channels)`):
+ Context values of the time series. For a pretraining task, this denotes the input time series to predict
+ the masked portion. For a forecasting task, this denotes the history/past time series values. Similarly,
+ for classification or regression tasks, it denotes the appropriate context values of the time series.
+
+ For univariate time series, `num_input_channels` dimension should be 1. For multivariate time series, it is
+ greater than 1.
+ observed_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_input_channels)`, *optional*):
+ Boolean mask to indicate which `past_values` were observed and which were missing. Mask values selected
+ in `[0, 1]`:
+ - 1 for values that are **observed**,
+ - 0 for values that are **missing** (i.e. NaNs that were replaced by zeros).
+ return_loss (`bool`, *optional*):
+ Whether to return the loss in the `forward` call.
+ """
+ return_dict = return_dict if return_dict is not None else self.use_return_dict
+
+ if self.masked_loss is True:
+ loss = torch.nn.MSELoss(reduction="none")
+ else:
+ loss = torch.nn.MSELoss(reduction="mean")
+
+ # past_values: tensor [batch_size x context_length x num_input_channels]
+ model_output = self.model(
+ past_values,
+ observed_mask=observed_mask,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ ) # x.last_hidden_state: [batch_size x nvars x num_patch x d_model]
+ if isinstance(model_output, tuple):
+ model_output = PatchTSMixerModelOutput(*model_output)
+
+ x_hat = self.head(model_output.last_hidden_state) # tensor [batch_size x nvars x num_patch x patch_length]
+
+ if return_loss is True:
+ loss_val = loss(x_hat, model_output.patch_input)
+ else:
+ loss_val = None
+
+ # calculate masked_loss
+ if self.masked_loss is True and loss_val is not None:
+ loss_val = (loss_val.mean(dim=-1) * model_output.mask).sum() / (model_output.mask.sum() + 1e-10)
+
+ if not return_dict:
+ return tuple(
+ v
+ for v in [
+ loss_val,
+ x_hat,
+ model_output.last_hidden_state,
+ model_output.hidden_states,
+ ]
+ )
+
+ return PatchTSMixerForPreTrainingOutput(
+ loss=loss_val,
+ prediction_outputs=x_hat, # tensor [batch_size x nvars x num_patch x patch_length]
+ last_hidden_state=model_output.last_hidden_state, # x: [batch_size x nvars x num_patch x d_model]
+ hidden_states=model_output.hidden_states,
+ )
+
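+# Usage sketch (illustrative only, not part of the library): self-supervised mask pretraining on
+# randomly generated data. The configuration values below are assumptions chosen for the example.
+#
+#     import torch
+#     from transformers import PatchTSMixerConfig, PatchTSMixerForPretraining
+#
+#     config = PatchTSMixerConfig(context_length=32, patch_length=8, num_input_channels=3)
+#     model = PatchTSMixerForPretraining(config)
+#     past_values = torch.randn(4, config.context_length, config.num_input_channels)
+#     output = model(past_values)      # masked-reconstruction loss plus per-patch predictions
+#     output.loss.backward()
+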
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+    Output type of [`PatchTSMixerForPrediction`].
+ """
+)
+class PatchTSMixerForPredictionOutput(ModelOutput):
+ r"""
+    loss (`torch.FloatTensor` of shape `()`, *optional*, returned when `future_values` is provided):
+        Total loss.
+ prediction_outputs (`torch.FloatTensor` of shape `(batch_size, prediction_length, num_input_channels)`):
+ Prediction output from the forecast head.
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_input_channels, num_patches, d_model)`):
+ Backbone embeddings before passing through the head.
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*):
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+    loc (`torch.FloatTensor` of shape `(batch_size, 1, num_input_channels)`, *optional*):
+        Input mean.
+    scale (`torch.FloatTensor` of shape `(batch_size, 1, num_input_channels)`, *optional*):
+        Input standard deviation.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ prediction_outputs: Optional[torch.FloatTensor] = None
+ last_hidden_state: Optional[torch.FloatTensor] = None
+ hidden_states: Optional[tuple[torch.FloatTensor]] = None
+ loc: Optional[torch.FloatTensor] = None
+ scale: Optional[torch.FloatTensor] = None
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+    Base class for time series models' prediction outputs that contain the sampled values from the chosen
+    distribution.
+ """
+)
+class SamplePatchTSMixerPredictionOutput(ModelOutput):
+ r"""
+    sequences (`torch.FloatTensor` of shape `(batch_size, num_samples, prediction_length, num_input_channels)`):
+ Sampled values from the chosen distribution.
+ """
+
+ sequences: Optional[torch.FloatTensor] = None
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+    Base class for time series models' prediction outputs that contain the sampled values from the chosen
+    distribution.
+ """
+)
+class SamplePatchTSMixerRegressionOutput(ModelOutput):
+ r"""
+    sequences (`torch.FloatTensor` of shape `(batch_size, num_samples, num_targets)`):
+ Sampled values from the chosen distribution.
+ """
+
+ sequences: Optional[torch.FloatTensor] = None
+
+
+# Copied from transformers.models.time_series_transformer.modeling_time_series_transformer.nll
+def nll(input: torch.distributions.Distribution, target: torch.Tensor) -> torch.Tensor:
+ """
+ Computes the negative log likelihood loss from input distribution with respect to target.
+ """
+ return -input.log_prob(target)
+
+
+# Copied from transformers.models.time_series_transformer.modeling_time_series_transformer.weighted_average
+def weighted_average(input_tensor: torch.Tensor, weights: Optional[torch.Tensor] = None, dim=None) -> torch.Tensor:
+ """
+ Computes the weighted average of a given tensor across a given `dim`, masking values associated with weight zero,
+ meaning instead of `nan * 0 = nan` you will get `0 * 0 = 0`.
+
+ Args:
+ input_tensor (`torch.FloatTensor`):
+ Input tensor, of which the average must be computed.
+ weights (`torch.FloatTensor`, *optional*):
+ Weights tensor, of the same shape as `input_tensor`.
+ dim (`int`, *optional*):
+ The dim along which to average `input_tensor`.
+
+ Returns:
+ `torch.FloatTensor`: The tensor with values averaged along the specified `dim`.
+ """
+ if weights is not None:
+ weighted_tensor = torch.where(weights != 0, input_tensor * weights, torch.zeros_like(input_tensor))
+ sum_weights = torch.clamp(weights.sum(dim=dim) if dim else weights.sum(), min=1.0)
+ return (weighted_tensor.sum(dim=dim) if dim else weighted_tensor.sum()) / sum_weights
+ else:
+ return input_tensor.mean(dim=dim)
+
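+# Illustrative sketch (not part of the library): how `nll` and `weighted_average` combine into the
+# "nll" loss path used by the heads below. The tensor sizes are arbitrary assumptions for the example.
+#
+#     import torch
+#     from torch.distributions import StudentT
+#
+#     target = torch.randn(4, 16, 3)   # [batch_size x prediction_length x num_input_channels]
+#     dist = StudentT(df=5.0, loc=torch.zeros_like(target), scale=torch.ones_like(target))
+#     per_step_nll = nll(dist, target)            # same shape as `target`
+#     loss = weighted_average(per_step_nll)       # plain mean (scalar) when no weights are given
+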
+
+class PatchTSMixerForPrediction(PatchTSMixerPreTrainedModel):
+ r"""
+    `PatchTSMixer` for forecasting applications.
+
+ Args:
+ config (`PatchTSMixerConfig`):
+ Configuration.
+
+ Returns:
+ `None`.
+ """
+
+ def __init__(self, config: PatchTSMixerConfig):
+ super().__init__(config)
+ self.loss = config.loss
+ self.use_return_dict = config.use_return_dict
+ self.prediction_channel_indices = config.prediction_channel_indices
+ self.num_parallel_samples = config.num_parallel_samples
+
+ if config.loss == "mse":
+ self.distribution_output = None
+ else:
+ dim = config.prediction_length
+ distribution_output_map = {
+ "student_t": StudentTOutput,
+ "normal": NormalOutput,
+ "negative_binomial": NegativeBinomialOutput,
+ }
+ output_class = distribution_output_map.get(config.distribution_output, None)
+ if output_class is not None:
+ self.distribution_output = output_class(dim=dim)
+ else:
+ raise ValueError(f"Unknown distribution output {config.distribution_output}")
+
+ self.model = PatchTSMixerModel(config)
+ self.head = PatchTSMixerForPredictionHead(
+ config=config,
+ distribution_output=self.distribution_output,
+ )
+
+ # Initialize weights and apply final processing
+ if config.post_init:
+ self.post_init()
+
+ @auto_docstring
+ def forward(
+ self,
+ past_values: torch.Tensor,
+ observed_mask: Optional[torch.Tensor] = None,
+ future_values: Optional[torch.Tensor] = None,
+ output_hidden_states: Optional[bool] = False,
+ return_loss: bool = True,
+ return_dict: Optional[bool] = None,
+ ) -> PatchTSMixerForPredictionOutput:
+ r"""
+ past_values (`torch.FloatTensor` of shape `(batch_size, seq_length, num_input_channels)`):
+ Context values of the time series. For a pretraining task, this denotes the input time series to predict
+ the masked portion. For a forecasting task, this denotes the history/past time series values. Similarly,
+ for classification or regression tasks, it denotes the appropriate context values of the time series.
+
+ For univariate time series, `num_input_channels` dimension should be 1. For multivariate time series, it is
+ greater than 1.
+ observed_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_input_channels)`, *optional*):
+ Boolean mask to indicate which `past_values` were observed and which were missing. Mask values selected
+ in `[0, 1]`:
+ - 1 for values that are **observed**,
+ - 0 for values that are **missing** (i.e. NaNs that were replaced by zeros).
+        future_values (`torch.FloatTensor` of shape `(batch_size, target_len, num_input_channels)` for forecasting,
+            `(batch_size, num_targets)` for regression, or `(batch_size,)` for classification, *optional*):
+            Target values of the time series that serve as labels for the model. The `future_values` is what the
+            model needs during training to learn to output, given the `past_values`. Note that this is NOT
+            required for a pretraining task.
+
+            For a forecasting task, the shape is `(batch_size, target_len, num_input_channels)`. Even if we want
+            to forecast only specific channels by setting the indices in the `prediction_channel_indices` parameter,
+            pass the target data with all channels, as channel filtering for both prediction and target will be
+            manually applied before the loss computation.
+ return_loss (`bool`, *optional*):
+ Whether to return the loss in the `forward` call.
+ """
+ if self.loss == "mse":
+ loss = nn.MSELoss(reduction="mean")
+ elif self.loss == "nll":
+ loss = nll
+ else:
+            raise ValueError("Invalid loss function. Allowed values are 'mse' and 'nll'.")
+
+ return_dict = return_dict if return_dict is not None else self.use_return_dict
+
+ # past_values: tensor [batch_size x context_length x num_input_channels]
+ model_output = self.model(
+ past_values,
+ observed_mask=observed_mask,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ ) # model_output: [batch_size x nvars x num_patch x d_model]
+ if isinstance(model_output, tuple):
+ model_output = PatchTSMixerModelOutput(*model_output)
+
+ # tensor [batch_size x prediction_length x num_input_channels]
+ y_hat = self.head(model_output.last_hidden_state)
+
+ loss_val = None
+ if self.prediction_channel_indices is not None:
+ if self.distribution_output:
+ distribution = self.distribution_output.distribution(
+ y_hat,
+ loc=model_output.loc[..., self.prediction_channel_indices],
+ scale=model_output.scale[..., self.prediction_channel_indices],
+ )
+ if future_values is not None and return_loss is True:
+ loss_val = loss(
+ distribution,
+ future_values[..., self.prediction_channel_indices],
+ )
+ # take average of the loss
+ loss_val = weighted_average(loss_val)
+ else:
+ y_hat = (
+ y_hat * model_output.scale[..., self.prediction_channel_indices]
+ + model_output.loc[..., self.prediction_channel_indices]
+ )
+ if future_values is not None and return_loss is True:
+ loss_val = loss(y_hat, future_values[..., self.prediction_channel_indices])
+ else:
+ if self.distribution_output:
+ distribution = self.distribution_output.distribution(
+ y_hat, loc=model_output.loc, scale=model_output.scale
+ )
+ if future_values is not None and return_loss is True:
+ loss_val = loss(distribution, future_values)
+ loss_val = weighted_average(loss_val)
+ else:
+ y_hat = y_hat * model_output.scale + model_output.loc
+ if future_values is not None and return_loss is True:
+ loss_val = loss(y_hat, future_values)
+
+ if self.prediction_channel_indices is not None:
+ loc = model_output.loc[..., self.prediction_channel_indices]
+ scale = model_output.scale[..., self.prediction_channel_indices]
+ else:
+ loc = model_output.loc
+ scale = model_output.scale
+
+ if not return_dict:
+ return tuple(
+ v
+ for v in [
+ loss_val,
+ y_hat,
+ model_output.last_hidden_state,
+ model_output.hidden_states,
+ loc,
+ scale,
+ ]
+ )
+
+ return PatchTSMixerForPredictionOutput(
+ loss=loss_val,
+ prediction_outputs=y_hat, # tensor [batch_size x prediction_length x num_input_channels]
+ last_hidden_state=model_output.last_hidden_state, # x: [batch_size x nvars x num_patch x d_model]
+ hidden_states=model_output.hidden_states,
+ loc=loc,
+ scale=scale,
+ )
+
+ @torch.no_grad()
+ def generate(
+ self,
+ past_values: torch.Tensor,
+ observed_mask: Optional[torch.Tensor] = None,
+ ) -> SamplePatchTSMixerPredictionOutput:
+ """
+ Generate sequences of sample predictions from a model with a probability distribution head.
+
+ Args:
+ past_values (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_input_channels)`):
+ Past values of the time series that serves as context in order to predict the future.
+
+ observed_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`, *optional*):
+ Boolean mask to indicate which `past_values` were observed and which were missing. Mask values selected
+ in `[0, 1]`:
+
+ - 1 for values that are **observed**,
+ - 0 for values that are **missing** (i.e. NaNs that were replaced by zeros).
+
+ Return:
+ [`SamplePatchTSMixerPredictionOutput`] where the outputs `sequences` tensor will have shape `(batch_size,
+ number of samples, prediction_length, num_input_channels)`.
+ """
+ # get number of samples
+ num_parallel_samples = self.num_parallel_samples
+
+ # get model output
+ outputs = self(
+ past_values=past_values,
+ future_values=None,
+ observed_mask=observed_mask,
+ output_hidden_states=False,
+ )
+
+        # get distribution
+        distribution = self.distribution_output.distribution(
+ outputs.prediction_outputs, loc=outputs.loc, scale=outputs.scale
+ )
+
+ # get samples: list of [batch_size x prediction_length x num_channels]
+ samples = [distribution.sample() for _ in range(num_parallel_samples)]
+
+ # stack tensors
+ samples = torch.stack(samples, dim=1) # [batch_size x num_samples x prediction_length x num_channels]
+ return SamplePatchTSMixerPredictionOutput(sequences=samples)
+
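+# Usage sketch (illustrative only): probabilistic forecasting via `generate`, assuming the model is
+# configured with a distribution head. The configuration values are assumptions for the example.
+#
+#     import torch
+#     from transformers import PatchTSMixerConfig, PatchTSMixerForPrediction
+#
+#     config = PatchTSMixerConfig(
+#         context_length=32, prediction_length=16, patch_length=8, num_input_channels=3,
+#         loss="nll", distribution_output="student_t",
+#     )
+#     model = PatchTSMixerForPrediction(config)
+#     past_values = torch.randn(4, config.context_length, config.num_input_channels)
+#     samples = model.generate(past_values).sequences
+#     # samples: [batch_size x num_parallel_samples x prediction_length x num_input_channels]
+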
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+    Output type of [`PatchTSMixerForTimeSeriesClassification`].
+ """
+)
+class PatchTSMixerForTimeSeriesClassificationOutput(ModelOutput):
+ r"""
+    loss (`torch.FloatTensor` of shape `()`, *optional*, returned when `target_values` is provided):
+        Total loss.
+ prediction_outputs (`torch.FloatTensor` of shape `(batch_size, num_labels)`):
+ Prediction output from the classification head.
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_input_channels, num_patches, d_model)`):
+ Backbone embeddings before passing through the head.
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*):
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ prediction_outputs: Optional[torch.FloatTensor] = None
+ last_hidden_state: Optional[torch.FloatTensor] = None
+ hidden_states: Optional[tuple[torch.FloatTensor]] = None
+
+
+class PatchTSMixerForTimeSeriesClassification(PatchTSMixerPreTrainedModel):
+ r"""
+    `PatchTSMixer` for classification applications.
+
+ Args:
+ config (`PatchTSMixerConfig`):
+ Configuration.
+
+ Returns:
+ `None`.
+ """
+
+ def __init__(self, config: PatchTSMixerConfig):
+ super().__init__(config)
+
+ self.model = PatchTSMixerModel(config)
+ self.head = PatchTSMixerLinearHead(
+ config=config,
+ )
+ self.use_return_dict = config.use_return_dict
+ if config.scaling in ["std", "mean", True]:
+ self.inject_scale = InjectScalerStatistics4D(d_model=config.d_model, num_patches=config.num_patches)
+ else:
+ self.inject_scale = None
+
+ # Initialize weights and apply final processing
+ if config.post_init:
+ self.post_init()
+
+ @auto_docstring
+ def forward(
+ self,
+ past_values: torch.Tensor,
+ target_values: Optional[torch.Tensor] = None,
+ output_hidden_states: Optional[bool] = False,
+ return_loss: bool = True,
+ return_dict: Optional[bool] = None,
+ ) -> PatchTSMixerForTimeSeriesClassificationOutput:
+ r"""
+ past_values (`torch.FloatTensor` of shape `(batch_size, seq_length, num_input_channels)`):
+ Context values of the time series. For a pretraining task, this denotes the input time series to predict
+ the masked portion. For a forecasting task, this denotes the history/past time series values. Similarly,
+ for classification or regression tasks, it denotes the appropriate context values of the time series.
+
+ For univariate time series, `num_input_channels` dimension should be 1. For multivariate time series, it is
+ greater than 1.
+ target_values (`torch.FloatTensor` of shape `(batch_size, target_len, num_input_channels)` for forecasting,
+ `(batch_size, num_targets)` for regression, or `(batch_size,)` for classification, *optional*):
+            Target values of the time series that serve as labels for the model. The `target_values` is what the
+            model needs during training to learn to output, given the `past_values`. Note that this is NOT
+            required for a pretraining task.
+
+            For a forecasting task, the shape is `(batch_size, target_len, num_input_channels)`. Even if we want
+            to forecast only specific channels by setting the indices in the `prediction_channel_indices` parameter,
+            pass the target data with all channels, as channel filtering for both prediction and target will be
+            manually applied before the loss computation.
+
+ For a classification task, it has a shape of `(batch_size,)`.
+
+ For a regression task, it has a shape of `(batch_size, num_targets)`.
+ return_loss (`bool`, *optional*):
+ Whether to return the loss in the `forward` call.
+ """
+
+ loss = torch.nn.CrossEntropyLoss()
+
+ return_dict = return_dict if return_dict is not None else self.use_return_dict
+
+ model_output = self.model(
+ past_values,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ ) # x: [batch_size x nvars x num_patch x d_model]
+ if isinstance(model_output, tuple):
+ model_output = PatchTSMixerModelOutput(*model_output)
+
+ if self.inject_scale is not None:
+ model_output.last_hidden_state = self.inject_scale(
+ model_output.last_hidden_state,
+ loc=model_output.loc,
+ scale=model_output.scale,
+ ) # x: [batch_size x nvars x num_patch x d_model]
+
+ y_hat = self.head(model_output.last_hidden_state) # tensor [batch_size x n_labels]
+
+ if target_values is not None and return_loss is True:
+ loss_val = loss(y_hat, target_values)
+ else:
+ loss_val = None
+
+ if not return_dict:
+ return tuple(
+ v
+ for v in [
+ loss_val,
+ y_hat,
+ model_output.last_hidden_state,
+ model_output.hidden_states,
+ ]
+ )
+
+ return PatchTSMixerForTimeSeriesClassificationOutput(
+ loss=loss_val,
+ prediction_outputs=y_hat, # tensor [batch_size x n_labels]
+ last_hidden_state=model_output.last_hidden_state, # x: [batch_size x nvars x num_patch x d_model]
+ hidden_states=model_output.hidden_states,
+ )
+
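+# Usage sketch (illustrative only): classifying whole multivariate series; `num_targets` acts as the
+# number of classes for the linear head. All values below are example assumptions.
+#
+#     import torch
+#     from transformers import PatchTSMixerConfig, PatchTSMixerForTimeSeriesClassification
+#
+#     config = PatchTSMixerConfig(context_length=32, patch_length=8, num_input_channels=3, num_targets=5)
+#     model = PatchTSMixerForTimeSeriesClassification(config)
+#     past_values = torch.randn(4, config.context_length, config.num_input_channels)
+#     labels = torch.randint(0, config.num_targets, (4,))
+#     output = model(past_values, target_values=labels)   # output.loss is a cross-entropy loss
+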
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+    Output type of [`PatchTSMixerForRegression`].
+ """
+)
+class PatchTSMixerForRegressionOutput(ModelOutput):
+ r"""
+    loss (`torch.FloatTensor` of shape `()`, *optional*, returned when `target_values` is provided):
+        Total loss.
+ regression_outputs (`torch.FloatTensor` of shape `(batch_size, num_targets)`):
+ Prediction output from the regression head.
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_input_channels, num_patches, d_model)`):
+ Backbone embeddings before passing through the head.
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*):
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ regression_outputs: Optional[torch.FloatTensor] = None
+ last_hidden_state: Optional[torch.FloatTensor] = None
+ hidden_states: Optional[tuple[torch.FloatTensor]] = None
+
+
+class InjectScalerStatistics4D(nn.Module):
+ def __init__(self, d_model: int, num_patches: int, expansion: int = 2):
+ super().__init__()
+
+ self.inverse_trans_expansion = nn.Linear(d_model + 2, expansion * d_model)
+ self.inverse_trans_compression = nn.Linear(expansion * d_model, d_model)
+ self.map_scale_expansion = nn.Linear(2, 2 * expansion)
+ self.map_scale_compression = nn.Linear(2 * expansion, 2)
+ self.num_patches = num_patches
+
+ def forward(self, inputs: torch.Tensor, loc: torch.Tensor, scale: torch.Tensor):
+ """
+ Args:
+ inputs (`torch.Tensor` of shape `(batch_size, num_input_channels, num_patch, d_model)`)
+ loc (`torch.Tensor` of shape `(batch_size, 1, num_input_channels)`)
+ scale (`torch.Tensor` of shape `(batch_size, 1, num_input_channels)`)
+ Returns:
+ `torch.Tensor` of shape `(batch_size, num_input_channels, num_patch, d_model)`
+ """
+
+ mean = loc.transpose(-1, -2) # [batch_size x n_channels x 1 ]
+ mean = mean.unsqueeze(-2) # [batch_size x n_channels x 1 x 1]
+ mean = mean.repeat(1, 1, self.num_patches, 1) # [batch_size x n_channels x num_patch x 1]
+
+ stdev = scale.transpose(-1, -2) # [batch_size x n_channels x 1 ]
+ stdev = stdev.unsqueeze(-2) # [batch_size x n_channels x 1 x 1]
+ stdev = stdev.repeat(1, 1, self.num_patches, 1) # [batch_size x n_channels x num_patch x 1]
+
+ concat_stats = torch.cat([mean, stdev], dim=-1) # [batch_size x n_channels x num_patch x 2]
+
+ concat_stats = self.map_scale_expansion(concat_stats) # [batch_size x n_channels x num_patch x (2*expansion)]
+ concat_stats = self.map_scale_compression(concat_stats) # [batch_size x n_channels x num_patch x 2]
+
+ inputs = torch.cat([inputs, concat_stats], dim=-1) # [batch_size x channels x num_patch x d_model+2]
+ inputs = self.inverse_trans_expansion(inputs) # [batch_size x channels x num_patch x (expansion*d_model)]
+ inputs = self.inverse_trans_compression(inputs) # [batch_size x channels x num_patch x d_model]
+
+ return inputs
+
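+# Shape walk-through (illustrative, with assumed sizes): the scaler statistics are broadcast to one
+# (mean, std) pair per patch, appended to the feature dimension, and projected back to `d_model`.
+#
+#     import torch
+#
+#     inject = InjectScalerStatistics4D(d_model=16, num_patches=4)
+#     hidden = torch.randn(2, 3, 4, 16)   # [batch_size x num_input_channels x num_patch x d_model]
+#     loc = torch.zeros(2, 1, 3)          # [batch_size x 1 x num_input_channels]
+#     scale = torch.ones(2, 1, 3)
+#     out = inject(hidden, loc=loc, scale=scale)
+#     assert out.shape == (2, 3, 4, 16)
+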
+
+@auto_docstring(
+ custom_intro="""
+    `PatchTSMixer` for regression applications.
+ """
+)
+class PatchTSMixerForRegression(PatchTSMixerPreTrainedModel):
+ def __init__(self, config: PatchTSMixerConfig):
+ super().__init__(config)
+
+ self.model = PatchTSMixerModel(config)
+
+ self.loss = config.loss
+ self.distribution_output = config.distribution_output
+
+ self.use_return_dict = config.use_return_dict
+ self.num_parallel_samples = config.num_parallel_samples
+
+ if config.loss == "mse":
+ self.distribution_output = None
+ else:
+ distribution_output_map = {
+ "student_t": StudentTOutput,
+ "normal": NormalOutput,
+ "negative_binomial": NegativeBinomialOutput,
+ }
+ output_class = distribution_output_map.get(config.distribution_output)
+ if output_class is not None:
+ self.distribution_output = output_class(dim=config.num_targets)
+ else:
+ raise ValueError(f"Unknown distribution output {config.distribution_output}")
+
+ if config.scaling in ["std", "mean", True]:
+ self.inject_scale = InjectScalerStatistics4D(d_model=config.d_model, num_patches=config.num_patches)
+ else:
+ self.inject_scale = None
+
+ self.head = PatchTSMixerLinearHead(
+ config=config,
+ distribution_output=self.distribution_output,
+ )
+
+ # Initialize weights and apply final processing
+ if config.post_init:
+ self.post_init()
+
+ @auto_docstring
+ def forward(
+ self,
+ past_values: torch.Tensor,
+ target_values: Optional[torch.Tensor] = None,
+ output_hidden_states: Optional[bool] = False,
+ return_loss: bool = True,
+ return_dict: Optional[bool] = None,
+ ) -> PatchTSMixerForRegressionOutput:
+ r"""
+ past_values (`torch.FloatTensor` of shape `(batch_size, seq_length, num_input_channels)`):
+ Context values of the time series. For a pretraining task, this denotes the input time series to predict
+ the masked portion. For a forecasting task, this denotes the history/past time series values. Similarly,
+ for classification or regression tasks, it denotes the appropriate context values of the time series.
+
+ For univariate time series, `num_input_channels` dimension should be 1. For multivariate time series, it is
+ greater than 1.
+ target_values (`torch.FloatTensor` of shape `(batch_size, target_len, num_input_channels)` for forecasting,
+ `(batch_size, num_targets)` for regression, or `(batch_size,)` for classification, *optional*):
+            Target values of the time series that serve as labels for the model. The `target_values` is what the
+            model needs during training to learn to output, given the `past_values`. Note that this is NOT
+            required for a pretraining task.
+
+            For a forecasting task, the shape is `(batch_size, target_len, num_input_channels)`. Even if we want
+            to forecast only specific channels by setting the indices in the `prediction_channel_indices` parameter,
+            pass the target data with all channels, as channel filtering for both prediction and target will be
+            manually applied before the loss computation.
+
+ For a classification task, it has a shape of `(batch_size,)`.
+
+ For a regression task, it has a shape of `(batch_size, num_targets)`.
+ return_loss (`bool`, *optional*):
+ Whether to return the loss in the `forward` call.
+ """
+
+ if self.loss == "mse":
+ loss = nn.MSELoss(reduction="mean")
+ elif self.loss == "nll":
+ loss = nll
+ else:
+            raise ValueError("Invalid loss function. Allowed values are 'mse' and 'nll'.")
+
+ return_dict = return_dict if return_dict is not None else self.use_return_dict
+ model_output = self.model(
+ past_values,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ ) # model_output: [batch_size x nvars x num_patch x d_model]
+ if isinstance(model_output, tuple):
+ model_output = PatchTSMixerModelOutput(*model_output)
+
+ if self.inject_scale is not None:
+ model_output.last_hidden_state = self.inject_scale(
+ model_output.last_hidden_state,
+ loc=model_output.loc,
+ scale=model_output.scale,
+ ) # x: [batch_size x nvars x num_patch x d_model]
+
+ y_hat = self.head(model_output.last_hidden_state) # [batch_size x num_targets]
+
+ if target_values is not None and return_loss is True:
+ if self.distribution_output:
+                if self.config.distribution_output == "negative_binomial" and torch.any(target_values < 0):
+                    raise ValueError("target_values cannot be negative for negative_binomial distribution.")
+ distribution = self.distribution_output.distribution(y_hat)
+ # y_hat should be a 2-tuple, each with dimension [bs, num_targets]
+ y_hat = tuple(item.view(-1, self.config.num_targets) for item in y_hat)
+ loss_val = loss(distribution, target_values)
+ # take average of the loss
+ loss_val = weighted_average(loss_val)
+ else:
+ loss_val = loss(y_hat, target_values)
+ else:
+ loss_val = None
+
+ if not return_dict:
+ return tuple(
+ v
+ for v in [
+ loss_val,
+ y_hat,
+ model_output.last_hidden_state,
+ model_output.hidden_states,
+ ]
+ )
+
+ return PatchTSMixerForRegressionOutput(
+ loss=loss_val,
+ regression_outputs=y_hat, # tensor [batch_size x num_targets]
+ last_hidden_state=model_output.last_hidden_state, # [batch_size x nvars x num_patch x d_model]
+ hidden_states=model_output.hidden_states,
+ )
+
+ @torch.no_grad()
+ def generate(
+ self,
+ past_values: torch.Tensor,
+ ) -> SamplePatchTSMixerRegressionOutput:
+ """
+ Generate sequences of sample predictions from a model with a probability distribution head.
+
+ Args:
+ past_values (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_input_channels)`):
+ Past values of the time series that serves as context in order to predict the target values.
+
+ Return:
+ [`SamplePatchTSMixerRegressionOutput`] where the outputs `sequences` tensor will have shape `(batch_size,
+ number of samples, num_targets)`.
+ """
+ # get number of samples
+ num_parallel_samples = self.num_parallel_samples
+
+ # get model output
+ outputs = self(
+ past_values=past_values,
+ target_values=None,
+ output_hidden_states=False,
+ )
+
+ # get distribution
+ distribution = self.distribution_output.distribution(outputs.regression_outputs)
+
+ # get samples
+ samples = [
+ distribution.sample() for _ in range(num_parallel_samples)
+ ] # samples: list of [batch_size x num_targets]
+ # stack tensors
+ # [batch_size x num_samples x num_targets]
+ samples = torch.stack(samples, dim=1).view(-1, num_parallel_samples, self.config.num_targets)
+ return SamplePatchTSMixerRegressionOutput(sequences=samples)
+
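+# Usage sketch (illustrative only): probabilistic regression over scalar targets; the configuration
+# values below are assumptions made for the example.
+#
+#     import torch
+#     from transformers import PatchTSMixerConfig, PatchTSMixerForRegression
+#
+#     config = PatchTSMixerConfig(
+#         context_length=32, patch_length=8, num_input_channels=3,
+#         num_targets=2, loss="nll", distribution_output="student_t",
+#     )
+#     model = PatchTSMixerForRegression(config)
+#     past_values = torch.randn(4, config.context_length, config.num_input_channels)
+#     samples = model.generate(past_values).sequences   # [batch_size x num_parallel_samples x num_targets]
+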
+
+__all__ = [
+ "PatchTSMixerPreTrainedModel",
+ "PatchTSMixerModel",
+ "PatchTSMixerForPretraining",
+ "PatchTSMixerForPrediction",
+ "PatchTSMixerForTimeSeriesClassification",
+ "PatchTSMixerForRegression",
+]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/pegasus/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/pegasus/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..4903c400f9827239f7920b7e2c9bfddfda48eecd
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/pegasus/__init__.py
@@ -0,0 +1,31 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_pegasus import *
+ from .modeling_flax_pegasus import *
+ from .modeling_pegasus import *
+ from .modeling_tf_pegasus import *
+ from .tokenization_pegasus import *
+ from .tokenization_pegasus_fast import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/pegasus/configuration_pegasus.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/pegasus/configuration_pegasus.py
new file mode 100644
index 0000000000000000000000000000000000000000..3c27f7d44d952d769ae362b000f87b3dd6a1cfd5
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/pegasus/configuration_pegasus.py
@@ -0,0 +1,164 @@
+# coding=utf-8
+# Copyright 2021, Google and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PEGASUS model configuration"""
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class PegasusConfig(PretrainedConfig):
+ r"""
+    This is the configuration class to store the configuration of a [`PegasusModel`]. It is used to instantiate a
+ PEGASUS model according to the specified arguments, defining the model architecture. Instantiating a configuration
+ with the defaults will yield a similar configuration to that of the PEGASUS
+ [google/pegasus-large](https://huggingface.co/google/pegasus-large) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 50265):
+ Vocabulary size of the PEGASUS model. Defines the number of different tokens that can be represented by the
+            `input_ids` passed when calling [`PegasusModel`] or [`TFPegasusModel`].
+ d_model (`int`, *optional*, defaults to 1024):
+ Dimensionality of the layers and the pooler layer.
+ encoder_layers (`int`, *optional*, defaults to 12):
+ Number of encoder layers.
+ decoder_layers (`int`, *optional*, defaults to 12):
+ Number of decoder layers.
+ encoder_attention_heads (`int`, *optional*, defaults to 16):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ decoder_attention_heads (`int`, *optional*, defaults to 16):
+ Number of attention heads for each attention layer in the Transformer decoder.
+        decoder_ffn_dim (`int`, *optional*, defaults to 4096):
+            Dimensionality of the "intermediate" (often named feed-forward) layer in the decoder.
+        encoder_ffn_dim (`int`, *optional*, defaults to 4096):
+            Dimensionality of the "intermediate" (often named feed-forward) layer in the encoder.
+ activation_function (`str` or `function`, *optional*, defaults to `"gelu"`):
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+ `"relu"`, `"silu"` and `"gelu_new"` are supported.
+ dropout (`float`, *optional*, defaults to 0.1):
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ activation_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for activations inside the fully connected layer.
+ max_position_embeddings (`int`, *optional*, defaults to 1024):
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
+ just in case (e.g., 512 or 1024 or 2048).
+ init_std (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ encoder_layerdrop (`float`, *optional*, defaults to 0.0):
+            The LayerDrop probability for the encoder. See the [LayerDrop paper](https://huggingface.co/papers/1909.11556)
+ for more details.
+ decoder_layerdrop (`float`, *optional*, defaults to 0.0):
+            The LayerDrop probability for the decoder. See the [LayerDrop paper](https://huggingface.co/papers/1909.11556)
+ for more details.
+ scale_embedding (`bool`, *optional*, defaults to `False`):
+            Scale embeddings by dividing by sqrt(d_model).
+ use_cache (`bool`, *optional*, defaults to `True`):
+            Whether or not the model should return the last key/value attentions (not used by all models).
+ forced_eos_token_id (`int`, *optional*, defaults to 1):
+ The id of the token to force as the last generated token when `max_length` is reached. Usually set to
+ `eos_token_id`.
+
+ Example:
+
+ ```python
+ >>> from transformers import PegasusConfig, PegasusModel
+
+ >>> # Initializing a PEGASUS google/pegasus-large style configuration
+ >>> configuration = PegasusConfig()
+
+ >>> # Initializing a model (with random weights) from the google/pegasus-large style configuration
+ >>> model = PegasusModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "pegasus"
+ keys_to_ignore_at_inference = ["past_key_values"]
+ attribute_map = {"num_attention_heads": "encoder_attention_heads", "hidden_size": "d_model"}
+
+ def __init__(
+ self,
+ vocab_size=50265,
+ max_position_embeddings=1024,
+ encoder_layers=12,
+ encoder_ffn_dim=4096,
+ encoder_attention_heads=16,
+ decoder_layers=12,
+ decoder_ffn_dim=4096,
+ decoder_attention_heads=16,
+ encoder_layerdrop=0.0,
+ decoder_layerdrop=0.0,
+ use_cache=True,
+ is_encoder_decoder=True,
+ activation_function="gelu",
+ d_model=1024,
+ dropout=0.1,
+ attention_dropout=0.0,
+ activation_dropout=0.0,
+ init_std=0.02,
+ decoder_start_token_id=0,
+ scale_embedding=False,
+ pad_token_id=0,
+ eos_token_id=1,
+ forced_eos_token_id=1,
+ **kwargs,
+ ):
+ self.vocab_size = vocab_size
+ self.max_position_embeddings = max_position_embeddings
+ self.d_model = d_model
+ self.encoder_ffn_dim = encoder_ffn_dim
+ self.encoder_layers = encoder_layers
+ self.encoder_attention_heads = encoder_attention_heads
+ self.decoder_ffn_dim = decoder_ffn_dim
+ self.decoder_layers = decoder_layers
+ self.decoder_attention_heads = decoder_attention_heads
+ self.dropout = dropout
+ self.attention_dropout = attention_dropout
+ self.activation_dropout = activation_dropout
+ self.activation_function = activation_function
+ self.init_std = init_std
+ self.encoder_layerdrop = encoder_layerdrop
+ self.decoder_layerdrop = decoder_layerdrop
+ self.use_cache = use_cache
+ self.num_hidden_layers = encoder_layers
+ self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True
+ super().__init__(
+ pad_token_id=pad_token_id,
+ eos_token_id=eos_token_id,
+ is_encoder_decoder=is_encoder_decoder,
+ decoder_start_token_id=decoder_start_token_id,
+ forced_eos_token_id=forced_eos_token_id,
+ **kwargs,
+ )
+
+ @property
+ def num_attention_heads(self) -> int:
+ return self.encoder_attention_heads
+
+ @property
+ def hidden_size(self) -> int:
+ return self.d_model
+
+
+__all__ = ["PegasusConfig"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/pegasus/modeling_flax_pegasus.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/pegasus/modeling_flax_pegasus.py
new file mode 100644
index 0000000000000000000000000000000000000000..ddf0ae492407c17bbc678b82b3d19e34da7d8a75
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/pegasus/modeling_flax_pegasus.py
@@ -0,0 +1,1532 @@
+# coding=utf-8
+# Copyright 2021, Google and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Flax PEGASUS model."""
+
+import math
+import random
+from functools import partial
+from typing import Callable, Optional
+
+import flax.linen as nn
+import jax
+import jax.numpy as jnp
+import numpy as np
+from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
+from flax.linen import combine_masks, make_causal_mask
+from flax.linen.attention import dot_product_attention_weights
+from flax.traverse_util import flatten_dict, unflatten_dict
+from jax import lax
+from jax.random import PRNGKey
+
+from ...modeling_flax_outputs import (
+ FlaxBaseModelOutput,
+ FlaxBaseModelOutputWithPastAndCrossAttentions,
+ FlaxCausalLMOutputWithCrossAttentions,
+ FlaxSeq2SeqLMOutput,
+ FlaxSeq2SeqModelOutput,
+)
+from ...modeling_flax_utils import (
+ ACT2FN,
+ FlaxPreTrainedModel,
+ add_start_docstrings_to_model_forward,
+ append_call_sample_docstring,
+ append_replace_return_docstrings,
+ overwrite_call_docstring,
+)
+from ...utils import add_start_docstrings, logging, replace_return_docstrings
+from .configuration_pegasus import PegasusConfig
+
+
+logger = logging.get_logger(__name__)
+
+_CHECKPOINT_FOR_DOC = "google/pegasus-large"
+_CONFIG_FOR_DOC = "PegasusConfig"
+
+PEGASUS_START_DOCSTRING = r"""
+ This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
+    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning
+    heads, etc.).
+
+ This model is also a Flax Linen
+ [flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a
+    regular Flax Module and refer to the Flax documentation for all matters related to general usage and behavior.
+
+ Finally, this model supports inherent JAX features such as:
+
+ - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
+ - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
+ - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
+ - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
+
+ Parameters:
+ config ([`PegasusConfig`]): Model configuration class with all the parameters of the model.
+ Initializing with a config file does not load the weights associated with the model, only the
+ configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights.
+ dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
+ The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
+ `jax.numpy.bfloat16` (on TPUs).
+
+ This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
+ specified all the computation will be performed with the given `dtype`.
+
+ **Note that this only specifies the dtype of the computation and does not influence the dtype of model
+ parameters.**
+
+ If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and
+ [`~FlaxPreTrainedModel.to_bf16`].
+"""
+
+PEGASUS_INPUTS_DOCSTRING = r"""
+ Args:
+ input_ids (`jnp.ndarray` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
+ it.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ decoder_input_ids (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
+ Indices of decoder input sequence tokens in the vocabulary.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are decoder input IDs?](../glossary#decoder-input-ids)
+ decoder_attention_mask (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
+ Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
+ be used by default.
+
+ If you want to change padding behavior, you should modify to your needs. See diagram 1 in [the
+ paper](https://huggingface.co/papers/1910.13461) for more information on the default strategy.
+ position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+ config.max_position_embeddings - 1]`.
+ decoder_position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
+ Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the
+ range `[0, config.max_position_embeddings - 1]`.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+PEGASUS_ENCODE_INPUTS_DOCSTRING = r"""
+ Args:
+ input_ids (`jnp.ndarray` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
+ it.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+ config.max_position_embeddings - 1]`.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+PEGASUS_DECODE_INPUTS_DOCSTRING = r"""
+ Args:
+ decoder_input_ids (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`):
+ Indices of decoder input sequence tokens in the vocabulary.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are decoder input IDs?](../glossary#decoder-input-ids)
+        encoder_outputs (`tuple(tuple(jnp.ndarray))`):
+            Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`).
+            `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)` is a sequence of
+ hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
+ encoder_attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ decoder_attention_mask (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
+ Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
+ be used by default.
+
+ If you want to change padding behavior, you should modify to your needs. See diagram 1 in [the
+ paper](https://huggingface.co/papers/1910.13461) for more information on the default strategy.
+ decoder_position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
+ Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the
+ range `[0, config.max_position_embeddings - 1]`.
+ past_key_values (`dict[str, np.ndarray]`, *optional*, returned by `init_cache` or when passing previous `past_key_values`):
+ Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast
+ auto-regressive decoding. Pre-computed key and value hidden-states are of shape *[batch_size, max_length]*.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+# Copied from transformers.models.bart.modeling_flax_bart.shift_tokens_right
+def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray:
+ """
+ Shift input ids one token to the right.
+ """
+ shifted_input_ids = jnp.zeros_like(input_ids)
+ shifted_input_ids = shifted_input_ids.at[:, 1:].set(input_ids[:, :-1])
+ shifted_input_ids = shifted_input_ids.at[:, 0].set(decoder_start_token_id)
+
+ shifted_input_ids = jnp.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids)
+ return shifted_input_ids
+
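+# Worked example (illustrative): the start token is prepended, labels move one position to the
+# right, and any -100 positions that remain become `pad_token_id`.
+#
+#     import jax.numpy as jnp
+#
+#     labels = jnp.array([[5, 6, 7, -100]])
+#     shift_tokens_right(labels, pad_token_id=0, decoder_start_token_id=0)
+#     # -> [[0, 5, 6, 7]]
+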
+
+# Copied from transformers.models.marian.modeling_flax_marian.create_sinusoidal_positions
+def create_sinusoidal_positions(n_pos, dim):
+ position_enc = np.array([[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)])
+ sentinel = dim // 2 + dim % 2
+ out = np.zeros_like(position_enc)
+ out[:, 0:sentinel] = np.sin(position_enc[:, 0::2])
+ out[:, sentinel:] = np.cos(position_enc[:, 1::2])
+
+ return jnp.array(out)
+
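+# Quick check (illustrative): one row per position; the first half of each row holds the sine terms
+# and the second half the cosine terms, so position 0 gives zeros followed by ones.
+#
+#     table = create_sinusoidal_positions(n_pos=1024, dim=64)
+#     table.shape      # (1024, 64)
+#     table[0, 32:]    # cosine half of position 0 -> all ones
+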
+
+# Copied from transformers.models.bart.modeling_flax_bart.FlaxBartAttention with Bart->Pegasus
+class FlaxPegasusAttention(nn.Module):
+ config: PegasusConfig
+ embed_dim: int
+ num_heads: int
+ dropout: float = 0.0
+ causal: bool = False
+ bias: bool = True
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
+
+ def setup(self) -> None:
+ self.head_dim = self.embed_dim // self.num_heads
+ if self.head_dim * self.num_heads != self.embed_dim:
+ raise ValueError(
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
+ f" and `num_heads`: {self.num_heads})."
+ )
+
+ dense = partial(
+ nn.Dense,
+ self.embed_dim,
+ use_bias=self.bias,
+ dtype=self.dtype,
+ kernel_init=jax.nn.initializers.normal(self.config.init_std),
+ )
+
+ self.q_proj, self.k_proj, self.v_proj = dense(), dense(), dense()
+ self.out_proj = dense()
+
+ self.dropout_layer = nn.Dropout(rate=self.dropout)
+
+ if self.causal:
+ self.causal_mask = make_causal_mask(
+ jnp.ones((1, self.config.max_position_embeddings), dtype="bool"), dtype="bool"
+ )
+
+ def _split_heads(self, hidden_states):
+ return hidden_states.reshape(hidden_states.shape[:2] + (self.num_heads, self.head_dim))
+
+ def _merge_heads(self, hidden_states):
+ return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,))
+
+ @nn.compact
+ def _concatenate_to_cache(self, key, value, query, attention_mask):
+ """
+ This function takes projected key, value states from a single input token and concatenates the states to cached
+ states from previous steps. This function is slightly adapted from the official Flax repository:
+ https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252
+ """
+ # detect if we're initializing by absence of existing cache data.
+ is_initialized = self.has_variable("cache", "cached_key")
+ cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype)
+ cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype)
+ cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32))
+
+ if is_initialized:
+ *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape
+ # update key, value caches with our new 1d spatial slices
+ cur_index = cache_index.value
+ indices = (0,) * len(batch_dims) + (cur_index, 0, 0)
+ key = lax.dynamic_update_slice(cached_key.value, key, indices)
+ value = lax.dynamic_update_slice(cached_value.value, value, indices)
+ cached_key.value = key
+ cached_value.value = value
+ num_updated_cache_vectors = query.shape[1]
+ cache_index.value = cache_index.value + num_updated_cache_vectors
+ # causal mask for cached decoder self-attention: our single query position should only attend to those key positions that have already been generated and cached, not the remaining zero elements.
+ pad_mask = jnp.broadcast_to(
+ jnp.arange(max_length) < cur_index + num_updated_cache_vectors,
+ tuple(batch_dims) + (1, num_updated_cache_vectors, max_length),
+ )
+ attention_mask = combine_masks(pad_mask, attention_mask)
+ return key, value, attention_mask
+
+ def __call__(
+ self,
+ hidden_states: jnp.ndarray,
+ key_value_states: Optional[jnp.ndarray] = None,
+ attention_mask: Optional[jnp.ndarray] = None,
+ init_cache: bool = False,
+ deterministic: bool = True,
+ ) -> tuple[jnp.ndarray]:
+ """Input shape: Batch x Time x Channel"""
+
+ # if key_value_states are provided this layer is used as a cross-attention layer
+ # for the decoder
+ is_cross_attention = key_value_states is not None
+ batch_size = hidden_states.shape[0]
+
+ # get query proj
+ query_states = self.q_proj(hidden_states)
+ # get key, value proj
+ if is_cross_attention:
+ # cross_attentions
+ key_states = self.k_proj(key_value_states)
+ value_states = self.v_proj(key_value_states)
+ else:
+ # self_attention
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ query_states = self._split_heads(query_states)
+ key_states = self._split_heads(key_states)
+ value_states = self._split_heads(value_states)
+
+ # handle cache prepare causal attention mask
+ if self.causal:
+ query_length, key_length = query_states.shape[1], key_states.shape[1]
+ if self.has_variable("cache", "cached_key"):
+ mask_shift = self.variables["cache"]["cache_index"]
+ max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
+ causal_mask = lax.dynamic_slice(
+ self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length)
+ )
+ else:
+ causal_mask = self.causal_mask[:, :, :query_length, :key_length]
+ causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:])
+
+ # combine masks if needed
+ if attention_mask is not None and self.causal:
+ attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape)
+ attention_mask = combine_masks(attention_mask, causal_mask)
+ elif self.causal:
+ attention_mask = causal_mask
+ elif attention_mask is not None:
+ attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
+
+ # During fast autoregressive decoding, we feed one position at a time,
+ # and cache the keys and values step by step.
+ if self.causal and (self.has_variable("cache", "cached_key") or init_cache):
+ key_states, value_states, attention_mask = self._concatenate_to_cache(
+ key_states, value_states, query_states, attention_mask
+ )
+
+ # Convert the boolean attention mask to an attention bias.
+ if attention_mask is not None:
+ # attention mask in the form of attention bias
+ attention_bias = lax.select(
+ attention_mask > 0,
+ jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
+ jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
+ )
+ else:
+ attention_bias = None
+
+ dropout_rng = None
+ if not deterministic and self.dropout > 0.0:
+ dropout_rng = self.make_rng("dropout")
+
+ attn_weights = dot_product_attention_weights(
+ query_states,
+ key_states,
+ bias=attention_bias,
+ dropout_rng=dropout_rng,
+ dropout_rate=self.dropout,
+ broadcast_dropout=True,
+ deterministic=deterministic,
+ dtype=self.dtype,
+ precision=None,
+ )
+
+ attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
+ attn_output = self._merge_heads(attn_output)
+ attn_output = self.out_proj(attn_output)
+
+ return attn_output, attn_weights
+
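+# Shape sketch (illustrative, values are example assumptions): with embed_dim=1024 and num_heads=16,
+# `_split_heads` reshapes each projection from [batch x time x 1024] to [batch x time x 16 x 64]
+# before the dot-product attention, and `_merge_heads` undoes the reshape afterwards.
+#
+#     import jax
+#     import jax.numpy as jnp
+#
+#     attn = FlaxPegasusAttention(config=PegasusConfig(), embed_dim=1024, num_heads=16)
+#     hidden = jnp.ones((1, 7, 1024))
+#     params = attn.init(jax.random.PRNGKey(0), hidden)
+#     out, weights = attn.apply(params, hidden)   # out: (1, 7, 1024), weights: (1, 16, 7, 7)
+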
+
+# Copied from transformers.models.mbart.modeling_flax_mbart.FlaxMBartEncoderLayer with MBart->Pegasus
+class FlaxPegasusEncoderLayer(nn.Module):
+ config: PegasusConfig
+ dtype: jnp.dtype = jnp.float32
+
+ def setup(self) -> None:
+ self.embed_dim = self.config.d_model
+ self.self_attn = FlaxPegasusAttention(
+ config=self.config,
+ embed_dim=self.embed_dim,
+ num_heads=self.config.encoder_attention_heads,
+ dropout=self.config.attention_dropout,
+ dtype=self.dtype,
+ )
+ self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
+ self.dropout_layer = nn.Dropout(rate=self.config.dropout)
+ self.activation_fn = ACT2FN[self.config.activation_function]
+ self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout)
+ self.fc1 = nn.Dense(
+ self.config.encoder_ffn_dim,
+ dtype=self.dtype,
+ kernel_init=jax.nn.initializers.normal(self.config.init_std),
+ )
+ self.fc2 = nn.Dense(
+ self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std)
+ )
+ self.final_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
+
+ def __call__(
+ self,
+ hidden_states: jnp.ndarray,
+ attention_mask: jnp.ndarray,
+ output_attentions: bool = True,
+ deterministic: bool = True,
+ ) -> tuple[jnp.ndarray]:
+ residual = hidden_states
+ hidden_states = self.self_attn_layer_norm(hidden_states)
+ hidden_states, attn_weights = self.self_attn(hidden_states=hidden_states, attention_mask=attention_mask)
+ hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
+ hidden_states = residual + hidden_states
+
+ residual = hidden_states
+ hidden_states = self.final_layer_norm(hidden_states)
+ hidden_states = self.activation_fn(self.fc1(hidden_states))
+ hidden_states = self.activation_dropout_layer(hidden_states, deterministic=deterministic)
+ hidden_states = self.fc2(hidden_states)
+ hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
+ hidden_states = residual + hidden_states
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (attn_weights,)
+
+ return outputs
+
+
+# Copied from transformers.models.bart.modeling_flax_bart.FlaxBartEncoderLayerCollection with Bart->Pegasus
+class FlaxPegasusEncoderLayerCollection(nn.Module):
+ config: PegasusConfig
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
+
+ def setup(self):
+ self.layers = [
+ FlaxPegasusEncoderLayer(self.config, name=str(i), dtype=self.dtype)
+ for i in range(self.config.encoder_layers)
+ ]
+ self.layerdrop = self.config.encoder_layerdrop
+
+ def __call__(
+ self,
+ hidden_states,
+ attention_mask,
+ deterministic: bool = True,
+ output_attentions: bool = False,
+ output_hidden_states: bool = False,
+ return_dict: bool = True,
+ ):
+ all_attentions = () if output_attentions else None
+ all_hidden_states = () if output_hidden_states else None
+
+ for encoder_layer in self.layers:
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+ # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
+ dropout_probability = random.uniform(0, 1)
+ if not deterministic and (dropout_probability < self.layerdrop): # skip the layer
+ layer_outputs = (None, None)
+ else:
+ layer_outputs = encoder_layer(
+ hidden_states,
+ attention_mask,
+ output_attentions,
+ deterministic,
+ )
+ hidden_states = layer_outputs[0]
+ if output_attentions:
+ all_attentions = all_attentions + (layer_outputs[1],)
+
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ outputs = (hidden_states, all_hidden_states, all_attentions)
+
+ if not return_dict:
+ return tuple(v for v in outputs if v is not None)
+
+ return FlaxBaseModelOutput(
+ last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
+ )
+
+
+# Copied from transformers.models.mbart.modeling_flax_mbart.FlaxMBartDecoderLayer with MBart->Pegasus
+class FlaxPegasusDecoderLayer(nn.Module):
+ config: PegasusConfig
+ dtype: jnp.dtype = jnp.float32
+
+ def setup(self) -> None:
+ self.embed_dim = self.config.d_model
+ self.self_attn = FlaxPegasusAttention(
+ config=self.config,
+ embed_dim=self.embed_dim,
+ num_heads=self.config.decoder_attention_heads,
+ dropout=self.config.attention_dropout,
+ causal=True,
+ dtype=self.dtype,
+ )
+ self.dropout_layer = nn.Dropout(rate=self.config.dropout)
+ self.activation_fn = ACT2FN[self.config.activation_function]
+ self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout)
+
+ self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
+ self.encoder_attn = FlaxPegasusAttention(
+ config=self.config,
+ embed_dim=self.embed_dim,
+ num_heads=self.config.decoder_attention_heads,
+ dropout=self.config.attention_dropout,
+ dtype=self.dtype,
+ )
+ self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
+ self.fc1 = nn.Dense(
+ self.config.decoder_ffn_dim,
+ dtype=self.dtype,
+ kernel_init=jax.nn.initializers.normal(self.config.init_std),
+ )
+ self.fc2 = nn.Dense(
+ self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std)
+ )
+ self.final_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
+
+ def __call__(
+ self,
+ hidden_states: jnp.ndarray,
+ attention_mask: jnp.ndarray,
+ encoder_hidden_states: Optional[jnp.ndarray] = None,
+ encoder_attention_mask: Optional[jnp.ndarray] = None,
+ init_cache: bool = False,
+ output_attentions: bool = True,
+ deterministic: bool = True,
+ ) -> tuple[jnp.ndarray]:
+ residual = hidden_states
+ hidden_states = self.self_attn_layer_norm(hidden_states)
+
+ # Self Attention
+ hidden_states, self_attn_weights = self.self_attn(
+ hidden_states=hidden_states, attention_mask=attention_mask, init_cache=init_cache
+ )
+ hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
+ hidden_states = residual + hidden_states
+
+ # Cross-Attention Block
+ cross_attn_weights = None
+ if encoder_hidden_states is not None:
+ residual = hidden_states
+
+ hidden_states = self.encoder_attn_layer_norm(hidden_states)
+ hidden_states, cross_attn_weights = self.encoder_attn(
+ hidden_states=hidden_states,
+ key_value_states=encoder_hidden_states,
+ attention_mask=encoder_attention_mask,
+ )
+ hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
+ hidden_states = residual + hidden_states
+
+ # Fully Connected
+ residual = hidden_states
+ hidden_states = self.final_layer_norm(hidden_states)
+ hidden_states = self.activation_fn(self.fc1(hidden_states))
+ hidden_states = self.activation_dropout_layer(hidden_states, deterministic=deterministic)
+ hidden_states = self.fc2(hidden_states)
+ hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
+ hidden_states = residual + hidden_states
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (self_attn_weights, cross_attn_weights)
+
+ return outputs
+
+
+# Copied from transformers.models.bart.modeling_flax_bart.FlaxBartDecoderLayerCollection with Bart->Pegasus
+class FlaxPegasusDecoderLayerCollection(nn.Module):
+ config: PegasusConfig
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
+
+ def setup(self):
+ self.layers = [
+ FlaxPegasusDecoderLayer(self.config, name=str(i), dtype=self.dtype)
+ for i in range(self.config.decoder_layers)
+ ]
+ self.layerdrop = self.config.decoder_layerdrop
+
+ def __call__(
+ self,
+ hidden_states,
+ attention_mask,
+ encoder_hidden_states: Optional[jnp.ndarray] = None,
+ encoder_attention_mask: Optional[jnp.ndarray] = None,
+ deterministic: bool = True,
+ init_cache: bool = False,
+ output_attentions: bool = False,
+ output_hidden_states: bool = False,
+ return_dict: bool = True,
+ ):
+ # decoder layers
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+ all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
+
+ for decoder_layer in self.layers:
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+ # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
+ dropout_probability = random.uniform(0, 1)
+ if not deterministic and (dropout_probability < self.layerdrop):
+ layer_outputs = (None, None, None)
+ else:
+ layer_outputs = decoder_layer(
+ hidden_states,
+ attention_mask=attention_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ init_cache=init_cache,
+ output_attentions=output_attentions,
+ deterministic=deterministic,
+ )
+
+ hidden_states = layer_outputs[0]
+ if output_attentions:
+ all_self_attns += (layer_outputs[1],)
+
+ if encoder_hidden_states is not None:
+ all_cross_attentions += (layer_outputs[2],)
+
+ # add hidden states from the last decoder layer
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ outputs = [hidden_states, all_hidden_states, all_self_attns, all_cross_attentions]
+
+ if not return_dict:
+ return tuple(v for v in outputs if v is not None)
+
+ return FlaxBaseModelOutputWithPastAndCrossAttentions(
+ last_hidden_state=hidden_states,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ cross_attentions=all_cross_attentions,
+ )
+
+
+class FlaxPegasusEncoder(nn.Module):
+ config: PegasusConfig
+ embed_tokens: nn.Embed
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
+
+ def setup(self):
+ self.dropout_layer = nn.Dropout(rate=self.config.dropout)
+
+ embed_dim = self.config.d_model
+ self.padding_idx = self.config.pad_token_id
+ self.max_source_positions = self.config.max_position_embeddings
+ self.embed_scale = math.sqrt(embed_dim) if self.config.scale_embedding else 1.0
+
+ self.embed_positions = create_sinusoidal_positions(self.config.max_position_embeddings, embed_dim)
+ self.layers = FlaxPegasusEncoderLayerCollection(self.config, self.dtype)
+ self.layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
+
+ def __call__(
+ self,
+ input_ids,
+ attention_mask,
+ position_ids,
+ output_attentions: bool = False,
+ output_hidden_states: bool = False,
+ return_dict: bool = True,
+ deterministic: bool = True,
+ ):
+ input_shape = input_ids.shape
+ input_ids = input_ids.reshape(-1, input_shape[-1])
+
+ inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
+
+ # embed positions
+ embed_pos = jnp.take(self.embed_positions, position_ids, axis=0)
+ # explicitly cast the positions here, since self.embed_positions is not registered as a parameter
+ embed_pos = embed_pos.astype(inputs_embeds.dtype)
+
+ hidden_states = inputs_embeds + embed_pos
+ hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
+ outputs = self.layers(
+ hidden_states,
+ attention_mask,
+ deterministic=deterministic,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ last_hidden_state = outputs[0]
+ last_hidden_state = self.layer_norm(last_hidden_state)
+
+ # update the last element in `hidden_states` after applying `layernorm` above
+ hidden_states = None
+ if output_hidden_states:
+ hidden_states = outputs[1]
+ hidden_states = hidden_states[:-1] + (last_hidden_state,)
+
+ if not return_dict:
+ outputs = (last_hidden_state, hidden_states) + (outputs[2:] if output_hidden_states else outputs[1:])
+ return tuple(v for v in outputs if v is not None)
+
+ return FlaxBaseModelOutput(
+ last_hidden_state=last_hidden_state,
+ hidden_states=hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+class FlaxPegasusDecoder(nn.Module):
+ config: PegasusConfig
+ embed_tokens: nn.Embed
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
+
+ def setup(self):
+ self.dropout_layer = nn.Dropout(rate=self.config.dropout)
+
+ embed_dim = self.config.d_model
+ self.padding_idx = self.config.pad_token_id
+ self.max_target_positions = self.config.max_position_embeddings
+ self.embed_scale = math.sqrt(self.config.d_model) if self.config.scale_embedding else 1.0
+
+ self.embed_positions = create_sinusoidal_positions(self.config.max_position_embeddings, embed_dim)
+
+ self.layers = FlaxPegasusDecoderLayerCollection(self.config, self.dtype)
+ self.layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
+
+ def __call__(
+ self,
+ input_ids,
+ attention_mask,
+ position_ids,
+ encoder_hidden_states: Optional[jnp.ndarray] = None,
+ encoder_attention_mask: Optional[jnp.ndarray] = None,
+ init_cache: bool = False,
+ output_attentions: bool = False,
+ output_hidden_states: bool = False,
+ return_dict: bool = True,
+ deterministic: bool = True,
+ ):
+ input_shape = input_ids.shape
+ input_ids = input_ids.reshape(-1, input_shape[-1])
+
+ inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
+
+ # embed positions
+ positions = jnp.take(self.embed_positions, position_ids, axis=0)
+ # explicitly cast the positions here, since self.embed_positions is not registered as a parameter
+ positions = positions.astype(inputs_embeds.dtype)
+
+ hidden_states = inputs_embeds + positions
+ hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
+ outputs = self.layers(
+ hidden_states,
+ attention_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ deterministic=deterministic,
+ init_cache=init_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ last_hidden_state = outputs[0]
+ last_hidden_state = self.layer_norm(last_hidden_state)
+
+ # update the last element in `hidden_states` after applying `layernorm` above
+ hidden_states = None
+ if output_hidden_states:
+ hidden_states = outputs[1]
+ hidden_states = hidden_states[:-1] + (last_hidden_state,)
+
+ if not return_dict:
+ outputs = (last_hidden_state, hidden_states) + (outputs[2:] if output_hidden_states else outputs[1:])
+ return tuple(v for v in outputs if v is not None)
+
+ return FlaxBaseModelOutputWithPastAndCrossAttentions(
+ last_hidden_state=last_hidden_state,
+ hidden_states=hidden_states,
+ attentions=outputs.attentions,
+ cross_attentions=outputs.cross_attentions,
+ )
+
+
+# Copied from transformers.models.bart.modeling_flax_bart.FlaxBartModule with Bart->Pegasus
+class FlaxPegasusModule(nn.Module):
+ config: PegasusConfig
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
+
+ def setup(self):
+ self.shared = nn.Embed(
+ self.config.vocab_size,
+ self.config.d_model,
+ embedding_init=jax.nn.initializers.normal(self.config.init_std),
+ dtype=self.dtype,
+ )
+
+ self.encoder = FlaxPegasusEncoder(self.config, dtype=self.dtype, embed_tokens=self.shared)
+ self.decoder = FlaxPegasusDecoder(self.config, dtype=self.dtype, embed_tokens=self.shared)
+
+ def _get_encoder_module(self):
+ return self.encoder
+
+ def _get_decoder_module(self):
+ return self.decoder
+
+ def __call__(
+ self,
+ input_ids,
+ attention_mask,
+ decoder_input_ids,
+ decoder_attention_mask,
+ position_ids,
+ decoder_position_ids,
+ output_attentions: bool = False,
+ output_hidden_states: bool = False,
+ return_dict: bool = True,
+ deterministic: bool = True,
+ ):
+ encoder_outputs = self.encoder(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ deterministic=deterministic,
+ )
+
+ decoder_outputs = self.decoder(
+ input_ids=decoder_input_ids,
+ attention_mask=decoder_attention_mask,
+ position_ids=decoder_position_ids,
+ encoder_hidden_states=encoder_outputs[0],
+ encoder_attention_mask=attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ deterministic=deterministic,
+ )
+
+ if not return_dict:
+ return decoder_outputs + encoder_outputs
+
+ return FlaxSeq2SeqModelOutput(
+ last_hidden_state=decoder_outputs.last_hidden_state,
+ decoder_hidden_states=decoder_outputs.hidden_states,
+ decoder_attentions=decoder_outputs.attentions,
+ cross_attentions=decoder_outputs.cross_attentions,
+ encoder_last_hidden_state=encoder_outputs.last_hidden_state,
+ encoder_hidden_states=encoder_outputs.hidden_states,
+ encoder_attentions=encoder_outputs.attentions,
+ )
+
+
+class FlaxPegasusPreTrainedModel(FlaxPreTrainedModel):
+ config_class = PegasusConfig
+ base_model_prefix: str = "model"
+ module_class: nn.Module = None
+
+ def __init__(
+ self,
+ config: PegasusConfig,
+ input_shape: tuple[int] = (1, 1),
+ seed: int = 0,
+ dtype: jnp.dtype = jnp.float32,
+ _do_init: bool = True,
+ **kwargs,
+ ):
+ module = self.module_class(config=config, dtype=dtype, **kwargs)
+ super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
+
+ def init_weights(self, rng: jax.random.PRNGKey, input_shape: tuple, params: FrozenDict = None) -> FrozenDict:
+ # init input tensors
+ input_ids = jnp.zeros(input_shape, dtype="i4")
+ attention_mask = jnp.ones_like(input_ids)
+ decoder_input_ids = input_ids
+ decoder_attention_mask = jnp.ones_like(input_ids)
+
+ batch_size, sequence_length = input_ids.shape
+ position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))
+ decoder_position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))
+
+ params_rng, dropout_rng = jax.random.split(rng)
+ rngs = {"params": params_rng, "dropout": dropout_rng}
+
+ random_params = self.module.init(
+ rngs,
+ input_ids,
+ attention_mask,
+ decoder_input_ids,
+ decoder_attention_mask,
+ position_ids,
+ decoder_position_ids,
+ )["params"]
+
+ if params is not None:
+ random_params = flatten_dict(unfreeze(random_params))
+ params = flatten_dict(unfreeze(params))
+ for missing_key in self._missing_keys:
+ params[missing_key] = random_params[missing_key]
+ self._missing_keys = set()
+ return freeze(unflatten_dict(params))
+ else:
+ return random_params
+
+ def init_cache(self, batch_size, max_length, encoder_outputs):
+ r"""
+ Args:
+ batch_size (`int`):
+ batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
+ max_length (`int`):
+ maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized
+ cache.
+ encoder_outputs (`Union[FlaxBaseModelOutput, tuple(tuple(jnp.ndarray))]`):
+ `encoder_outputs` consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*:
+ `attentions`). `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`,
+ *optional*, is a sequence of hidden-states at the output of the last layer of the encoder. Used in
+ the cross-attention of the decoder.
+ """
+ # init input variables to retrieve cache
+ decoder_input_ids = jnp.ones((batch_size, max_length), dtype="i4")
+ decoder_attention_mask = jnp.ones_like(decoder_input_ids)
+ decoder_position_ids = jnp.broadcast_to(
+ jnp.arange(jnp.atleast_2d(decoder_input_ids).shape[-1]), decoder_input_ids.shape
+ )
+
+ def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs):
+ decoder_module = module._get_decoder_module()
+ return decoder_module(
+ decoder_input_ids,
+ decoder_attention_mask,
+ decoder_position_ids,
+ **kwargs,
+ )
+
+ init_variables = self.module.init(
+ jax.random.PRNGKey(0),
+ decoder_input_ids=decoder_input_ids,
+ decoder_attention_mask=decoder_attention_mask,
+ decoder_position_ids=decoder_position_ids,
+ encoder_hidden_states=encoder_outputs[0],
+ init_cache=True,
+ method=_decoder_forward, # we only need to call the decoder to init the cache
+ )
+ return unfreeze(init_variables["cache"])
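+
+ # Usage sketch (illustrative, assumed variable names; not part of the upstream file): during
+ # autoregressive generation the cache returned by `init_cache` is threaded through `decode` as
+ # `past_key_values`, roughly:
+ #
+ #     encoder_outputs = model.encode(input_ids, attention_mask)
+ #     past = model.init_cache(input_ids.shape[0], max_length, encoder_outputs)
+ #     out = model.decode(decoder_input_ids, encoder_outputs, past_key_values=past,
+ #                        decoder_position_ids=decoder_position_ids)
+ #
+ # `decoder_position_ids` must be provided explicitly once `past_key_values` is passed (see the
+ # check in `decode` below).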
+
+ @add_start_docstrings(PEGASUS_ENCODE_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=FlaxBaseModelOutput, config_class=PegasusConfig)
+ def encode(
+ self,
+ input_ids: jnp.ndarray,
+ attention_mask: Optional[jnp.ndarray] = None,
+ position_ids: Optional[jnp.ndarray] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ train: bool = False,
+ params: Optional[dict] = None,
+ dropout_rng: PRNGKey = None,
+ ):
+ r"""
+ Returns:
+
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, FlaxPegasusForConditionalGeneration
+
+ >>> model = FlaxPegasusForConditionalGeneration.from_pretrained("google/pegasus-large")
+ >>> tokenizer = AutoTokenizer.from_pretrained("google/pegasus-large")
+
+ >>> text = "My friends are cool but they eat too many carbs."
+ >>> inputs = tokenizer(text, max_length=1024, return_tensors="np")
+ >>> encoder_outputs = model.encode(**inputs)
+ ```"""
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
+
+ if attention_mask is None:
+ attention_mask = jnp.ones_like(input_ids)
+ if position_ids is None:
+ batch_size, sequence_length = input_ids.shape
+ position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))
+
+ # Handle any PRNG if needed
+ rngs = {}
+ if dropout_rng is not None:
+ rngs["dropout"] = dropout_rng
+
+ def _encoder_forward(module, input_ids, attention_mask, position_ids, **kwargs):
+ encode_module = module._get_encoder_module()
+ return encode_module(input_ids, attention_mask, position_ids, **kwargs)
+
+ return self.module.apply(
+ {"params": params or self.params},
+ input_ids=jnp.array(input_ids, dtype="i4"),
+ attention_mask=jnp.array(attention_mask, dtype="i4"),
+ position_ids=jnp.array(position_ids, dtype="i4"),
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ deterministic=not train,
+ rngs=rngs,
+ method=_encoder_forward,
+ )
+
+ @add_start_docstrings(PEGASUS_DECODE_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=FlaxBaseModelOutputWithPastAndCrossAttentions, config_class=PegasusConfig)
+ def decode(
+ self,
+ decoder_input_ids,
+ encoder_outputs,
+ encoder_attention_mask: Optional[jnp.ndarray] = None,
+ decoder_attention_mask: Optional[jnp.ndarray] = None,
+ decoder_position_ids: Optional[jnp.ndarray] = None,
+ past_key_values: Optional[dict] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ train: bool = False,
+ params: Optional[dict] = None,
+ dropout_rng: PRNGKey = None,
+ ):
+ r"""
+ Returns:
+
+ Example:
+
+ ```python
+ >>> import jax.numpy as jnp
+ >>> from transformers import AutoTokenizer, FlaxPegasusForConditionalGeneration
+
+ >>> model = FlaxPegasusForConditionalGeneration.from_pretrained("google/pegasus-large")
+ >>> tokenizer = AutoTokenizer.from_pretrained("google/pegasus-large")
+
+ >>> text = "My friends are cool but they eat too many carbs."
+ >>> inputs = tokenizer(text, max_length=1024, return_tensors="np")
+ >>> encoder_outputs = model.encode(**inputs)
+
+ >>> decoder_start_token_id = model.config.decoder_start_token_id
+ >>> decoder_input_ids = jnp.ones((inputs.input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id
+
+ >>> outputs = model.decode(decoder_input_ids, encoder_outputs)
+ >>> last_decoder_hidden_states = outputs.last_hidden_state
+ ```"""
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
+
+ encoder_hidden_states = encoder_outputs[0]
+ if encoder_attention_mask is None:
+ batch_size, sequence_length = encoder_hidden_states.shape[:2]
+ encoder_attention_mask = jnp.ones((batch_size, sequence_length))
+
+ batch_size, sequence_length = decoder_input_ids.shape
+ if decoder_attention_mask is None:
+ decoder_attention_mask = jnp.ones((batch_size, sequence_length))
+
+ if decoder_position_ids is None:
+ if past_key_values is not None:
+ raise ValueError("Make sure to provide `decoder_position_ids` when passing `past_key_values`.")
+
+ decoder_position_ids = jnp.broadcast_to(
+ jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
+ )
+
+ # Handle any PRNG if needed
+ rngs = {}
+ if dropout_rng is not None:
+ rngs["dropout"] = dropout_rng
+
+ inputs = {"params": params or self.params}
+
+ # If past_key_values are passed, the cache is already initialized and a private flag `init_cache` has to be
+ # passed down to ensure the cache is used. The cache also has to be marked as mutable so that it can be
+ # changed by the FlaxPegasusAttention module.
+ if past_key_values:
+ inputs["cache"] = past_key_values
+ mutable = ["cache"]
+ else:
+ mutable = False
+
+ def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs):
+ decoder_module = module._get_decoder_module()
+ return decoder_module(
+ decoder_input_ids,
+ decoder_attention_mask,
+ decoder_position_ids,
+ **kwargs,
+ )
+
+ outputs = self.module.apply(
+ inputs,
+ decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
+ decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
+ decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"),
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"),
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ deterministic=not train,
+ rngs=rngs,
+ mutable=mutable,
+ method=_decoder_forward,
+ )
+
+ # add updated cache to model output
+ if past_key_values is not None and return_dict:
+ outputs, past = outputs
+ outputs["past_key_values"] = unfreeze(past["cache"])
+ return outputs
+ elif past_key_values is not None and not return_dict:
+ outputs, past = outputs
+ outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:]
+
+ return outputs
+
+ @add_start_docstrings_to_model_forward(PEGASUS_INPUTS_DOCSTRING)
+ def __call__(
+ self,
+ input_ids: jnp.ndarray,
+ attention_mask: Optional[jnp.ndarray] = None,
+ decoder_input_ids: Optional[jnp.ndarray] = None,
+ decoder_attention_mask: Optional[jnp.ndarray] = None,
+ position_ids: Optional[jnp.ndarray] = None,
+ decoder_position_ids: Optional[jnp.ndarray] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ train: bool = False,
+ params: Optional[dict] = None,
+ dropout_rng: PRNGKey = None,
+ ):
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
+
+ # prepare encoder inputs
+ if attention_mask is None:
+ attention_mask = jnp.ones_like(input_ids)
+ if position_ids is None:
+ batch_size, sequence_length = input_ids.shape
+ position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))
+
+ # prepare decoder inputs
+ if decoder_input_ids is None:
+ decoder_input_ids = shift_tokens_right(
+ input_ids, self.config.pad_token_id, decoder_start_token_id=self.config.decoder_start_token_id
+ )
+ if decoder_attention_mask is None:
+ decoder_attention_mask = jnp.ones_like(decoder_input_ids)
+ if decoder_position_ids is None:
+ batch_size, sequence_length = decoder_input_ids.shape
+ decoder_position_ids = jnp.broadcast_to(
+ jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
+ )
+
+ # Handle any PRNG if needed
+ rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}
+
+ return self.module.apply(
+ {"params": params or self.params},
+ input_ids=jnp.array(input_ids, dtype="i4"),
+ attention_mask=jnp.array(attention_mask, dtype="i4"),
+ position_ids=jnp.array(position_ids, dtype="i4"),
+ decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
+ decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
+ decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"),
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ deterministic=not train,
+ rngs=rngs,
+ )
+
+
+@add_start_docstrings(
+ "The bare Pegasus Model transformer outputting raw hidden-states without any specific head on top.",
+ PEGASUS_START_DOCSTRING,
+)
+class FlaxPegasusModel(FlaxPegasusPreTrainedModel):
+ config: PegasusConfig
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
+ module_class = FlaxPegasusModule
+
+
+append_call_sample_docstring(FlaxPegasusModel, _CHECKPOINT_FOR_DOC, FlaxSeq2SeqModelOutput, _CONFIG_FOR_DOC)
+
+
+# Copied from transformers.models.bart.modeling_flax_bart.FlaxBartForConditionalGenerationModule with Bart->Pegasus
+class FlaxPegasusForConditionalGenerationModule(nn.Module):
+ config: PegasusConfig
+ dtype: jnp.dtype = jnp.float32
+ bias_init: Callable[..., jnp.ndarray] = jax.nn.initializers.zeros
+
+ def setup(self):
+ self.model = FlaxPegasusModule(config=self.config, dtype=self.dtype)
+ self.lm_head = nn.Dense(
+ self.model.shared.num_embeddings,
+ use_bias=False,
+ dtype=self.dtype,
+ kernel_init=jax.nn.initializers.normal(self.config.init_std),
+ )
+ self.final_logits_bias = self.param("final_logits_bias", self.bias_init, (1, self.model.shared.num_embeddings))
+
+ def _get_encoder_module(self):
+ return self.model.encoder
+
+ def _get_decoder_module(self):
+ return self.model.decoder
+
+ def __call__(
+ self,
+ input_ids,
+ attention_mask,
+ decoder_input_ids,
+ decoder_attention_mask,
+ position_ids,
+ decoder_position_ids,
+ output_attentions: bool = False,
+ output_hidden_states: bool = False,
+ return_dict: bool = True,
+ deterministic: bool = True,
+ ):
+ outputs = self.model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ decoder_input_ids=decoder_input_ids,
+ decoder_attention_mask=decoder_attention_mask,
+ position_ids=position_ids,
+ decoder_position_ids=decoder_position_ids,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ deterministic=deterministic,
+ )
+
+ hidden_states = outputs[0]
+
+ if self.config.tie_word_embeddings:
+ shared_embedding = self.model.variables["params"]["shared"]["embedding"]
+ lm_logits = self.lm_head.apply({"params": {"kernel": shared_embedding.T}}, hidden_states)
+ else:
+ lm_logits = self.lm_head(hidden_states)
+
+ lm_logits += jax.lax.stop_gradient(self.final_logits_bias.astype(self.dtype))
+
+ if not return_dict:
+ output = (lm_logits,) + outputs[1:]
+ return output
+
+ return FlaxSeq2SeqLMOutput(
+ logits=lm_logits,
+ decoder_hidden_states=outputs.decoder_hidden_states,
+ decoder_attentions=outputs.decoder_attentions,
+ cross_attentions=outputs.cross_attentions,
+ encoder_last_hidden_state=outputs.encoder_last_hidden_state,
+ encoder_hidden_states=outputs.encoder_hidden_states,
+ encoder_attentions=outputs.encoder_attentions,
+ )
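+
+ # Note (illustrative): when `tie_word_embeddings` is enabled, the shared embedding matrix E of shape
+ # (vocab_size, d_model) is re-applied as the output projection, i.e. lm_logits = hidden_states @ E.T,
+ # and `final_logits_bias` is added under `stop_gradient`, so it never receives gradients.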
+
+
+@add_start_docstrings(
+ "The PEGASUS Model with a language modeling head. Can be used for summarization.", PEGASUS_START_DOCSTRING
+)
+class FlaxPegasusForConditionalGeneration(FlaxPegasusPreTrainedModel):
+ module_class = FlaxPegasusForConditionalGenerationModule
+ dtype: jnp.dtype = jnp.float32
+
+ @add_start_docstrings(PEGASUS_DECODE_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=FlaxCausalLMOutputWithCrossAttentions, config_class=PegasusConfig)
+ def decode(
+ self,
+ decoder_input_ids,
+ encoder_outputs,
+ encoder_attention_mask: Optional[jnp.ndarray] = None,
+ decoder_attention_mask: Optional[jnp.ndarray] = None,
+ decoder_position_ids: Optional[jnp.ndarray] = None,
+ past_key_values: Optional[dict] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ deterministic: bool = True,
+ params: Optional[dict] = None,
+ dropout_rng: PRNGKey = None,
+ ):
+ r"""
+ Returns:
+
+ Example:
+
+ ```python
+ >>> import jax.numpy as jnp
+ >>> from transformers import AutoTokenizer, FlaxPegasusForConditionalGeneration
+
+ >>> model = FlaxPegasusForConditionalGeneration.from_pretrained("google/pegasus-large")
+ >>> tokenizer = AutoTokenizer.from_pretrained("google/pegasus-large")
+
+ >>> text = "My friends are cool but they eat too many carbs."
+ >>> inputs = tokenizer(text, max_length=1024, return_tensors="np")
+ >>> encoder_outputs = model.encode(**inputs)
+
+ >>> decoder_start_token_id = model.config.decoder_start_token_id
+ >>> decoder_input_ids = jnp.ones((inputs.input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id
+
+ >>> outputs = model.decode(decoder_input_ids, encoder_outputs)
+ >>> logits = outputs.logits
+ ```"""
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
+
+ encoder_hidden_states = encoder_outputs[0]
+ if encoder_attention_mask is None:
+ batch_size, sequence_length = encoder_hidden_states.shape[:2]
+ encoder_attention_mask = jnp.ones((batch_size, sequence_length))
+
+ batch_size, sequence_length = decoder_input_ids.shape
+ if decoder_attention_mask is None:
+ decoder_attention_mask = jnp.ones((batch_size, sequence_length))
+
+ if decoder_position_ids is None:
+ if past_key_values is not None:
+ raise ValueError("Make sure to provide `decoder_position_ids` when passing `past_key_values`.")
+
+ decoder_position_ids = jnp.broadcast_to(
+ jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
+ )
+
+ # Handle any PRNG if needed
+ rngs = {}
+ if dropout_rng is not None:
+ rngs["dropout"] = dropout_rng
+
+ inputs = {"params": params or self.params}
+
+ # If past_key_values are passed, the cache is already initialized and a private flag `init_cache` has to be
+ # passed down to ensure the cache is used. The cache also has to be marked as mutable so that it can be
+ # changed by the FlaxPegasusAttention module.
+ if past_key_values:
+ inputs["cache"] = past_key_values
+ mutable = ["cache"]
+ else:
+ mutable = False
+
+ def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs):
+ decoder_module = module._get_decoder_module()
+ outputs = decoder_module(
+ decoder_input_ids,
+ decoder_attention_mask,
+ decoder_position_ids,
+ **kwargs,
+ )
+ hidden_states = outputs[0]
+
+ if self.config.tie_word_embeddings:
+ shared_embedding = module.model.variables["params"]["shared"]["embedding"]
+ lm_logits = module.lm_head.apply({"params": {"kernel": shared_embedding.T}}, hidden_states)
+ else:
+ lm_logits = module.lm_head(hidden_states)
+
+ lm_logits += module.final_logits_bias.astype(self.dtype)
+ return lm_logits, outputs
+
+ outputs = self.module.apply(
+ inputs,
+ decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
+ decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
+ decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"),
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"),
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ deterministic=deterministic,
+ rngs=rngs,
+ mutable=mutable,
+ method=_decoder_forward,
+ )
+
+ if past_key_values is None:
+ lm_logits, decoder_outputs = outputs
+ else:
+ (lm_logits, decoder_outputs), past = outputs
+
+ if return_dict:
+ outputs = FlaxCausalLMOutputWithCrossAttentions(
+ logits=lm_logits,
+ hidden_states=decoder_outputs.hidden_states,
+ attentions=decoder_outputs.attentions,
+ cross_attentions=decoder_outputs.cross_attentions,
+ )
+ else:
+ outputs = (lm_logits,) + decoder_outputs[1:]
+
+ # add updated cache to model output
+ if past_key_values is not None and return_dict:
+ outputs["past_key_values"] = unfreeze(past["cache"])
+ return outputs
+ elif past_key_values is not None and not return_dict:
+ outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:]
+
+ return outputs
+
+ def prepare_inputs_for_generation(
+ self,
+ decoder_input_ids,
+ max_length,
+ attention_mask: Optional[jax.Array] = None,
+ decoder_attention_mask: Optional[jax.Array] = None,
+ encoder_outputs=None,
+ **kwargs,
+ ):
+ # initializing the cache
+ batch_size, seq_length = decoder_input_ids.shape
+
+ past_key_values = self.init_cache(batch_size, max_length, encoder_outputs)
+ # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length.
+ # But since the decoder uses a causal mask, those positions are masked anyway.
+ # Thus we can create a single static attention_mask here, which is more efficient for compilation.
+ extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
+ if decoder_attention_mask is not None:
+ position_ids = decoder_attention_mask.cumsum(axis=-1) - 1
+ extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, decoder_attention_mask, (0, 0))
+ else:
+ position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length))
+
+ return {
+ "past_key_values": past_key_values,
+ "encoder_outputs": encoder_outputs,
+ "encoder_attention_mask": attention_mask,
+ "decoder_attention_mask": extended_attention_mask,
+ "decoder_position_ids": position_ids,
+ }
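+
+ # Illustration (assumed values, not part of the upstream file): with max_length=5 and
+ # decoder_attention_mask=[[1, 1]], `extended_attention_mask` stays [[1, 1, 1, 1, 1]] and
+ # position_ids become [[0, 1]]; the causal mask inside the attention module hides the trailing 1s.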
+
+ def update_inputs_for_generation(self, model_outputs, model_kwargs):
+ model_kwargs["past_key_values"] = model_outputs.past_key_values
+ model_kwargs["decoder_position_ids"] = model_kwargs["decoder_position_ids"][:, -1:] + 1
+ return model_kwargs
+
+
+FLAX_PEGASUS_CONDITIONAL_GENERATION_DOCSTRING = """
+ Returns:
+
+ Summarization example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, FlaxPegasusForConditionalGeneration
+
+ >>> model = FlaxPegasusForConditionalGeneration.from_pretrained('google/pegasus-large')
+ >>> tokenizer = AutoTokenizer.from_pretrained('google/pegasus-large')
+
+ >>> ARTICLE_TO_SUMMARIZE = "My friends are cool but they eat too many carbs."
+ >>> inputs = tokenizer([ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors='np')
+
+ >>> # Generate Summary
+ >>> summary_ids = model.generate(inputs['input_ids']).sequences
+ >>> print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False))
+ ```
+
+ Mask filling example:
+
+ ```python
+ >>> import jax
+ >>> from transformers import AutoTokenizer, FlaxPegasusForConditionalGeneration
+
+ >>> tokenizer = AutoTokenizer.from_pretrained("google/pegasus-large")
+ >>> TXT = "My friends are <mask_2> but they eat too many carbs."
+
+ >>> model = FlaxPegasusForConditionalGeneration.from_pretrained("google/pegasus-large")
+ >>> input_ids = tokenizer([TXT], return_tensors="np")["input_ids"]
+ >>> logits = model(input_ids).logits
+
+ >>> masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero()[0].item()
+ >>> probs = jax.nn.softmax(logits[0, masked_index], axis=0)
+ >>> values, predictions = jax.lax.top_k(probs, k=5)
+
+ >>> tokenizer.decode(predictions).split()
+ ```
+"""
+
+overwrite_call_docstring(
+ FlaxPegasusForConditionalGeneration, PEGASUS_INPUTS_DOCSTRING + FLAX_PEGASUS_CONDITIONAL_GENERATION_DOCSTRING
+)
+append_replace_return_docstrings(
+ FlaxPegasusForConditionalGeneration, output_type=FlaxSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC
+)
+
+
+__all__ = ["FlaxPegasusForConditionalGeneration", "FlaxPegasusModel", "FlaxPegasusPreTrainedModel"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/pegasus/modeling_pegasus.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/pegasus/modeling_pegasus.py
new file mode 100644
index 0000000000000000000000000000000000000000..dc3a8005acac814393789cfa6146f413af2863bd
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/pegasus/modeling_pegasus.py
@@ -0,0 +1,1671 @@
+# coding=utf-8
+# Copyright 2021, Google and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch PEGASUS model."""
+
+import copy
+import math
+from typing import Callable, Optional, Union
+
+import numpy as np
+import torch
+from torch import nn
+from torch.nn import CrossEntropyLoss
+
+from ...activations import ACT2FN
+from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
+from ...generation import GenerationMixin
+from ...modeling_attn_mask_utils import (
+ AttentionMaskConverter,
+ _prepare_4d_attention_mask,
+ _prepare_4d_attention_mask_for_sdpa,
+)
+from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_outputs import (
+ BaseModelOutput,
+ BaseModelOutputWithPastAndCrossAttentions,
+ CausalLMOutputWithCrossAttentions,
+ Seq2SeqLMOutput,
+ Seq2SeqModelOutput,
+)
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
+from ...utils import (
+ auto_docstring,
+ is_torch_flex_attn_available,
+ is_torchdynamo_compiling,
+ logging,
+)
+from ...utils.deprecation import deprecate_kwarg
+from .configuration_pegasus import PegasusConfig
+
+
+if is_torch_flex_attn_available():
+ from ...integrations.flex_attention import BlockMask, make_flex_block_causal_mask
+
+
+logger = logging.get_logger(__name__)
+
+
+# Copied from transformers.models.bart.modeling_bart.shift_tokens_right
+def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
+ """
+ Shift input ids one token to the right.
+ """
+ shifted_input_ids = input_ids.new_zeros(input_ids.shape)
+ shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
+ shifted_input_ids[:, 0] = decoder_start_token_id
+
+ if pad_token_id is None:
+ raise ValueError("self.model.config.pad_token_id has to be defined.")
+ # replace possible -100 values in labels by `pad_token_id`
+ shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
+
+ return shifted_input_ids
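+
+
+# Illustrative example (assumed values, not part of the upstream file): with pad_token_id=0 and
+# decoder_start_token_id=2, shift_tokens_right(torch.tensor([[5, 6, 7]]), 0, 2) returns
+# tensor([[2, 5, 6]]); any -100 label positions would be replaced by the pad token.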
+
+
+# Copied from transformers.models.marian.modeling_marian.MarianSinusoidalPositionalEmbedding with Marian->Pegasus
+class PegasusSinusoidalPositionalEmbedding(nn.Embedding):
+ """This module produces sinusoidal positional embeddings of any length."""
+
+ def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None) -> None:
+ super().__init__(num_positions, embedding_dim)
+
+ def _init_weight(self):
+ """
+ Identical to the XLM create_sinusoidal_embeddings except features are not interleaved. The cos features are in
+ the 2nd half of the vector (`[dim // 2:]`).
+ """
+ n_pos, dim = self.weight.shape
+ position_enc = np.array(
+ [[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)]
+ )
+ out = torch.empty(n_pos, dim, dtype=self.weight.dtype, requires_grad=False)
+ sentinel = dim // 2 if dim % 2 == 0 else (dim // 2) + 1
+ out[:, 0:sentinel] = torch.FloatTensor(np.sin(position_enc[:, 0::2]))
+ out[:, sentinel:] = torch.FloatTensor(np.cos(position_enc[:, 1::2]))
+ self.weight = nn.Parameter(out, requires_grad=False)
+
+ @torch.no_grad()
+ def forward(
+ self, input_ids_shape: torch.Size, past_key_values_length: int = 0, position_ids: Optional[torch.Tensor] = None
+ ) -> torch.Tensor:
+ """`input_ids_shape` is expected to be [bsz x seqlen]."""
+ if position_ids is None:
+ bsz, seq_len = input_ids_shape[:2]
+ position_ids = torch.arange(
+ past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device
+ )
+ return super().forward(position_ids)
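+
+
+# Layout note (illustrative): `_init_weight` puts all sin terms in the first half of each row and all
+# cos terms in the second half (non-interleaved); e.g. for dim=4, row p is approximately
+# [sin(p / 10000^0), sin(p / 10000^0.5), cos(p / 10000^0), cos(p / 10000^0.5)].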
+
+
+# Copied from transformers.models.bart.modeling_bart.eager_attention_forward
+def eager_attention_forward(
+ module: nn.Module,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_mask: Optional[torch.Tensor],
+ scaling: Optional[float] = None,
+ dropout: float = 0.0,
+ head_mask: Optional[torch.Tensor] = None,
+ **kwargs,
+):
+ if scaling is None:
+ scaling = query.size(-1) ** -0.5
+
+ attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
+ if attention_mask is not None:
+ attn_weights = attn_weights + attention_mask
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+
+ if head_mask is not None:
+ attn_weights = attn_weights * head_mask.view(1, -1, 1, 1)
+
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+ attn_output = torch.matmul(attn_weights, value)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ return attn_output, attn_weights
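+
+
+# Shape sketch (illustrative): for query/key/value of shape (bsz, num_heads, seq_len, head_dim),
+# `eager_attention_forward` returns attn_output of shape (bsz, seq_len, num_heads, head_dim) after the
+# transpose(1, 2) above, together with attention weights of shape (bsz, num_heads, tgt_len, src_len).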
+
+
+# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->Pegasus
+class PegasusAttention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(
+ self,
+ embed_dim: int,
+ num_heads: int,
+ dropout: float = 0.0,
+ is_decoder: bool = False,
+ bias: bool = True,
+ is_causal: bool = False,
+ config: Optional[PegasusConfig] = None,
+ layer_idx: Optional[int] = None,
+ ):
+ super().__init__()
+ self.embed_dim = embed_dim
+ self.num_heads = num_heads
+ self.dropout = dropout
+ self.head_dim = embed_dim // num_heads
+ self.config = config
+
+ if (self.head_dim * num_heads) != self.embed_dim:
+ raise ValueError(
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
+ f" and `num_heads`: {num_heads})."
+ )
+ self.scaling = self.head_dim**-0.5
+ self.is_decoder = is_decoder
+ self.is_causal = is_causal
+ self.layer_idx = layer_idx
+ if layer_idx is None and self.is_decoder:
+ logger.warning_once(
+ f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and "
+ "will lead to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
+ "when creating this class."
+ )
+
+ self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+ self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+ self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+ self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ key_value_states: Optional[torch.Tensor] = None,
+ past_key_values: Optional[Cache] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ layer_head_mask: Optional[torch.Tensor] = None,
+ output_attentions: bool = False,
+ cache_position: Optional[torch.Tensor] = None,
+ # TODO: we need a refactor so that the different attention modules can get their specific kwargs
+ # ATM, encoder, decoder, and encoder-decoder attention kwargs are mixed together
+ **kwargs: Unpack[FlashAttentionKwargs],
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
+ """Input shape: Batch x Time x Channel"""
+
+ # if key_value_states are provided this layer is used as a cross-attention layer
+ # for the decoder
+ is_cross_attention = key_value_states is not None
+
+ # determine input shapes
+ bsz, tgt_len = hidden_states.shape[:-1]
+ src_len = key_value_states.shape[1] if is_cross_attention else tgt_len
+
+ q_input_shape = (bsz, tgt_len, -1, self.head_dim)
+ kv_input_shape = (bsz, src_len, -1, self.head_dim)
+
+ # get query proj
+ query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2)
+
+ is_updated = False
+ if past_key_values is not None:
+ if isinstance(past_key_values, EncoderDecoderCache):
+ is_updated = past_key_values.is_updated.get(self.layer_idx)
+ if is_cross_attention:
+ # after the first generated id, we can subsequently re-use all key/value_states from cache
+ curr_past_key_value = past_key_values.cross_attention_cache
+ else:
+ curr_past_key_value = past_key_values.self_attention_cache
+ else:
+ curr_past_key_value = past_key_values
+
+ current_states = key_value_states if is_cross_attention else hidden_states
+ if is_cross_attention and past_key_values is not None and is_updated:
+ # reuse k,v, cross_attentions
+ key_states = curr_past_key_value.layers[self.layer_idx].keys
+ value_states = curr_past_key_value.layers[self.layer_idx].values
+ else:
+ key_states = self.k_proj(current_states)
+ value_states = self.v_proj(current_states)
+ key_states = key_states.view(*kv_input_shape).transpose(1, 2)
+ value_states = value_states.view(*kv_input_shape).transpose(1, 2)
+
+ if past_key_values is not None:
+ # save all key/value_states to cache to be re-used for fast auto-regressive generation
+ cache_position = cache_position if not is_cross_attention else None
+ key_states, value_states = curr_past_key_value.update(
+ key_states, value_states, self.layer_idx, {"cache_position": cache_position}
+ )
+ # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
+ if is_cross_attention and isinstance(past_key_values, EncoderDecoderCache):
+ past_key_values.is_updated[self.layer_idx] = True
+
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ attn_output, attn_weights = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ dropout=0.0 if not self.training else self.dropout,
+ scaling=self.scaling,
+ output_attentions=output_attentions,
+ head_mask=layer_head_mask,
+ **kwargs,
+ )
+
+ attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous()
+ attn_output = self.out_proj(attn_output)
+
+ return attn_output, attn_weights
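+
+ # Note (illustrative): with an `EncoderDecoderCache`, the cross-attention key/value projections of the
+ # encoder states are computed once on the first decoding step and re-used afterwards via the
+ # `is_updated` flag, while the self-attention cache grows by one position per generated token.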
+
+
+# Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->Pegasus, MBART->PEGASUS
+class PegasusEncoderLayer(GradientCheckpointingLayer):
+ def __init__(self, config: PegasusConfig):
+ super().__init__()
+ self.embed_dim = config.d_model
+
+ self.self_attn = PegasusAttention(
+ embed_dim=self.embed_dim,
+ num_heads=config.encoder_attention_heads,
+ dropout=config.attention_dropout,
+ config=config,
+ )
+ self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
+ self.dropout = config.dropout
+ self.activation_fn = ACT2FN[config.activation_function]
+ self.activation_dropout = config.activation_dropout
+ self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
+ self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
+ self.final_layer_norm = nn.LayerNorm(self.embed_dim)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: torch.Tensor,
+ layer_head_mask: torch.Tensor,
+ output_attentions: bool = False,
+ ) -> torch.Tensor:
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+ attention_mask (`torch.FloatTensor`): attention mask of size
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+ layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
+ `(encoder_attention_heads,)`.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ """
+ residual = hidden_states
+ hidden_states = self.self_attn_layer_norm(hidden_states)
+ hidden_states, attn_weights = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ layer_head_mask=layer_head_mask,
+ output_attentions=output_attentions,
+ )
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+ hidden_states = residual + hidden_states
+
+ residual = hidden_states
+ hidden_states = self.final_layer_norm(hidden_states)
+ hidden_states = self.activation_fn(self.fc1(hidden_states))
+ hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
+ hidden_states = self.fc2(hidden_states)
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+ hidden_states = residual + hidden_states
+
+ if hidden_states.dtype == torch.float16:
+ clamp_value = torch.finfo(hidden_states.dtype).max - 1000
+ hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
+
+ return hidden_states, attn_weights
+
+
+# Copied from transformers.models.mbart.modeling_mbart.MBartDecoderLayer with MBart->Pegasus, MBART->PEGASUS
+class PegasusDecoderLayer(GradientCheckpointingLayer):
+ def __init__(self, config: PegasusConfig, layer_idx: Optional[int] = None):
+ super().__init__()
+ self.embed_dim = config.d_model
+
+ self.self_attn = PegasusAttention(
+ embed_dim=self.embed_dim,
+ num_heads=config.decoder_attention_heads,
+ dropout=config.attention_dropout,
+ is_decoder=True,
+ is_causal=True,
+ config=config,
+ layer_idx=layer_idx,
+ )
+ self.dropout = config.dropout
+ self.activation_fn = ACT2FN[config.activation_function]
+ self.activation_dropout = config.activation_dropout
+
+ self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
+ self.encoder_attn = PegasusAttention(
+ self.embed_dim,
+ config.decoder_attention_heads,
+ dropout=config.attention_dropout,
+ is_decoder=True,
+ config=config,
+ layer_idx=layer_idx,
+ )
+ self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
+ self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
+ self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)
+ self.final_layer_norm = nn.LayerNorm(self.embed_dim)
+
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ layer_head_mask: Optional[torch.Tensor] = None,
+ cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
+ past_key_values: Optional[Cache] = None,
+ output_attentions: Optional[bool] = False,
+ use_cache: Optional[bool] = True,
+ cache_position: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+ attention_mask (`torch.FloatTensor`): attention mask of size
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+ encoder_hidden_states (`torch.FloatTensor`):
+ cross attention input to the layer of shape `(batch, seq_len, embed_dim)`
+ encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+ layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
+ `(encoder_attention_heads,)`.
+ cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of
+ size `(decoder_attention_heads,)`.
+ past_key_values (`Cache`): cached past key and value projection states
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+ Indices depicting the position of the input sequence tokens in the sequence. It is used to update the
+ cache in the correct position and to infer the complete sequence length.
+ """
+ residual = hidden_states
+ hidden_states = self.self_attn_layer_norm(hidden_states)
+
+ # Self Attention
+ hidden_states, self_attn_weights = self.self_attn(
+ hidden_states=hidden_states,
+ past_key_values=past_key_values,
+ attention_mask=attention_mask,
+ layer_head_mask=layer_head_mask,
+ output_attentions=output_attentions,
+ cache_position=cache_position,
+ )
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+ hidden_states = residual + hidden_states
+
+ # Cross-Attention Block
+ cross_attn_weights = None
+ if encoder_hidden_states is not None:
+ residual = hidden_states
+ hidden_states = self.encoder_attn_layer_norm(hidden_states)
+
+ hidden_states, cross_attn_weights = self.encoder_attn(
+ hidden_states=hidden_states,
+ key_value_states=encoder_hidden_states,
+ attention_mask=encoder_attention_mask,
+ layer_head_mask=cross_attn_layer_head_mask,
+ past_key_values=past_key_values,
+ output_attentions=output_attentions,
+ )
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+ hidden_states = residual + hidden_states
+
+ # Fully Connected
+ residual = hidden_states
+ hidden_states = self.final_layer_norm(hidden_states)
+ hidden_states = self.activation_fn(self.fc1(hidden_states))
+ hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
+ hidden_states = self.fc2(hidden_states)
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+ hidden_states = residual + hidden_states
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (self_attn_weights, cross_attn_weights)
+
+ return outputs
+
+
+@auto_docstring
+class PegasusPreTrainedModel(PreTrainedModel):
+ config: PegasusConfig
+ base_model_prefix = "model"
+ supports_gradient_checkpointing = True
+ _supports_flash_attn = True
+ _supports_sdpa = True
+ _supports_flex_attn = True
+
+ _can_compile_fullgraph = True
+
+ def _init_weights(self, module):
+ std = self.config.init_std
+ if isinstance(module, nn.Linear):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, PegasusSinusoidalPositionalEmbedding):
+ module._init_weight()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+ elif isinstance(module, nn.LayerNorm):
+ module.weight.data.fill_(1.0)
+ module.bias.data.zero_()
+
+ # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_full_mask
+ def _update_full_mask(
+ self,
+ attention_mask: Union[torch.Tensor, None],
+ inputs_embeds: torch.Tensor,
+ ):
+ if attention_mask is not None:
+ if self.config._attn_implementation == "flash_attention_2":
+ attention_mask = attention_mask if 0 in attention_mask else None
+ elif self.config._attn_implementation == "sdpa":
+ # output_attentions=True and head_mask cannot be supported when using SDPA; fall back to
+ # the manual implementation that requires a 4D attention mask in all cases.
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype)
+ elif self.config._attn_implementation == "flex_attention":
+ if isinstance(attention_mask, torch.Tensor):
+ attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False)
+ else:
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
+
+ return attention_mask
+
+ # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_causal_mask
+ def _update_causal_mask(
+ self,
+ attention_mask: Optional[Union[torch.Tensor, "BlockMask"]],
+ input_tensor: torch.Tensor,
+ cache_position: torch.Tensor,
+ past_key_values: Cache,
+ ):
+ if self.config._attn_implementation == "flex_attention":
+ if isinstance(attention_mask, torch.Tensor):
+ attention_mask = make_flex_block_causal_mask(attention_mask)
+ # Other attention flavors support in-built causal (when `mask is None`)
+ # while we need to create our specific block mask regardless
+ elif attention_mask is None:
+ attention_mask = make_flex_block_causal_mask(
+ torch.ones(
+ size=(input_tensor.shape[0], input_tensor.shape[1]),
+ device=input_tensor.device,
+ )
+ )
+ return attention_mask
+
+ if self.config._attn_implementation == "flash_attention_2":
+ if attention_mask is not None and (attention_mask == 0.0).any():
+ return attention_mask
+ return None
+
+ # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
+ # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
+ # to infer the attention mask.
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False
+
+ # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
+ if self.config._attn_implementation == "sdpa" and not using_compilable_cache:
+ if AttentionMaskConverter._ignore_causal_mask_sdpa(
+ attention_mask,
+ inputs_embeds=input_tensor,
+ past_key_values_length=past_seen_tokens,
+ is_training=self.training,
+ ):
+ return None
+
+ dtype = input_tensor.dtype
+ sequence_length = input_tensor.shape[1]
+ if using_compilable_cache:
+ target_length = past_key_values.get_max_cache_shape()
+ else:
+ target_length = (
+ attention_mask.shape[-1]
+ if isinstance(attention_mask, torch.Tensor)
+ else past_seen_tokens + sequence_length + 1
+ )
+
+ # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
+ causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
+ attention_mask,
+ sequence_length=sequence_length,
+ target_length=target_length,
+ dtype=dtype,
+ cache_position=cache_position,
+ batch_size=input_tensor.shape[0],
+ )
+
+ if (
+ self.config._attn_implementation == "sdpa"
+ and attention_mask is not None
+ and attention_mask.device.type in ["cuda", "xpu", "npu"]
+ ):
+ # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
+ # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
+ # Details: https://github.com/pytorch/pytorch/issues/110213
+ min_dtype = torch.finfo(dtype).min
+ causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
+
+ return causal_mask
+
+ @staticmethod
+ # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position
+ def _prepare_4d_causal_attention_mask_with_cache_position(
+ attention_mask: torch.Tensor,
+ sequence_length: int,
+ target_length: int,
+ dtype: torch.dtype,
+ cache_position: torch.Tensor,
+ batch_size: int,
+ **kwargs,
+ ):
+ """
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
+ `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
+
+ Args:
+ attention_mask (`torch.Tensor`):
+ A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
+ `(batch_size, 1, query_length, key_value_length)`.
+ sequence_length (`int`):
+ The sequence length being processed.
+ target_length (`int`):
+ The target length: when generating with static cache, the mask should be as long as the static cache,
+ to account for the 0 padding, the part of the cache that is not filled yet.
+ dtype (`torch.dtype`):
+ The dtype to use for the 4D attention mask.
+ cache_position (`torch.Tensor`):
+ Indices depicting the position of the input sequence tokens in the sequence.
+ batch_size (`torch.Tensor`):
+ Batch size.
+ """
+ if attention_mask is not None and attention_mask.dim() == 4:
+ # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
+ causal_mask = attention_mask
+ else:
+ min_dtype = torch.finfo(dtype).min
+ causal_mask = torch.full(
+ (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
+ )
+ if sequence_length != 1:
+ causal_mask = torch.triu(causal_mask, diagonal=1)
+ causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
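+ # Illustrative note (added comment): at this point entry (i, j) keeps the large-negative fill exactly when
+ # key position j lies after the query's absolute position cache_position[i]; earlier key positions are
+ # zeroed out and stay attendable unless they are re-masked as padding below.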
+ causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
+ if attention_mask is not None:
+ causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
+ mask_length = attention_mask.shape[-1]
+ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
+ causal_mask.device
+ )
+ padding_mask = padding_mask == 0
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+ padding_mask, min_dtype
+ )
+
+ return causal_mask
+
+ # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_cross_attn_mask
+ def _update_cross_attn_mask(
+ self,
+ encoder_hidden_states: Union[torch.Tensor, None],
+ encoder_attention_mask: Union[torch.Tensor, None],
+ input_shape: torch.Size,
+ inputs_embeds: torch.Tensor,
+ ):
+ # expand encoder attention mask
+ if encoder_hidden_states is not None and encoder_attention_mask is not None:
+ if self.config._attn_implementation == "flash_attention_2":
+ encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None
+ elif self.config._attn_implementation == "sdpa":
+ # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on
+ # the manual implementation that requires a 4D causal mask in all cases.
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
+ encoder_attention_mask,
+ inputs_embeds.dtype,
+ tgt_len=input_shape[-1],
+ )
+ elif self.config._attn_implementation == "flex_attention":
+ if isinstance(encoder_attention_mask, torch.Tensor):
+ encoder_attention_mask = make_flex_block_causal_mask(
+ encoder_attention_mask,
+ query_length=input_shape[-1],
+ is_causal=False,
+ )
+ else:
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ encoder_attention_mask = _prepare_4d_attention_mask(
+ encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
+ )
+
+ return encoder_attention_mask
+
+
+class PegasusEncoder(PegasusPreTrainedModel):
+ """
+ Transformer encoder consisting of *config.encoder_layers* self-attention layers. Each layer is a
+ [`PegasusEncoderLayer`].
+
+ Args:
+ config: PegasusConfig
+ embed_tokens (nn.Embedding): output embedding
+ """
+
+ def __init__(self, config: PegasusConfig, embed_tokens: Optional[nn.Embedding] = None):
+ super().__init__(config)
+
+ self.dropout = config.dropout
+ self.layerdrop = config.encoder_layerdrop
+
+ embed_dim = config.d_model
+ self.padding_idx = config.pad_token_id
+ self.max_source_positions = config.max_position_embeddings
+ self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
+
+ if embed_tokens is not None:
+ self.embed_tokens = embed_tokens
+ else:
+ self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx)
+
+ self.embed_positions = PegasusSinusoidalPositionalEmbedding(
+ config.max_position_embeddings,
+ embed_dim,
+ self.padding_idx,
+ )
+ self.layers = nn.ModuleList([PegasusEncoderLayer(config) for _ in range(config.encoder_layers)])
+ self.layer_norm = nn.LayerNorm(config.d_model)
+
+ self.gradient_checkpointing = False
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def resize_position_embeddings(self, new_num_position_embeddings: int):
+ """
+ Resizes position embeddings matrix of the model if `new_num_position_embeddings !=
+ config.max_position_embeddings`.
+
+ Arguments:
+ new_num_position_embeddings (`int`):
+ The number of new position embeddings. If position embeddings are learned, increasing the size will add
+ newly initialized vectors at the end, whereas reducing the size will remove vectors from the end. If
+ position embeddings are not learned (*e.g.* sinusoidal position embeddings), increasing the size will
+ add correct vectors at the end following the position encoding algorithm, whereas reducing the size
+ will remove vectors from the end.
+ """
+ logger.info(f"Setting `config.max_position_embeddings={new_num_position_embeddings}`...")
+ self.config.max_position_embeddings = new_num_position_embeddings
+
+ self.embed_positions = PegasusSinusoidalPositionalEmbedding(
+ self.config.max_position_embeddings,
+ self.config.d_model,
+ self.padding_idx,
+ )
+ self.embed_positions._init_weight()
+ self.embed_positions.to(self.device)
+
+ def get_position_embeddings(self) -> nn.Embedding:
+ """
+ Returns the position embeddings matrix
+ """
+ return self.embed_positions
+
+ def forward(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ head_mask=None,
+ inputs_embeds=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ ):
+ r"""
+ Args:
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
+ provide it.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
+ Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+ than the model's internal embedding lookup matrix.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+ for more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
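+
+ Example (illustrative sketch, not part of the original docstring; it assumes the `google/pegasus-large`
+ checkpoint and only demonstrates the expected input and output types):
+
+ ```python
+ >>> from transformers import AutoTokenizer, PegasusModel
+
+ >>> tokenizer = AutoTokenizer.from_pretrained("google/pegasus-large")
+ >>> encoder = PegasusModel.from_pretrained("google/pegasus-large").get_encoder()
+
+ >>> inputs = tokenizer("Studies have shown that owning a dog is good for you", return_tensors="pt")
+ >>> encoder_outputs = encoder(input_ids=inputs.input_ids, attention_mask=inputs.attention_mask)
+ >>> encoder_outputs.last_hidden_state.shape  # (batch_size, sequence_length, d_model)
+ ```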
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # retrieve input_ids and inputs_embeds
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+ elif input_ids is not None:
+ self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
+ input_shape = input_ids.size()
+ input_ids = input_ids.view(-1, input_shape[-1])
+ elif inputs_embeds is not None:
+ input_shape = inputs_embeds.size()[:-1]
+ else:
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
+
+ embed_pos = self.embed_positions(input_shape)
+
+ hidden_states = inputs_embeds + embed_pos
+
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+
+ attention_mask = self._update_full_mask(
+ attention_mask,
+ inputs_embeds,
+ )
+
+ encoder_states = () if output_hidden_states else None
+ all_attentions = () if output_attentions else None
+
+ # check if head_mask has a correct number of layers specified if desired
+ if head_mask is not None:
+ if head_mask.size()[0] != len(self.layers):
+ raise ValueError(
+ f"The head_mask should be specified for {len(self.layers)} layers, but it is for"
+ f" {head_mask.size()[0]}."
+ )
+ for idx, encoder_layer in enumerate(self.layers):
+ if output_hidden_states:
+ encoder_states = encoder_states + (hidden_states,)
+ # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
+ to_drop = False
+ if self.training:
+ dropout_probability = torch.rand([])
+ if dropout_probability < self.layerdrop: # skip the layer
+ to_drop = True
+
+ if to_drop:
+ layer_outputs = (None, None)
+ else:
+ layer_outputs = encoder_layer(
+ hidden_states,
+ attention_mask,
+ layer_head_mask=(head_mask[idx] if head_mask is not None else None),
+ output_attentions=output_attentions,
+ )
+
+ hidden_states = layer_outputs[0]
+
+ if output_attentions:
+ all_attentions = all_attentions + (layer_outputs[1],)
+
+ hidden_states = self.layer_norm(hidden_states)
+
+ if output_hidden_states:
+ encoder_states = encoder_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
+ return BaseModelOutput(
+ last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
+ )
+
+
+class PegasusDecoder(PegasusPreTrainedModel):
+ """
+ Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`PegasusDecoderLayer`]
+
+ Args:
+ config: PegasusConfig
+ embed_tokens (nn.Embedding): output embedding
+ """
+
+ def __init__(self, config: PegasusConfig, embed_tokens: Optional[nn.Embedding] = None):
+ super().__init__(config)
+ self.dropout = config.dropout
+ self.layerdrop = config.decoder_layerdrop
+ self.padding_idx = config.pad_token_id
+ self.max_target_positions = config.max_position_embeddings
+ self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
+
+ if embed_tokens is not None:
+ self.embed_tokens = embed_tokens
+ else:
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)
+
+ self.embed_positions = PegasusSinusoidalPositionalEmbedding(
+ config.max_position_embeddings,
+ config.d_model,
+ self.padding_idx,
+ )
+ self.layers = nn.ModuleList([PegasusDecoderLayer(config, layer_idx=i) for i in range(config.decoder_layers)])
+ self.layer_norm = nn.LayerNorm(config.d_model)
+
+ self.gradient_checkpointing = False
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def resize_position_embeddings(self, new_num_position_embeddings: int):
+ """
+ Resizes position embeddings matrix of the model if `new_num_position_embeddings !=
+ config.max_position_embeddings`.
+
+ Arguments:
+ new_num_position_embeddings (`int`):
+ The number of new position embeddings. If position embeddings are learned, increasing the size will add
+ newly initialized vectors at the end, whereas reducing the size will remove vectors from the end. If
+ position embeddings are not learned (*e.g.* sinusoidal position embeddings), increasing the size will
+ add correct vectors at the end following the position encoding algorithm, whereas reducing the size
+ will remove vectors from the end.
+ """
+ logger.info(f"Setting `config.max_position_embeddings={new_num_position_embeddings}`...")
+ self.config.max_position_embeddings = new_num_position_embeddings
+
+ self.embed_positions = PegasusSinusoidalPositionalEmbedding(
+ self.config.max_position_embeddings,
+ self.config.d_model,
+ self.padding_idx,
+ )
+ self.embed_positions._init_weight()
+ self.embed_positions.to(self.device)
+
+ def get_position_embeddings(self) -> nn.Embedding:
+ """
+ Returns the position embeddings matrix
+ """
+ return self.embed_positions
+
+ def forward(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ head_mask=None,
+ cross_attn_head_mask=None,
+ past_key_values=None,
+ inputs_embeds=None,
+ use_cache=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ cache_position=None,
+ ):
+ r"""
+ Args:
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
+ provide it.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
+ of the decoder.
+ encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*):
+ Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values
+ selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
+ Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
+ Mask to nullify selected heads of the cross-attention modules in decoder to avoid performing
+ cross-attention on hidden heads. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
+ cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
+
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
+ that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
+ all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+ than the model's internal embedding lookup matrix.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+ for more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+ Indices depicting the position of the input sequence tokens in the sequence. It is used to update the
+ cache in the correct position and to infer the complete sequence length.
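+
+ Example (illustrative sketch, not part of the original docstring; standalone decoder usage, with the
+ `google/pegasus-large` checkpoint assumed):
+
+ ```python
+ >>> import torch
+ >>> from transformers import AutoTokenizer, PegasusModel
+
+ >>> tokenizer = AutoTokenizer.from_pretrained("google/pegasus-large")
+ >>> model = PegasusModel.from_pretrained("google/pegasus-large")
+
+ >>> inputs = tokenizer("Studies have shown that owning a dog is good for you", return_tensors="pt")
+ >>> encoder_outputs = model.get_encoder()(**inputs)
+ >>> decoder_input_ids = torch.tensor([[model.config.decoder_start_token_id]])
+ >>> decoder_outputs = model.decoder(
+ ...     input_ids=decoder_input_ids,
+ ...     encoder_hidden_states=encoder_outputs.last_hidden_state,
+ ...     encoder_attention_mask=inputs.attention_mask,
+ ... )
+ >>> decoder_outputs.last_hidden_state.shape  # (batch_size, 1, d_model)
+ ```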
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # retrieve input_ids and inputs_embeds
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
+ elif input_ids is not None:
+ input = input_ids
+ input_shape = input.shape
+ input_ids = input_ids.view(-1, input_shape[-1])
+ elif inputs_embeds is not None:
+ input_shape = inputs_embeds.size()[:-1]
+ input = inputs_embeds[:, :, -1]
+ else:
+ raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input)
+
+ # important to apply scale outside of `if` in case users pass `embeds`
+ inputs_embeds = inputs_embeds * self.embed_scale
+
+ if self.gradient_checkpointing and self.training:
+ if use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing`. Setting `use_cache=False`..."
+ )
+ use_cache = False
+
+ # initialize `past_key_values`
+ if use_cache and past_key_values is None:
+ past_key_values = (
+ EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
+ if encoder_hidden_states is not None
+ else DynamicCache(config=self.config)
+ )
+ if use_cache and isinstance(past_key_values, tuple):
+ logger.warning_once(
+ "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. "
+ "You should pass an instance of `EncoderDecoderCache` instead, e.g. "
+ "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`."
+ )
+ past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values)
+
+ batch_size, seq_length = inputs_embeds.size()[:-1]
+ past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
+ if cache_position is None:
+ cache_position = torch.arange(
+ past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device
+ )
+
+ if attention_mask is None and not is_torchdynamo_compiling():
+ # required mask seq length can be calculated via length of past cache
+ mask_seq_length = past_key_values_length + seq_length
+ attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
+
+ self_attn_cache = (
+ past_key_values.self_attention_cache
+ if isinstance(past_key_values, EncoderDecoderCache)
+ else past_key_values
+ )
+
+ causal_mask = self._update_causal_mask(
+ attention_mask,
+ inputs_embeds,
+ cache_position,
+ self_attn_cache,
+ )
+ encoder_attention_mask = self._update_cross_attn_mask(
+ encoder_hidden_states,
+ encoder_attention_mask,
+ input_shape,
+ inputs_embeds,
+ )
+
+ # embed positions
+ positions = self.embed_positions((batch_size, seq_length), past_key_values_length, position_ids=cache_position)
+ hidden_states = inputs_embeds + positions
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+
+ # decoder layers
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+ all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
+
+ # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
+ for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
+ if attn_mask is not None:
+ if attn_mask.size()[0] != len(self.layers):
+ raise ValueError(
+ f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
+ f" {head_mask.size()[0]}."
+ )
+ for idx, decoder_layer in enumerate(self.layers):
+ # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+ if self.training:
+ dropout_probability = torch.rand([])
+ if dropout_probability < self.layerdrop:
+ continue
+
+ layer_outputs = decoder_layer(
+ hidden_states,
+ causal_mask,
+ encoder_hidden_states, # as a positional argument for gradient checkpointing
+ encoder_attention_mask=encoder_attention_mask,
+ layer_head_mask=(head_mask[idx] if head_mask is not None else None),
+ cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None),
+ past_key_values=past_key_values,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ )
+ hidden_states = layer_outputs[0]
+
+ if output_attentions:
+ all_self_attns += (layer_outputs[1],)
+
+ if encoder_hidden_states is not None:
+ all_cross_attentions += (layer_outputs[2],)
+
+ hidden_states = self.layer_norm(hidden_states)
+
+ # add hidden states from the last decoder layer
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ if not return_dict:
+ return tuple(
+ v
+ for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns, all_cross_attentions]
+ if v is not None
+ )
+ return BaseModelOutputWithPastAndCrossAttentions(
+ last_hidden_state=hidden_states,
+ past_key_values=past_key_values,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ cross_attentions=all_cross_attentions,
+ )
+
+
+@auto_docstring
+class PegasusModel(PegasusPreTrainedModel):
+ _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
+
+ def __init__(self, config: PegasusConfig):
+ super().__init__(config)
+
+ padding_idx, vocab_size = config.pad_token_id, config.vocab_size
+ self.shared = nn.Embedding(vocab_size, config.d_model, padding_idx)
+
+ self.encoder = PegasusEncoder(config, self.shared)
+ self.decoder = PegasusDecoder(config, self.shared)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.shared
+
+ def set_input_embeddings(self, value):
+ self.shared = value
+ self.encoder.embed_tokens = self.shared
+ self.decoder.embed_tokens = self.shared
+
+ def get_encoder(self):
+ return self.encoder
+
+ def resize_position_embeddings(self, new_num_position_embeddings: int):
+ """
+ Resizes position embeddings matrix of the model if `new_num_position_embeddings !=
+ config.max_position_embeddings`.
+
+ Arguments:
+ new_num_position_embeddings (`int`):
+ The number of new position embeddings. If position embeddings are learned, increasing the size will add
+ newly initialized vectors at the end, whereas reducing the size will remove vectors from the end. If
+ position embeddings are not learned (*e.g.* sinusoidal position embeddings), increasing the size will
+ add correct vectors at the end following the position encoding algorithm, whereas reducing the size
+ will remove vectors from the end.
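+
+ Example (illustrative sketch, not part of the original docstring; the checkpoint name is only an
+ assumption for demonstration):
+
+ ```python
+ >>> from transformers import PegasusModel
+
+ >>> model = PegasusModel.from_pretrained("google/pegasus-large")
+ >>> model.resize_position_embeddings(2048)  # sinusoidal table is rebuilt at the new size
+ >>> model.config.max_position_embeddings
+ 2048
+ ```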
+ """
+ self.config.max_position_embeddings = new_num_position_embeddings
+ self.encoder.resize_position_embeddings(new_num_position_embeddings)
+ self.decoder.resize_position_embeddings(new_num_position_embeddings)
+
+ def get_position_embeddings(self) -> tuple[nn.Embedding, nn.Embedding]:
+ """
+ Returns the position embeddings matrix
+ """
+ return (self.encoder.get_position_embeddings(), self.decoder.get_position_embeddings())
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ decoder_input_ids: Optional[torch.Tensor] = None,
+ decoder_attention_mask: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ decoder_head_mask: Optional[torch.Tensor] = None,
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
+ encoder_outputs: Optional[tuple[torch.FloatTensor]] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ decoder_inputs_embeds: Optional[torch.Tensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.Tensor] = None,
+ ) -> Union[tuple, Seq2SeqModelOutput]:
+ r"""
+ decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
+ Indices of decoder input sequence tokens in the vocabulary.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are decoder input IDs?](../glossary#decoder-input-ids)
+
+ Pegasus uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If
+ `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
+ `past_key_values`).
+ decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
+ Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
+ be used by default.
+ cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
+ Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in `[0,
+ 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, PegasusModel
+
+ >>> tokenizer = AutoTokenizer.from_pretrained("google/pegasus-large")
+ >>> model = PegasusModel.from_pretrained("google/pegasus-large")
+
+ >>> inputs = tokenizer("Studies have been shown that owning a dog is good for you", return_tensors="pt")
+ >>> decoder_inputs = tokenizer("Studies show that", return_tensors="pt")
+ >>> outputs = model(input_ids=inputs.input_ids, decoder_input_ids=decoder_inputs.input_ids)
+
+ >>> last_hidden_states = outputs.last_hidden_state
+ >>> list(last_hidden_states.shape)
+ [1, 4, 1024]
+ ```"""
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if encoder_outputs is None:
+ encoder_outputs = self.encoder(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
+ elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
+ encoder_outputs = BaseModelOutput(
+ last_hidden_state=encoder_outputs[0],
+ hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
+ attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
+ )
+
+ # decoder outputs consists of (dec_features, past_key_values, dec_hidden, dec_attn)
+ decoder_outputs = self.decoder(
+ input_ids=decoder_input_ids,
+ attention_mask=decoder_attention_mask,
+ encoder_hidden_states=encoder_outputs[0],
+ encoder_attention_mask=attention_mask,
+ head_mask=decoder_head_mask,
+ cross_attn_head_mask=cross_attn_head_mask,
+ past_key_values=past_key_values,
+ inputs_embeds=decoder_inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ cache_position=cache_position,
+ )
+
+ if not return_dict:
+ return decoder_outputs + encoder_outputs
+
+ return Seq2SeqModelOutput(
+ last_hidden_state=decoder_outputs.last_hidden_state,
+ past_key_values=decoder_outputs.past_key_values,
+ decoder_hidden_states=decoder_outputs.hidden_states,
+ decoder_attentions=decoder_outputs.attentions,
+ cross_attentions=decoder_outputs.cross_attentions,
+ encoder_last_hidden_state=encoder_outputs.last_hidden_state,
+ encoder_hidden_states=encoder_outputs.hidden_states,
+ encoder_attentions=encoder_outputs.attentions,
+ )
+
+
+@auto_docstring(
+ custom_intro="""
+ The PEGASUS Model with a language modeling head. Can be used for summarization.
+ """
+)
+class PegasusForConditionalGeneration(PegasusPreTrainedModel, GenerationMixin):
+ base_model_prefix = "model"
+ _keys_to_ignore_on_load_missing = ["final_logits_bias"]
+ _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]
+
+ def __init__(self, config: PegasusConfig):
+ super().__init__(config)
+ self.model = PegasusModel(config)
+ self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings)))
+ self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_encoder(self):
+ return self.model.get_encoder()
+
+ def get_decoder(self):
+ return self.model.get_decoder()
+
+ def resize_token_embeddings(
+ self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None, mean_resizing: bool = True
+ ) -> nn.Embedding:
+ new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of, mean_resizing)
+ self._resize_final_logits_bias(new_embeddings.weight.shape[0])
+ return new_embeddings
+
+ def _resize_final_logits_bias(self, new_num_tokens: int) -> None:
+ old_num_tokens = self.final_logits_bias.shape[-1]
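+ # keep the existing bias for tokens that remain when shrinking; pad new token slots with zeros when growing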
+ if new_num_tokens <= old_num_tokens:
+ new_bias = self.final_logits_bias[:, :new_num_tokens]
+ else:
+ extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens), device=self.final_logits_bias.device)
+ new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1)
+ self.register_buffer("final_logits_bias", new_bias)
+
+ def resize_position_embeddings(self, new_num_position_embeddings: int):
+ """
+ Resizes position embeddings matrix of the model if `new_num_position_embeddings !=
+ config.max_position_embeddings`.
+
+ Arguments:
+ new_num_position_embeddings (`int`):
+ The number of new position embeddings. If position embeddings are learned, increasing the size will add
+ newly initialized vectors at the end, whereas reducing the size will remove vectors from the end. If
+ position embeddings are not learned (*e.g.* sinusoidal position embeddings), increasing the size will
+ add correct vectors at the end following the position encoding algorithm, whereas reducing the size
+ will remove vectors from the end.
+ """
+ self.config.max_position_embeddings = new_num_position_embeddings
+ self.model.encoder.resize_position_embeddings(new_num_position_embeddings)
+ self.model.decoder.resize_position_embeddings(new_num_position_embeddings)
+
+ def get_position_embeddings(self) -> tuple[nn.Embedding, nn.Embedding]:
+ """
+ Returns the position embeddings matrix
+ """
+ return (self.model.encoder.get_position_embeddings(), self.model.decoder.get_position_embeddings())
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ decoder_input_ids: Optional[torch.Tensor] = None,
+ decoder_attention_mask: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ decoder_head_mask: Optional[torch.Tensor] = None,
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
+ encoder_outputs: Optional[tuple[torch.FloatTensor]] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ decoder_inputs_embeds: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.Tensor] = None,
+ ) -> Union[tuple, Seq2SeqLMOutput]:
+ r"""
+ decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
+ Indices of decoder input sequence tokens in the vocabulary.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are decoder input IDs?](../glossary#decoder-input-ids)
+
+ Pegasus uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If
+ `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
+ `past_key_values`).
+ decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
+ Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
+ be used by default.
+ cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
+ Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in `[0,
+ 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+ Example Summarization:
+
+ ```python
+ >>> from transformers import AutoTokenizer, PegasusForConditionalGeneration
+
+ >>> model = PegasusForConditionalGeneration.from_pretrained("google/pegasus-xsum")
+ >>> tokenizer = AutoTokenizer.from_pretrained("google/pegasus-xsum")
+
+ >>> ARTICLE_TO_SUMMARIZE = (
+ ... "PG&E stated it scheduled the blackouts in response to forecasts for high winds "
+ ... "amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were "
+ ... "scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow."
+ ... )
+ >>> inputs = tokenizer(ARTICLE_TO_SUMMARIZE, max_length=1024, return_tensors="pt")
+
+ >>> # Generate Summary
+ >>> summary_ids = model.generate(inputs["input_ids"])
+ >>> tokenizer.batch_decode(summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ "California's largest electricity provider has turned off power to hundreds of thousands of customers."
+ ```
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if labels is not None:
+ if use_cache:
+ logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.")
+ use_cache = False
+ if decoder_input_ids is None and decoder_inputs_embeds is None:
+ decoder_input_ids = shift_tokens_right(
+ labels, self.config.pad_token_id, self.config.decoder_start_token_id
+ )
+
+ outputs = self.model(
+ input_ids,
+ attention_mask=attention_mask,
+ decoder_input_ids=decoder_input_ids,
+ encoder_outputs=encoder_outputs,
+ decoder_attention_mask=decoder_attention_mask,
+ head_mask=head_mask,
+ decoder_head_mask=decoder_head_mask,
+ cross_attn_head_mask=cross_attn_head_mask,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ decoder_inputs_embeds=decoder_inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ cache_position=cache_position,
+ )
+ lm_logits = self.lm_head(outputs[0]) + self.final_logits_bias
+
+ masked_lm_loss = None
+ if labels is not None:
+ loss_fct = CrossEntropyLoss()
+ masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1))
+
+ if not return_dict:
+ output = (lm_logits,) + outputs[1:]
+ return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
+
+ return Seq2SeqLMOutput(
+ loss=masked_lm_loss,
+ logits=lm_logits,
+ past_key_values=outputs.past_key_values,
+ decoder_hidden_states=outputs.decoder_hidden_states,
+ decoder_attentions=outputs.decoder_attentions,
+ cross_attentions=outputs.cross_attentions,
+ encoder_last_hidden_state=outputs.encoder_last_hidden_state,
+ encoder_hidden_states=outputs.encoder_hidden_states,
+ encoder_attentions=outputs.encoder_attentions,
+ )
+
+ def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
+ return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
+
+
+# Copied from transformers.models.bart.modeling_bart.BartDecoderWrapper with Bart->Pegasus
+class PegasusDecoderWrapper(PegasusPreTrainedModel):
+ """
+ This wrapper class is a helper class to correctly load pretrained checkpoints when the causal language model is
+ used in combination with the [`EncoderDecoderModel`] framework.
+ """
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.decoder = PegasusDecoder(config)
+
+ def forward(self, *args, **kwargs):
+ return self.decoder(*args, **kwargs)
+
+
+class PegasusForCausalLM(PegasusPreTrainedModel, GenerationMixin):
+ _tied_weights_keys = ["lm_head.weight"]
+
+ def __init__(self, config):
+ config = copy.deepcopy(config)
+ config.is_decoder = True
+ config.is_encoder_decoder = False
+ super().__init__(config)
+ self.model = PegasusDecoderWrapper(config)
+
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.decoder.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.model.decoder.embed_tokens = value
+
+ def set_decoder(self, decoder):
+ self.model.decoder = decoder
+
+ def get_decoder(self):
+ return self.model.decoder
+
+ def get_position_embeddings(self) -> nn.Embedding:
+ """
+ Returns the position embeddings matrix
+ """
+ return self.model.decoder.get_position_embeddings()
+
+ def resize_position_embeddings(self, new_num_position_embeddings: int):
+ """
+ Resizes position embeddings matrix of the model if `new_num_position_embeddings !=
+ config.max_position_embeddings`.
+
+ Arguments:
+ new_num_position_embeddings (`int`):
+ The number of new position embeddings. If position embeddings are learned, increasing the size will add
+ newly initialized vectors at the end, whereas reducing the size will remove vectors from the end. If
+ position embeddings are not learned (*e.g.* sinusoidal position embeddings), increasing the size will
+ add correct vectors at the end following the position encoding algorithm, whereas reducing the size
+ will remove vectors from the end.
+ """
+ self.config.max_position_embeddings = new_num_position_embeddings
+ self.model.decoder.resize_position_embeddings(new_num_position_embeddings)
+
+ @auto_docstring
+ # Copied from transformers.models.bart.modeling_bart.BartForCausalLM.forward with Bart->Pegasus, facebook/bart-base->google/pegasus-large
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ ) -> Union[tuple, CausalLMOutputWithCrossAttentions]:
+ r"""
+ cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
+ Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, PegasusForCausalLM
+
+ >>> tokenizer = AutoTokenizer.from_pretrained("google/pegasus-large")
+ >>> model = PegasusForCausalLM.from_pretrained("google/pegasus-large", add_cross_attention=False)
+ >>> assert model.config.is_decoder, f"{model.__class__} has to be configured as a decoder."
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
+ >>> outputs = model(**inputs)
+
+ >>> logits = outputs.logits
+ >>> expected_shape = [1, inputs.input_ids.shape[-1], model.config.vocab_size]
+ >>> list(logits.shape) == expected_shape
+ True
+ ```"""
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+ outputs = self.model.decoder(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ head_mask=head_mask,
+ cross_attn_head_mask=cross_attn_head_mask,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ cache_position=cache_position,
+ )
+
+ logits = self.lm_head(outputs[0])
+
+ loss = None
+ if labels is not None:
+ labels = labels.to(logits.device)
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
+
+ if not return_dict:
+ output = (logits,) + outputs[1:]
+ return (loss,) + output if loss is not None else output
+
+ return CausalLMOutputWithCrossAttentions(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ cross_attentions=outputs.cross_attentions,
+ )
+
+
+__all__ = ["PegasusForCausalLM", "PegasusForConditionalGeneration", "PegasusModel", "PegasusPreTrainedModel"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/pegasus/modeling_tf_pegasus.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/pegasus/modeling_tf_pegasus.py
new file mode 100644
index 0000000000000000000000000000000000000000..d159fc00138dcb6ca2f14213857b6b21e97c6865
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/pegasus/modeling_tf_pegasus.py
@@ -0,0 +1,1573 @@
+# coding=utf-8
+# Copyright 2021, Google Inc. and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""TF 2.0 Pegasus model."""
+
+from __future__ import annotations
+
+import random
+
+import numpy as np
+import tensorflow as tf
+
+from ...activations_tf import get_tf_activation
+from ...modeling_tf_outputs import (
+ TFBaseModelOutput,
+ TFBaseModelOutputWithPastAndCrossAttentions,
+ TFSeq2SeqLMOutput,
+ TFSeq2SeqModelOutput,
+)
+
+# Public API
+from ...modeling_tf_utils import (
+ TFCausalLanguageModelingLoss,
+ TFModelInputType,
+ TFPreTrainedModel,
+ keras,
+ keras_serializable,
+ unpack_inputs,
+)
+from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax
+from ...utils import (
+ add_code_sample_docstrings,
+ add_end_docstrings,
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ logging,
+ replace_return_docstrings,
+)
+from .configuration_pegasus import PegasusConfig
+
+
+logger = logging.get_logger(__name__)
+
+_CHECKPOINT_FOR_DOC = "google/pegasus-large"
+_CONFIG_FOR_DOC = "PegasusConfig"
+
+
+LARGE_NEGATIVE = -1e8
+
+
+# Copied from transformers.models.bart.modeling_tf_bart.shift_tokens_right
+def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_token_id: int):
+ pad_token_id = tf.cast(pad_token_id, input_ids.dtype)
+ decoder_start_token_id = tf.cast(decoder_start_token_id, input_ids.dtype)
+ start_tokens = tf.fill(
+ (shape_list(input_ids)[0], 1), tf.convert_to_tensor(decoder_start_token_id, input_ids.dtype)
+ )
+ shifted_input_ids = tf.concat([start_tokens, input_ids[:, :-1]], -1)
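+ # e.g. with decoder_start_token_id = 0, labels [[5, 6, 7]] become decoder inputs [[0, 5, 6]]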
+ # replace possible -100 values in labels by `pad_token_id`
+ shifted_input_ids = tf.where(
+ shifted_input_ids == -100,
+ tf.fill(shape_list(shifted_input_ids), tf.convert_to_tensor(pad_token_id, input_ids.dtype)),
+ shifted_input_ids,
+ )
+
+ # "Verify that `labels` has only positive values and -100"
+ assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=input_ids.dtype))
+
+ # Make sure the assertion op is called by wrapping the result in an identity no-op
+ with tf.control_dependencies([assert_gte0]):
+ shifted_input_ids = tf.identity(shifted_input_ids)
+
+ return shifted_input_ids
+
+
+# Copied from transformers.models.bart.modeling_tf_bart._make_causal_mask
+def _make_causal_mask(input_ids_shape: tf.TensorShape, past_key_values_length: int = 0):
+ """
+ Make causal mask used for uni-directional (causal) self-attention.
+ """
+ bsz = input_ids_shape[0]
+ tgt_len = input_ids_shape[1]
+ mask = tf.ones((tgt_len, tgt_len)) * LARGE_NEGATIVE
+ mask_cond = tf.range(shape_list(mask)[-1])
+
+ mask = tf.where(mask_cond < tf.reshape(mask_cond + 1, (shape_list(mask)[-1], 1)), 0.0, mask)
+
+ if past_key_values_length > 0:
+ mask = tf.concat([tf.zeros((tgt_len, past_key_values_length)), mask], axis=-1)
+
+ return tf.tile(mask[None, None, :, :], (bsz, 1, 1, 1))
+
+
+# Copied from transformers.models.bart.modeling_tf_bart._expand_mask
+def _expand_mask(mask: tf.Tensor, tgt_len: int | None = None):
+ """
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
+ """
+ src_len = shape_list(mask)[1]
+ tgt_len = tgt_len if tgt_len is not None else src_len
+ one_cst = tf.constant(1.0)
+ mask = tf.cast(mask, dtype=one_cst.dtype)
+ expanded_mask = tf.tile(mask[:, None, None, :], (1, 1, tgt_len, 1))
+
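+ # positions to attend to (mask == 1) map to 0.0; padding positions (mask == 0) map to LARGE_NEGATIVE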
+ return (one_cst - expanded_mask) * LARGE_NEGATIVE
+
+
+# Copied from transformers.models.marian.modeling_tf_marian.TFMarianSinusoidalPositionalEmbedding with Marian->Pegasus
+class TFPegasusSinusoidalPositionalEmbedding(keras.layers.Layer):
+ """This module produces sinusoidal positional embeddings of any length."""
+
+ def __init__(self, num_positions: int, embedding_dim: int, **kwargs):
+ super().__init__(**kwargs)
+
+ if embedding_dim % 2 != 0:
+ raise NotImplementedError(f"odd embedding_dim {embedding_dim} not supported")
+
+ self.embedding_dim = embedding_dim
+ self.num_positions = num_positions
+
+ def build(self, input_shape: tf.TensorShape):
+ """
+ Build shared token embedding layer. Shared weights logic adapted from
+ https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24
+ """
+
+ weight = self._init_weight(self.num_positions, self.embedding_dim)
+
+ self.weight = self.add_weight(
+ name="embeddings",
+ shape=[self.num_positions, self.embedding_dim],
+ )
+ weight = tf.cast(weight, dtype=self.weight.dtype)
+
+ self.weight.assign(weight)
+
+ super().build(input_shape)
+
+ @staticmethod
+ def _init_weight(n_pos: int, dim: int):
+ """
+ Identical to the XLM create_sinusoidal_embeddings except features are not interleaved. The cos features are in
+ the 2nd half of the vector. [dim // 2:]
+ """
+ position_enc = np.array(
+ [[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)]
+ )
+ table = np.zeros_like(position_enc)
+ # index 0 is all zero
+ table[:, 0 : dim // 2] = np.sin(position_enc[:, 0::2])
+ table[:, dim // 2 :] = np.cos(position_enc[:, 1::2])
+ # convert to tensor
+ table = tf.convert_to_tensor(table)
+ tf.stop_gradient(table)
+ return table
+
+ def call(
+ self, input_shape: tf.TensorShape, past_key_values_length: int = 0, position_ids: tf.Tensor | None = None
+ ):
+ """Input is expected to be of size [bsz x seqlen]."""
+ if position_ids is None:
+ seq_len = input_shape[1]
+ position_ids = tf.range(past_key_values_length, seq_len + past_key_values_length, delta=1, name="range")
+ return tf.gather(self.weight, position_ids)
+
+
+# Copied from transformers.models.bart.modeling_tf_bart.TFBartAttention with Bart->Pegasus
+class TFPegasusAttention(keras.layers.Layer):
+ """Multi-headed attention from "Attention Is All You Need"""
+
+ def __init__(
+ self,
+ embed_dim: int,
+ num_heads: int,
+ dropout: float = 0.0,
+ is_decoder: bool = False,
+ bias: bool = True,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.embed_dim = embed_dim
+
+ self.num_heads = num_heads
+ self.dropout = keras.layers.Dropout(dropout)
+ self.head_dim = embed_dim // num_heads
+ if (self.head_dim * num_heads) != self.embed_dim:
+ raise ValueError(
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
+ f" and `num_heads`: {num_heads})."
+ )
+ self.scaling = self.head_dim**-0.5
+ self.is_decoder = is_decoder
+
+ self.k_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="k_proj")
+ self.q_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="q_proj")
+ self.v_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="v_proj")
+ self.out_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="out_proj")
+
+ def _shape(self, tensor: tf.Tensor, seq_len: int, bsz: int):
+ return tf.transpose(tf.reshape(tensor, (bsz, seq_len, self.num_heads, self.head_dim)), (0, 2, 1, 3))
+
+ def call(
+ self,
+ hidden_states: tf.Tensor,
+ key_value_states: tf.Tensor | None = None,
+ past_key_value: tuple[tuple[tf.Tensor]] | None = None,
+ attention_mask: tf.Tensor | None = None,
+ layer_head_mask: tf.Tensor | None = None,
+ training: bool | None = False,
+ ) -> tuple[tf.Tensor, tf.Tensor | None]:
+ """Input shape: Batch x Time x Channel"""
+
+ # if key_value_states are provided this layer is used as a cross-attention layer
+ # for the decoder
+ is_cross_attention = key_value_states is not None
+ bsz, tgt_len, embed_dim = shape_list(hidden_states)
+
+ # get query proj
+ query_states = self.q_proj(hidden_states) * self.scaling
+ # get key, value proj
+ if is_cross_attention and past_key_value is not None:
+ # reuse k,v, cross_attentions
+ key_states = past_key_value[0]
+ value_states = past_key_value[1]
+ elif is_cross_attention:
+ # cross_attentions
+ key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
+ value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
+ elif past_key_value is not None:
+ # reuse k, v, self_attention
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+ key_states = tf.concat([past_key_value[0], key_states], axis=2)
+ value_states = tf.concat([past_key_value[1], value_states], axis=2)
+ else:
+ # self_attention
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+
+ if self.is_decoder:
+ # if cross_attention save Tuple(tf.Tensor, tf.Tensor) of all cross attention key/value_states.
+ # Further calls to cross_attention layer can then reuse all cross-attention
+ # key/value_states (first "if" case)
+ # if uni-directional self-attention (decoder) save Tuple(tf.Tensor, tf.Tensor) of
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
+ past_key_value = (key_states, value_states)
+
+ proj_shape = (bsz * self.num_heads, -1, self.head_dim)
+ query_states = tf.reshape(self._shape(query_states, tgt_len, bsz), proj_shape)
+ key_states = tf.reshape(key_states, proj_shape)
+ value_states = tf.reshape(value_states, proj_shape)
+
+ src_len = shape_list(key_states)[1]
+ attn_weights = tf.matmul(query_states, key_states, transpose_b=True)
+
+ tf.debugging.assert_equal(
+ shape_list(attn_weights),
+ [bsz * self.num_heads, tgt_len, src_len],
+ message=(
+ f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
+ f" {shape_list(attn_weights)}"
+ ),
+ )
+
+ if attention_mask is not None:
+ tf.debugging.assert_equal(
+ shape_list(attention_mask),
+ [bsz, 1, tgt_len, src_len],
+ message=(
+ f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
+ f" {shape_list(attention_mask)}"
+ ),
+ )
+
+ attention_mask = tf.cast(attention_mask, dtype=attn_weights.dtype)
+ attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask
+ attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))
+
+ attn_weights = stable_softmax(attn_weights, axis=-1)
+
+ if layer_head_mask is not None:
+ tf.debugging.assert_equal(
+ shape_list(layer_head_mask),
+ [self.num_heads],
+ message=(
+ f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
+ f" {shape_list(layer_head_mask)}"
+ ),
+ )
+
+ attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
+ attn_weights, (bsz, self.num_heads, tgt_len, src_len)
+ )
+ attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))
+
+ attn_probs = self.dropout(attn_weights, training=training)
+ attn_output = tf.matmul(attn_probs, value_states)
+
+ tf.debugging.assert_equal(
+ shape_list(attn_output),
+ [bsz * self.num_heads, tgt_len, self.head_dim],
+ message=(
+ f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
+ f" {shape_list(attn_output)}"
+ ),
+ )
+
+ attn_output = tf.transpose(
+ tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3)
+ )
+ attn_output = tf.reshape(attn_output, (bsz, tgt_len, embed_dim))
+
+ attn_output = self.out_proj(attn_output)
+ attn_weights: tf.Tensor = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len))
+
+ return attn_output, attn_weights, past_key_value
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "k_proj", None) is not None:
+ with tf.name_scope(self.k_proj.name):
+ self.k_proj.build([None, None, self.embed_dim])
+ if getattr(self, "q_proj", None) is not None:
+ with tf.name_scope(self.q_proj.name):
+ self.q_proj.build([None, None, self.embed_dim])
+ if getattr(self, "v_proj", None) is not None:
+ with tf.name_scope(self.v_proj.name):
+ self.v_proj.build([None, None, self.embed_dim])
+ if getattr(self, "out_proj", None) is not None:
+ with tf.name_scope(self.out_proj.name):
+ self.out_proj.build([None, None, self.embed_dim])
+
+
+# Copied from transformers.models.mbart.modeling_tf_mbart.TFMBartEncoderLayer with MBart->Pegasus
+class TFPegasusEncoderLayer(keras.layers.Layer):
+ def __init__(self, config: PegasusConfig, **kwargs):
+ super().__init__(**kwargs)
+ self.embed_dim = config.d_model
+ self.self_attn = TFPegasusAttention(
+ self.embed_dim, config.encoder_attention_heads, dropout=config.attention_dropout, name="self_attn"
+ )
+ self.self_attn_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="self_attn_layer_norm")
+ self.dropout = keras.layers.Dropout(config.dropout)
+ self.activation_fn = get_tf_activation(config.activation_function)
+ self.activation_dropout = keras.layers.Dropout(config.activation_dropout)
+ self.fc1 = keras.layers.Dense(config.encoder_ffn_dim, name="fc1")
+ self.fc2 = keras.layers.Dense(self.embed_dim, name="fc2")
+ self.final_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm")
+ self.config = config
+
+ def call(
+ self,
+ hidden_states: tf.Tensor,
+ attention_mask: tf.Tensor,
+ layer_head_mask: tf.Tensor,
+ training: bool | None = False,
+ ):
+ """
+ Args:
+ hidden_states (`tf.Tensor`): input to the layer of shape *(batch, seq_len, embed_dim)*
+ attention_mask (`tf.Tensor`): attention mask of size
+ *(batch, 1, tgt_len, src_len)* where padding elements are indicated by very large negative values.
+ layer_head_mask (`tf.Tensor`): mask for attention heads in a given layer of size
+ *(encoder_attention_heads,)*
+ """
+ residual = hidden_states
+ hidden_states = self.self_attn_layer_norm(hidden_states)
+ hidden_states, self_attn_weights, _ = self.self_attn(
+ hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask
+ )
+
+ tf.debugging.assert_equal(
+ shape_list(hidden_states),
+ shape_list(residual),
+ message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}",
+ )
+
+ hidden_states = self.dropout(hidden_states, training=training)
+ hidden_states = residual + hidden_states
+
+ residual = hidden_states
+ hidden_states = self.final_layer_norm(hidden_states)
+ hidden_states = self.activation_fn(self.fc1(hidden_states))
+ hidden_states = self.activation_dropout(hidden_states, training=training)
+ hidden_states = self.fc2(hidden_states)
+ hidden_states = self.dropout(hidden_states, training=training)
+ hidden_states = residual + hidden_states
+
+ return hidden_states, self_attn_weights
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "self_attn", None) is not None:
+ with tf.name_scope(self.self_attn.name):
+ self.self_attn.build(None)
+ if getattr(self, "self_attn_layer_norm", None) is not None:
+ with tf.name_scope(self.self_attn_layer_norm.name):
+ self.self_attn_layer_norm.build([None, None, self.embed_dim])
+ if getattr(self, "fc1", None) is not None:
+ with tf.name_scope(self.fc1.name):
+ self.fc1.build([None, None, self.embed_dim])
+ if getattr(self, "fc2", None) is not None:
+ with tf.name_scope(self.fc2.name):
+ self.fc2.build([None, None, self.config.encoder_ffn_dim])
+ if getattr(self, "final_layer_norm", None) is not None:
+ with tf.name_scope(self.final_layer_norm.name):
+ self.final_layer_norm.build([None, None, self.embed_dim])
+
+
+# Copied from transformers.models.mbart.modeling_tf_mbart.TFMBartDecoderLayer with MBart->Pegasus
+class TFPegasusDecoderLayer(keras.layers.Layer):
+ def __init__(self, config: PegasusConfig, **kwargs):
+ super().__init__(**kwargs)
+ self.embed_dim = config.d_model
+ self.self_attn = TFPegasusAttention(
+ embed_dim=self.embed_dim,
+ num_heads=config.decoder_attention_heads,
+ dropout=config.attention_dropout,
+ name="self_attn",
+ is_decoder=True,
+ )
+ self.dropout = keras.layers.Dropout(config.dropout)
+ self.activation_fn = get_tf_activation(config.activation_function)
+ self.activation_dropout = keras.layers.Dropout(config.activation_dropout)
+
+ self.self_attn_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="self_attn_layer_norm")
+ self.encoder_attn = TFPegasusAttention(
+ self.embed_dim,
+ config.decoder_attention_heads,
+ dropout=config.attention_dropout,
+ name="encoder_attn",
+ is_decoder=True,
+ )
+ self.encoder_attn_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="encoder_attn_layer_norm")
+ self.fc1 = keras.layers.Dense(config.decoder_ffn_dim, name="fc1")
+ self.fc2 = keras.layers.Dense(self.embed_dim, name="fc2")
+ self.final_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm")
+ self.config = config
+
+ def call(
+ self,
+ hidden_states: tf.Tensor,
+ attention_mask: tf.Tensor | None = None,
+ encoder_hidden_states: tf.Tensor | None = None,
+ encoder_attention_mask: tf.Tensor | None = None,
+ layer_head_mask: tf.Tensor | None = None,
+ cross_attn_layer_head_mask: tf.Tensor | None = None,
+ past_key_value: tuple[tf.Tensor] | None = None,
+ training: bool | None = False,
+ ) -> tuple[tf.Tensor, tf.Tensor, tuple[tuple[tf.Tensor]]]:
+ """
+ Args:
+ hidden_states (`tf.Tensor`): input to the layer of shape *(batch, seq_len, embed_dim)*
+ attention_mask (`tf.Tensor`): attention mask of size
+ *(batch, 1, tgt_len, src_len)* where padding elements are indicated by very large negative values.
+ encoder_hidden_states (`tf.Tensor`):
+ cross attention input to the layer of shape *(batch, seq_len, embed_dim)*
+ encoder_attention_mask (`tf.Tensor`): encoder attention mask of size
+ *(batch, 1, tgt_len, src_len)* where padding elements are indicated by very large negative values.
+ layer_head_mask (`tf.Tensor`): mask for attention heads in a given layer of size
+ *(decoder_attention_heads,)*
+ cross_attn_layer_head_mask (`tf.Tensor`): mask for heads of the cross-attention module.
+ *(decoder_attention_heads,)*
+ past_key_value (`Tuple(tf.Tensor)`): cached past key and value projection states
+ """
+ residual = hidden_states
+ hidden_states = self.self_attn_layer_norm(hidden_states)
+
+ # Self Attention
+ # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
+ self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
+ # add present self-attn cache to positions 1,2 of present_key_value tuple
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
+ hidden_states=hidden_states,
+ past_key_value=self_attn_past_key_value,
+ attention_mask=attention_mask,
+ layer_head_mask=layer_head_mask,
+ )
+ hidden_states = self.dropout(hidden_states, training=training)
+ hidden_states = residual + hidden_states
+
+ # Cross-Attention Block
+ cross_attn_present_key_value = None
+ cross_attn_weights = None
+ if encoder_hidden_states is not None:
+ residual = hidden_states
+ hidden_states = self.encoder_attn_layer_norm(hidden_states)
+
+ # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
+ cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
+ hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn(
+ hidden_states=hidden_states,
+ key_value_states=encoder_hidden_states,
+ attention_mask=encoder_attention_mask,
+ layer_head_mask=cross_attn_layer_head_mask,
+ past_key_value=cross_attn_past_key_value,
+ )
+ hidden_states = self.dropout(hidden_states, training=training)
+ hidden_states = residual + hidden_states
+
+ # add cross-attn to positions 3,4 of present_key_value tuple
+ present_key_value = present_key_value + cross_attn_present_key_value
+
+ # Fully Connected
+ residual = hidden_states
+ hidden_states = self.final_layer_norm(hidden_states)
+ hidden_states = self.activation_fn(self.fc1(hidden_states))
+ hidden_states = self.activation_dropout(hidden_states, training=training)
+ hidden_states = self.fc2(hidden_states)
+ hidden_states = self.dropout(hidden_states, training=training)
+ hidden_states = residual + hidden_states
+
+ return (
+ hidden_states,
+ self_attn_weights,
+ cross_attn_weights,
+ present_key_value,
+ )
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "self_attn", None) is not None:
+ with tf.name_scope(self.self_attn.name):
+ self.self_attn.build(None)
+ if getattr(self, "self_attn_layer_norm", None) is not None:
+ with tf.name_scope(self.self_attn_layer_norm.name):
+ self.self_attn_layer_norm.build([None, None, self.embed_dim])
+ if getattr(self, "encoder_attn", None) is not None:
+ with tf.name_scope(self.encoder_attn.name):
+ self.encoder_attn.build(None)
+ if getattr(self, "encoder_attn_layer_norm", None) is not None:
+ with tf.name_scope(self.encoder_attn_layer_norm.name):
+ self.encoder_attn_layer_norm.build([None, None, self.embed_dim])
+ if getattr(self, "fc1", None) is not None:
+ with tf.name_scope(self.fc1.name):
+ self.fc1.build([None, None, self.embed_dim])
+ if getattr(self, "fc2", None) is not None:
+ with tf.name_scope(self.fc2.name):
+ self.fc2.build([None, None, self.config.decoder_ffn_dim])
+ if getattr(self, "final_layer_norm", None) is not None:
+ with tf.name_scope(self.final_layer_norm.name):
+ self.final_layer_norm.build([None, None, self.embed_dim])
+
+
+class TFPegasusPreTrainedModel(TFPreTrainedModel):
+ config_class = PegasusConfig
+ base_model_prefix = "model"
+
+
+PEGASUS_START_DOCSTRING = r"""
+ This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
+ etc.)
+
+ This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it
+ as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matters related to general usage and
+ behavior.
+
+ <Tip>
+
+ TensorFlow models and layers in `transformers` accept two formats as input:
+
+ - having all inputs as keyword arguments (like PyTorch models), or
+ - having all inputs as a list, tuple or dict in the first positional argument.
+
+ The reason the second format is supported is that Keras methods prefer this format when passing inputs to models
+ and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just
+ pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second
+ format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with
+ the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first
+ positional argument:
+
+ - a single Tensor with `input_ids` only and nothing else: `model(input_ids)`
+ - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
+ `model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])`
+ - a dictionary with one or several input Tensors associated to the input names given in the docstring:
+ `model({"input_ids": input_ids, "token_type_ids": token_type_ids})`
+
+ Note that when creating models and layers with
+ [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry
+ about any of this, as you can just pass inputs like you would to any other Python function!
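+
+ For instance, as a minimal sketch (assuming `model` is an already-loaded Pegasus model and `input_ids`,
+ `attention_mask` and `decoder_input_ids` are pre-built `tf.Tensor`s), the following calls are equivalent:
+
+ ```python
+ outputs = model(input_ids=input_ids, attention_mask=attention_mask, decoder_input_ids=decoder_input_ids)
+ outputs = model([input_ids, attention_mask, decoder_input_ids])
+ outputs = model({"input_ids": input_ids, "attention_mask": attention_mask, "decoder_input_ids": decoder_input_ids})
+ ```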
+
+ </Tip>
+
+ Args:
+ config ([`PegasusConfig`]): Model configuration class with all the parameters of the model.
+ Initializing with a config file does not load the weights associated with the model, only the
+ configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+PEGASUS_GENERATION_EXAMPLE = r"""
+ Summarization example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, TFPegasusForConditionalGeneration
+
+ >>> model = TFPegasusForConditionalGeneration.from_pretrained("google/pegasus-xsum")
+ >>> tokenizer = AutoTokenizer.from_pretrained("google/pegasus-xsum")
+
+ >>> ARTICLE_TO_SUMMARIZE = (
+ ... "PG&E stated it scheduled the blackouts in response to forecasts for high winds "
+ ... "amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were "
+ ... "scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow."
+ ... )
+ >>> inputs = tokenizer(ARTICLE_TO_SUMMARIZE, max_length=1024, return_tensors="tf")
+
+ >>> # Generate Summary
+ >>> summary_ids = model.generate(inputs["input_ids"])
+ >>> print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False))
+ ```
+"""
+
+PEGASUS_INPUTS_DOCSTRING = r"""
+ Args:
+ input_ids (`tf.Tensor` of shape `({0})`):
+ Indices of input sequence tokens in the vocabulary.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ attention_mask (`tf.Tensor` of shape `({0})`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ decoder_input_ids (`tf.Tensor` of shape `(batch_size, target_sequence_length)`, *optional*):
+ Indices of decoder input sequence tokens in the vocabulary.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are decoder input IDs?](../glossary#decoder-input-ids)
+
+ Pegasus uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If
+ `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
+ `past_key_values`).
+ decoder_attention_mask (`tf.Tensor` of shape `(batch_size, target_sequence_length)`, *optional*):
+ If not provided, a default mask that ignores pad tokens will be created. It is not recommended to set this for most use cases.
+ decoder_position_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the
+ range `[0, config.max_position_embeddings - 1]`.
+ head_mask (`tf.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
+ Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ decoder_head_mask (`tf.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
+ Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ cross_attn_head_mask (`tf.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
+ Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ encoder_outputs (`tf.FloatTensor`, *optional*):
+ Sequence of hidden states at the output of the last layer of the encoder, of shape
+ `(batch_size, sequence_length, hidden_size)`. Used in the cross-attention of the decoder.
+ past_key_values (`tuple[tuple[tf.Tensor]]` of length `config.n_layers`):
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+ inputs_embeds (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+ than the model's internal embedding lookup matrix.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+ `past_key_values`). Set to `False` during training, `True` during generation.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the
+ config will be used instead.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail. This argument can be used only in eager mode, in graph mode the value in the config will be
+ used instead.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in
+ eager mode, in graph mode the value will always be set to True.
+ training (`bool`, *optional*, defaults to `False`):
+ Whether or not to use the model in training mode (some modules like dropout modules have different
+ behaviors between training and evaluation).
+"""
+
+
+@keras_serializable
+class TFPegasusEncoder(keras.layers.Layer):
+ config_class = PegasusConfig
+ """
+ Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
+ [`TFPegasusEncoderLayer`].
+
+ Args:
+ config: PegasusConfig
+ """
+
+ def __init__(self, config: PegasusConfig, embed_tokens: keras.layers.Embedding | None = None, **kwargs):
+ super().__init__(**kwargs)
+ self.config = config
+ self.dropout = keras.layers.Dropout(config.dropout)
+ self.layerdrop = config.encoder_layerdrop
+ self.padding_idx = config.pad_token_id
+ self.max_source_positions = config.max_position_embeddings
+ self.embed_scale = tf.math.sqrt(float(config.d_model)) if config.scale_embedding else 1.0
+
+ self.embed_tokens = embed_tokens
+ self.embed_positions = TFPegasusSinusoidalPositionalEmbedding(
+ config.max_position_embeddings,
+ config.d_model,
+ name="embed_positions",
+ )
+ self.layers = [TFPegasusEncoderLayer(config, name=f"layers.{i}") for i in range(config.encoder_layers)]
+ self.layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="layer_norm")
+
+ def get_embed_tokens(self):
+ return self.embed_tokens
+
+ def set_embed_tokens(self, embed_tokens):
+ self.embed_tokens = embed_tokens
+
+ @unpack_inputs
+ def call(
+ self,
+ input_ids: tf.Tensor | None = None,
+ inputs_embeds: tf.Tensor | None = None,
+ attention_mask: tf.Tensor | None = None,
+ head_mask: tf.Tensor | None = None,
+ output_attentions: bool | None = None,
+ output_hidden_states: bool | None = None,
+ return_dict: bool | None = None,
+ training: bool | None = False,
+ ):
+ """
+ Args:
+ input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
+ provide it.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ head_mask (`tf.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
+ Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ inputs_embeds (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+ than the model's internal embedding lookup matrix.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail. This argument can be used only in eager mode, in graph mode the value
+ in the config will be used instead.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+ for more detail. This argument can be used only in eager mode, in graph mode the value in the config
+ will be used instead.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used
+ in eager mode, in graph mode the value will always be set to True.
+ training (`bool`, *optional*, defaults to `False`):
+ Whether or not to use the model in training mode (some modules like dropout modules have different
+ behaviors between training and evaluation).
+ """
+
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+ elif input_ids is not None:
+ input_shape = shape_list(input_ids)
+ elif inputs_embeds is not None:
+ input_shape = shape_list(inputs_embeds)[:-1]
+ else:
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+ if inputs_embeds is None:
+ check_embeddings_within_bounds(input_ids, self.embed_tokens.input_dim)
+ inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
+
+ embed_pos = self.embed_positions(input_shape)
+ hidden_states = inputs_embeds + embed_pos
+ hidden_states = self.dropout(hidden_states, training=training)
+
+ # check attention mask and invert
+ if attention_mask is not None:
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ attention_mask = _expand_mask(attention_mask)
+
+ encoder_states = () if output_hidden_states else None
+ all_attentions = () if output_attentions else None
+
+ # check if head_mask has a correct number of layers specified if desired
+ if head_mask is not None:
+ tf.debugging.assert_equal(
+ shape_list(head_mask)[0],
+ len(self.layers),
+ message=(
+ f"The head_mask should be specified for {len(self.layers)} layers, but it is for"
+ f" {shape_list(head_mask)[0]}."
+ ),
+ )
+
+ # encoder layers
+ for idx, encoder_layer in enumerate(self.layers):
+ if output_hidden_states:
+ encoder_states = encoder_states + (hidden_states,)
+ # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
+ dropout_probability = random.uniform(0, 1)
+ if training and (dropout_probability < self.layerdrop): # skip the layer
+ continue
+
+ hidden_states, attn = encoder_layer(
+ hidden_states,
+ attention_mask,
+ head_mask[idx] if head_mask is not None else None,
+ )
+
+ if output_attentions:
+ all_attentions += (attn,)
+
+ hidden_states = self.layer_norm(hidden_states)
+
+ if output_hidden_states:
+ encoder_states = encoder_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
+ return TFBaseModelOutput(
+ last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
+ )
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "embed_positions", None) is not None:
+ with tf.name_scope(self.embed_positions.name):
+ self.embed_positions.build(None)
+ if getattr(self, "layer_norm", None) is not None:
+ with tf.name_scope(self.layer_norm.name):
+ self.layer_norm.build([None, None, self.config.d_model])
+ if getattr(self, "layers", None) is not None:
+ for layer in self.layers:
+ with tf.name_scope(layer.name):
+ layer.build(None)
+
+
+@keras_serializable
+class TFPegasusDecoder(keras.layers.Layer):
+ config_class = PegasusConfig
+ """
+ Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`TFPegasusDecoderLayer`]
+
+ Args:
+ config: PegasusConfig
+ embed_tokens: output embedding
+ """
+
+ def __init__(self, config: PegasusConfig, embed_tokens: keras.layers.Embedding | None = None, **kwargs):
+ super().__init__(**kwargs)
+ self.config = config
+ self.padding_idx = config.pad_token_id
+ self.embed_tokens = embed_tokens
+ self.layerdrop = config.decoder_layerdrop
+ self.embed_positions = TFPegasusSinusoidalPositionalEmbedding(
+ config.max_position_embeddings,
+ config.d_model,
+ name="embed_positions",
+ )
+ self.embed_scale = tf.math.sqrt(float(config.d_model)) if config.scale_embedding else 1.0
+ self.layers = [TFPegasusDecoderLayer(config, name=f"layers.{i}") for i in range(config.decoder_layers)]
+ self.layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="layer_norm")
+
+ self.dropout = keras.layers.Dropout(config.dropout)
+
+ def get_embed_tokens(self):
+ return self.embed_tokens
+
+ def set_embed_tokens(self, embed_tokens):
+ self.embed_tokens = embed_tokens
+
+ @unpack_inputs
+ def call(
+ self,
+ input_ids: tf.Tensor | None = None,
+ inputs_embeds: tf.Tensor | None = None,
+ attention_mask: tf.Tensor | None = None,
+ position_ids: tf.Tensor | None = None,
+ encoder_hidden_states: tf.Tensor | None = None,
+ encoder_attention_mask: tf.Tensor | None = None,
+ head_mask: tf.Tensor | None = None,
+ cross_attn_head_mask: tf.Tensor | None = None,
+ past_key_values: tuple[tuple[tf.Tensor]] | None = None,
+ use_cache: bool | None = None,
+ output_attentions: bool | None = None,
+ output_hidden_states: bool | None = None,
+ return_dict: bool | None = None,
+ training: bool | None = False,
+ ):
+ r"""
+ Args:
+ input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
+ provide it.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ position_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the
+ range `[0, config.max_position_embeddings - 1]`.
+ encoder_hidden_states (`tf.Tensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
+ of the decoder.
+ encoder_attention_mask (`tf.Tensor` of shape `(batch_size, encoder_sequence_length)`, *optional*):
+ Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values
+ selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ head_mask (`tf.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
+ Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ cross_attn_head_mask (`tf.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
+ Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ past_key_values (`tuple[tuple[tf.Tensor]]` of length `config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
+ Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
+ decoding.
+
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
+ that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
+ all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+ inputs_embeds (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+ than the model's internal embedding lookup matrix.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail. This argument can be used only in eager mode, in graph mode the value
+ in the config will be used instead.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+ for more detail. This argument can be used only in eager mode, in graph mode the value in the config
+ will be used instead.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used
+ in eager mode, in graph mode the value will always be set to True.
+ training (`bool`, *optional*, defaults to `False`):
+ Whether or not to use the model in training mode (some modules like dropout modules have different
+ behaviors between training and evaluation).
+ """
+
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
+ elif input_ids is not None:
+ input_shape = shape_list(input_ids)
+ elif inputs_embeds is not None:
+ input_shape = shape_list(inputs_embeds)[:-1]
+ else:
+ raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
+
+ past_key_values_length = shape_list(past_key_values[0][0])[2] if past_key_values is not None else 0
+
+ # embed positions
+ if position_ids is None:
+ positions = self.embed_positions(input_shape, past_key_values_length)
+ else:
+ positions = self.embed_positions(input_shape, position_ids=position_ids)
+
+ if inputs_embeds is None:
+ check_embeddings_within_bounds(input_ids, self.embed_tokens.input_dim)
+ inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
+
+ hidden_states = inputs_embeds
+
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
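+ # a causal (look-ahead) mask is only needed when more than one target token is fed at once;
+ # for single-token (cached) decoding steps a mask over the past + current positions is sufficient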
+ if input_shape[-1] > 1:
+ combined_attention_mask = _make_causal_mask(input_shape, past_key_values_length=past_key_values_length)
+ else:
+ combined_attention_mask = _expand_mask(
+ tf.ones((input_shape[0], input_shape[1] + past_key_values_length)), tgt_len=input_shape[-1]
+ )
+
+ if attention_mask is not None:
+ combined_attention_mask = combined_attention_mask + _expand_mask(attention_mask, tgt_len=input_shape[-1])
+
+ if encoder_hidden_states is not None and encoder_attention_mask is not None:
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ encoder_attention_mask = _expand_mask(encoder_attention_mask, tgt_len=input_shape[-1])
+
+ hidden_states = self.dropout(hidden_states + positions, training=training)
+
+ # decoder layers
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+ all_cross_attns = () if (output_attentions and encoder_hidden_states is not None) else None
+ present_key_values = () if use_cache else None
+
+ # check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
+ for attn_mask_name, attn_mask in [("head_mask", head_mask), ("cross_attn_head_mask", cross_attn_head_mask)]:
+ if attn_mask is not None:
+ tf.debugging.assert_equal(
+ shape_list(attn_mask)[0],
+ len(self.layers),
+ message=(
+ f"The {attn_mask_name} should be specified for {len(self.layers)} layers, but it is for"
+ f" {shape_list(attn_mask)[0]}."
+ ),
+ )
+
+ for idx, decoder_layer in enumerate(self.layers):
+ # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+ dropout_probability = random.uniform(0, 1)
+
+ if training and (dropout_probability < self.layerdrop):
+ continue
+
+ past_key_value = past_key_values[idx] if past_key_values is not None else None
+
+ hidden_states, layer_self_attn, layer_cross_attn, present_key_value = decoder_layer(
+ hidden_states,
+ attention_mask=combined_attention_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ layer_head_mask=head_mask[idx] if head_mask is not None else None,
+ cross_attn_layer_head_mask=cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
+ past_key_value=past_key_value,
+ )
+
+ if use_cache:
+ present_key_values += (present_key_value,)
+
+ if output_attentions:
+ all_self_attns += (layer_self_attn,)
+
+ if encoder_hidden_states is not None:
+ all_cross_attns += (layer_cross_attn,)
+
+ hidden_states = self.layer_norm(hidden_states)
+
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ if not return_dict:
+ return hidden_states, present_key_values, all_hidden_states, all_self_attns, all_cross_attns
+ else:
+ return TFBaseModelOutputWithPastAndCrossAttentions(
+ last_hidden_state=hidden_states,
+ past_key_values=present_key_values,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ cross_attentions=all_cross_attns,
+ )
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "embed_positions", None) is not None:
+ with tf.name_scope(self.embed_positions.name):
+ self.embed_positions.build(None)
+ if getattr(self, "layer_norm", None) is not None:
+ with tf.name_scope(self.layer_norm.name):
+ self.layer_norm.build([None, None, self.config.d_model])
+ if getattr(self, "layers", None) is not None:
+ for layer in self.layers:
+ with tf.name_scope(layer.name):
+ layer.build(None)
+
+
+@keras_serializable
+class TFPegasusMainLayer(keras.layers.Layer):
+ config_class = PegasusConfig
+
+ def __init__(self, config: PegasusConfig, **kwargs):
+ super().__init__(**kwargs)
+
+ self.config = config
+ self.shared = keras.layers.Embedding(
+ input_dim=config.vocab_size,
+ output_dim=config.d_model,
+ embeddings_initializer=keras.initializers.TruncatedNormal(stddev=self.config.init_std),
+ name="model.shared",
+ )
+ # Additional attribute to specify the expected name scope of the layer (for loading/storing weights)
+ self.shared.load_weight_prefix = "model.shared"
+
+ self.encoder = TFPegasusEncoder(config, self.shared, name="encoder")
+ self.decoder = TFPegasusDecoder(config, self.shared, name="decoder")
+
+ def get_input_embeddings(self):
+ return self.shared
+
+ def set_input_embeddings(self, new_embeddings):
+ self.shared = new_embeddings
+ self.encoder.embed_tokens = self.shared
+ self.decoder.embed_tokens = self.shared
+
+ @unpack_inputs
+ def call(
+ self,
+ input_ids: tf.Tensor | None = None,
+ attention_mask: tf.Tensor | None = None,
+ decoder_input_ids: tf.Tensor | None = None,
+ decoder_attention_mask: tf.Tensor | None = None,
+ decoder_position_ids: tf.Tensor | None = None,
+ head_mask: tf.Tensor | None = None,
+ decoder_head_mask: tf.Tensor | None = None,
+ cross_attn_head_mask: tf.Tensor | None = None,
+ encoder_outputs: tuple | TFBaseModelOutput | None = None,
+ past_key_values: tuple[tuple[tf.Tensor]] | None = None,
+ inputs_embeds: tf.Tensor | None = None,
+ decoder_inputs_embeds: tf.Tensor | None = None,
+ use_cache: bool | None = None,
+ output_attentions: bool | None = None,
+ output_hidden_states: bool | None = None,
+ return_dict: bool | None = None,
+ training: bool | None = False,
+ **kwargs,
+ ):
+ if decoder_input_ids is None and decoder_inputs_embeds is None:
+ use_cache = False
+
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+
+ if encoder_outputs is None:
+ encoder_outputs = self.encoder(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ training=training,
+ )
+ # If the user passed a tuple for encoder_outputs, we wrap it in a TFBaseModelOutput when return_dict=True
+ elif return_dict and not isinstance(encoder_outputs, TFBaseModelOutput):
+ encoder_outputs = TFBaseModelOutput(
+ last_hidden_state=encoder_outputs[0],
+ hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
+ attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
+ )
+ # If the user passed a TFBaseModelOutput for encoder_outputs, we wrap it in a tuple when return_dict=False
+ elif not return_dict and not isinstance(encoder_outputs, tuple):
+ encoder_outputs = encoder_outputs.to_tuple()
+
+ decoder_outputs = self.decoder(
+ decoder_input_ids,
+ attention_mask=decoder_attention_mask,
+ position_ids=decoder_position_ids,
+ encoder_hidden_states=encoder_outputs[0],
+ encoder_attention_mask=attention_mask,
+ head_mask=decoder_head_mask,
+ cross_attn_head_mask=cross_attn_head_mask,
+ past_key_values=past_key_values,
+ inputs_embeds=decoder_inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ training=training,
+ )
+
+ if not return_dict:
+ return decoder_outputs + encoder_outputs
+
+ return TFSeq2SeqModelOutput(
+ last_hidden_state=decoder_outputs.last_hidden_state,
+ past_key_values=decoder_outputs.past_key_values,
+ decoder_hidden_states=decoder_outputs.hidden_states,
+ decoder_attentions=decoder_outputs.attentions,
+ cross_attentions=decoder_outputs.cross_attentions,
+ encoder_last_hidden_state=encoder_outputs.last_hidden_state,
+ encoder_hidden_states=encoder_outputs.hidden_states,
+ encoder_attentions=encoder_outputs.attentions,
+ )
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ # The shared/tied weights expect to be in the model base namespace
+ # Adding "/" to the end (not the start!) of a tf.name_scope puts it in the root namespace rather than
+ # the current one.
+ with tf.name_scope(self.shared.load_weight_prefix + "/" + self.shared.name + "/"):
+ self.shared.build(None)
+ if getattr(self, "encoder", None) is not None:
+ with tf.name_scope(self.encoder.name):
+ self.encoder.build(None)
+ if getattr(self, "decoder", None) is not None:
+ with tf.name_scope(self.decoder.name):
+ self.decoder.build(None)
+
+
+@add_start_docstrings(
+ "The bare PEGASUS Model outputting raw hidden-states without any specific head on top.",
+ PEGASUS_START_DOCSTRING,
+)
+class TFPegasusModel(TFPegasusPreTrainedModel):
+ def __init__(self, config: PegasusConfig, *inputs, **kwargs):
+ super().__init__(config, *inputs, **kwargs)
+
+ self.model = TFPegasusMainLayer(config, name="model")
+
+ def get_encoder(self):
+ return self.model.encoder
+
+ def get_decoder(self):
+ return self.model.decoder
+
+ @unpack_inputs
+ @add_start_docstrings_to_model_forward(PEGASUS_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+ @add_code_sample_docstrings(
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=TFSeq2SeqModelOutput,
+ config_class=_CONFIG_FOR_DOC,
+ )
+ def call(
+ self,
+ input_ids: TFModelInputType | None = None,
+ attention_mask: np.ndarray | tf.Tensor | None = None,
+ decoder_input_ids: np.ndarray | tf.Tensor | None = None,
+ decoder_attention_mask: np.ndarray | tf.Tensor | None = None,
+ decoder_position_ids: np.ndarray | tf.Tensor | None = None,
+ head_mask: np.ndarray | tf.Tensor | None = None,
+ decoder_head_mask: np.ndarray | tf.Tensor | None = None,
+ cross_attn_head_mask: np.ndarray | tf.Tensor | None = None,
+ encoder_outputs: tuple | TFBaseModelOutput | None = None,
+ past_key_values: tuple[tuple[np.ndarray | tf.Tensor]] | None = None,
+ inputs_embeds: np.ndarray | tf.Tensor | None = None,
+ decoder_inputs_embeds: np.ndarray | tf.Tensor | None = None,
+ use_cache: bool | None = None,
+ output_attentions: bool | None = None,
+ output_hidden_states: bool | None = None,
+ return_dict: bool | None = None,
+ training: bool = False,
+ **kwargs,
+ ) -> TFSeq2SeqModelOutput | tuple[tf.Tensor]:
+ outputs = self.model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ decoder_input_ids=decoder_input_ids,
+ decoder_attention_mask=decoder_attention_mask,
+ decoder_position_ids=decoder_position_ids,
+ head_mask=head_mask,
+ decoder_head_mask=decoder_head_mask,
+ cross_attn_head_mask=cross_attn_head_mask,
+ encoder_outputs=encoder_outputs,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ decoder_inputs_embeds=decoder_inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ training=training,
+ )
+
+ return outputs
+
+ # Copied from transformers.models.bart.modeling_tf_bart.TFBartModel.serving_output
+ def serving_output(self, output):
+ pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
+ dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
+ dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
+ cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.config.output_attentions else None
+ enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
+ enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None
+
+ return TFSeq2SeqModelOutput(
+ last_hidden_state=output.last_hidden_state,
+ past_key_values=pkv,
+ decoder_hidden_states=dec_hs,
+ decoder_attentions=dec_attns,
+ cross_attentions=cross_attns,
+ encoder_last_hidden_state=output.encoder_last_hidden_state,
+ encoder_hidden_states=enc_hs,
+ encoder_attentions=enc_attns,
+ )
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "model", None) is not None:
+ with tf.name_scope(self.model.name):
+ self.model.build(None)
+
+
+# Copied from transformers.models.bart.modeling_tf_bart.BiasLayer
+class BiasLayer(keras.layers.Layer):
+ """
+ Bias as a layer. It is used for serialization purposes: `keras.Model.save_weights` stores on a per-layer basis,
+ so all weights have to be registered in a layer.
+ """
+
+ def __init__(self, shape, initializer, trainable, name, **kwargs):
+ super().__init__(name=name, **kwargs)
+ # Note: the name of this variable will NOT be scoped when serialized, i.e. it will not be in the format of
+ # "outer_layer/inner_layer/.../name:0". Instead, it will be "name:0". For further details, see:
+ # https://github.com/huggingface/transformers/pull/18833#issuecomment-1233090214
+ self.bias = self.add_weight(name=name, shape=shape, initializer=initializer, trainable=trainable)
+
+ def call(self, x):
+ return x + self.bias
+
+
+@add_start_docstrings(
+ "The PEGASUS Model with a language modeling head. Can be used for summarization.",
+ PEGASUS_START_DOCSTRING,
+)
+class TFPegasusForConditionalGeneration(TFPegasusPreTrainedModel, TFCausalLanguageModelingLoss):
+ _keys_to_ignore_on_load_unexpected = [
+ r"model.encoder.embed_tokens.weight",
+ r"model.decoder.embed_tokens.weight",
+ ]
+
+ def __init__(self, config, *inputs, **kwargs):
+ super().__init__(config, *inputs, **kwargs)
+ self.model = TFPegasusMainLayer(config, name="model")
+ self.use_cache = config.use_cache
+ # final_logits_bias is registered as a buffer in pytorch, so it is kept non-trainable here for the sake of consistency.
+ self.bias_layer = BiasLayer(
+ name="final_logits_bias", shape=[1, config.vocab_size], initializer="zeros", trainable=False
+ )
+
+ def get_decoder(self):
+ return self.model.decoder
+
+ def get_encoder(self):
+ return self.model.encoder
+
+ def get_output_embeddings(self):
+ return self.get_input_embeddings()
+
+ def set_output_embeddings(self, value):
+ self.set_input_embeddings(value)
+
+ def get_bias(self):
+ return {"final_logits_bias": self.bias_layer.bias}
+
+ def set_bias(self, value):
+ # Replaces the existing layers containing bias for correct (de)serialization.
+ vocab_size = value["final_logits_bias"].shape[-1]
+ self.bias_layer = BiasLayer(
+ name="final_logits_bias", shape=[1, vocab_size], initializer="zeros", trainable=False
+ )
+ self.bias_layer.bias.assign(value["final_logits_bias"])
+
+ @unpack_inputs
+ @add_start_docstrings_to_model_forward(PEGASUS_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=TFSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
+ @add_end_docstrings(PEGASUS_GENERATION_EXAMPLE)
+ def call(
+ self,
+ input_ids: TFModelInputType | None = None,
+ attention_mask: np.ndarray | tf.Tensor | None = None,
+ decoder_input_ids: np.ndarray | tf.Tensor | None = None,
+ decoder_attention_mask: np.ndarray | tf.Tensor | None = None,
+ decoder_position_ids: np.ndarray | tf.Tensor | None = None,
+ head_mask: np.ndarray | tf.Tensor | None = None,
+ decoder_head_mask: np.ndarray | tf.Tensor | None = None,
+ cross_attn_head_mask: np.ndarray | tf.Tensor | None = None,
+ encoder_outputs: TFBaseModelOutput | None = None,
+ past_key_values: tuple[tuple[np.ndarray | tf.Tensor]] | None = None,
+ inputs_embeds: np.ndarray | tf.Tensor | None = None,
+ decoder_inputs_embeds: np.ndarray | tf.Tensor | None = None,
+ use_cache: bool | None = None,
+ output_attentions: bool | None = None,
+ output_hidden_states: bool | None = None,
+ return_dict: bool | None = None,
+ labels: np.ndarray | tf.Tensor | None = None,
+ training: bool = False,
+ ) -> TFSeq2SeqLMOutput | tuple[tf.Tensor]:
+ """
+ labels (`tf.tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+ Returns:
+
+ """
+
+ if labels is not None:
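+ # replace padding positions in the labels with -100 so they are ignored by the loss,
+ # and disable caching while training with labels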
+ labels = tf.where(
+ labels == self.config.pad_token_id,
+ tf.cast(tf.fill(shape_list(labels), -100), labels.dtype),
+ labels,
+ )
+ use_cache = False
+ if decoder_input_ids is None and decoder_inputs_embeds is None:
+ decoder_input_ids = shift_tokens_right(
+ labels, self.config.pad_token_id, self.config.decoder_start_token_id
+ )
+
+ outputs = self.model(
+ input_ids,
+ attention_mask=attention_mask,
+ decoder_input_ids=decoder_input_ids,
+ encoder_outputs=encoder_outputs,
+ decoder_attention_mask=decoder_attention_mask,
+ decoder_position_ids=decoder_position_ids,
+ head_mask=head_mask,
+ decoder_head_mask=decoder_head_mask,
+ cross_attn_head_mask=cross_attn_head_mask,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ decoder_inputs_embeds=decoder_inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ training=training,
+ )
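+ # weight tying: project decoder hidden states back onto the shared embedding matrix,
+ # then add the (non-trainable) final-logits bias held by `self.bias_layer`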
+ lm_logits = tf.matmul(outputs[0], self.model.shared.weights, transpose_b=True)
+ lm_logits = self.bias_layer(lm_logits)
+ masked_lm_loss = None if labels is None else self.hf_compute_loss(labels, lm_logits)
+
+ if not return_dict:
+ output = (lm_logits,) + outputs[1:]
+ return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
+ return TFSeq2SeqLMOutput(
+ loss=masked_lm_loss,
+ logits=lm_logits,
+ past_key_values=outputs.past_key_values, # index 1 of d outputs
+ decoder_hidden_states=outputs.decoder_hidden_states, # index 2 of d outputs
+ decoder_attentions=outputs.decoder_attentions, # index 3 of d outputs
+ cross_attentions=outputs.cross_attentions, # index 4 of d outputs
+ encoder_last_hidden_state=outputs.encoder_last_hidden_state, # index 0 of encoder outputs
+ encoder_hidden_states=outputs.encoder_hidden_states, # 1 of e out
+ encoder_attentions=outputs.encoder_attentions, # 2 of e out
+ )
+
+ # Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration.serving_output
+ def serving_output(self, output):
+ pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
+ dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
+ dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
+ cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.config.output_attentions else None
+ enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
+ enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None
+
+ return TFSeq2SeqLMOutput(
+ logits=output.logits,
+ past_key_values=pkv,
+ decoder_hidden_states=dec_hs,
+ decoder_attentions=dec_attns,
+ cross_attentions=cross_attns,
+ encoder_last_hidden_state=output.encoder_last_hidden_state,
+ encoder_hidden_states=enc_hs,
+ encoder_attentions=enc_attns,
+ )
+
+ # Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration.prepare_inputs_for_generation
+ def prepare_inputs_for_generation(
+ self,
+ decoder_input_ids,
+ past_key_values=None,
+ attention_mask=None,
+ decoder_attention_mask=None,
+ head_mask=None,
+ decoder_head_mask=None,
+ cross_attn_head_mask=None,
+ use_cache=None,
+ encoder_outputs=None,
+ **kwargs,
+ ):
+ # cut decoder_input_ids if past_key_values is used
+ if past_key_values is not None:
+ decoder_input_ids = decoder_input_ids[:, -1:]
+
+ if decoder_attention_mask is not None: # xla
+ decoder_position_ids = tf.math.cumsum(decoder_attention_mask, axis=-1, exclusive=True)[:, -1:]
+ elif past_key_values is not None: # no xla + past_key_values
+ decoder_position_ids = past_key_values[0][0].shape[2]
+ else: # no xla + no past_key_values
+ decoder_position_ids = tf.range(decoder_input_ids.shape[1])
+
+ return {
+ "input_ids": None, # encoder_outputs is defined. input_ids not needed
+ "encoder_outputs": encoder_outputs,
+ "past_key_values": past_key_values,
+ "decoder_input_ids": decoder_input_ids,
+ "attention_mask": attention_mask,
+ "decoder_attention_mask": decoder_attention_mask,
+ "decoder_position_ids": decoder_position_ids,
+ "head_mask": head_mask,
+ "decoder_head_mask": decoder_head_mask,
+ "cross_attn_head_mask": cross_attn_head_mask,
+ "use_cache": use_cache, # change this to avoid caching (presumably for debugging)
+ }
+
+ def prepare_decoder_input_ids_from_labels(self, labels: tf.Tensor):
+ return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "model", None) is not None:
+ with tf.name_scope(self.model.name):
+ self.model.build(None)
+ if getattr(self, "bias_layer", None) is not None:
+ with tf.name_scope(self.bias_layer.name):
+ self.bias_layer.build(None)
+
+
+__all__ = ["TFPegasusForConditionalGeneration", "TFPegasusModel", "TFPegasusPreTrainedModel"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/pegasus/tokenization_pegasus.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/pegasus/tokenization_pegasus.py
new file mode 100644
index 0000000000000000000000000000000000000000..b8a4a1c737d1592d9362f506e93b20854edfe7a5
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/pegasus/tokenization_pegasus.py
@@ -0,0 +1,292 @@
+# coding=utf-8
+# Copyright 2020 Google and The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+from shutil import copyfile
+from typing import Any, Optional
+
+import sentencepiece as spm
+
+from ...tokenization_utils import AddedToken, PreTrainedTokenizer
+from ...utils import logging
+from ...utils.import_utils import requires
+
+
+SPIECE_UNDERLINE = "▁"
+
+VOCAB_FILES_NAMES = {"vocab_file": "spiece.model"}
+
+
+logger = logging.get_logger(__name__)
+
+
+# TODO ArthurZ refactor this to only use the added_tokens_encoder
+
+
+@requires(backends=("sentencepiece",))
+class PegasusTokenizer(PreTrainedTokenizer):
+ r"""
+ Construct a PEGASUS tokenizer. Based on [SentencePiece](https://github.com/google/sentencepiece).
+
+ This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
+ this superclass for more information regarding those methods.
+
+ Args:
+ vocab_file (`str`):
+ [SentencePiece](https://github.com/google/sentencepiece) file (generally has a *.spm* extension) that
+ contains the vocabulary necessary to instantiate a tokenizer.
+ pad_token (`str`, *optional*, defaults to `"<pad>"`):
+ The token used for padding, for example when batching sequences of different lengths.
+ eos_token (`str`, *optional*, defaults to `"</s>"`):
+ The end of sequence token.
+
+ <Tip>
+
+ When building a sequence using special tokens, this is not the token that is used for the end of sequence.
+ The token used is the `sep_token`.
+
+ </Tip>
+
+ unk_token (`str`, *optional*, defaults to `"<unk>"`):
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+ token instead.
+ mask_token (`str`, *optional*, defaults to `"<mask_2>"`):
+ The token used for masking single token values. This is the token used when training this model with masked
+ language modeling (MLM). This is the token that the PEGASUS encoder will try to predict during pretraining.
+ It corresponds to *[MASK2]* in [PEGASUS: Pre-training with Extracted Gap-sentences for Abstractive
+ Summarization](https://huggingface.co/papers/1912.08777).
+ mask_token_sent (`str`, *optional*, defaults to `"<mask_1>"`):
+ The token used for masking whole target sentences. This is the token used when training this model with gap
+ sentences generation (GSG). This is the sentence that the PEGASUS decoder will try to predict during
+ pretraining. It corresponds to *[MASK1]* in [PEGASUS: Pre-training with Extracted Gap-sentences for
+ Abstractive Summarization](https://huggingface.co/papers/1912.08777).
+ additional_special_tokens (`List[str]`, *optional*):
+ Additional special tokens used by the tokenizer. If no additional_special_tokens are provided, `<mask_2>` and
+ `<unk_2>, ..., <unk_102>` are used as additional special tokens corresponding to the [original PEGASUS
+ tokenizer](https://github.com/google-research/pegasus/blob/939830367bcf411193d2b5eca2f2f90f3f9260ca/pegasus/ops/pretrain_parsing_ops.cc#L66)
+ that uses the tokens 2 - 104 only for pretraining.
+ sp_model_kwargs (`dict`, *optional*):
+ Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for
+ SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things,
+ to set:
+
+ - `enable_sampling`: Enable subword regularization.
+ - `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout.
+
+ - `nbest_size = {0,1}`: No sampling is performed.
+ - `nbest_size > 1`: samples from the nbest_size results.
+ - `nbest_size < 0`: assuming that nbest_size is infinite and samples from the all hypothesis (lattice)
+ using forward-filtering-and-backward-sampling algorithm.
+
+ - `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for
+ BPE-dropout.
+ """
+
+ vocab_files_names = VOCAB_FILES_NAMES
+ model_input_names = ["input_ids", "attention_mask"]
+
+ def __init__(
+ self,
+ vocab_file,
+ pad_token="",
+ eos_token="",
+ unk_token="",
+ mask_token="",
+ mask_token_sent="",
+ additional_special_tokens=None,
+ offset=103, # entries 2 - 104 are only used for pretraining
+ sp_model_kwargs: Optional[dict[str, Any]] = None,
+ **kwargs,
+ ) -> None:
+ self.offset = offset
+ if additional_special_tokens is not None:
+ if not isinstance(additional_special_tokens, list):
+ raise TypeError(
+ f"additional_special_tokens should be of type {type(list)}, but is"
+ f" {type(additional_special_tokens)}"
+ )
+ additional_special_tokens_extended = (
+ ([mask_token_sent] + additional_special_tokens)
+ if mask_token_sent not in additional_special_tokens and mask_token_sent is not None
+ else additional_special_tokens
+ )
+ # fill additional tokens with ..., <unk_token_102> in case not all additional tokens are already taken
+ additional_special_tokens_extended += [
+ f"<unk_{i}>" for i in range(len(additional_special_tokens_extended), self.offset - 1)
+ ]
+
+ if len(set(additional_special_tokens_extended)) != len(additional_special_tokens_extended):
+ raise ValueError(
+ "Please make sure that the provided additional_special_tokens do not contain an incorrectly"
+ f" shifted list of tokens. Found {additional_special_tokens_extended}."
+ )
+ additional_special_tokens = additional_special_tokens_extended
+ else:
+ additional_special_tokens_extended = []
+ additional_special_tokens = [mask_token_sent] if mask_token_sent is not None else []
+ additional_special_tokens += [f"" for i in range(2, self.offset)]
+
+ self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
+ self.mask_token_sent = mask_token_sent
+ self.vocab_file = vocab_file
+ self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
+ self.sp_model.Load(vocab_file)
+
+ _added_tokens_decoder = {
+ 0: AddedToken(str(pad_token), special=True),
+ 1: AddedToken(str(eos_token), special=True),
+ }
+
+ if self.mask_token_sent is not None:
+ _added_tokens_decoder[2] = AddedToken(mask_token_sent, special=True)
+ _added_tokens_decoder[3] = AddedToken(str(mask_token), special=True)
+
+ for i in range(2, self.offset):
+ _added_tokens_decoder[len(_added_tokens_decoder)] = AddedToken(f"<unk_{i}>", special=True)
+
+ # Force update as we want to make sure vocab is enforced (same as fast)
+ self._added_tokens_decoder = kwargs.pop("added_tokens_decoder", {})
+ self._added_tokens_decoder.update(_added_tokens_decoder)
+
+ super().__init__(
+ eos_token=eos_token,
+ unk_token=unk_token,
+ mask_token=mask_token,
+ pad_token=pad_token,
+ mask_token_sent=mask_token_sent,
+ offset=offset,
+ additional_special_tokens=additional_special_tokens,
+ sp_model_kwargs=self.sp_model_kwargs,
+ **kwargs,
+ )
+
+ @property
+ def vocab_size(self) -> int:
+ return len(self.sp_model) + self.offset
+
+ def get_vocab(self) -> dict[str, int]:
+ vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
+ vocab.update(self.added_tokens_encoder)
+ return vocab
+
+ def __getstate__(self):
+ state = self.__dict__.copy()
+ state["sp_model"] = None
+ return state
+
+ def __setstate__(self, d):
+ self.__dict__ = d
+
+ # for backward compatibility
+ if not hasattr(self, "sp_model_kwargs"):
+ self.sp_model_kwargs = {}
+
+ self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
+ self.sp_model.Load(self.vocab_file)
+
+ def _tokenize(self, text: str) -> list[str]:
+ """Take as input a string and return a list of strings (tokens) for words/sub-words"""
+ return self.sp_model.encode(text, out_type=str)
+
+ def _convert_token_to_id(self, token: str) -> int:
+ """Converts a token (str) to an id using the vocab."""
+ sp_id = self.sp_model.piece_to_id(token)
+ return sp_id + self.offset
+
+ def _convert_id_to_token(self, index: int) -> str:
+ """Converts an index (integer) to a token (str) using the vocab."""
+ if index < self.offset:
+ return self.sp_model.IdToPiece(index)
+ token = self.sp_model.IdToPiece(index - self.offset)
+ return token
+
+ def convert_tokens_to_string(self, tokens):
+ """Converts a sequence of tokens (string) in a single string."""
+ current_sub_tokens = []
+ out_string = ""
+ for token in tokens:
+ # make sure that special tokens are not decoded using sentencepiece model
+ if token in self.all_special_tokens:
+ out_string += self.sp_model.decode(current_sub_tokens) + token
+ current_sub_tokens = []
+ else:
+ current_sub_tokens.append(token)
+ out_string += self.sp_model.decode(current_sub_tokens)
+ return out_string.strip()
+
+ def num_special_tokens_to_add(self, pair=False):
+ """Just EOS"""
+ return 1
+
+ def _special_token_mask(self, seq):
+ all_special_ids = set(self.all_special_ids) # call it once instead of inside list comp
+ all_special_ids.remove(self.unk_token_id) # <unk> is only sometimes special
+
+ return [1 if x in all_special_ids else 0 for x in seq]
+
+ def get_special_tokens_mask(
+ self, token_ids_0: list, token_ids_1: Optional[list] = None, already_has_special_tokens: bool = False
+ ) -> list[int]:
+ """Get list where entries are [1] if a token is [eos] or [pad] else 0."""
+ if already_has_special_tokens:
+ return self._special_token_mask(token_ids_0)
+ elif token_ids_1 is None:
+ return self._special_token_mask(token_ids_0) + [1]
+ else:
+ return self._special_token_mask(token_ids_0 + token_ids_1) + [1]
+
+ def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None) -> list[int]:
+ """
+ Build model inputs from a sequence or a pair of sequences for sequence classification tasks by concatenating
+ and adding special tokens. A PEGASUS sequence has the following format, where `X` represents the sequence:
+
+ - single sequence: `X </s>`
+ - pair of sequences: `A B </s>` (not intended use)
+
+ BOS is never used. Pairs of sequences are not the expected use case, but they will be handled without a
+ separator.
+
+ Args:
+ token_ids_0 (`List[int]`):
+ List of IDs to which the special tokens will be added.
+ token_ids_1 (`List[int]`, *optional*):
+ Optional second list of IDs for sequence pairs.
+
+ Returns:
+ `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
+ """
+ if token_ids_1 is None:
+ return token_ids_0 + [self.eos_token_id]
+ # We don't expect to process pairs, but leave the pair logic for API consistency
+ return token_ids_0 + token_ids_1 + [self.eos_token_id]
+
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]:
+ if not os.path.isdir(save_directory):
+ logger.error(f"Vocabulary path ({save_directory}) should be a directory")
+ return
+ out_vocab_file = os.path.join(
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
+ )
+
+ if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
+ copyfile(self.vocab_file, out_vocab_file)
+ elif not os.path.isfile(self.vocab_file):
+ with open(out_vocab_file, "wb") as fi:
+ content_spiece_model = self.sp_model.serialized_model_proto()
+ fi.write(content_spiece_model)
+
+ return (out_vocab_file,)
+
+
+__all__ = ["PegasusTokenizer"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/pegasus/tokenization_pegasus_fast.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/pegasus/tokenization_pegasus_fast.py
new file mode 100644
index 0000000000000000000000000000000000000000..92a37c44ff2e302fe1e2a56849f2e91fa481ec9b
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/pegasus/tokenization_pegasus_fast.py
@@ -0,0 +1,215 @@
+# coding=utf-8
+# Copyright 2020 Google and The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tokenization class for model PEGASUS."""
+
+import os
+from shutil import copyfile
+from typing import Optional
+
+from ...tokenization_utils_fast import PreTrainedTokenizerFast
+from ...utils import is_sentencepiece_available, logging
+
+
+if is_sentencepiece_available():
+ from .tokenization_pegasus import PegasusTokenizer
+else:
+ PegasusTokenizer = None
+
+
+logger = logging.get_logger(__name__)
+
+
+SPIECE_UNDERLINE = "▁"
+
+VOCAB_FILES_NAMES = {"vocab_file": "spiece.model", "tokenizer_file": "tokenizer.json"}
+
+
+class PegasusTokenizerFast(PreTrainedTokenizerFast):
+ r"""
+ Construct a "fast" PEGASUS tokenizer (backed by HuggingFace's *tokenizers* library). Based on
+ [Unigram](https://huggingface.co/docs/tokenizers/python/latest/components.html?highlight=unigram#models).
+
+ This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should
+ refer to this superclass for more information regarding those methods.
+
+ Args:
+ vocab_file (`str`):
+ [SentencePiece](https://github.com/google/sentencepiece) file (generally has a *.spm* extension) that
+ contains the vocabulary necessary to instantiate a tokenizer.
+ pad_token (`str`, *optional*, defaults to `"<pad>"`):
+ The token used for padding, for example when batching sequences of different lengths.
+ eos_token (`str`, *optional*, defaults to `"</s>"`):
+ The end of sequence token.
+
+ <Tip>
+
+ When building a sequence using special tokens, this is not the token that is used for the end of sequence.
+ The token used is the `sep_token`.
+
+ </Tip>
+
+ unk_token (`str`, *optional*, defaults to `"<unk>"`):
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+ token instead.
+ mask_token (`str`, *optional*, defaults to `"<mask_2>"`):
+ The token used for masking single token values. This is the token used when training this model with masked
+ language modeling (MLM). This is the token that the PEGASUS encoder will try to predict during pretraining.
+ It corresponds to *[MASK2]* in [PEGASUS: Pre-training with Extracted Gap-sentences for Abstractive
+ Summarization](https://huggingface.co/papers/1912.08777).
+ mask_token_sent (`str`, *optional*, defaults to `"<mask_1>"`):
+ The token used for masking whole target sentences. This is the token used when training this model with gap
+ sentences generation (GSG). This is the sentence that the PEGASUS decoder will try to predict during
+ pretraining. It corresponds to *[MASK1]* in [PEGASUS: Pre-training with Extracted Gap-sentences for
+ Abstractive Summarization](https://huggingface.co/papers/1912.08777).
+ additional_special_tokens (`List[str]`, *optional*):
+ Additional special tokens used by the tokenizer. If no additional_special_tokens are provided, `<mask_2>` and
+ `<unk_2>, ..., <unk_102>` are used as additional special tokens corresponding to the [original PEGASUS
+ tokenizer](https://github.com/google-research/pegasus/blob/939830367bcf411193d2b5eca2f2f90f3f9260ca/pegasus/ops/pretrain_parsing_ops.cc#L66)
+ that uses the tokens 2 - 104 only for pretraining.
+ """
+
+ vocab_files_names = VOCAB_FILES_NAMES
+ slow_tokenizer_class = PegasusTokenizer
+ model_input_names = ["input_ids", "attention_mask"]
+
+ def __init__(
+ self,
+ vocab_file=None,
+ tokenizer_file=None,
+ pad_token="",
+ eos_token="",
+ unk_token="",
+ mask_token="",
+ mask_token_sent="",
+ additional_special_tokens=None,
+ offset=103, # entries 2 - 104 are only used for pretraining
+ **kwargs,
+ ):
+ self.offset = offset
+
+ if additional_special_tokens is not None:
+ if not isinstance(additional_special_tokens, list):
+ raise TypeError(
+ f"additional_special_tokens should be of type {type(list)}, but is"
+ f" {type(additional_special_tokens)}"
+ )
+
+ additional_special_tokens_extended = (
+ ([mask_token_sent] + additional_special_tokens)
+ if mask_token_sent not in additional_special_tokens and mask_token_sent is not None
+ else additional_special_tokens
+ )
+ # fill additional tokens with ..., <unk_token_102> in case not all additional tokens are already taken
+ additional_special_tokens_extended += [
+ f"<unk_{i}>" for i in range(len(additional_special_tokens_extended), self.offset - 1)
+ ]
+
+ if len(set(additional_special_tokens_extended)) != len(additional_special_tokens_extended):
+ raise ValueError(
+ "Please make sure that the provided additional_special_tokens do not contain an incorrectly"
+ f" shifted list of tokens. Found {additional_special_tokens_extended}."
+ )
+ additional_special_tokens = additional_special_tokens_extended
+ else:
+ additional_special_tokens = [mask_token_sent] if mask_token_sent is not None else []
+ additional_special_tokens += [f"" for i in range(2, self.offset)]
+
+ # pegasus was designed to support changing the index of the first tokens. If one of the padding/eos/unk/mask tokens
+ # is different from the default, we must rebuild the vocab
+ from_slow = kwargs.pop("from_slow", None)
+ from_slow = from_slow or str(pad_token) != "<pad>" or str(eos_token) != "</s>" or str(unk_token) != "<unk>"
+
+ kwargs.pop("added_tokens_decoder", {})
+
+ super().__init__(
+ vocab_file,
+ tokenizer_file=tokenizer_file,
+ pad_token=pad_token,
+ eos_token=eos_token,
+ unk_token=unk_token,
+ mask_token=mask_token,
+ mask_token_sent=mask_token_sent,
+ offset=offset,
+ additional_special_tokens=additional_special_tokens,
+ from_slow=from_slow,
+ **kwargs,
+ )
+ self.vocab_file = vocab_file
+
+ def _special_token_mask(self, seq):
+ all_special_ids = set(self.all_special_ids) # call it once instead of inside list comp
+ all_special_ids.remove(self.unk_token_id) # <unk> is only sometimes special
+
+ if all_special_ids != set(range(len(self.additional_special_tokens) + 3)):
+ raise ValueError(
+ "There should be 3 special tokens: mask_token, pad_token, and eos_token +"
+ f" {len(self.additional_special_tokens)} additional_special_tokens, but got {all_special_ids}"
+ )
+
+ return [1 if x in all_special_ids else 0 for x in seq]
+
+ def get_special_tokens_mask(
+ self, token_ids_0: list, token_ids_1: Optional[list] = None, already_has_special_tokens: bool = False
+ ) -> list[int]:
+ """Get list where entries are [1] if a token is [eos] or [pad] else 0."""
+ if already_has_special_tokens:
+ return self._special_token_mask(token_ids_0)
+ elif token_ids_1 is None:
+ return self._special_token_mask(token_ids_0) + [1]
+ else:
+ return self._special_token_mask(token_ids_0 + token_ids_1) + [1]
+
+ def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None) -> list[int]:
+ """
+ Build model inputs from a sequence by adding eos to the end. No BOS token is added to the front.
+
+ - single sequence: `X </s>`
+ - pair of sequences: `A B </s>` (not intended use)
+
+ Args:
+ token_ids_0 (`List[int]`):
+ List of IDs to which the special tokens will be added
+ token_ids_1 (`List[int]`, *optional*):
+ Optional second list of IDs for sequence pairs.
+
+ Returns:
+ `List[int]`: list of [input IDs](../glossary#input-ids) with the appropriate special tokens.
+ """
+ if token_ids_1 is None:
+ return token_ids_0 + [self.eos_token_id]
+ # We don't expect to process pairs, but leave the pair logic for API consistency
+ return token_ids_0 + token_ids_1 + [self.eos_token_id]
+
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]:
+ if not self.can_save_slow_tokenizer:
+ raise ValueError(
+ "Your fast tokenizer does not have the necessary information to save the vocabulary for a slow "
+ "tokenizer."
+ )
+
+ if not os.path.isdir(save_directory):
+ logger.error(f"Vocabulary path ({save_directory}) should be a directory")
+ return
+ out_vocab_file = os.path.join(
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
+ )
+
+ if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
+ copyfile(self.vocab_file, out_vocab_file)
+
+ return (out_vocab_file,)
+
+
+__all__ = ["PegasusTokenizerFast"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/pegasus_x/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/pegasus_x/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..3d362bacf491ea32dcbc89e7b8bba4d9ae1a9261
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/pegasus_x/__init__.py
@@ -0,0 +1,27 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_pegasus_x import *
+ from .modeling_pegasus_x import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/pegasus_x/configuration_pegasus_x.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/pegasus_x/configuration_pegasus_x.py
new file mode 100644
index 0000000000000000000000000000000000000000..626389c448b86f64995b825910bde4b1bde543cb
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/pegasus_x/configuration_pegasus_x.py
@@ -0,0 +1,177 @@
+# coding=utf-8
+# Copyright 2022, Google and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PEGASUS-X model configuration"""
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class PegasusXConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`PegasusXModel`]. It is used to instantiate a
+ PEGASUS-X model according to the specified arguments, defining the model architecture. Instantiating a
+ configuration with the defaults will yield a similar configuration to that of the PEGASUS-X
+ [google/pegasus-x-large](https://huggingface.co/google/pegasus-x-large) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 96103):
+ Vocabulary size of the PEGASUS-X model. Defines the number of different tokens that can be represented by
+ the `inputs_ids` passed when calling [`PegasusXModel`].
+ d_model (`int`, *optional*, defaults to 1024):
+ Dimension of the layers and the pooler layer.
+ encoder_layers (`int`, *optional*, defaults to 16):
+ Number of encoder layers.
+ decoder_layers (`int`, *optional*, defaults to 16):
+ Number of decoder layers.
+ encoder_attention_heads (`int`, *optional*, defaults to 16):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ decoder_attention_heads (`int`, *optional*, defaults to 16):
+ Number of attention heads for each attention layer in the Transformer decoder.
+ decoder_ffn_dim (`int`, *optional*, defaults to 4096):
+ Dimension of the "intermediate" (often named feed-forward) layer in decoder.
+ encoder_ffn_dim (`int`, *optional*, defaults to 4096):
+ Dimension of the "intermediate" (often named feed-forward) layer in encoder.
+ activation_function (`str` or `function`, *optional*, defaults to `"gelu"`):
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+ `"relu"`, `"silu"` and `"gelu_new"` are supported.
+ dropout (`float`, *optional*, defaults to 0.1):
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ activation_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for activations inside the fully connected layer.
+ max_position_embeddings (`int`, *optional*, defaults to 16384):
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
+ just in case (e.g., 512 or 1024 or 2048).
+ init_std (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ encoder_layerdrop (`float`, *optional*, defaults to 0.0):
+ The LayerDrop probability for the encoder. See the [LayerDrop paper](https://huggingface.co/papers/1909.11556)
+ for more details.
+ decoder_layerdrop (`float`, *optional*, defaults to 0.0):
+ The LayerDrop probability for the decoder. See the [LayerDrop paper](https://huggingface.co/papers/1909.11556)
+ for more details.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return the last key/values attentions (not used by all models)
+ forced_eos_token_id (`int`, *optional*, defaults to 1):
+ The id of the token to force as the last generated token when `max_length` is reached. Usually set to
+ `eos_token_id`.
+ num_global_tokens (`int`, *optional*, defaults to 32):
+ Number of global tokens to use for the encoder.
+ block_size (`int`, *optional*, defaults to 512):
+ Block size for encoder local attention. The sequence length should be an exact multiple of the block size, and
+ block_size must be a multiple of 2 if stagger_local_blocks is True.
+ stagger_local_blocks (`bool`, *optional*, defaults to `True`):
+ Whether to stagger every other local attention layer by half a block.
+
+ Example:
+
+ ```python
+ >>> from transformers import PegasusXConfig, PegasusXModel
+
+ >>> # Initializing a PEGASUS google/pegasus-x-large style configuration
+ >>> configuration = PegasusXConfig()
+
+ >>> # Initializing a model (with random weights) from the google/pegasus-x-large style configuration
+ >>> model = PegasusXModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "pegasus_x"
+ keys_to_ignore_at_inference = ["past_key_values"]
+ attribute_map = {"num_attention_heads": "encoder_attention_heads", "hidden_size": "d_model"}
+
+ def __init__(
+ self,
+ vocab_size=96103,
+ max_position_embeddings=16384,
+ encoder_layers=16,
+ encoder_ffn_dim=4096,
+ encoder_attention_heads=16,
+ decoder_layers=16,
+ decoder_ffn_dim=4096,
+ decoder_attention_heads=16,
+ encoder_layerdrop=0.0,
+ decoder_layerdrop=0.0,
+ use_cache=True,
+ is_encoder_decoder=True,
+ activation_function="gelu",
+ d_model=1024,
+ dropout=0.1,
+ attention_dropout=0.0,
+ activation_dropout=0.0,
+ init_std=0.02,
+ decoder_start_token_id=0,
+ scale_embedding=True,
+ pad_token_id=0,
+ eos_token_id=1,
+ forced_eos_token_id=1,
+ num_global_tokens=32,
+ block_size=512,
+ stagger_local_blocks=True,
+ **kwargs,
+ ):
+ self.vocab_size = vocab_size
+ self.max_position_embeddings = max_position_embeddings
+ self.d_model = d_model
+ self.encoder_ffn_dim = encoder_ffn_dim
+ self.encoder_layers = encoder_layers
+ self.encoder_attention_heads = encoder_attention_heads
+ self.decoder_ffn_dim = decoder_ffn_dim
+ self.decoder_layers = decoder_layers
+ self.decoder_attention_heads = decoder_attention_heads
+ self.dropout = dropout
+ self.attention_dropout = attention_dropout
+ self.activation_dropout = activation_dropout
+ self.activation_function = activation_function
+ self.init_std = init_std
+ self.encoder_layerdrop = encoder_layerdrop
+ self.decoder_layerdrop = decoder_layerdrop
+ self.use_cache = use_cache
+ self.num_hidden_layers = encoder_layers
+ self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True
+
+ self.num_global_tokens = num_global_tokens
+ self.block_size = block_size
+ self.stagger_local_blocks = stagger_local_blocks
+
+ super().__init__(
+ pad_token_id=pad_token_id,
+ eos_token_id=eos_token_id,
+ is_encoder_decoder=is_encoder_decoder,
+ decoder_start_token_id=decoder_start_token_id,
+ forced_eos_token_id=forced_eos_token_id,
+ **kwargs,
+ )
+
+ @property
+ def num_attention_heads(self) -> int:
+ return self.encoder_attention_heads
+
+ @property
+ def hidden_size(self) -> int:
+ return self.d_model
+
+
+__all__ = ["PegasusXConfig"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/pegasus_x/modeling_pegasus_x.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/pegasus_x/modeling_pegasus_x.py
new file mode 100644
index 0000000000000000000000000000000000000000..0c1ae32cabe26ae9f05811731ad7deabe6b82bdd
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/pegasus_x/modeling_pegasus_x.py
@@ -0,0 +1,1717 @@
+# coding=utf-8
+# Copyright 2022, Google and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch PEGASUS-X model."""
+
+import math
+from dataclasses import dataclass
+from typing import Callable, Optional, Union
+
+import numpy as np
+import torch
+from torch import nn
+from torch.nn import CrossEntropyLoss
+
+from ...activations import ACT2FN
+from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
+from ...generation import GenerationMixin
+from ...modeling_attn_mask_utils import (
+ AttentionMaskConverter,
+ _prepare_4d_attention_mask,
+ _prepare_4d_attention_mask_for_sdpa,
+)
+from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_outputs import (
+ BaseModelOutput,
+ BaseModelOutputWithPastAndCrossAttentions,
+ Seq2SeqLMOutput,
+ Seq2SeqModelOutput,
+)
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
+from ...utils import auto_docstring, is_torch_flex_attn_available, is_torchdynamo_compiling, logging
+from ...utils.deprecation import deprecate_kwarg
+from .configuration_pegasus_x import PegasusXConfig
+
+
+if is_torch_flex_attn_available():
+ from ...integrations.flex_attention import BlockMask, make_flex_block_causal_mask
+
+
+logger = logging.get_logger(__name__)
+
+
+@dataclass
+class DimensionInfo:
+ """Wrapper for dimension info."""
+
+ batch_size: int # batch size
+ seq_len: int # token length
+ block_size: int # block size
+ num_heads: int # num heads
+ hidden_dim: int # hidden dim
+ dim_per_head: int # dim per head
+ num_blocks: int # num blocks
+ global_len: int # global length
+ padded_seq_len: int # padded token seq length
+
+ # Note: Compared to the original Flax implementation, we will pad the token representations to
+ # a multiple of block size at the start of the encoder layers, so T=P always.
+
+
+# Copied from transformers.models.bart.modeling_bart.shift_tokens_right
+def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
+ """
+ Shift input ids one token to the right.
+ """
+ shifted_input_ids = input_ids.new_zeros(input_ids.shape)
+ shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
+ shifted_input_ids[:, 0] = decoder_start_token_id
+
+ if pad_token_id is None:
+ raise ValueError("self.model.config.pad_token_id has to be defined.")
+ # replace possible -100 values in labels by `pad_token_id`
+ shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
+
+ return shifted_input_ids
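A quick illustration of `shift_tokens_right` with made-up token ids; `pad_token_id=0` and `decoder_start_token_id=0` match the PEGASUS-X config defaults.

```python
# Illustrative only: labels -> decoder_input_ids via shift_tokens_right.
import torch

labels = torch.tensor([[4511, 18, 87, 1, -100, -100]])  # -100 marks ignored positions
decoder_input_ids = shift_tokens_right(labels, pad_token_id=0, decoder_start_token_id=0)
# tensor([[   0, 4511,   18,   87,    1,    0]])
# The decoder start token is prepended, everything shifts right by one position,
# and any -100 placeholders are replaced by the pad token id.
```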
+
+
+# Copied from transformers.models.bart.modeling_bart.BartScaledWordEmbedding with Bart->PegasusX
+class PegasusXScaledWordEmbedding(nn.Embedding):
+ """
+ This module overrides nn.Embeddings' forward by multiplying with embeddings scale.
+ """
+
+ def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: Optional[float] = 1.0):
+ super().__init__(num_embeddings, embedding_dim, padding_idx)
+ self.embed_scale = embed_scale
+
+ def forward(self, input_ids: torch.Tensor):
+ return super().forward(input_ids) * self.embed_scale
+
+
+class PegasusXSinusoidalPositionalEmbedding(nn.Module):
+ """This module produces sinusoidal positional embeddings of any length."""
+
+ def __init__(self, embed_dim, max_scale: int = 10000.0):
+ super().__init__()
+ self.embed_dim = embed_dim
+ self.max_scale = max_scale
+
+ @torch.no_grad()
+ def forward(
+ self, input_embeds: torch.Tensor, past_key_values_length: int = 0, position_ids: Optional[torch.Tensor] = None
+ ) -> torch.Tensor:
+ """`input_ids_shape` is expected to be [bsz x seqlen]."""
+ batch_size, seq_len = input_embeds.shape[:2]
+ if position_ids is None:
+ position_ids = torch.arange(
+ past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=input_embeds.device
+ )[:, None]
+
+ pe = torch.zeros((seq_len, self.embed_dim), device=input_embeds.device, dtype=input_embeds.dtype)
+ half_d_feature = self.embed_dim // 2
+ div_term = torch.exp(
+ torch.arange(half_d_feature, device=input_embeds.device, dtype=torch.int64).type_as(input_embeds)
+ * -(np.log(float(self.max_scale)) / (half_d_feature - 1))
+ )
+ pe[:, :half_d_feature] = torch.sin(position_ids * div_term)
+ pe[:, half_d_feature:] = torch.cos(position_ids * div_term)
+ return pe[None].expand(batch_size, -1, -1)
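A small sanity sketch of the sinusoidal embedding layout produced above (illustrative, with tiny made-up dimensions): the first half of the feature dimension holds sine terms and the second half cosine terms.

```python
# Sketch only: check shape and the position-0 pattern of the sinusoidal embeddings.
import torch

embed = PegasusXSinusoidalPositionalEmbedding(embed_dim=8)
dummy = torch.zeros(2, 5, 8)               # [batch, seq_len, d_model]
pe = embed(dummy)
assert pe.shape == (2, 5, 8)
# sin(0) = 0 for the first half, cos(0) = 1 for the second half.
assert torch.allclose(pe[0, 0], torch.tensor([0., 0., 0., 0., 1., 1., 1., 1.]))
```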
+
+
+# Copied from transformers.models.bart.modeling_bart.eager_attention_forward
+def eager_attention_forward(
+ module: nn.Module,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_mask: Optional[torch.Tensor],
+ scaling: Optional[float] = None,
+ dropout: float = 0.0,
+ head_mask: Optional[torch.Tensor] = None,
+ **kwargs,
+):
+ if scaling is None:
+ scaling = query.size(-1) ** -0.5
+
+ attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
+ if attention_mask is not None:
+ attn_weights = attn_weights + attention_mask
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+
+ if head_mask is not None:
+ attn_weights = attn_weights * head_mask.view(1, -1, 1, 1)
+
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+ attn_output = torch.matmul(attn_weights, value)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ return attn_output, attn_weights
+
+
+# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->PegasusX
+class PegasusXAttention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(
+ self,
+ embed_dim: int,
+ num_heads: int,
+ dropout: float = 0.0,
+ is_decoder: bool = False,
+ bias: bool = True,
+ is_causal: bool = False,
+ config: Optional[PegasusXConfig] = None,
+ layer_idx: Optional[int] = None,
+ ):
+ super().__init__()
+ self.embed_dim = embed_dim
+ self.num_heads = num_heads
+ self.dropout = dropout
+ self.head_dim = embed_dim // num_heads
+ self.config = config
+
+ if (self.head_dim * num_heads) != self.embed_dim:
+ raise ValueError(
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
+ f" and `num_heads`: {num_heads})."
+ )
+ self.scaling = self.head_dim**-0.5
+ self.is_decoder = is_decoder
+ self.is_causal = is_causal
+ self.layer_idx = layer_idx
+ if layer_idx is None and self.is_decoder:
+ logger.warning_once(
+ f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and "
+ "will lead to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
+ "when creating this class."
+ )
+
+ self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+ self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+ self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+ self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ key_value_states: Optional[torch.Tensor] = None,
+ past_key_values: Optional[Cache] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ layer_head_mask: Optional[torch.Tensor] = None,
+ output_attentions: bool = False,
+ cache_position: Optional[torch.Tensor] = None,
+ # TODO: we need a refactor so that the different attention modules can get their specific kwargs
+ # ATM, we have mixed things encoder, decoder, and encoder-decoder attn
+ **kwargs: Unpack[FlashAttentionKwargs],
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
+ """Input shape: Batch x Time x Channel"""
+
+ # if key_value_states are provided this layer is used as a cross-attention layer
+ # for the decoder
+ is_cross_attention = key_value_states is not None
+
+ # determine input shapes
+ bsz, tgt_len = hidden_states.shape[:-1]
+ src_len = key_value_states.shape[1] if is_cross_attention else tgt_len
+
+ q_input_shape = (bsz, tgt_len, -1, self.head_dim)
+ kv_input_shape = (bsz, src_len, -1, self.head_dim)
+
+ # get query proj
+ query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2)
+
+ is_updated = False
+ if past_key_values is not None:
+ if isinstance(past_key_values, EncoderDecoderCache):
+ is_updated = past_key_values.is_updated.get(self.layer_idx)
+ if is_cross_attention:
+ # after the first generated id, we can subsequently re-use all key/value_states from cache
+ curr_past_key_value = past_key_values.cross_attention_cache
+ else:
+ curr_past_key_value = past_key_values.self_attention_cache
+ else:
+ curr_past_key_value = past_key_values
+
+ current_states = key_value_states if is_cross_attention else hidden_states
+ if is_cross_attention and past_key_values is not None and is_updated:
+ # reuse k,v, cross_attentions
+ key_states = curr_past_key_value.layers[self.layer_idx].keys
+ value_states = curr_past_key_value.layers[self.layer_idx].values
+ else:
+ key_states = self.k_proj(current_states)
+ value_states = self.v_proj(current_states)
+ key_states = key_states.view(*kv_input_shape).transpose(1, 2)
+ value_states = value_states.view(*kv_input_shape).transpose(1, 2)
+
+ if past_key_values is not None:
+ # save all key/value_states to cache to be re-used for fast auto-regressive generation
+ cache_position = cache_position if not is_cross_attention else None
+ key_states, value_states = curr_past_key_value.update(
+ key_states, value_states, self.layer_idx, {"cache_position": cache_position}
+ )
+ # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
+ if is_cross_attention and isinstance(past_key_values, EncoderDecoderCache):
+ past_key_values.is_updated[self.layer_idx] = True
+
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ attn_output, attn_weights = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ dropout=0.0 if not self.training else self.dropout,
+ scaling=self.scaling,
+ output_attentions=output_attentions,
+ head_mask=layer_head_mask,
+ **kwargs,
+ )
+
+ attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous()
+ attn_output = self.out_proj(attn_output)
+
+ return attn_output, attn_weights
+
+
+class PegasusXGlobalLocalAttention(nn.Module):
+ """Global + Local attention. For use with Encoder only."""
+
+ def __init__(
+ self,
+ embed_dim: int,
+ num_heads: int,
+ block_size: int,
+ dropout: float = 0.0,
+ is_decoder: bool = False,
+ ):
+ super().__init__()
+ self.embed_dim = embed_dim
+ self.num_heads = num_heads
+ self.block_size = block_size
+ self.dropout = dropout
+ self.head_dim = embed_dim // num_heads
+
+ if (self.head_dim * num_heads) != self.embed_dim:
+ raise ValueError(
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
+ f" and `num_heads`: {num_heads})."
+ )
+ self.scaling = self.head_dim**-0.5
+ self.is_decoder = is_decoder
+
+ self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False)
+ self.v_proj = nn.Linear(embed_dim, embed_dim, bias=False)
+ self.q_proj = nn.Linear(embed_dim, embed_dim, bias=False)
+ self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False)
+
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
+
+ def forward(
+ self,
+ token_hidden_states: torch.Tensor,
+ global_hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ output_attentions: bool = False,
+ ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
+ """Input shape: Batch x Time x Channel"""
+ dim = DimensionInfo(
+ batch_size=token_hidden_states.shape[0],
+ seq_len=token_hidden_states.shape[1],
+ block_size=self.block_size,
+ num_heads=self.num_heads,
+ hidden_dim=token_hidden_states.shape[2],
+ dim_per_head=self.head_dim,
+ num_blocks=token_hidden_states.shape[1] // self.block_size,
+ global_len=global_hidden_states.shape[1],
+ padded_seq_len=token_hidden_states.shape[1],
+ )
+
+ # [batch_size, num_heads, padded_seq_len, dim_per_head]
+ local_q = self._shape(
+ self.q_proj(token_hidden_states) * self.scaling,
+ seq_len=dim.padded_seq_len,
+ bsz=dim.batch_size,
+ )
+ local_k = self._shape(
+ self.k_proj(token_hidden_states),
+ seq_len=dim.padded_seq_len,
+ bsz=dim.batch_size,
+ )
+ local_v = self._shape(
+ self.v_proj(token_hidden_states),
+ seq_len=dim.padded_seq_len,
+ bsz=dim.batch_size,
+ )
+
+ # [batch_size, num_heads, global_len, dim_per_head]
+ global_q = self._shape(
+ self.q_proj(global_hidden_states) * self.scaling,
+ seq_len=dim.global_len,
+ bsz=dim.batch_size,
+ )
+ global_k = self._shape(
+ self.k_proj(global_hidden_states),
+ seq_len=dim.global_len,
+ bsz=dim.batch_size,
+ )
+ global_v = self._shape(
+ self.v_proj(global_hidden_states),
+ seq_len=dim.global_len,
+ bsz=dim.batch_size,
+ )
+
+ global_attn_output, global_attn_probs = self.compute_global_attention_representations(
+ global_q=global_q,
+ global_k=global_k,
+ global_v=global_v,
+ local_k=local_k,
+ local_v=local_v,
+ mask=attention_mask,
+ dim=dim,
+ )
+ local_attn_output, local_attn_probs = self.compute_local_attention_representations(
+ global_k=global_k,
+ global_v=global_v,
+ local_q=local_q,
+ local_k=local_k,
+ local_v=local_v,
+ mask=attention_mask,
+ dim=dim,
+ )
+
+ # [batch_size, global_len, hidden_dim]
+ global_attn_output = (
+ global_attn_output.transpose(1, 2).contiguous().view(dim.batch_size, dim.global_len, dim.hidden_dim)
+ )
+ # [batch_size, global_len, hidden_dim]
+ global_attn_output = self.out_proj(global_attn_output)
+ # [batch_size, num_heads, block_size, num_heads, dim_per_head]
+ local_attn_output = local_attn_output.permute(0, 2, 3, 1, 4).contiguous()
+ # [batch_size, padded_seq_len, hidden_dim]
+ local_attn_output = local_attn_output.view(dim.batch_size, dim.padded_seq_len, dim.hidden_dim)
+ # [batch_size, padded_seq_len, hidden_dim]
+ local_attn_output = self.out_proj(local_attn_output)
+
+ if output_attentions:
+ attn_probs = {"global": global_attn_probs, "local": local_attn_probs}
+ else:
+ attn_probs = None
+
+ return local_attn_output, global_attn_output, attn_probs
+
+ def compute_global_attention_representations(
+ self, global_q, global_k, global_v, local_k, local_v, mask, dim: DimensionInfo
+ ):
+ """Compute attention representations for global tokens.
+
+ Global tokens will attend to both global tokens as well as all input sequence tokens. Because the input
+ sequence tokens are arranged in blocks for local attention, we unblock them and compute attention.
+
+ Args:
+ global_q (`torch.FloatTensor`) of shape [batch_size, num_heads, global_len, dim_per_head]:
+ query vectors from global tokens
+ global_k (`torch.FloatTensor`) of shape [batch_size, num_heads, global_len, dim_per_head]:
+ key vectors from global tokens
+ global_v (`torch.FloatTensor`) of shape [batch_size, num_heads, global_len, dim_per_head]:
+ value vectors from global tokens
+ local_k (`torch.FloatTensor`) of shape [batch_size, num_heads, padded_seq_len, dim_per_head]:
+ key vectors from local tokens
+ local_v (`torch.FloatTensor`) of shape [batch_size, num_heads, padded_seq_len, dim_per_head]:
+ value vectors from local tokens
+ mask (`torch.FloatTensor`) of shape [batch_size, padded_seq_len]: attention mask
+ dim (DimensionInfo): DimensionInfo wrapper for dimensions
+
+ Returns:
+ output of shape `[batch_size, length, features]`, where length will be padded to a multiple of block_size
+ """
+ # [batch_size, num_heads, global_len+padded_seq_len, dim_per_head]
+ global_and_local_k = torch.cat([global_k, local_k], dim=2)
+ # [batch_size, num_heads, global_len+padded_seq_len, dim_per_head]
+ global_and_local_v = torch.cat([global_v, local_v], dim=2)
+
+ # [batch_size, global_len+padded_seq_len]
+ extended_mask = nn.functional.pad(mask, pad=(dim.global_len, 0), value=0)
+
+ # [batch_size, num_heads, global_len, global_len+padded_seq_len]
+ attn_weights = torch.einsum("BHGF,BHXF->BHGX", global_q, global_and_local_k)
+ attn_weights = attn_weights + extended_mask[:, None, None, :]
+ attn_probs = nn.functional.softmax(attn_weights, dim=-1)
+ attn_probs = nn.functional.dropout(attn_probs, p=self.dropout, training=self.training)
+
+ # [batch_size, num_heads, global_len, F]
+ attn_output = torch.einsum("BHGX,BHXF->BHGF", attn_probs, global_and_local_v)
+ return attn_output, attn_probs
+
+ def compute_local_attention_representations(
+ self, global_k, global_v, local_q, local_k, local_v, mask, dim: DimensionInfo
+ ):
+ """Compute attention representations for local tokens.
+
+ Local tokens will attend to both global tokens as well as all other tokens within the same local block. Hence,
+ we need to tile and concatenate the global tokens to every local block
+
+ Args:
+ global_k (`torch.FloatTensor`) of shape [batch_size, num_heads, global_len, dim_per_head]:
+ key vectors from global tokens
+ global_v (`torch.FloatTensor`) of shape [batch_size, num_heads, global_len, dim_per_head]:
+ value vectors from global tokens
+ local_q (`torch.FloatTensor`) of shape [batch_size, num_heads, padded_seq_len, dim_per_head]:
+ query vectors from local tokens
+ local_k (`torch.FloatTensor`) of shape [batch_size, num_heads, padded_seq_len, dim_per_head]:
+ key vectors from local tokens
+ local_v (`torch.FloatTensor`) of shape [batch_size, num_heads, padded_seq_len, dim_per_head]:
+ value vectors from local tokens
+ mask (`torch.FloatTensor`) of shape [batch_size, padded_seq_len]: attention mask
+ dim (DimensionInfo): DimensionInfo wrapper for dimensions
+
+ Returns:
+ output of shape `[batch_size, length, features]`, where length will be padded to a multiple of block_size
+ """
+ # [batch_size, num_heads, num_blocks, block_size, dim_per_head]
+ blocked_local_q = local_q.view(dim.batch_size, dim.num_heads, dim.num_blocks, dim.block_size, dim.dim_per_head)
+ # [batch_size, num_heads, num_blocks, block_size, dim_per_head]
+ blocked_local_k = local_k.view(dim.batch_size, dim.num_heads, dim.num_blocks, dim.block_size, dim.dim_per_head)
+ # [batch_size, num_heads, num_blocks, block_size, dim_per_head]
+ blocked_local_v = local_v.view(dim.batch_size, dim.num_heads, dim.num_blocks, dim.block_size, dim.dim_per_head)
+
+ # [batch_size, num_blocks, global_len+block_size]
+ extended_mask = nn.functional.pad(
+ mask.view(dim.batch_size, dim.num_blocks, dim.block_size),
+ pad=(dim.global_len, 0),
+ value=0,
+ )
+
+ # [batch_size, num_heads, num_blocks, block_size, global_len]
+ blocked_local2global = torch.einsum("BHNKF,BHGF->BHNKG", blocked_local_q, global_k)
+ # [batch_size, num_heads, num_blocks, block_size, block_size]
+ blocked_local2local = torch.einsum("BHNKF,BHNXF->BHNKX", blocked_local_q, blocked_local_k)
+
+ # [batch_size, num_heads, num_blocks, block_size, global_len+block_size]
+ attn_weights = torch.cat([blocked_local2global, blocked_local2local], dim=-1)
+ attn_weights = attn_weights + extended_mask[:, None, :, None, :]
+ attn_probs = nn.functional.softmax(attn_weights, dim=-1)
+ attn_probs = nn.functional.dropout(attn_probs, p=self.dropout, training=self.training)
+
+ # [batch_size, num_heads, num_blocks, block_size, global_len]
+ local2global_attn_probs = attn_probs[:, :, :, :, : dim.global_len]
+ # [batch_size, num_heads, num_blocks, block_size, block_size]
+ local2local_attn_probs = attn_probs[:, :, :, :, dim.global_len :]
+
+ # [batch_size, num_heads, num_blocks, block_size, dim_per_head]
+ local2global_attn_output = torch.einsum("BHNKG,BHGF->BHNKF", local2global_attn_probs, global_v)
+ # [batch_size, num_heads, num_blocks, block_size, dim_per_head]
+ local2local_attn_output = torch.einsum("BHNKX,BHNXF->BHNKF", local2local_attn_probs, blocked_local_v)
+ # [batch_size, num_heads, num_blocks, block_size, dim_per_head]
+ attn_output = local2global_attn_output + local2local_attn_output
+ return attn_output, attn_probs
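A shape walk-through of the local branch above, with toy dimensions (not library code): local queries/keys are reshaped into blocks and each block's score matrix is widened by the global tokens.

```python
# Toy shapes: B=batch, H=heads, T=seq_len, K=block_size, G=global_len, F=dim_per_head.
import torch

B, H, T, K, G, F = 2, 4, 8, 4, 3, 16
N = T // K  # number of local blocks

local_q = torch.randn(B, H, T, F)
local_k = torch.randn(B, H, T, F)
global_k = torch.randn(B, H, G, F)

blocked_local_q = local_q.view(B, H, N, K, F)
blocked_local_k = local_k.view(B, H, N, K, F)

# Each local block attends to all global tokens plus its own block,
# so the per-block score matrix has width G + K.
local2global = torch.einsum("BHNKF,BHGF->BHNKG", blocked_local_q, global_k)
local2local = torch.einsum("BHNKF,BHNXF->BHNKX", blocked_local_q, blocked_local_k)
scores = torch.cat([local2global, local2local], dim=-1)
assert scores.shape == (B, H, N, K, G + K)
```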
+
+
+class PegasusXEncoderLayer(GradientCheckpointingLayer):
+ def __init__(self, stagger_blocks_this_layer: bool, config: PegasusXConfig):
+ super().__init__()
+ self.embed_dim = config.d_model
+ self.self_attn = PegasusXGlobalLocalAttention(
+ embed_dim=self.embed_dim,
+ num_heads=config.encoder_attention_heads,
+ block_size=config.block_size,
+ dropout=config.attention_dropout,
+ )
+ self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
+ self.global_self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
+ self.dropout = config.dropout
+ self.activation_fn = ACT2FN[config.activation_function]
+ self.activation_dropout = config.activation_dropout
+ self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
+ self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
+ self.final_layer_norm = nn.LayerNorm(self.embed_dim)
+ self.stagger_blocks_this_layer = stagger_blocks_this_layer
+ self.block_size = config.block_size
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ global_hidden_states: torch.Tensor,
+ attention_mask: torch.Tensor,
+ output_attentions: bool = False,
+ ) -> torch.Tensor:
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`): input to the layer of shape *(seq_len, batch, embed_dim)*
+ global_hidden_states (`torch.FloatTensor`): global token hidden states
+ *(seq_len, num_global_tokens, embed_dim)*
+ attention_mask (`torch.FloatTensor`): attention mask of size
+ *(batch, 1, tgt_len, src_len)* where padding elements are indicated by very large negative values.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ """
+ residual = hidden_states
+ global_residual = global_hidden_states
+
+ hidden_states = self.self_attn_layer_norm(hidden_states)
+ global_hidden_states = self.global_self_attn_layer_norm(global_hidden_states)
+
+ if self.stagger_blocks_this_layer:
+ # Pad the blocks to simulate staggering
+ hidden_states, attention_mask = self.pad_local_tokens(
+ hidden_states=hidden_states, attention_mask=attention_mask, block_size=self.block_size
+ )
+
+ hidden_states, global_hidden_states, attn_weights = self.self_attn(
+ token_hidden_states=hidden_states,
+ global_hidden_states=global_hidden_states,
+ attention_mask=attention_mask,
+ output_attentions=output_attentions,
+ )
+
+ if self.stagger_blocks_this_layer:
+ # Undo the padding
+ hidden_states = self.unpad_local_tokens(padded_hidden_states=hidden_states, block_size=self.block_size)
+
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+ hidden_states = residual + hidden_states
+
+ global_hidden_states = nn.functional.dropout(global_hidden_states, p=self.dropout, training=self.training)
+ global_hidden_states = global_residual + global_hidden_states
+
+ residual = hidden_states
+ hidden_states = self.final_layer_norm(hidden_states)
+ hidden_states = self.activation_fn(self.fc1(hidden_states))
+ hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
+ hidden_states = self.fc2(hidden_states)
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+ hidden_states = residual + hidden_states
+
+ global_residual = global_hidden_states
+ global_hidden_states = self.final_layer_norm(global_hidden_states)
+ global_hidden_states = self.activation_fn(self.fc1(global_hidden_states))
+ global_hidden_states = nn.functional.dropout(
+ global_hidden_states, p=self.activation_dropout, training=self.training
+ )
+ global_hidden_states = self.fc2(global_hidden_states)
+ global_hidden_states = nn.functional.dropout(global_hidden_states, p=self.dropout, training=self.training)
+ global_hidden_states = global_residual + global_hidden_states
+ outputs = (hidden_states, global_hidden_states)
+
+ if output_attentions:
+ outputs += (attn_weights,)
+
+ return outputs
+
+ @classmethod
+ def pad_local_tokens(cls, hidden_states, attention_mask, block_size):
+ # hidden_states: [batch_size, seq_len, hidden_dim]
+ pad_size = block_size // 2
+ mask_min_value = torch.finfo(hidden_states.dtype).min
+ padded_hidden_states = torch.nn.functional.pad(
+ hidden_states,
+ pad=(0, 0, pad_size, pad_size),
+ )
+ padded_mask = torch.nn.functional.pad(
+ attention_mask,
+ pad=(pad_size, pad_size),
+ value=mask_min_value,
+ )
+ return padded_hidden_states, padded_mask
+
+ @classmethod
+ def unpad_local_tokens(cls, padded_hidden_states, block_size):
+ # padded_hidden_states: [batch_size, padded seq_len, hidden_dim]
+ pad_size = block_size // 2
+ return padded_hidden_states[:, pad_size:-pad_size, :]
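A small sketch of the staggering trick these two helpers implement (toy tensors; the real layer operates on `[batch, seq_len, hidden_dim]` activations): padding by half a block shifts where the block boundaries fall in alternating layers, and unpadding restores the original length.

```python
# Illustrative only: stagger padding/unpadding around local attention blocks.
import torch

block_size = 4
hidden = torch.arange(8).view(1, 8, 1).float()   # seq_len = 8, hidden_dim = 1
mask = torch.zeros(1, 8)                          # additive mask, 0 = attend

padded, padded_mask = PegasusXEncoderLayer.pad_local_tokens(
    hidden_states=hidden, attention_mask=mask, block_size=block_size
)
# Padding block_size // 2 tokens on each side shifts the block boundaries,
# so tokens 2-5 now share a block instead of 0-3 / 4-7.
assert padded.shape == (1, 8 + block_size, 1)
assert PegasusXEncoderLayer.unpad_local_tokens(padded, block_size).shape == hidden.shape
```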
+
+
+class PegasusXDecoderLayer(GradientCheckpointingLayer):
+ def __init__(self, config: PegasusXConfig, layer_idx: Optional[int] = None):
+ super().__init__()
+ self.embed_dim = config.d_model
+
+ self.self_attn = PegasusXAttention(
+ embed_dim=self.embed_dim,
+ num_heads=config.decoder_attention_heads,
+ dropout=config.attention_dropout,
+ is_decoder=True,
+ bias=False,
+ config=config,
+ layer_idx=layer_idx,
+ )
+ self.dropout = config.dropout
+ self.activation_fn = ACT2FN[config.activation_function]
+ self.activation_dropout = config.activation_dropout
+
+ self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
+ self.encoder_attn = PegasusXAttention(
+ self.embed_dim,
+ config.decoder_attention_heads,
+ dropout=config.attention_dropout,
+ is_decoder=True,
+ bias=False,
+ config=config,
+ layer_idx=layer_idx,
+ )
+ self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
+ self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
+ self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)
+ self.final_layer_norm = nn.LayerNorm(self.embed_dim)
+
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ past_key_values: Optional[Cache] = None,
+ output_attentions: Optional[bool] = False,
+ use_cache: Optional[bool] = True,
+ cache_position: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`): input to the layer of shape *(seq_len, batch, embed_dim)*
+ attention_mask (`torch.FloatTensor`): attention mask of size
+ *(batch, 1, tgt_len, src_len)* where padding elements are indicated by very large negative values.
+ encoder_hidden_states (`torch.FloatTensor`):
+ cross attention input to the layer of shape *(seq_len, batch, embed_dim)*
+ encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
+ *(batch, 1, tgt_len, src_len)* where padding elements are indicated by very large negative values.
+ past_key_values (`Cache`): cached past key and value projection states
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ use_cache: Whether to use the KV cache for decoding
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+ Indices depicting the position of the input sequence tokens in the sequence. It is used to update the
+ cache in the correct position and to infer the complete sequence length.
+ """
+ residual = hidden_states
+ hidden_states = self.self_attn_layer_norm(hidden_states)
+
+ # Self Attention
+ hidden_states, self_attn_weights = self.self_attn(
+ hidden_states=hidden_states,
+ past_key_values=past_key_values,
+ attention_mask=attention_mask,
+ output_attentions=output_attentions,
+ cache_position=cache_position,
+ )
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+ hidden_states = residual + hidden_states
+
+ # Cross-Attention Block
+ cross_attn_weights = None
+ if encoder_hidden_states is not None:
+ residual = hidden_states
+ hidden_states = self.encoder_attn_layer_norm(hidden_states)
+
+ hidden_states, cross_attn_weights = self.encoder_attn(
+ hidden_states=hidden_states,
+ key_value_states=encoder_hidden_states,
+ attention_mask=encoder_attention_mask,
+ past_key_values=past_key_values,
+ output_attentions=output_attentions,
+ )
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+ hidden_states = residual + hidden_states
+
+ # Fully Connected
+ residual = hidden_states
+ hidden_states = self.final_layer_norm(hidden_states)
+ hidden_states = self.activation_fn(self.fc1(hidden_states))
+ hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
+ hidden_states = self.fc2(hidden_states)
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+ hidden_states = residual + hidden_states
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (self_attn_weights, cross_attn_weights)
+ return outputs
+
+
+@auto_docstring
+class PegasusXPreTrainedModel(PreTrainedModel):
+ config: PegasusXConfig
+ base_model_prefix = "model"
+ supports_gradient_checkpointing = True
+ _no_split_modules = [r"PegasusXEncoderLayer", r"PegasusXDecoderLayer"]
+ _supports_flash_attn = True
+ # Flaky logits
+ _supports_sdpa = False
+ _supports_flex_attn = True
+
+ _can_compile_fullgraph = True
+
+ def _init_weights(self, module):
+ std = self.config.init_std
+ if isinstance(module, nn.Linear):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=std)
+ elif isinstance(module, nn.LayerNorm):
+ module.weight.data.fill_(1.0)
+ module.bias.data.zero_()
+
+ # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_full_mask
+ def _update_full_mask(
+ self,
+ attention_mask: Union[torch.Tensor, None],
+ inputs_embeds: torch.Tensor,
+ ):
+ if attention_mask is not None:
+ if self.config._attn_implementation == "flash_attention_2":
+ attention_mask = attention_mask if 0 in attention_mask else None
+ elif self.config._attn_implementation == "sdpa":
+ # output_attentions=True & head_mask can not be supported when using SDPA, fall back to
+ # the manual implementation that requires a 4D causal mask in all cases.
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype)
+ elif self.config._attn_implementation == "flex_attention":
+ if isinstance(attention_mask, torch.Tensor):
+ attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False)
+ else:
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
+
+ return attention_mask
+
+ # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_causal_mask
+ def _update_causal_mask(
+ self,
+ attention_mask: Optional[Union[torch.Tensor, "BlockMask"]],
+ input_tensor: torch.Tensor,
+ cache_position: torch.Tensor,
+ past_key_values: Cache,
+ ):
+ if self.config._attn_implementation == "flex_attention":
+ if isinstance(attention_mask, torch.Tensor):
+ attention_mask = make_flex_block_causal_mask(attention_mask)
+ # Other attention flavors support in-built causal (when `mask is None`)
+ # while we need to create our specific block mask regardless
+ elif attention_mask is None:
+ attention_mask = make_flex_block_causal_mask(
+ torch.ones(
+ size=(input_tensor.shape[0], input_tensor.shape[1]),
+ device=input_tensor.device,
+ )
+ )
+ return attention_mask
+
+ if self.config._attn_implementation == "flash_attention_2":
+ if attention_mask is not None and (attention_mask == 0.0).any():
+ return attention_mask
+ return None
+
+ # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
+ # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
+ # to infer the attention mask.
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False
+
+ # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
+ if self.config._attn_implementation == "sdpa" and not using_compilable_cache:
+ if AttentionMaskConverter._ignore_causal_mask_sdpa(
+ attention_mask,
+ inputs_embeds=input_tensor,
+ past_key_values_length=past_seen_tokens,
+ is_training=self.training,
+ ):
+ return None
+
+ dtype = input_tensor.dtype
+ sequence_length = input_tensor.shape[1]
+ if using_compilable_cache:
+ target_length = past_key_values.get_max_cache_shape()
+ else:
+ target_length = (
+ attention_mask.shape[-1]
+ if isinstance(attention_mask, torch.Tensor)
+ else past_seen_tokens + sequence_length + 1
+ )
+
+ # In case the provided `attention_mask` is 2D, we generate a causal mask here (4D).
+ causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
+ attention_mask,
+ sequence_length=sequence_length,
+ target_length=target_length,
+ dtype=dtype,
+ cache_position=cache_position,
+ batch_size=input_tensor.shape[0],
+ )
+
+ if (
+ self.config._attn_implementation == "sdpa"
+ and attention_mask is not None
+ and attention_mask.device.type in ["cuda", "xpu", "npu"]
+ ):
+ # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
+ # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
+ # Details: https://github.com/pytorch/pytorch/issues/110213
+ min_dtype = torch.finfo(dtype).min
+ causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
+
+ return causal_mask
+
+ @staticmethod
+ # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position
+ def _prepare_4d_causal_attention_mask_with_cache_position(
+ attention_mask: torch.Tensor,
+ sequence_length: int,
+ target_length: int,
+ dtype: torch.dtype,
+ cache_position: torch.Tensor,
+ batch_size: int,
+ **kwargs,
+ ):
+ """
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
+ `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
+
+ Args:
+ attention_mask (`torch.Tensor`):
+ A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
+ `(batch_size, 1, query_length, key_value_length)`.
+ sequence_length (`int`):
+ The sequence length being processed.
+ target_length (`int`):
+ The target length: when generating with static cache, the mask should be as long as the static cache,
+ to account for the 0 padding, the part of the cache that is not filled yet.
+ dtype (`torch.dtype`):
+ The dtype to use for the 4D attention mask.
+ cache_position (`torch.Tensor`):
+ Indices depicting the position of the input sequence tokens in the sequence.
+ batch_size (`int`):
+ Batch size.
+ """
+ if attention_mask is not None and attention_mask.dim() == 4:
+ # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
+ causal_mask = attention_mask
+ else:
+ min_dtype = torch.finfo(dtype).min
+ causal_mask = torch.full(
+ (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
+ )
+ if sequence_length != 1:
+ causal_mask = torch.triu(causal_mask, diagonal=1)
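+ # Zero out (unmask) key positions up to and including each query's absolute cache position;
+ # strictly later positions keep `min_dtype` and stay masked.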
+ causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
+ causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
+ if attention_mask is not None:
+ causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
+ mask_length = attention_mask.shape[-1]
+ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
+ causal_mask.device
+ )
+ padding_mask = padding_mask == 0
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+ padding_mask, min_dtype
+ )
+
+ return causal_mask
+
+ # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_cross_attn_mask
+ def _update_cross_attn_mask(
+ self,
+ encoder_hidden_states: Union[torch.Tensor, None],
+ encoder_attention_mask: Union[torch.Tensor, None],
+ input_shape: torch.Size,
+ inputs_embeds: torch.Tensor,
+ ):
+ # expand encoder attention mask
+ if encoder_hidden_states is not None and encoder_attention_mask is not None:
+ if self.config._attn_implementation == "flash_attention_2":
+ encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None
+ elif self.config._attn_implementation == "sdpa":
+ # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on
+ # the manual implementation that requires a 4D causal mask in all cases.
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
+ encoder_attention_mask,
+ inputs_embeds.dtype,
+ tgt_len=input_shape[-1],
+ )
+ elif self.config._attn_implementation == "flex_attention":
+ if isinstance(encoder_attention_mask, torch.Tensor):
+ encoder_attention_mask = make_flex_block_causal_mask(
+ encoder_attention_mask,
+ query_length=input_shape[-1],
+ is_causal=False,
+ )
+ else:
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ encoder_attention_mask = _prepare_4d_attention_mask(
+ encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
+ )
+
+ return encoder_attention_mask
+
+
+class PegasusXEncoder(PegasusXPreTrainedModel):
+ """
+ Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
+ [`PegasusXEncoderLayer`].
+
+ Args:
+ config: PegasusXConfig
+ embed_tokens (nn.Embedding): input token embedding matrix
+ """
+
+ def __init__(self, config: PegasusXConfig, embed_tokens: Optional[nn.Embedding] = None):
+ super().__init__(config)
+
+ self.dropout = config.dropout
+ self.layerdrop = config.encoder_layerdrop
+
+ embed_dim = config.d_model
+ padding_idx = config.pad_token_id
+ self.max_source_positions = config.max_position_embeddings
+ embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
+
+ if embed_tokens is not None:
+ self.embed_tokens = embed_tokens
+ else:
+ self.embed_tokens = PegasusXScaledWordEmbedding(
+ config.vocab_size, embed_dim, padding_idx, embed_scale=embed_scale
+ )
+
+ self.embed_global = nn.Embedding(config.num_global_tokens, embed_dim)
+ self.embed_positions = PegasusXSinusoidalPositionalEmbedding(embed_dim)
+ self.layers = nn.ModuleList(
+ [
+ PegasusXEncoderLayer(
+ stagger_blocks_this_layer=i % 2 == 1 and config.stagger_local_blocks, config=config
+ )
+ for i in range(config.encoder_layers)
+ ]
+ )
+ self.layer_norm = nn.LayerNorm(config.d_model)
+
+ self.gradient_checkpointing = False
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def resize_position_embeddings(self, new_num_position_embeddings: int):
+ """
+ Resizes position embeddings matrix of the model if `new_num_position_embeddings !=
+ config.max_position_embeddings`.
+
+ Arguments:
+ new_num_position_embeddings (`int`):
+ The number of new position embeddings. If position embeddings are learned, increasing the size will add
+ newly initialized vectors at the end, whereas reducing the size will remove vectors from the end. If
+ position embeddings are not learned (*e.g.* sinusoidal position embeddings), increasing the size will
+ add correct vectors at the end following the position encoding algorithm, whereas reducing the size
+ will remove vectors from the end.
+ """
+ logger.info(f"Setting `config.max_position_embeddings={new_num_position_embeddings}`...")
+ self.config.max_position_embeddings = new_num_position_embeddings
+
+ self.embed_positions = PegasusXSinusoidalPositionalEmbedding(self.config.d_model)
+ self.embed_positions.to(self.device)
+
+ def get_position_embeddings(self) -> nn.Embedding:
+ """
+ Returns the position embeddings matrix
+ """
+ return self.embed_positions
+
+ def forward(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ inputs_embeds=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ ):
+ r"""
+ Args:
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
+ provide it.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+ than the model's internal embedding lookup matrix.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+ for more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # retrieve input_ids and inputs_embeds
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+ elif input_ids is not None:
+ self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
+ input_shape = input_ids.size()
+ input_ids = input_ids.view(-1, input_shape[-1])
+ elif inputs_embeds is not None:
+ input_shape = inputs_embeds.size()[:-1]
+ else:
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+
+ embed_pos = self.embed_positions(inputs_embeds)
+
+ hidden_states = inputs_embeds + embed_pos
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+
+ batch_size, seq_len, _ = hidden_states.shape
+
+ # Setup mask
+ if attention_mask is None:
+ attention_mask = torch.ones(*input_shape, dtype=inputs_embeds.dtype, device=inputs_embeds.device)
+ attention_mask = attention_mask.to(dtype=hidden_states.dtype)
+ mask_min_value = torch.finfo(hidden_states.dtype).min
+ inverted_mask = 1.0 - attention_mask
+ attention_mask = inverted_mask.masked_fill(
+ inverted_mask.to(torch.bool),
+ mask_min_value,
+ )
+
+ # padding to block_size
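+ # Local attention operates on fixed-size blocks, so the sequence (and its mask) is padded up to a
+ # multiple of `block_size`; the padded tail is masked out here and sliced off again after the
+ # encoder layers.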
+ if seq_len % self.config.block_size != 0:
+ pad_len = self.config.block_size - seq_len % self.config.block_size
+ hidden_states = nn.functional.pad(hidden_states, pad=(0, 0, 0, pad_len), value=0)
+ attention_mask = nn.functional.pad(attention_mask, pad=(0, pad_len), value=mask_min_value)
+
+ # Global tokens
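+ # A fixed number of learned global token embeddings, shared across the batch, exchange
+ # information with every local block in each encoder layer, providing long-range connectivity.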
+ global_hidden_states = self.embed_global(
+ torch.arange(self.config.num_global_tokens, device=hidden_states.device)[None].expand(batch_size, -1)
+ )
+
+ encoder_states = () if output_hidden_states else None
+ all_attentions = () if output_attentions else None
+
+ for idx, encoder_layer in enumerate(self.layers):
+ if output_hidden_states:
+ encoder_states = encoder_states + (hidden_states,)
+ # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
+ to_drop = False
+ if self.training:
+ dropout_probability = torch.rand([])
+ if dropout_probability < self.layerdrop: # skip the layer
+ to_drop = True
+
+ if to_drop:
+ layer_outputs = (None, None)
+ else:
+ layer_outputs = encoder_layer(
+ hidden_states,
+ global_hidden_states,
+ attention_mask,
+ output_attentions=output_attentions,
+ )
+
+ hidden_states = layer_outputs[0]
+ global_hidden_states = layer_outputs[1]
+
+ if output_attentions:
+ all_attentions = all_attentions + (layer_outputs[2],)
+
+ # Undo padding-to-block-size
+ hidden_states = hidden_states[:, :seq_len]
+
+ hidden_states = self.layer_norm(hidden_states)
+
+ if output_hidden_states:
+ encoder_states = encoder_states + ((hidden_states, global_hidden_states),)
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
+ return BaseModelOutput(
+ last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
+ )
+
+
+class PegasusXDecoder(PegasusXPreTrainedModel):
+ """
+ Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`PegasusXDecoderLayer`]
+
+ Args:
+ config: PegasusXConfig
+ embed_tokens (nn.Embedding): input token embedding matrix
+ """
+
+ def __init__(self, config: PegasusXConfig, embed_tokens: Optional[nn.Embedding] = None):
+ super().__init__(config)
+ self.dropout = config.dropout
+ self.layerdrop = config.decoder_layerdrop
+ self.max_target_positions = config.max_position_embeddings
+ embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
+ padding_idx = config.pad_token_id
+
+ if embed_tokens is not None:
+ self.embed_tokens = embed_tokens
+ else:
+ self.embed_tokens = PegasusXScaledWordEmbedding(
+ config.vocab_size, config.d_model, padding_idx=padding_idx, embed_scale=embed_scale
+ )
+
+ self.embed_positions = PegasusXSinusoidalPositionalEmbedding(config.d_model)
+ self.layers = nn.ModuleList([PegasusXDecoderLayer(config, layer_idx=i) for i in range(config.decoder_layers)])
+ self.layer_norm = nn.LayerNorm(config.d_model)
+
+ self.gradient_checkpointing = False
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def forward(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ past_key_values=None,
+ inputs_embeds=None,
+ use_cache=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ cache_position=None,
+ ):
+ r"""
+ Args:
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
+ provide it.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
+ of the decoder.
+ encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*):
+ Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values
+ selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
+ cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
+
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
+ that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
+ all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+ than the model's internal embedding lookup matrix.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+ for more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+ Indices depicting the position of the input sequence tokens in the sequence. It is used to update the
+ cache in the correct position and to infer the complete sequence length.
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # retrieve input_ids and inputs_embeds
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
+ elif input_ids is not None:
+ input = input_ids
+ input_shape = input.shape
+ input_ids = input_ids.view(-1, input_shape[-1])
+ elif inputs_embeds is not None:
+ input_shape = inputs_embeds.size()[:-1]
+ input = inputs_embeds[:, :, -1]
+ else:
+ raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input)
+
+ if self.gradient_checkpointing and self.training:
+ if use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing`. Setting `use_cache=False`..."
+ )
+ use_cache = False
+
+ # initialize `past_key_values`
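+ # `EncoderDecoderCache` keeps two caches: one for decoder self-attention and one for the
+ # cross-attention over the (fixed) encoder hidden states.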
+ if use_cache and past_key_values is None:
+ past_key_values = EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
+ if use_cache and isinstance(past_key_values, tuple):
+ logger.warning_once(
+ "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. "
+ "You should pass an instance of `EncoderDecoderCache` instead, e.g. "
+ "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`."
+ )
+ past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values)
+
+ batch_size, seq_length = inputs_embeds.size()[:-1]
+ past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
+ if cache_position is None:
+ cache_position = torch.arange(
+ past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device
+ )
+
+ if attention_mask is None and not is_torchdynamo_compiling():
+ # required mask seq length can be calculated via length of past cache
+ mask_seq_length = past_key_values_length + seq_length
+ attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
+
+ self_attn_cache = (
+ past_key_values.self_attention_cache
+ if isinstance(past_key_values, EncoderDecoderCache)
+ else past_key_values
+ )
+
+ causal_mask = self._update_causal_mask(
+ attention_mask,
+ inputs_embeds,
+ cache_position,
+ self_attn_cache,
+ )
+ encoder_attention_mask = self._update_cross_attn_mask(
+ encoder_hidden_states,
+ encoder_attention_mask,
+ input_shape,
+ inputs_embeds,
+ )
+
+ # embed positions
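+ # During cached decoding only the new tokens are embedded, so their absolute positions are taken
+ # from `cache_position` instead of being recomputed from the current sequence length.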
+ position_ids = cache_position.unsqueeze(1)
+ position_ids = self.embed_positions(inputs_embeds, past_key_values_length, position_ids)
+ position_ids = position_ids.to(inputs_embeds.device)
+ hidden_states = inputs_embeds + position_ids
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+
+ # decoder layers
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+ all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
+
+ for idx, decoder_layer in enumerate(self.layers):
+ # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+ if self.training:
+ dropout_probability = torch.rand([])
+ if dropout_probability < self.layerdrop:
+ continue
+
+ layer_outputs = decoder_layer(
+ hidden_states,
+ causal_mask,
+ encoder_hidden_states, # as a positional argument for gradient checkpointing
+ encoder_attention_mask=encoder_attention_mask,
+ past_key_values=past_key_values,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ )
+ hidden_states = layer_outputs[0]
+
+ if output_attentions:
+ all_self_attns += (layer_outputs[1],)
+
+ if encoder_hidden_states is not None:
+ all_cross_attentions += (layer_outputs[2],)
+
+ hidden_states = self.layer_norm(hidden_states)
+
+ # add hidden states from the last decoder layer
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ if not return_dict:
+ return tuple(
+ v
+ for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns, all_cross_attentions]
+ if v is not None
+ )
+ return BaseModelOutputWithPastAndCrossAttentions(
+ last_hidden_state=hidden_states,
+ past_key_values=past_key_values,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ cross_attentions=all_cross_attentions,
+ )
+
+
+@auto_docstring
+class PegasusXModel(PegasusXPreTrainedModel):
+ _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
+
+ def __init__(self, config: PegasusXConfig):
+ super().__init__(config)
+
+ vocab_size = config.vocab_size
+ embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
+ padding_idx = config.pad_token_id
+ self.shared = PegasusXScaledWordEmbedding(
+ vocab_size, config.d_model, padding_idx=padding_idx, embed_scale=embed_scale
+ )
+
+ self.encoder = PegasusXEncoder(config, self.shared)
+ self.decoder = PegasusXDecoder(config, self.shared)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.shared
+
+ def set_input_embeddings(self, value):
+ self.shared = value
+ self.encoder.embed_tokens = self.shared
+ self.decoder.embed_tokens = self.shared
+
+ def get_encoder(self):
+ return self.encoder
+
+ def resize_position_embeddings(self, new_num_position_embeddings: int):
+ """
+ Resizes position embeddings matrix of the model if `new_num_position_embeddings !=
+ config.max_position_embeddings`.
+
+ Arguments:
+ new_num_position_embeddings (`int`):
+ The number of new position embeddings. If position embeddings are learned, increasing the size will add
+ newly initialized vectors at the end, whereas reducing the size will remove vectors from the end. If
+ position embeddings are not learned (*e.g.* sinusoidal position embeddings), increasing the size will
+ add correct vectors at the end following the position encoding algorithm, whereas reducing the size
+ will remove vectors from the end.
+ """
+ self.config.max_position_embeddings = new_num_position_embeddings
+ self.encoder.resize_position_embeddings(new_num_position_embeddings)
+ self.decoder.resize_position_embeddings(new_num_position_embeddings)
+
+ def get_position_embeddings(self) -> tuple[nn.Embedding]:
+ """
+ Returns the position embeddings matrix
+ """
+ return (self.encoder.get_position_embeddings(), self.decoder.get_position_embeddings())
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ decoder_input_ids: Optional[torch.Tensor] = None,
+ decoder_attention_mask: Optional[torch.Tensor] = None,
+ encoder_outputs: Optional[tuple[torch.FloatTensor]] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ decoder_inputs_embeds: Optional[torch.Tensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.Tensor] = None,
+ ) -> Union[tuple, Seq2SeqModelOutput]:
+ r"""
+ decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
+ Indices of decoder input sequence tokens in the vocabulary.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are decoder input IDs?](../glossary#decoder-input-ids)
+
+ PEGASUS-X uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If
+ `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
+ `past_key_values`).
+ decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
+ Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
+ be used by default.
+
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, PegasusXModel
+
+ >>> tokenizer = AutoTokenizer.from_pretrained("google/pegasus-x-large")
+ >>> model = PegasusXModel.from_pretrained("google/pegasus-x-large")
+
+ >>> inputs = tokenizer("Studies have been shown that owning a dog is good for you", return_tensors="pt")
+ >>> decoder_inputs = tokenizer("Studies show that", return_tensors="pt")
+ >>> outputs = model(input_ids=inputs.input_ids, decoder_input_ids=decoder_inputs.input_ids)
+
+ >>> last_hidden_states = outputs.last_hidden_state
+ >>> list(last_hidden_states.shape)
+ [1, 4, 1024]
+ ```"""
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if encoder_outputs is None:
+ encoder_outputs = self.encoder(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
+ elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
+ encoder_outputs = BaseModelOutput(
+ last_hidden_state=encoder_outputs[0],
+ hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
+ attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
+ )
+
+ # decoder outputs consists of (dec_features, past_key_values, dec_hidden, dec_attn)
+ decoder_outputs = self.decoder(
+ input_ids=decoder_input_ids,
+ attention_mask=decoder_attention_mask,
+ encoder_hidden_states=encoder_outputs[0],
+ encoder_attention_mask=attention_mask,
+ past_key_values=past_key_values,
+ inputs_embeds=decoder_inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ cache_position=cache_position,
+ )
+
+ if not return_dict:
+ return decoder_outputs + encoder_outputs
+
+ return Seq2SeqModelOutput(
+ last_hidden_state=decoder_outputs.last_hidden_state,
+ past_key_values=decoder_outputs.past_key_values,
+ decoder_hidden_states=decoder_outputs.hidden_states,
+ decoder_attentions=decoder_outputs.attentions,
+ cross_attentions=decoder_outputs.cross_attentions,
+ encoder_last_hidden_state=encoder_outputs.last_hidden_state,
+ encoder_hidden_states=encoder_outputs.hidden_states,
+ encoder_attentions=encoder_outputs.attentions,
+ )
+
+
+@auto_docstring(
+ custom_intro="""
+ The PEGASUS-X Model for conditional generation (e.g. summarization).
+ """
+)
+class PegasusXForConditionalGeneration(PegasusXPreTrainedModel, GenerationMixin):
+ base_model_prefix = "model"
+ _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]
+
+ def __init__(self, config: PegasusXConfig):
+ super().__init__(config)
+ self.model = PegasusXModel(config)
+ self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_encoder(self):
+ return self.model.get_encoder()
+
+ def get_decoder(self):
+ return self.model.get_decoder()
+
+ def resize_position_embeddings(self, new_num_position_embeddings: int):
+ """
+ Resizes position embeddings matrix of the model if `new_num_position_embeddings !=
+ config.max_position_embeddings`.
+
+ Arguments:
+ new_num_position_embeddings (`int`):
+ The number of new position embeddings. If position embeddings are learned, increasing the size will add
+ newly initialized vectors at the end, whereas reducing the size will remove vectors from the end. If
+ position embeddings are not learned (*e.g.* sinusoidal position embeddings), increasing the size will
+ add correct vectors at the end following the position encoding algorithm, whereas reducing the size
+ will remove vectors from the end.
+ """
+ self.config.max_position_embeddings = new_num_position_embeddings
+ self.model.encoder.resize_position_embeddings(new_num_position_embeddings)
+ self.model.decoder.resize_position_embeddings(new_num_position_embeddings)
+
+ def get_position_embeddings(self) -> tuple[nn.Embedding]:
+ """
+ Returns the position embeddings matrix
+ """
+ return (self.model.encoder.get_position_embeddings(), self.model.decoder.get_position_embeddings())
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ decoder_input_ids: Optional[torch.Tensor] = None,
+ decoder_attention_mask: Optional[torch.Tensor] = None,
+ encoder_outputs: Optional[tuple[torch.FloatTensor]] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ decoder_inputs_embeds: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.Tensor] = None,
+ ) -> Union[tuple, Seq2SeqLMOutput]:
+ r"""
+ decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
+ Indices of decoder input sequence tokens in the vocabulary.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are decoder input IDs?](../glossary#decoder-input-ids)
+
+ PEGASUS-X uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If
+ `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
+ `past_key_values`).
+ decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
+ Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
+ be used by default.
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
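+
+ Example (a minimal usage sketch; the generated summary will vary with the checkpoint and input):
+
+ ```python
+ >>> from transformers import AutoTokenizer, PegasusXForConditionalGeneration
+
+ >>> tokenizer = AutoTokenizer.from_pretrained("google/pegasus-x-base")
+ >>> model = PegasusXForConditionalGeneration.from_pretrained("google/pegasus-x-base")
+
+ >>> ARTICLE = "PEGASUS-X extends PEGASUS with staggered block-local attention and global tokens so that it can summarize very long documents."
+ >>> inputs = tokenizer(ARTICLE, return_tensors="pt")
+ >>> summary_ids = model.generate(**inputs, max_new_tokens=32)
+ >>> print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True)[0])  # doctest: +SKIP
+ ```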
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if labels is not None:
+ if use_cache:
+ logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.")
+ use_cache = False
+ if decoder_input_ids is None and decoder_inputs_embeds is None:
+ decoder_input_ids = shift_tokens_right(
+ labels, self.config.pad_token_id, self.config.decoder_start_token_id
+ )
+
+ outputs = self.model(
+ input_ids,
+ attention_mask=attention_mask,
+ decoder_input_ids=decoder_input_ids,
+ encoder_outputs=encoder_outputs,
+ decoder_attention_mask=decoder_attention_mask,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ decoder_inputs_embeds=decoder_inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ cache_position=cache_position,
+ )
+ lm_logits = self.lm_head(outputs[0])
+
+ masked_lm_loss = None
+ if labels is not None:
+ loss_fct = CrossEntropyLoss()
+ masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1))
+
+ if not return_dict:
+ output = (lm_logits,) + outputs[1:]
+ return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
+
+ return Seq2SeqLMOutput(
+ loss=masked_lm_loss,
+ logits=lm_logits,
+ past_key_values=outputs.past_key_values,
+ decoder_hidden_states=outputs.decoder_hidden_states,
+ decoder_attentions=outputs.decoder_attentions,
+ cross_attentions=outputs.cross_attentions,
+ encoder_last_hidden_state=outputs.encoder_last_hidden_state,
+ encoder_hidden_states=outputs.encoder_hidden_states,
+ encoder_attentions=outputs.encoder_attentions,
+ )
+
+ def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
+ return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
+
+
+# Copied from transformers.models.bart.modeling_bart.BartDecoderWrapper with Bart->PegasusX
+class PegasusXDecoderWrapper(PegasusXPreTrainedModel):
+ """
+ This wrapper class is a helper class to correctly load pretrained checkpoints when the causal language model is
+ used in combination with the [`EncoderDecoderModel`] framework.
+ """
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.decoder = PegasusXDecoder(config)
+
+ def forward(self, *args, **kwargs):
+ return self.decoder(*args, **kwargs)
+
+
+__all__ = ["PegasusXForConditionalGeneration", "PegasusXModel", "PegasusXPreTrainedModel"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/perception_lm/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/perception_lm/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..81c3ba93bcf470394a99309dbf2f9869cd68d887
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/perception_lm/__init__.py
@@ -0,0 +1,29 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_perception_lm import *
+ from .image_processing_perception_lm_fast import *
+ from .modeling_perception_lm import *
+ from .processing_perception_lm import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/perception_lm/configuration_perception_lm.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/perception_lm/configuration_perception_lm.py
new file mode 100644
index 0000000000000000000000000000000000000000..08c084065ff86a519007afc9d6a1d97eacf6d381
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/perception_lm/configuration_perception_lm.py
@@ -0,0 +1,88 @@
+# coding=utf-8
+# Copyright 2025 Meta Platforms, Inc. and the HuggingFace Inc. team. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PerceptionLM model configuration"""
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+from ..auto import CONFIG_MAPPING, AutoConfig
+from ..timm_wrapper.configuration_timm_wrapper import TimmWrapperConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+class PerceptionLMConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`PerceptionLMForConditionalGeneration`]. It is used to instantiate a
+ PerceptionLM model according to the specified arguments, defining the model architecture.
+
+ Example models:
+ - [facebook/Perception-LM-1B](https://huggingface.co/facebook/Perception-LM-1B).
+ - [facebook/Perception-LM-3B](https://huggingface.co/facebook/Perception-LM-3B).
+ - [facebook/Perception-LM-8B](https://huggingface.co/facebook/Perception-LM-8B).
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ vision_config (`Union[TimmWrapperConfig, dict]`, *optional*, defaults to `TimmWrapperConfig()`):
+ The config object or dictionary of the vision backbone.
+ text_config (`Union[PretrainedConfig, dict]`, *optional*, defaults to `LlamaConfig()`):
+ The config object or dictionary of the text backbone.
+ vision_use_cls_token (`bool`, *optional*, defaults to `True`):
+ Whether CLS token is used in the vision backbone. If used, we remove CLS token embedding from vision output.
+ projector_pooling_ratio (`int`, *optional*, defaults to 1):
+ The pooling ratio used in the multimodal projector.
+ image_token_id (`int`, *optional*, defaults to 128002):
+ The image token index to encode the image prompt.
+ video_token_id (`int`, *optional*, defaults to 128003):
+ The video token index to encode the video prompt.
+ """
+
+ model_type = "perception_lm"
+ sub_configs = {"text_config": AutoConfig, "vision_config": TimmWrapperConfig}
+
+ def __init__(
+ self,
+ vision_config=None,
+ text_config=None,
+ vision_use_cls_token=True,
+ projector_pooling_ratio=1,
+ image_token_id=128002,
+ video_token_id=128003,
+ **kwargs,
+ ):
+ self.image_token_id = image_token_id
+ self.video_token_id = video_token_id
+ if isinstance(vision_config, dict):
+ vision_config = TimmWrapperConfig(**vision_config)
+ elif isinstance(vision_config, TimmWrapperConfig):
+ pass
+ elif vision_config is None:
+ vision_config = TimmWrapperConfig()
+ self.vision_config = vision_config
+ self.vision_use_cls_token = vision_use_cls_token
+
+ if isinstance(text_config, dict):
+ text_config["model_type"] = text_config.get("model_type", "llama")
+ text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
+ elif text_config is None:
+ text_config = CONFIG_MAPPING["llama"]()
+
+ self.text_config = text_config
+ self.projector_pooling_ratio = projector_pooling_ratio
+ super().__init__(**kwargs)
+
+
+__all__ = ["PerceptionLMConfig"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/perception_lm/image_processing_perception_lm_fast.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/perception_lm/image_processing_perception_lm_fast.py
new file mode 100644
index 0000000000000000000000000000000000000000..c26132a484397067ad7354ea50705e493f063ff1
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/perception_lm/image_processing_perception_lm_fast.py
@@ -0,0 +1,309 @@
+# Copyright 2025 Meta Platforms, Inc. and the HuggingFace Inc. team. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Fast Image processor class for PerceptionLM."""
+
+import math
+from functools import reduce
+from typing import Optional, Union
+
+import numpy as np
+import torch
+from torchvision.transforms import functional as F
+
+from ...image_processing_utils import (
+ BatchFeature,
+)
+from ...image_processing_utils_fast import (
+ BaseImageProcessorFast,
+ DefaultFastImageProcessorKwargs,
+ get_image_size,
+ group_images_by_shape,
+ reorder_images,
+)
+from ...image_utils import (
+ IMAGENET_STANDARD_MEAN,
+ IMAGENET_STANDARD_STD,
+ ChannelDimension,
+ PILImageResampling,
+)
+from ...processing_utils import Unpack
+from ...utils import TensorType, auto_docstring
+
+
+class PerceptionLMFastImageProcessorKwargs(DefaultFastImageProcessorKwargs):
+ r"""
+ vision_input_type (`str`, *optional*, defaults to `"thumb+tile"`):
+ Vision processing strategy. `"thumb+tile"` uses both thumbnails and multiple tiles for
+ multi-scale processing, otherwise uses single tile for lower memory usage.
+ tile_size (`int`, *optional*, defaults to `448`):
+ Height and width dimension (in pixels) of each tile used for image processing.
+ max_num_tiles (`int`, *optional*, defaults to `36`):
+ Maximum number of tiles an image can be split into based on its aspect ratio.
+ """
+
+ vision_input_type: str = "thumb+tile"
+ tile_size: int = 448
+ max_num_tiles: int = 36
+
+
+@auto_docstring
+class PerceptionLMImageProcessorFast(BaseImageProcessorFast):
+ resample = PILImageResampling.BICUBIC
+ image_mean = IMAGENET_STANDARD_MEAN
+ image_std = IMAGENET_STANDARD_STD
+ do_resize = True
+ do_center_crop = False
+ do_rescale = True
+ do_normalize = True
+ do_convert_rgb = True
+ size = {"width": 448, "height": 448} # for backward compatibility in tests
+ valid_kwargs = PerceptionLMFastImageProcessorKwargs
+
+ def __init__(self, **kwargs: Unpack[PerceptionLMFastImageProcessorKwargs]) -> None:
+ super().__init__(**kwargs)
+
+ @auto_docstring
+ def preprocess(self, images, **kwargs: Unpack[PerceptionLMFastImageProcessorKwargs]) -> BatchFeature:
+ return super().preprocess(images, **kwargs)
+
+ @staticmethod
+ def _factors(n: int):
+ """Return all factors of a number."""
+ return set(
+ reduce(
+ list.__add__,
+ ([i, n // i] for i in range(1, int(n**0.5) + 1) if n % i == 0),
+ )
+ )
+
+ def _find_supported_aspect_ratios(self):
+ """
+ This function computes all the allowed aspect ratios for a fixed
+ number of input chunks. The order of returned items matters for the result of `_fit_image_to_canvas` function.
+ If tie exists in `_fit_image_to_canvas`, the latter in `_find_supported_aspect_ratios` wins.
+
+ For example, with `num_tiles=5`, it will return:
+ {
+ 0.2: [(1, 5)],
+ 5.0: [(5, 1)],
+ 0.25: [(1, 4)],
+ 1.0: [(2, 2), (1, 1)],
+ 4.0: [(4, 1)],
+ 0.3333333333333333: [(1, 3)],
+ 3.0: [(3, 1)],
+ 0.5: [(1, 2)],
+ 2.0: [(2, 1)]
+ }
+ """
+ asp_dict = {}
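+ # Iterate from the largest tile budget down so that, for a given aspect-ratio key, arrangements
+ # using more tiles are listed first.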
+ for chunk_size in range(self.max_num_tiles, 0, -1):
+ _factors = sorted(self._factors(chunk_size))
+ _asp_ratios = [(x, chunk_size // x) for x in _factors]
+ for ratio in _asp_ratios:
+ k = ratio[0] / ratio[1]
+ if k not in asp_dict:
+ asp_dict[k] = [ratio]
+ else:
+ asp_dict[k].append(ratio)
+ return asp_dict
+
+ def _get_image_height_width(
+ self, image_width: int, image_height: int, target_width: int, target_height: int
+ ) -> tuple[int, int]:
+ """
+ Given the image width and height and the target canvas width and height, return the dimensions
+ the image would be resized to while preserving its aspect ratio.
+ """
+ scale = image_width / image_height
+
+ if scale > 1.0:
+ # Width is larger than height
+
+ # Rescaling factor is the minimum of the two scaling factors. Else one side would be outside of the canvas.
+ rescaling_factor = min(target_width / image_width, target_height / image_height)
+
+ # Rescale the width by that factor and derive the height from the aspect ratio.
+ new_w = rescaling_factor * image_width
+ new_h = math.floor(new_w / scale)
+
+ else:
+ # Height is larger than width
+
+ # Rescaling factor is the minimum of the two scaling factors. Else one side would be outside of the canvas.
+ rescaling_factor = min(target_width / image_width, target_height / image_height)
+
+ # Rescale the height by that factor and derive the width from the aspect ratio.
+ new_h = rescaling_factor * image_height
+ new_w = math.floor(new_h * scale)
+
+ return new_w, new_h
+
+ def _fit_image_to_canvas(self, img_width: int, img_height: int, tile_size: int):
+ """
+ Given an image width, height and target number of chunks, this function checks whether the image
+ can fit into any of the canvases that can be built by arranging the tiles in a grid.
+ If the image fits onto several canvases, it returns the canvas on which the shorter edge
+ of the image will be largest.
+ """
+ # Initialize the optimal canvas to None. If no canvas is found where image fits, function returns None.
+ optimal_canvas = None
+ optimal_image_width_height = None
+
+ scale = img_width / img_height
+
+ # Gather all potential supported image resolutions and iterate through them to find best match
+ potential_arrangements = [
+ item for sublist in self._find_supported_aspect_ratios().values() for item in sublist
+ ]
+ for n_w, n_h in potential_arrangements:
+ # Compute the canvas size
+ canvas_width, canvas_height = n_w * tile_size, n_h * tile_size
+
+ # Check if image can fit into the canvas without downsampling
+ if canvas_width >= img_width and canvas_height >= img_height:
+ # If we did not find a good canvas yet, we will use the current one
+ if optimal_canvas is None:
+ # Set optimal canvas and determine the actual image height and width in the canvas with aspect ratio preserving resampling
+ optimal_canvas = (n_w, n_h)
+ optimal_image_width_height = self._get_image_height_width(
+ image_width=img_width,
+ image_height=img_height,
+ target_width=n_w * tile_size,
+ target_height=n_h * tile_size,
+ )
+ else:
+ # If we already found an optimal canvas before, we will check if the shorter edge of the image will be larger than the current optimal canvas.
+ # This means we can potentially upsample the image resolution which is beneficial to performance.
+ image_width_height = self._get_image_height_width(
+ image_width=img_width,
+ image_height=img_height,
+ target_width=n_w * tile_size,
+ target_height=n_h * tile_size,
+ )
+ # Llama3V dynamic tiling. Prioritize biggest canvas.
+ if (scale < 1.0 and (image_width_height[0] >= optimal_image_width_height[0])) or (
+ scale >= 1.0 and (image_width_height[1] >= optimal_image_width_height[1])
+ ):
+ optimal_canvas = (n_w, n_h)
+ optimal_image_width_height = image_width_height
+ return optimal_canvas
+
+ def _find_closest_aspect_ratio(self, img_width: int, img_height: int, tile_size: int) -> tuple:
+ """
+ Given an image width, height and target number of chunks
+ this function will find the closest supported aspect ratio.
+ """
+ target_aspect_ratio = img_width / img_height
+ asp_dict = self._find_supported_aspect_ratios()
+ closest_aspect_ratio = None
+ if target_aspect_ratio >= 1:
+ closest_aspect_ratio = min(
+ [k for k in asp_dict if k <= target_aspect_ratio],
+ key=lambda x: abs(x - target_aspect_ratio),
+ )
+ tiles_given_aspect_ratio = asp_dict[closest_aspect_ratio]
+ # select largest width
+ return max(tiles_given_aspect_ratio, key=lambda x: x[0])
+ else:
+ closest_aspect_ratio = min(
+ [k for k in asp_dict if k > target_aspect_ratio],
+ key=lambda x: abs(1 / x - 1 / target_aspect_ratio),
+ )
+ tiles_given_aspect_ratio = asp_dict[closest_aspect_ratio]
+ # select largest height
+ return max(tiles_given_aspect_ratio, key=lambda x: x[1])
+
+ def _split(self, image: torch.Tensor, ncw: int, nch: int) -> torch.Tensor:
+ # Split image into number of required tiles (width x height)
+ batch_size, num_channels, height, width = image.size()
+ image = image.view(batch_size, num_channels, nch, height // nch, ncw, width // ncw)
+ # Permute so the tile-grid axes (nch, ncw) come before the channel axis
+ image = image.permute(0, 2, 4, 1, 3, 5).contiguous()
+ # Reshape into the output shape (batch_size, ncw * nch, num_channels, height // nch, width // ncw)
+ image = image.view(batch_size, ncw * nch, num_channels, height // nch, width // ncw)
+ return image
+
+ def resize(
+ self,
+ image: np.ndarray,
+ tile_size: int,
+ max_num_tiles: int,
+ resample: PILImageResampling = PILImageResampling.BICUBIC,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ ):
+ height, width = get_image_size(image, channel_dim=input_data_format)
+ if max_num_tiles > 1:
+ aspect_ratio = self._fit_image_to_canvas(img_width=width, img_height=height, tile_size=tile_size)
+ if aspect_ratio is None:
+ # If we did not find a canvas, we have to find the closest aspect ratio and downsample the image
+ aspect_ratio = self._find_closest_aspect_ratio(img_width=width, img_height=height, tile_size=tile_size)
+ else:
+ aspect_ratio = (1, 1)
+ new_width, new_height = aspect_ratio[0] * tile_size, aspect_ratio[1] * tile_size
+ image = F.resize(image, (new_height, new_width), interpolation=resample)
+ return image, aspect_ratio
+
+ def _preprocess(
+ self,
+ images: list["torch.Tensor"],
+ do_resize: bool,
+ do_rescale: Optional[bool],
+ rescale_factor: Optional[Union[int, float]],
+ do_normalize: Optional[bool],
+ image_mean: Optional[Union[float, list[float]]],
+ image_std: Optional[Union[float, list[float]]],
+ vision_input_type: str,
+ tile_size: int,
+ max_num_tiles: int,
+ return_tensors: Optional[Union[str, TensorType]],
+ disable_grouping: bool,
+ **kwargs: Unpack[PerceptionLMFastImageProcessorKwargs],
+ ) -> BatchFeature:
+ # Group images by size for batched transformation
+ grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
+ resized_images_grouped = {}
+ for shape, stacked_images in grouped_images.items():
+ if do_resize:
+ if vision_input_type == "thumb+tile":
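+ # Build a single-tile thumbnail of the whole image plus a grid of higher-resolution tiles,
+ # then stack thumbnail and tiles along a new tiles dimension.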
+ thumbnails, _ = self.resize(stacked_images, tile_size, max_num_tiles=1)
+ images_for_tiling, (tiles_w, tiles_h) = self.resize(
+ stacked_images, tile_size, max_num_tiles=max_num_tiles
+ )
+ image_tiles = self._split(images_for_tiling, tiles_w, tiles_h)
+ stacked_images = torch.cat([thumbnails.unsqueeze(1), image_tiles], dim=1)
+ else: # vanilla single tile for low memory devices
+ stacked_images, _ = self.resize(stacked_images, tile_size, max_num_tiles=1)
+
+ resized_images_grouped[shape] = stacked_images
+ resized_images = reorder_images(resized_images_grouped, grouped_images_index)
+
+ grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping)
+ processed_images_grouped = {}
+ for shape, stacked_images in grouped_images.items():
+ # Fused rescale and normalize
+ stacked_images = self.rescale_and_normalize(
+ stacked_images,
+ do_rescale,
+ rescale_factor,
+ do_normalize,
+ image_mean,
+ image_std,
+ )
+ processed_images_grouped[shape] = stacked_images
+ processed_images = reorder_images(processed_images_grouped, grouped_images_index)
+ processed_images = [p[None] if p.ndim == 3 else p for p in processed_images] # add tiles dimension if needed
+ processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
+ return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)
+
+
+__all__ = ["PerceptionLMImageProcessorFast"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/perception_lm/modeling_perception_lm.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/perception_lm/modeling_perception_lm.py
new file mode 100644
index 0000000000000000000000000000000000000000..074e91e14e88d29296e40f8a3f5c94663221997d
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/perception_lm/modeling_perception_lm.py
@@ -0,0 +1,487 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/perception_lm/modular_perception_lm.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_perception_lm.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2025 Meta Platforms, Inc. and the HuggingFace Inc. team. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from dataclasses import dataclass
+from typing import Optional, Union
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from ...cache_utils import Cache
+from ...generation import GenerationMixin
+from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput
+from ...modeling_utils import PreTrainedModel
+from ...utils import auto_docstring, can_return_tuple
+from ..auto import AutoModel
+from .configuration_perception_lm import PerceptionLMConfig
+
+
+class PerceptionLMAdaptiveAvgPooling(nn.Module):
+ def __init__(self, pooling_ratio=2):
+ super().__init__()
+ self.pooling_ratio = pooling_ratio
+
+ def forward(self, hidden_states):
+ b, num_tokens, c = hidden_states.shape
+ h = int(math.sqrt(num_tokens))
+ if h * h != num_tokens:
+ raise ValueError(f"num_tokens {num_tokens} is expected to be a square number")
+
+ shape = (h // self.pooling_ratio, h // self.pooling_ratio)
+ hidden_states = hidden_states.permute(0, 2, 1).reshape(b, -1, h, h)
+ hidden_states = F.adaptive_avg_pool2d(hidden_states, shape)
+ hidden_states = hidden_states.flatten(2).transpose(1, 2)
+
+ return hidden_states
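+
+    # Illustrative sketch (not part of the library code): with pooling_ratio=2, a sequence of 1024
+    # vision tokens is treated as a 32x32 grid and average-pooled down to a 16x16 grid, so the token
+    # count shrinks by pooling_ratio**2 while the feature dimension is unchanged:
+    #
+    #     >>> pool = PerceptionLMAdaptiveAvgPooling(pooling_ratio=2)
+    #     >>> pool(torch.rand(1, 1024, 8)).shape
+    #     torch.Size([1, 256, 8])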
+
+
+class PerceptionLMMultiModalProjector(nn.Module):
+ def __init__(self, config: PerceptionLMConfig):
+ super().__init__()
+ input_size = config.vision_config.model_args["embed_dim"]
+ output_size = config.text_config.hidden_size
+ self.linear_1 = nn.Linear(
+ in_features=input_size,
+ out_features=output_size,
+ bias=True,
+ )
+ self.gelu = nn.GELU()
+ self.linear_2 = nn.Linear(
+ in_features=output_size,
+ out_features=output_size,
+ bias=True,
+ )
+ self.pooling = (
+ PerceptionLMAdaptiveAvgPooling(config.projector_pooling_ratio)
+ if config.projector_pooling_ratio > 1
+ else nn.Identity()
+ )
+
+ def forward(self, features):
+ features = features.permute(1, 0, 2) # NLD -> LND
+ features = self.linear_1(features)
+ features = self.gelu(features)
+ features = self.linear_2(features)
+ features = features.permute(1, 0, 2) # LND -> NLD
+ features = self.pooling(features)
+ return features
+
+
+@auto_docstring
+class PerceptionLMPreTrainedModel(PreTrainedModel):
+ config: PerceptionLMConfig
+ base_model_prefix = "model"
+ supports_gradient_checkpointing = True
+ _skip_keys_device_placement = "past_key_values"
+
+ _supports_flash_attn = True
+ _supports_sdpa = True
+
+ _can_compile_fullgraph = True
+ _supports_flex_attn = True
+ _supports_attention_backend = True
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Base class for PerceptionLM outputs, with hidden states and attentions.
+ """
+)
+class PerceptionLMModelOutputWithPast(BaseModelOutputWithPast):
+ r"""
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
+ `past_key_values` input) to speed up sequential decoding.
+ image_hidden_states (`torch.FloatTensor`, *optional*):
+ A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
+ Image hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
+ video_hidden_states (`torch.FloatTensor`, *optional*):
+ A `torch.FloatTensor` of size `(batch_size, num_videos, sequence_length, hidden_size)`.
+ Video hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
+ """
+
+ image_hidden_states: Optional[torch.FloatTensor] = None
+
+ video_hidden_states: Optional[torch.FloatTensor] = None
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Base class for PerceptionLM causal language model (or autoregressive) outputs.
+ """
+)
+class PerceptionLMCausalLMOutputWithPast(ModelOutput):
+ r"""
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Language modeling loss (for next-token prediction).
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
+ `past_key_values` input) to speed up sequential decoding.
+ image_hidden_states (`torch.FloatTensor`, *optional*):
+ A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
+ Image hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
+ video_hidden_states (`torch.FloatTensor`, *optional*):
+ A `torch.FloatTensor` of size `(batch_size, num_videos, sequence_length, hidden_size)`.
+ Video hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ logits: Optional[torch.FloatTensor] = None
+ past_key_values: Optional[Cache] = None
+ hidden_states: Optional[tuple[torch.FloatTensor]] = None
+ attentions: Optional[tuple[torch.FloatTensor]] = None
+ image_hidden_states: Optional[torch.FloatTensor] = None
+
+ video_hidden_states: Optional[torch.FloatTensor] = None
+
+
+@auto_docstring
+class PerceptionLMModel(PerceptionLMPreTrainedModel):
+ _checkpoint_conversion_mapping = {}
+
+ def __init__(self, config: PerceptionLMConfig):
+ super().__init__(config)
+ self.vision_tower = AutoModel.from_config(config.vision_config)
+ self.multi_modal_projector = PerceptionLMMultiModalProjector(config)
+ self.language_model = AutoModel.from_config(config.text_config)
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.language_model.get_input_embeddings()
+
+ def set_input_embeddings(self, value):
+ self.language_model.set_input_embeddings(value)
+
+ def set_decoder(self, decoder):
+ self.language_model = decoder
+
+ def get_decoder(self):
+ return self.language_model
+
+ def get_image_features(
+ self,
+ pixel_values: torch.FloatTensor,
+ **kwargs,
+ ):
+ """
+        Obtains image last hidden states from the vision tower and applies the multimodal projection.
+
+        Args:
+            pixel_values (`torch.FloatTensor` of shape `(batch_size, num_tiles, channels, height, width)`):
+                The tensors corresponding to the input images.
+        Returns:
+            image_features (`torch.Tensor`): Image feature tensor of shape `(batch_size * num_tiles, num_patches, embed_dim)`.
+ """
+ image_outputs = self.vision_tower(pixel_values.flatten(0, 1))
+ image_outputs = image_outputs.last_hidden_state
+ if self.config.vision_use_cls_token:
+ image_outputs = image_outputs[:, 1:, :]
+ image_features = self.multi_modal_projector(image_outputs)
+ return image_features
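+
+    # Shape walkthrough (illustrative, assuming 448x448 tiles, a 14-pixel patch size and
+    # projector_pooling_ratio=2): `pixel_values` of shape (batch_size, num_tiles, 3, 448, 448) is
+    # flattened to (batch_size * num_tiles, 3, 448, 448), the vision tower returns 32 * 32 = 1024
+    # patch tokens per tile (after the optional CLS token is dropped), and the projector pools them
+    # down to 16 * 16 = 256 tokens of size `text_config.hidden_size` per tile.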
+
+ def get_placeholder_mask(
+ self,
+ input_ids: torch.LongTensor,
+ inputs_embeds: torch.FloatTensor,
+ image_features: Optional[torch.FloatTensor] = None,
+ video_features: Optional[torch.FloatTensor] = None,
+ ):
+ """
+ Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is
+ equal to the length of multimodal features. If the lengths are different, an error is raised.
+ """
+ if input_ids is None:
+ special_image_mask = inputs_embeds == self.get_input_embeddings()(
+ torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
+ )
+ special_image_mask = special_image_mask.all(-1)
+ special_video_mask = inputs_embeds == self.get_input_embeddings()(
+ torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device)
+ )
+ special_video_mask = special_video_mask.all(-1)
+ else:
+ special_image_mask = input_ids == self.config.image_token_id
+ special_video_mask = input_ids == self.config.video_token_id
+
+ n_image_tokens = special_image_mask.sum()
+ special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
+ if image_features is not None and inputs_embeds[special_image_mask].numel() != image_features.numel():
+ raise ValueError(
+ f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {image_features.size()[:-1].numel()}"
+ )
+
+ n_video_tokens = special_video_mask.sum()
+ special_video_mask = special_video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
+ if video_features is not None and inputs_embeds[special_video_mask].numel() != video_features.numel():
+ raise ValueError(
+ f"Videos features and image tokens do not match: tokens: {n_video_tokens}, features {video_features.size()[:-1].numel()}"
+ )
+
+ return special_image_mask, special_video_mask
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ pixel_values_videos: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ logits_to_keep: Union[int, torch.Tensor] = 0,
+ **lm_kwargs,
+ ) -> Union[tuple, PerceptionLMModelOutputWithPast]:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+ if (pixel_values is not None or pixel_values_videos is not None) and inputs_embeds is not None:
+ raise ValueError(
+ "You cannot specify both (pixel_values or pixel_values_videos) and inputs_embeds at the same time, and must specify either one"
+ )
+
+ if inputs_embeds is None:
+ inputs_embeds = self.get_input_embeddings()(input_ids)
+
+ image_features = None
+ if pixel_values is not None:
+ image_features = self.get_image_features(pixel_values=pixel_values)
+ image_features = image_features.to(inputs_embeds.device, dtype=inputs_embeds.dtype)
+ special_image_mask, _ = self.get_placeholder_mask(
+ input_ids, inputs_embeds=inputs_embeds, image_features=image_features
+ )
+ inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
+
+ video_features = None
+ if pixel_values_videos is not None:
+ video_features = self.get_image_features(pixel_values=pixel_values_videos)
+ video_features = video_features.to(inputs_embeds.device, dtype=inputs_embeds.dtype)
+ _, special_video_mask = self.get_placeholder_mask(
+ input_ids, inputs_embeds=inputs_embeds, video_features=video_features
+ )
+ inputs_embeds = inputs_embeds.masked_scatter(special_video_mask, video_features)
+
+ outputs = self.language_model(
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=True,
+ cache_position=cache_position,
+ logits_to_keep=logits_to_keep,
+ **lm_kwargs,
+ )
+ return PerceptionLMModelOutputWithPast(
+ last_hidden_state=outputs.last_hidden_state,
+ hidden_states=outputs.hidden_states,
+ past_key_values=outputs.past_key_values,
+ attentions=outputs.attentions,
+ image_hidden_states=image_features if pixel_values is not None else None,
+ video_hidden_states=(video_features if pixel_values_videos is not None else None),
+ )
+
+
+@auto_docstring
+class PerceptionLMForConditionalGeneration(PerceptionLMPreTrainedModel, GenerationMixin):
+ _checkpoint_conversion_mapping = {}
+ _tied_weights_keys = ["lm_head.weight"]
+
+ def __init__(self, config: PerceptionLMConfig):
+ super().__init__(config)
+ self.model = PerceptionLMModel(config)
+ self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.get_input_embeddings()
+
+ def set_input_embeddings(self, value):
+ self.model.set_input_embeddings(value)
+
+ def get_output_embeddings(self) -> nn.Module:
+ return self.lm_head
+
+ def set_decoder(self, decoder):
+ self.model.set_decoder(decoder)
+
+ def get_decoder(self):
+ return self.model.get_decoder()
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ pixel_values_videos: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ logits_to_keep: Union[int, torch.Tensor] = 0,
+ **lm_kwargs,
+ ) -> Union[tuple, PerceptionLMCausalLMOutputWithPast]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+ Example:
+
+ ```python
+ from transformers import AutoProcessor, AutoModelForImageTextToText
+ from huggingface_hub import hf_hub_download
+
+ MODEL_PATH = "facebook/Perception-LM-1B"
+ processor = AutoProcessor.from_pretrained(MODEL_PATH, use_fast=True)
+ model = AutoModelForImageTextToText.from_pretrained(MODEL_PATH).to("cuda")
+ test_image_file = hf_hub_download(
+ repo_id="shumingh/perception_lm_test_images",
+ filename="14496_0.PNG",
+ repo_type="dataset",
+ )
+ conversation = [
+ {
+ "role": "user",
+ "content": [
+ {
+ "type": "image",
+ "url": test_image_file,
+ },
+ {"type": "text", "text": "Describe the bar plot in the image."},
+ ],
+ }
+ ]
+
+ inputs = processor.apply_chat_template(
+ [conversation],
+ add_generation_prompt=True,
+ tokenize=True,
+ return_dict=True,
+ return_tensors="pt",
+ )
+ inputs = inputs.to(model.device)
+ generate_ids = model.generate(**inputs, max_new_tokens=256)
+ input_length = inputs["input_ids"].shape[1]
+ generate_ids_without_inputs = generate_ids[:, input_length:]
+
+ for output in processor.batch_decode(generate_ids_without_inputs, skip_special_tokens=True):
+ print(output)
+ ```"""
+ outputs = self.model(
+ input_ids=input_ids,
+ pixel_values=pixel_values,
+ pixel_values_videos=pixel_values_videos,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ cache_position=cache_position,
+ logits_to_keep=logits_to_keep,
+ **lm_kwargs,
+ )
+
+ hidden_states = outputs[0]
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(
+ logits=logits,
+ labels=labels,
+ vocab_size=self.config.text_config.vocab_size,
+ **lm_kwargs,
+ )
+
+ return PerceptionLMCausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ image_hidden_states=outputs.image_hidden_states,
+ video_hidden_states=outputs.video_hidden_states,
+ )
+
+ def prepare_inputs_for_generation(
+ self,
+ input_ids,
+ past_key_values=None,
+ inputs_embeds=None,
+ pixel_values=None,
+ pixel_values_videos=None,
+ attention_mask=None,
+ cache_position=None,
+ logits_to_keep=None,
+ **kwargs,
+ ):
+ # Overwritten -- in specific circumstances we don't want to forward image inputs to the model
+
+ model_inputs = super().prepare_inputs_for_generation(
+ input_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ cache_position=cache_position,
+ logits_to_keep=logits_to_keep,
+ **kwargs,
+ )
+
+ if cache_position[0] == 0:
+            # Pixel values are only forwarded on the first (prefill) step, when `input_ids` still contains the
+            # special image/video placeholder tokens. During cached decoding they are dropped, since the
+            # placeholders have already been replaced by the projected multimodal features.
+ model_inputs["pixel_values"] = pixel_values
+ model_inputs["pixel_values_videos"] = pixel_values_videos
+ return model_inputs
+
+
+__all__ = ["PerceptionLMForConditionalGeneration", "PerceptionLMPreTrainedModel", "PerceptionLMModel"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/perception_lm/modular_perception_lm.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/perception_lm/modular_perception_lm.py
new file mode 100644
index 0000000000000000000000000000000000000000..2b50b824220210bebb65f5c27b206b8d8e101f84
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/perception_lm/modular_perception_lm.py
@@ -0,0 +1,441 @@
+# coding=utf-8
+# Copyright 2025 Meta Platforms, Inc. and the HuggingFace Inc. team. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch PerceptionLM model."""
+
+import math
+from typing import Optional, Union
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from ...cache_utils import Cache
+from ...utils import (
+ auto_docstring,
+ can_return_tuple,
+ logging,
+)
+from ..auto import AutoModel
+from ..llava.modeling_llava import (
+ LlavaCausalLMOutputWithPast,
+ LlavaForConditionalGeneration,
+ LlavaModel,
+ LlavaModelOutputWithPast,
+ LlavaPreTrainedModel,
+)
+from .configuration_perception_lm import PerceptionLMConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+class PerceptionLMAdaptiveAvgPooling(nn.Module):
+ def __init__(self, pooling_ratio=2):
+ super().__init__()
+ self.pooling_ratio = pooling_ratio
+
+ def forward(self, hidden_states):
+ b, num_tokens, c = hidden_states.shape
+ h = int(math.sqrt(num_tokens))
+ if h * h != num_tokens:
+ raise ValueError(f"num_tokens {num_tokens} is expected to be a square number")
+
+ shape = (h // self.pooling_ratio, h // self.pooling_ratio)
+ hidden_states = hidden_states.permute(0, 2, 1).reshape(b, -1, h, h)
+ hidden_states = F.adaptive_avg_pool2d(hidden_states, shape)
+ hidden_states = hidden_states.flatten(2).transpose(1, 2)
+
+ return hidden_states
+
+
+class PerceptionLMMultiModalProjector(nn.Module):
+ def __init__(self, config: PerceptionLMConfig):
+ super().__init__()
+ input_size = config.vision_config.model_args["embed_dim"]
+ output_size = config.text_config.hidden_size
+ self.linear_1 = nn.Linear(
+ in_features=input_size,
+ out_features=output_size,
+ bias=True,
+ )
+ self.gelu = nn.GELU()
+ self.linear_2 = nn.Linear(
+ in_features=output_size,
+ out_features=output_size,
+ bias=True,
+ )
+ self.pooling = (
+ PerceptionLMAdaptiveAvgPooling(config.projector_pooling_ratio)
+ if config.projector_pooling_ratio > 1
+ else nn.Identity()
+ )
+
+ def forward(self, features):
+ features = features.permute(1, 0, 2) # NLD -> LND
+ features = self.linear_1(features)
+ features = self.gelu(features)
+ features = self.linear_2(features)
+ features = features.permute(1, 0, 2) # LND -> NLD
+ features = self.pooling(features)
+ return features
+
+
+class PerceptionLMPreTrainedModel(LlavaPreTrainedModel):
+ base_model_prefix = "model"
+
+
+class PerceptionLMModelOutputWithPast(LlavaModelOutputWithPast):
+ r"""
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
+ `past_key_values` input) to speed up sequential decoding.
+ image_hidden_states (`torch.FloatTensor`, *optional*):
+ A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
+ Image hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
+ video_hidden_states (`torch.FloatTensor`, *optional*):
+ A `torch.FloatTensor` of size `(batch_size, num_videos, sequence_length, hidden_size)`.
+ Video hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
+ """
+
+ video_hidden_states: Optional[torch.FloatTensor] = None
+
+
+class PerceptionLMCausalLMOutputWithPast(LlavaCausalLMOutputWithPast):
+ r"""
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Language modeling loss (for next-token prediction).
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
+ `past_key_values` input) to speed up sequential decoding.
+ image_hidden_states (`torch.FloatTensor`, *optional*):
+ A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
+ Image hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
+ video_hidden_states (`torch.FloatTensor`, *optional*):
+ A `torch.FloatTensor` of size `(batch_size, num_videos, sequence_length, hidden_size)`.
+ Video hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
+ """
+
+ video_hidden_states: Optional[torch.FloatTensor] = None
+
+
+@auto_docstring
+class PerceptionLMModel(LlavaModel):
+ _checkpoint_conversion_mapping = {}
+
+ def __init__(self, config: PerceptionLMConfig):
+ super().__init__(config)
+ self.vision_tower = AutoModel.from_config(config.vision_config)
+ self.multi_modal_projector = PerceptionLMMultiModalProjector(config)
+ self.language_model = AutoModel.from_config(config.text_config)
+
+ def get_image_features(
+ self,
+ pixel_values: torch.FloatTensor,
+ **kwargs,
+ ):
+ """
+        Obtains image last hidden states from the vision tower and applies the multimodal projection.
+
+        Args:
+            pixel_values (`torch.FloatTensor` of shape `(batch_size, num_tiles, channels, height, width)`):
+                The tensors corresponding to the input images.
+        Returns:
+            image_features (`torch.Tensor`): Image feature tensor of shape `(batch_size * num_tiles, num_patches, embed_dim)`.
+ """
+ image_outputs = self.vision_tower(pixel_values.flatten(0, 1))
+ image_outputs = image_outputs.last_hidden_state
+ if self.config.vision_use_cls_token:
+ image_outputs = image_outputs[:, 1:, :]
+ image_features = self.multi_modal_projector(image_outputs)
+ return image_features
+
+ def get_placeholder_mask(
+ self,
+ input_ids: torch.LongTensor,
+ inputs_embeds: torch.FloatTensor,
+ image_features: Optional[torch.FloatTensor] = None,
+ video_features: Optional[torch.FloatTensor] = None,
+ ):
+ """
+ Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is
+ equal to the length of multimodal features. If the lengths are different, an error is raised.
+ """
+ if input_ids is None:
+ special_image_mask = inputs_embeds == self.get_input_embeddings()(
+ torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
+ )
+ special_image_mask = special_image_mask.all(-1)
+ special_video_mask = inputs_embeds == self.get_input_embeddings()(
+ torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device)
+ )
+ special_video_mask = special_video_mask.all(-1)
+ else:
+ special_image_mask = input_ids == self.config.image_token_id
+ special_video_mask = input_ids == self.config.video_token_id
+
+ n_image_tokens = special_image_mask.sum()
+ special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
+ if image_features is not None and inputs_embeds[special_image_mask].numel() != image_features.numel():
+ raise ValueError(
+ f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {image_features.size()[:-1].numel()}"
+ )
+
+ n_video_tokens = special_video_mask.sum()
+ special_video_mask = special_video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
+ if video_features is not None and inputs_embeds[special_video_mask].numel() != video_features.numel():
+ raise ValueError(
+ f"Videos features and image tokens do not match: tokens: {n_video_tokens}, features {video_features.size()[:-1].numel()}"
+ )
+
+ return special_image_mask, special_video_mask
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ pixel_values_videos: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ logits_to_keep: Union[int, torch.Tensor] = 0,
+ **lm_kwargs,
+ ) -> Union[tuple, PerceptionLMModelOutputWithPast]:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+ if (pixel_values is not None or pixel_values_videos is not None) and inputs_embeds is not None:
+ raise ValueError(
+ "You cannot specify both (pixel_values or pixel_values_videos) and inputs_embeds at the same time, and must specify either one"
+ )
+
+ if inputs_embeds is None:
+ inputs_embeds = self.get_input_embeddings()(input_ids)
+
+ image_features = None
+ if pixel_values is not None:
+ image_features = self.get_image_features(pixel_values=pixel_values)
+ image_features = image_features.to(inputs_embeds.device, dtype=inputs_embeds.dtype)
+ special_image_mask, _ = self.get_placeholder_mask(
+ input_ids, inputs_embeds=inputs_embeds, image_features=image_features
+ )
+ inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
+
+ video_features = None
+ if pixel_values_videos is not None:
+ video_features = self.get_image_features(pixel_values=pixel_values_videos)
+ video_features = video_features.to(inputs_embeds.device, dtype=inputs_embeds.dtype)
+ _, special_video_mask = self.get_placeholder_mask(
+ input_ids, inputs_embeds=inputs_embeds, video_features=video_features
+ )
+ inputs_embeds = inputs_embeds.masked_scatter(special_video_mask, video_features)
+
+ outputs = self.language_model(
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=True,
+ cache_position=cache_position,
+ logits_to_keep=logits_to_keep,
+ **lm_kwargs,
+ )
+ return PerceptionLMModelOutputWithPast(
+ last_hidden_state=outputs.last_hidden_state,
+ hidden_states=outputs.hidden_states,
+ past_key_values=outputs.past_key_values,
+ attentions=outputs.attentions,
+ image_hidden_states=image_features if pixel_values is not None else None,
+ video_hidden_states=(video_features if pixel_values_videos is not None else None),
+ )
+
+
+@auto_docstring
+class PerceptionLMForConditionalGeneration(LlavaForConditionalGeneration):
+ _checkpoint_conversion_mapping = {}
+
+ def prepare_inputs_for_generation(
+ self,
+ input_ids,
+ past_key_values=None,
+ inputs_embeds=None,
+ pixel_values=None,
+ pixel_values_videos=None,
+ attention_mask=None,
+ cache_position=None,
+ logits_to_keep=None,
+ **kwargs,
+ ):
+ # Overwritten -- in specific circumstances we don't want to forward image inputs to the model
+
+ model_inputs = super().prepare_inputs_for_generation(
+ input_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ cache_position=cache_position,
+ logits_to_keep=logits_to_keep,
+ **kwargs,
+ )
+
+ if cache_position[0] == 0:
+            # Pixel values are only forwarded on the first (prefill) step, when `input_ids` still contains the
+            # special image/video placeholder tokens. During cached decoding they are dropped, since the
+            # placeholders have already been replaced by the projected multimodal features.
+ model_inputs["pixel_values"] = pixel_values
+ model_inputs["pixel_values_videos"] = pixel_values_videos
+ return model_inputs
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ pixel_values_videos: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ logits_to_keep: Union[int, torch.Tensor] = 0,
+ **lm_kwargs,
+ ) -> Union[tuple, PerceptionLMCausalLMOutputWithPast]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+ Example:
+
+ ```python
+ from transformers import AutoProcessor, AutoModelForImageTextToText
+ from huggingface_hub import hf_hub_download
+
+ MODEL_PATH = "facebook/Perception-LM-1B"
+ processor = AutoProcessor.from_pretrained(MODEL_PATH, use_fast=True)
+ model = AutoModelForImageTextToText.from_pretrained(MODEL_PATH).to("cuda")
+ test_image_file = hf_hub_download(
+ repo_id="shumingh/perception_lm_test_images",
+ filename="14496_0.PNG",
+ repo_type="dataset",
+ )
+ conversation = [
+ {
+ "role": "user",
+ "content": [
+ {
+ "type": "image",
+ "url": test_image_file,
+ },
+ {"type": "text", "text": "Describe the bar plot in the image."},
+ ],
+ }
+ ]
+
+ inputs = processor.apply_chat_template(
+ [conversation],
+ add_generation_prompt=True,
+ tokenize=True,
+ return_dict=True,
+ return_tensors="pt",
+ )
+ inputs = inputs.to(model.device)
+ generate_ids = model.generate(**inputs, max_new_tokens=256)
+ input_length = inputs["input_ids"].shape[1]
+ generate_ids_without_inputs = generate_ids[:, input_length:]
+
+ for output in processor.batch_decode(generate_ids_without_inputs, skip_special_tokens=True):
+ print(output)
+ ```"""
+ outputs = self.model(
+ input_ids=input_ids,
+ pixel_values=pixel_values,
+ pixel_values_videos=pixel_values_videos,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ cache_position=cache_position,
+ logits_to_keep=logits_to_keep,
+ **lm_kwargs,
+ )
+
+ hidden_states = outputs[0]
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(
+ logits=logits,
+ labels=labels,
+ vocab_size=self.config.text_config.vocab_size,
+ **lm_kwargs,
+ )
+
+ return PerceptionLMCausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ image_hidden_states=outputs.image_hidden_states,
+ video_hidden_states=outputs.video_hidden_states,
+ )
+
+ def get_image_features(self, **kwargs):
+ raise AttributeError("Not needed for PerceptionLM")
+
+ def language_model(self):
+ raise AttributeError("Not needed for PerceptionLM")
+
+ def vision_tower(self):
+ raise AttributeError("Not needed for PerceptionLM")
+
+ def multi_modal_projector(self):
+ raise AttributeError("Not needed for PerceptionLM")
+
+
+__all__ = [
+ "PerceptionLMForConditionalGeneration",
+ "PerceptionLMPreTrainedModel",
+ "PerceptionLMModel",
+]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/perception_lm/processing_perception_lm.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/perception_lm/processing_perception_lm.py
new file mode 100644
index 0000000000000000000000000000000000000000..f61c54554d32c22b4f22ae8905233405c828f3da
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/perception_lm/processing_perception_lm.py
@@ -0,0 +1,244 @@
+# coding=utf-8
+# Copyright 2025 Meta Platforms, Inc. and the HuggingFace Inc. team. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Processor class for PerceptionLM.
+"""
+
+from collections.abc import Iterable
+from typing import Optional, Union
+
+import numpy as np
+
+from ...feature_extraction_utils import BatchFeature
+from ...image_utils import ImageInput, get_image_size, to_numpy_array
+from ...processing_utils import MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack
+from ...tokenization_utils_base import PreTokenizedInput, TextInput
+from ...utils import logging
+from ...video_utils import VideoInput
+
+
+logger = logging.get_logger(__name__)
+
+
+class PerceptionLMProcessorKwargs(ProcessingKwargs, total=False):
+ _defaults = {
+ "text_kwargs": {
+ "padding": False,
+ "return_mm_token_type_ids": False,
+ },
+ }
+
+
+class PerceptionLMProcessor(ProcessorMixin):
+ r"""
+ Constructs a PerceptionLM processor which wraps a PerceptionLM image processor, a PerceptionLM video processor, and a tokenizer into a single processor.
+
+ [`PerceptionLMProcessor`] offers all the functionalities of [`PerceptionLMImageProcessorFast`], [`PerceptionLMVideoProcessor`], and the tokenizer (e.g. [`LlamaTokenizerFast`]). See the
+ [`~PerceptionLMProcessor.__call__`] and [`~PerceptionLMProcessor.decode`] for more information.
+
+ Args:
+ video_processor ([`PerceptionLMVideoProcessor`], *optional*):
+ The video processor to process video inputs.
+ image_processor ([`PerceptionLMImageProcessorFast`], *optional*):
+ The image processor to process image inputs.
+ tokenizer ([`LlamaTokenizerFast`] or similar, *optional*):
+ The tokenizer to process text inputs.
+ patch_size (`int`, *optional*):
+ Patch size from the vision tower.
+ chat_template (`str`, *optional*):
+ A Jinja template which will be used to convert lists of messages in a chat into a tokenizable string.
+ pooling_ratio (`int`, *optional*, defaults to 2):
+ Pooling ratio for vision tokens. If not 1, 2D adaptive pooling is applied over projected vision tokens.
+ """
+
+ attributes = ["video_processor", "image_processor", "tokenizer"]
+ image_processor_class = "AutoImageProcessor"
+ video_processor_class = "AutoVideoProcessor"
+ tokenizer_class = "AutoTokenizer"
+
+ def __init__(
+ self,
+ video_processor=None,
+ image_processor=None,
+ tokenizer=None,
+ patch_size=None,
+ chat_template=None,
+ pooling_ratio=2,
+ **kwargs,
+ ):
+ self.patch_size = patch_size
+ self.pooling_ratio = pooling_ratio
+ self.image_token = tokenizer.image_token
+ self.video_token = tokenizer.video_token
+ self.image_token_id = tokenizer.image_token_id
+ self.video_token_id = tokenizer.video_token_id
+ super().__init__(video_processor, image_processor, tokenizer, chat_template=chat_template)
+
+ def __call__(
+ self,
+ images: Optional[ImageInput] = None,
+ text: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]] = None,
+ audio=None,
+ videos: Optional[VideoInput] = None,
+ **kwargs: Unpack[PerceptionLMProcessorKwargs],
+ ) -> BatchFeature:
+ """
+ Prepares a batch containing one or more sequences of text and/or images and/or videos.
+
+ If `text` is provided, it is tokenized using the tokenizer.
+ If `images` is provided, they are processed using the image processor.
+ If `videos` is provided, they are processed using the video processor.
+
+ Args:
+ images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`, *optional*):
+ The image or batch of images to be processed. Each image can be a PIL image, NumPy array, or PyTorch tensor.
+ Both channels-first and channels-last formats are supported.
+ text (`str`, `List[str]`, *optional*):
+ The sequence or batch of sequences to be tokenized. Each sequence can be a string.
+ videos (`Any`, *optional*):
+ The video or batch of videos to be processed.
+ return_tensors (`str` or [`~utils.TensorType`], *optional*):
+ If set, will return tensors of a particular framework. Acceptable values are:
+ - `'tf'`: Return TensorFlow `tf.constant` objects.
+ - `'pt'`: Return PyTorch `torch.Tensor` objects.
+ - `'np'`: Return NumPy `np.ndarray` objects.
+ - `'jax'`: Return JAX `jnp.ndarray` objects.
+
+ Returns:
+ [`BatchFeature`]: A [`BatchFeature`] with the following fields:
+
+ - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is provided.
+ - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
+ `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is provided).
+ - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is provided.
+ - **pixel_values_videos** -- Video pixel values to be fed to a model. Returned when `videos` is provided.
+ """
+ if text is None:
+ raise ValueError(
+ "You have to specify at least `text` input. Optionally, you can also specify `images` or `videos`."
+ )
+
+ output_kwargs = self._merge_kwargs(
+ PerceptionLMProcessorKwargs,
+ tokenizer_init_kwargs=self.tokenizer.init_kwargs,
+ **kwargs,
+ )
+ if images is not None:
+ image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"])
+ else:
+ image_inputs = {}
+
+ if videos is not None:
+ videos_inputs = self.video_processor(videos, **output_kwargs["videos_kwargs"])
+ else:
+ videos_inputs = {}
+
+ if isinstance(text, str):
+ text = [text]
+        elif not isinstance(text, list) or not all(isinstance(t, str) for t in text):
+ raise ValueError("Invalid input text. Please provide a string, or a list of strings")
+
+ # try to expand inputs in processing if we have the necessary parts
+ prompt_strings = []
+
+ pixel_values = iter(image_inputs.get("pixel_values", []))
+ pixel_values_videos = iter(videos_inputs.get("pixel_values_videos", []))
+ for sample in text:
+ # Replace the media token with the expanded media token sequence
+ sample = self._expand_media_tokens(sample, self.tokenizer.image_token, pixel_values)
+ sample = self._expand_media_tokens(sample, self.tokenizer.video_token, pixel_values_videos)
+ prompt_strings.append(sample)
+
+ return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
+ return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", False)
+ text_inputs = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"], return_tensors=None)
+ self._check_special_mm_tokens(prompt_strings, text_inputs, modalities=["image", "video"])
+
+ if return_mm_token_type_ids:
+ array_ids = np.array(text_inputs["input_ids"])
+ mm_token_type_ids = np.zeros_like(text_inputs["input_ids"])
+ mm_token_type_ids[array_ids == self.image_token_id] = 1
+ text_inputs["mm_token_type_ids"] = mm_token_type_ids.tolist()
+
+ return BatchFeature(data={**text_inputs, **image_inputs, **videos_inputs}, tensor_type=return_tensors)
+
+ def _expand_media_tokens(self, sample, media_token: str, media_iter: Iterable):
+ media_count = sample.count(media_token)
+ if media_count > 0:
+ media_list = [next(media_iter) for _ in range(media_count)]
+ sample_splits = sample.split(media_token)
+ media_token_list = []
+ for media in media_list:
+ height, width = get_image_size(to_numpy_array(media))
+ num_tiles = media.shape[0]
+ num_media_tokens = (
+ (height // self.patch_size // self.pooling_ratio)
+ * (width // self.patch_size // self.pooling_ratio)
+ * num_tiles
+ )
+ media_token_list.append(num_media_tokens)
+ sample = ""
+ for i, num_media_tokens in enumerate(media_token_list):
+ sample += sample_splits[i]
+ sample += media_token * num_media_tokens
+ sample += sample_splits[-1]
+ return sample
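+
+    # Worked example (illustrative, assuming patch_size=14 and the default pooling_ratio=2): a
+    # processed image made of 4 tiles of 448x448 pixels (a thumbnail plus a 1x3 grid) contributes
+    # (448 // 14 // 2) * (448 // 14 // 2) = 16 * 16 = 256 placeholder tokens per tile, so the single
+    # image token in the prompt is expanded to 4 * 256 = 1024 repetitions of `self.image_token`.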
+
+ def _get_num_multimodal_tokens(self, image_sizes=None, **kwargs):
+ """
+ Computes the number of placeholder tokens needed for multimodal inputs with the given sizes.
+
+ Args:
+ image_sizes (`list[list[int]]`, *optional*):
+ The input sizes formatted as (height, width) per each image.
+
+ Returns:
+ `MultiModalData`: A `MultiModalData` object holding number of tokens per each of the provided
+ input modalities, along with other useful data.
+ """
+
+ vision_data = {}
+ if image_sizes is not None:
+ images_kwargs = PerceptionLMProcessorKwargs._defaults.get("images_kwargs", {})
+ images_kwargs.update(kwargs)
+ tile_size = images_kwargs.get("tile_size", None) or self.image_processor.tile_size
+ vision_input_type = images_kwargs.get("vision_input_type", None) or self.image_processor.vision_input_type
+
+ num_image_tokens = []
+ num_image_patches = []
+ for height, width in image_sizes:
+ if vision_input_type == "thumb+tile":
+ aspect_ratio = self.image_processor._fit_image_to_canvas(
+ img_width=width, img_height=height, tile_size=tile_size
+ )
+ if aspect_ratio is None:
+ aspect_ratio = self.image_processor._find_closest_aspect_ratio(
+ img_width=width, img_height=height, tile_size=tile_size
+ )
+ num_tiles = aspect_ratio[0] * aspect_ratio[1] + 1 # base image and tiles
+ else:
+ num_tiles = 1
+
+ num_image_tokens.append(
+ (tile_size // self.patch_size // self.pooling_ratio)
+ * (tile_size // self.patch_size // self.pooling_ratio)
+ * num_tiles
+ )
+ num_image_patches.append(num_tiles)
+
+ vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches})
+ return MultiModalData(**vision_data)
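+
+    # Worked example (illustrative, assuming tile_size=448, patch_size=14, pooling_ratio=2,
+    # vision_input_type="thumb+tile" and max_num_tiles >= 2): for a 448-pixel-tall, 896-pixel-wide
+    # image the 2x1 tile canvas fits exactly, so num_tiles = 2 * 1 + 1 = 3 (the grid plus the
+    # thumbnail) and num_image_tokens = (448 // 14 // 2) ** 2 * 3 = 256 * 3 = 768.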
+
+
+__all__ = ["PerceptionLMProcessor"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/perception_lm/video_processing_perception_lm.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/perception_lm/video_processing_perception_lm.py
new file mode 100644
index 0000000000000000000000000000000000000000..1023aa7c589dab5b249c6dff3d61e4d6e523eda8
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/perception_lm/video_processing_perception_lm.py
@@ -0,0 +1,41 @@
+# coding=utf-8
+# Copyright 2025 Meta Platforms, Inc. and the HuggingFace Inc. team. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Video processor class for PerceptionLM."""
+
+from ...image_utils import IMAGENET_STANDARD_MEAN, IMAGENET_STANDARD_STD, PILImageResampling
+from ...processing_utils import Unpack, VideosKwargs
+from ...video_processing_utils import BaseVideoProcessor
+
+
+class PerceptionLMFastVideoProcessorInitKwargs(VideosKwargs): ...
+
+
+class PerceptionLMVideoProcessor(BaseVideoProcessor):
+ resample = PILImageResampling.BICUBIC
+ image_mean = IMAGENET_STANDARD_MEAN
+ image_std = IMAGENET_STANDARD_STD
+ size = {"height": 448, "width": 448}
+ do_resize = True
+ do_center_crop = False
+ do_rescale = True
+ do_normalize = True
+ do_convert_rgb = True
+ valid_kwargs = PerceptionLMFastVideoProcessorInitKwargs
+ model_input_names = ["pixel_values_videos"]
+
+ def __init__(self, **kwargs: Unpack[PerceptionLMFastVideoProcessorInitKwargs]):
+ super().__init__(**kwargs)
+
+
+__all__ = ["PerceptionLMVideoProcessor"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/persimmon/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/persimmon/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..cb71eae2547c59a4f9ba7ecbafda56fb9c86b494
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/persimmon/__init__.py
@@ -0,0 +1,27 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_persimmon import *
+ from .modeling_persimmon import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/persimmon/configuration_persimmon.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/persimmon/configuration_persimmon.py
new file mode 100644
index 0000000000000000000000000000000000000000..3773ad4174d1a44095e3f75bcbd6cc8a440a24a4
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/persimmon/configuration_persimmon.py
@@ -0,0 +1,176 @@
+# coding=utf-8
+# Copyright 2023 Adept AI and the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Persimmon model configuration"""
+
+from ...configuration_utils import PretrainedConfig
+from ...modeling_rope_utils import rope_config_validation
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class PersimmonConfig(PretrainedConfig):
+ r"""
+    This is the configuration class to store the configuration of a [`PersimmonModel`]. It is used to instantiate a
+ Persimmon model according to the specified arguments, defining the model architecture. Instantiating a
+ configuration with the defaults will yield a similar configuration to that of the
+ [adept/persimmon-8b-base](https://huggingface.co/adept/persimmon-8b-base).
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 262144):
+ Vocabulary size of the Persimmon model. Defines the number of different tokens that can be represented by
+ the `inputs_ids` passed when calling [`PersimmonModel`]
+ hidden_size (`int`, *optional*, defaults to 4096):
+ Dimension of the hidden representations.
+ intermediate_size (`int`, *optional*, defaults to 16384):
+ Dimension of the MLP representations.
+ num_hidden_layers (`int`, *optional*, defaults to 36):
+ Number of hidden layers in the Transformer encoder.
+ num_attention_heads (`int`, *optional*, defaults to 64):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ hidden_act (`str` or `function`, *optional*, defaults to `"relu2"`):
+ The non-linear activation function (function or string) in the decoder.
+ max_position_embeddings (`int`, *optional*, defaults to 16384):
+ The maximum sequence length that this model might ever be used with.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ layer_norm_eps (`float`, *optional*, defaults to 1e-5):
+            The epsilon used by the layer normalization layers.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
+ relevant if `config.is_decoder=True`.
+        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+            Whether to tie the input and output word embeddings.
+ rope_theta (`float`, *optional*, defaults to 25000.0):
+ The base period of the RoPE embeddings.
+ rope_scaling (`Dict`, *optional*):
+            Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply a new rope type
+            and expect the model to work with a longer `max_position_embeddings`, we recommend updating that value
+            accordingly.
+ Expected contents:
+ `rope_type` (`str`):
+ The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
+ 'llama3'], with 'default' being the original RoPE implementation.
+ `factor` (`float`, *optional*):
+ Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
+ most scaling types, a `factor` of x will enable the model to handle sequences of length x *
+ original maximum pre-trained length.
+ `original_max_position_embeddings` (`int`, *optional*):
+ Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
+ pretraining.
+ `attention_factor` (`float`, *optional*):
+ Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
+ computation. If unspecified, it defaults to value recommended by the implementation, using the
+ `factor` field to infer the suggested value.
+ `beta_fast` (`float`, *optional*):
+ Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
+ ramp function. If unspecified, it defaults to 32.
+ `beta_slow` (`float`, *optional*):
+ Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
+ ramp function. If unspecified, it defaults to 1.
+ `short_factor` (`list[float]`, *optional*):
+ Only used with 'longrope'. The scaling factor to be applied to short contexts (<
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+ size divided by the number of attention heads divided by 2
+ `long_factor` (`list[float]`, *optional*):
+                    Only used with 'longrope'. The scaling factor to be applied to long contexts (>
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+ size divided by the number of attention heads divided by 2
+ `low_freq_factor` (`float`, *optional*):
+ Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
+ `high_freq_factor` (`float`, *optional*):
+ Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
+ qk_layernorm (`bool`, *optional*, defaults to `True`):
+ Whether or not to normalize the queries and keys after projecting the hidden states.
+ hidden_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio after applying the MLP to the hidden states.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio after computing the attention scores.
+ partial_rotary_factor (`float`, *optional*, defaults to 0.5):
+ Fraction of the query and key head dimensions that rotary embedding is applied to.
+
+ Example:
+
+ ```python
+ >>> from transformers import PersimmonModel, PersimmonConfig
+
+ >>> # Initializing a Persimmon persimmon-7b style configuration
+ >>> configuration = PersimmonConfig()
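+
+ >>> # Initializing a model (with random weights) from the persimmon-7b style configuration
+ >>> model = PersimmonModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+
+ >>> # A RoPE-scaled variant (illustrative values, not an official checkpoint)
+ >>> scaled_configuration = PersimmonConfig(rope_scaling={"rope_type": "linear", "factor": 2.0})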
+ ```"""
+
+ model_type = "persimmon"
+ keys_to_ignore_at_inference = ["past_key_values"]
+
+ def __init__(
+ self,
+ vocab_size=262144,
+ hidden_size=4096,
+ intermediate_size=16384,
+ num_hidden_layers=36,
+ num_attention_heads=64,
+ hidden_act="relu2",
+ max_position_embeddings=16384,
+ initializer_range=0.02,
+ layer_norm_eps=1e-5,
+ use_cache=True,
+ tie_word_embeddings=False,
+ rope_theta=25000.0,
+ rope_scaling=None,
+ qk_layernorm=True,
+ hidden_dropout=0.0,
+ attention_dropout=0.0,
+ partial_rotary_factor=0.5,
+ pad_token_id=None,
+ bos_token_id=1,
+ eos_token_id=2,
+ **kwargs,
+ ):
+ self.vocab_size = vocab_size
+ self.max_position_embeddings = max_position_embeddings
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.hidden_act = hidden_act
+ self.initializer_range = initializer_range
+ self.layer_norm_eps = layer_norm_eps
+ self.use_cache = use_cache
+ self.rope_theta = rope_theta
+ self.rope_scaling = rope_scaling
+ self.qk_layernorm = qk_layernorm
+ self.hidden_dropout = hidden_dropout
+ self.attention_dropout = attention_dropout
+ self.partial_rotary_factor = partial_rotary_factor
+ # Validate the correctness of rotary position embeddings parameters
+ # BC: if there is a 'type' field, move it to 'rope_type'.
+ if self.rope_scaling is not None and "type" in self.rope_scaling:
+ self.rope_scaling["rope_type"] = self.rope_scaling["type"]
+ rope_config_validation(self)
+
+ super().__init__(
+ pad_token_id=pad_token_id,
+ bos_token_id=bos_token_id,
+ eos_token_id=eos_token_id,
+ tie_word_embeddings=tie_word_embeddings,
+ **kwargs,
+ )
+
+
+__all__ = ["PersimmonConfig"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/persimmon/modeling_persimmon.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/persimmon/modeling_persimmon.py
new file mode 100644
index 0000000000000000000000000000000000000000..c963bb53852a1be5ca1c6005883a2746287d30a2
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/persimmon/modeling_persimmon.py
@@ -0,0 +1,768 @@
+# coding=utf-8
+# Copyright 2023 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch Persimmon model."""
+
+from typing import Callable, Optional, Union
+
+import torch
+from torch import nn
+
+from ...activations import ACT2FN
+from ...cache_utils import Cache, DynamicCache
+from ...generation import GenerationMixin
+from ...modeling_attn_mask_utils import AttentionMaskConverter
+from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_layers import (
+ GenericForSequenceClassification,
+ GenericForTokenClassification,
+ GradientCheckpointingLayer,
+)
+from ...modeling_outputs import (
+ BaseModelOutputWithPast,
+ CausalLMOutputWithPast,
+)
+from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
+from ...utils import auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging
+from ...utils.deprecation import deprecate_kwarg
+from .configuration_persimmon import PersimmonConfig
+
+
+if is_torch_flex_attn_available():
+ from torch.nn.attention.flex_attention import BlockMask
+
+ from ...integrations.flex_attention import make_flex_block_causal_mask
+
+
+logger = logging.get_logger(__name__)
+
+
+# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Persimmon
+class PersimmonRotaryEmbedding(nn.Module):
+ inv_freq: torch.Tensor # fix linting for `register_buffer`
+
+ def __init__(self, config: PersimmonConfig, device=None):
+ super().__init__()
+ # BC: "rope_type" was originally "type"
+ if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
+ else:
+ self.rope_type = "default"
+ self.max_seq_len_cached = config.max_position_embeddings
+ self.original_max_seq_len = config.max_position_embeddings
+
+ self.config = config
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
+ self.original_inv_freq = self.inv_freq
+
+ @torch.no_grad()
+ @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
+ def forward(self, x, position_ids):
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
+ position_ids_expanded = position_ids[:, None, :].float()
+
+ device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
+ with torch.autocast(device_type=device_type, enabled=False): # Force float32
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+ emb = torch.cat((freqs, freqs), dim=-1)
+ cos = emb.cos() * self.attention_scaling
+ sin = emb.sin() * self.attention_scaling
+
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
+# Copied from transformers.models.llama.modeling_llama.rotate_half
+def rotate_half(x):
+ """Rotates half the hidden dims of the input."""
+ x1 = x[..., : x.shape[-1] // 2]
+ x2 = x[..., x.shape[-1] // 2 :]
+ return torch.cat((-x2, x1), dim=-1)
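+
+# Illustrative sanity check (not part of the API): with the split at the midpoint,
+# rotate_half(torch.tensor([1., 2., 3., 4.])) returns tensor([-3., -4., 1., 2.]), which is the pairing
+# apply_rotary_pos_emb below relies on to express RoPE's 2D rotations as elementwise multiplies with
+# cos and sin.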
+
+
+# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+ """Applies Rotary Position Embedding to the query and key tensors.
+
+ Args:
+ q (`torch.Tensor`): The query tensor.
+ k (`torch.Tensor`): The key tensor.
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
+ position_ids (`torch.Tensor`, *optional*):
+ Deprecated and unused.
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+ Returns:
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+ """
+ cos = cos.unsqueeze(unsqueeze_dim)
+ sin = sin.unsqueeze(unsqueeze_dim)
+ q_embed = (q * cos) + (rotate_half(q) * sin)
+ k_embed = (k * cos) + (rotate_half(k) * sin)
+ return q_embed, k_embed
+
+
+# Copied from transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXMLP with GPTNeoX->Persimmon
+class PersimmonMLP(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense_h_to_4h = nn.Linear(config.hidden_size, config.intermediate_size)
+ self.dense_4h_to_h = nn.Linear(config.intermediate_size, config.hidden_size)
+ self.act = ACT2FN[config.hidden_act]
+
+ def forward(self, hidden_states):
+ hidden_states = self.dense_h_to_4h(hidden_states)
+ hidden_states = self.act(hidden_states)
+ hidden_states = self.dense_4h_to_h(hidden_states)
+ return hidden_states
+
+
+def eager_attention_forward(
+ module: nn.Module,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_mask: Optional[torch.Tensor],
+ scaling: float,
+ dropout: float = 0.0,
+ **kwargs,
+):
+ attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
+ if attention_mask is not None:
+ causal_mask = attention_mask[:, :, :, : key.shape[-2]]
+ attn_weights = attn_weights + causal_mask
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+ attn_output = torch.matmul(attn_weights, value)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
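+ # At this point attn_output is [batch_size, seq_len, num_heads, head_dim]; the caller is expected to
+ # reshape it to [batch_size, seq_len, hidden_size] before the output projection.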
+ return attn_output, attn_weights
+
+
+class PersimmonAttention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(self, config: PersimmonConfig, layer_idx: Optional[int] = None):
+ super().__init__()
+ self.config = config
+ self.layer_idx = layer_idx
+ if layer_idx is None:
+ logger.warning_once(
+ f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
+ "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
+ "when creating this class."
+ )
+
+ self.hidden_size = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.head_dim = self.hidden_size // self.num_heads
+ self.rope_theta = config.rope_theta
+ self.rotary_ndims = int(self.head_dim * config.partial_rotary_factor)
+ self.is_causal = True
+
+ if (self.head_dim * self.num_heads) != self.hidden_size:
+ raise ValueError(
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
+ f" and `num_heads`: {self.num_heads})."
+ )
+ self.query_key_value = nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=True)
+ self.dense = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=True)
+ self.qk_layernorm = config.qk_layernorm
+ self.scaling = self.head_dim**-0.5
+
+ if self.qk_layernorm:
+ self.q_layernorm = nn.LayerNorm(
+ config.hidden_size // self.num_heads, eps=config.layer_norm_eps, elementwise_affine=True
+ )
+ self.k_layernorm = nn.LayerNorm(
+ config.hidden_size // self.num_heads, eps=config.layer_norm_eps, elementwise_affine=True
+ )
+ self.attention_dropout = nn.Dropout(config.attention_dropout)
+ self.rotary_emb = PersimmonRotaryEmbedding(config=self.config)
+
+ def _split_heads(self, fused_qkv: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """
+ Split the last dimension into (num_heads, head_dim) without making any copies; the results share the
+ same memory storage as `fused_qkv`.
+
+ Args:
+ fused_qkv (`torch.tensor`): [batch_size, seq_length, num_heads * 3 * head_dim]
+
+ Returns:
+ query: [batch_size, seq_length, num_heads, head_dim] key: [batch_size, seq_length, num_heads, head_dim]
+ value: [batch_size, seq_length, num_heads, head_dim]
+ """
+ batch_size, seq_length, three_times_hidden_size = fused_qkv.shape
+ fused_qkv = fused_qkv.view(batch_size, seq_length, self.num_heads, 3, self.head_dim)
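+ # e.g. with the default hidden_size=4096 and num_heads=64 this is a view of shape [b, s, 64, 3, 64]:
+ # the fused projection is assumed to be laid out per head as (q, k, v) triples rather than as three
+ # contiguous hidden_size-wide blocks, so indexing the fourth axis below recovers q, k and v without copies.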
+ return fused_qkv[..., 0, :], fused_qkv[..., 1, :], fused_qkv[..., 2, :]
+
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
+ **kwargs: Unpack[FlashAttentionKwargs],
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
+ bsz, q_len, _ = hidden_states.size()
+
+ # [batch_size, seq_length, 3 x hidden_size]
+ fused_qkv = self.query_key_value(hidden_states)
+
+ # 3 x [batch_size, seq_length, num_heads, head_dim]
+ (query_states, key_states, value_states) = self._split_heads(fused_qkv)
+
+ if self.qk_layernorm:
+ query_states = self.q_layernorm(query_states)
+ key_states = self.k_layernorm(key_states)
+
+ # [batch_size, seq_length, num_heads, head_dim] -> [batch_size, num_heads, seq_length, head_dim]
+ query_states = query_states.transpose(1, 2)
+ value_states = value_states.transpose(1, 2)
+ key_states = key_states.transpose(1, 2)
+
+ cos, sin = position_embeddings
+
+ # Partial rotary embedding
+ query_rot, query_pass = (
+ query_states[..., : self.rotary_ndims],
+ query_states[..., self.rotary_ndims :],
+ )
+ key_rot, key_pass = (
+ key_states[..., : self.rotary_ndims],
+ key_states[..., self.rotary_ndims :],
+ )
+ # [batch_size, num_heads, seq_length, rotary_ndims] (rotary_ndims = head_dim * partial_rotary_factor)
+ query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin)
+
+ # [batch_size, num_heads, seq_length, head_dim]
+ query_states = torch.cat((query_rot, query_pass), dim=-1)
+ key_states = torch.cat((key_rot, key_pass), dim=-1)
+
+ if past_key_values is not None:
+ # Specific to RoPE models with partial rotation
+ cache_kwargs = {
+ "sin": sin,
+ "cos": cos,
+ "partial_rotation_size": self.rotary_ndims,
+ "cache_position": cache_position,
+ }
+ key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ attn_output, attn_weights = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ dropout=0.0 if not self.training else self.config.attention_dropout,
+ scaling=self.scaling,
+ **kwargs,
+ )
+
+ attn_output = attn_output.reshape(bsz, q_len, -1)
+ attn_output = self.dense(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights
+
+
+class PersimmonDecoderLayer(GradientCheckpointingLayer):
+ def __init__(self, config: PersimmonConfig, layer_idx: int):
+ super().__init__()
+ self.hidden_size = config.hidden_size
+ self.self_attn = PersimmonAttention(config=config, layer_idx=layer_idx)
+ self.mlp = PersimmonMLP(config)
+ self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.dropout = nn.Dropout(config.hidden_dropout)
+
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ output_attentions: Optional[bool] = False,
+ use_cache: Optional[bool] = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
+ **kwargs: Unpack[FlashAttentionKwargs],
+ ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Indices of positions of each input sequence token in the position embeddings. Selected in the range
+ `[0, config.max_position_embeddings - 1]`.
+ [What are position IDs?](../glossary#position-ids)
+ past_key_values (`Cache`, *optional*):
+ cached past key and value projection states
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+ (see `past_key_values`).
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+ Indices depicting the position of the input sequence tokens in the sequence.
+ position_embeddings (`tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
+ Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
+ with `head_dim` being the embedding dimension of each attention head.
+ """
+
+ residual = hidden_states
+
+ hidden_states = self.input_layernorm(hidden_states)
+
+ # Self Attention
+ hidden_states, self_attn_weights = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ **kwargs,
+ )
+ hidden_states = residual + hidden_states
+
+ # Fully Connected
+ residual = hidden_states
+ hidden_states = self.post_attention_layernorm(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = hidden_states + residual
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (self_attn_weights,)
+
+ return outputs
+
+
+@auto_docstring
+class PersimmonPreTrainedModel(PreTrainedModel):
+ config: PersimmonConfig
+ base_model_prefix = "model"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["PersimmonDecoderLayer"]
+ _skip_keys_device_placement = "past_key_values"
+
+ _can_compile_fullgraph = True
+ _supports_sdpa = True
+ _supports_flash_attn = True
+ _supports_attention_backend = True
+
+ def _init_weights(self, module):
+ std = self.config.initializer_range
+ if isinstance(module, nn.Linear):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+ elif isinstance(module, nn.LayerNorm):
+ module.weight.data.fill_(1.0)
+ module.bias.data.zero_()
+
+
+@auto_docstring
+class PersimmonModel(PersimmonPreTrainedModel):
+ """
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`PersimmonDecoderLayer`]
+
+ Args:
+ config: PersimmonConfig
+ """
+
+ def __init__(self, config: PersimmonConfig):
+ super().__init__(config)
+ self.padding_idx = config.pad_token_id
+ self.vocab_size = config.vocab_size
+
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+ self.layers = nn.ModuleList(
+ [PersimmonDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+ )
+ self.final_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+ self.rotary_emb = PersimmonRotaryEmbedding(config=config)
+
+ self.gradient_checkpointing = False
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs: Unpack[FlashAttentionKwargs],
+ ) -> BaseModelOutputWithPast:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ if self.gradient_checkpointing and self.training:
+ if use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+ )
+ use_cache = False
+
+ if use_cache and past_key_values is None:
+ past_key_values = DynamicCache(config=self.config)
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+
+ if cache_position is None:
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ cache_position = torch.arange(
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+ )
+ if position_ids is None:
+ position_ids = cache_position.unsqueeze(0)
+
+ causal_mask = self._update_causal_mask(
+ attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
+ )
+
+ hidden_states = inputs_embeds
+
+ # create position embeddings to be shared across the decoder layers
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+ # decoder layers
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+
+ for decoder_layer in self.layers:
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ layer_outputs = decoder_layer(
+ hidden_states,
+ attention_mask=causal_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ **kwargs,
+ )
+
+ hidden_states = layer_outputs[0]
+
+ if output_attentions:
+ all_self_attns += (layer_outputs[1],)
+
+ hidden_states = self.final_layernorm(hidden_states)
+
+ # add hidden states from the last decoder layer
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=past_key_values,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ )
+
+ # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask
+ def _update_causal_mask(
+ self,
+ attention_mask: Union[torch.Tensor, "BlockMask"],
+ input_tensor: torch.Tensor,
+ cache_position: torch.Tensor,
+ past_key_values: Cache,
+ output_attentions: bool = False,
+ ):
+ if self.config._attn_implementation == "flash_attention_2":
+ if attention_mask is not None and (attention_mask == 0.0).any():
+ return attention_mask
+ return None
+ if self.config._attn_implementation == "flex_attention":
+ if isinstance(attention_mask, torch.Tensor):
+ attention_mask = make_flex_block_causal_mask(attention_mask)
+ return attention_mask
+
+ # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
+ # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
+ # to infer the attention mask.
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False
+
+ # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
+ if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions:
+ if AttentionMaskConverter._ignore_causal_mask_sdpa(
+ attention_mask,
+ inputs_embeds=input_tensor,
+ past_key_values_length=past_seen_tokens,
+ is_training=self.training,
+ ):
+ return None
+
+ dtype = input_tensor.dtype
+ sequence_length = input_tensor.shape[1]
+ if using_compilable_cache:
+ target_length = past_key_values.get_max_cache_shape()
+ else:
+ target_length = (
+ attention_mask.shape[-1]
+ if isinstance(attention_mask, torch.Tensor)
+ else past_seen_tokens + sequence_length + 1
+ )
+
+ # In case the provided `attention_mask` is 2D, we generate a causal mask here (4D).
+ causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
+ attention_mask,
+ sequence_length=sequence_length,
+ target_length=target_length,
+ dtype=dtype,
+ cache_position=cache_position,
+ batch_size=input_tensor.shape[0],
+ )
+
+ if (
+ self.config._attn_implementation == "sdpa"
+ and attention_mask is not None
+ and attention_mask.device.type in ["cuda", "xpu", "npu"]
+ and not output_attentions
+ ):
+ # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
+ # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
+ # Details: https://github.com/pytorch/pytorch/issues/110213
+ min_dtype = torch.finfo(dtype).min
+ causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
+
+ return causal_mask
+
+ @staticmethod
+ # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position
+ def _prepare_4d_causal_attention_mask_with_cache_position(
+ attention_mask: torch.Tensor,
+ sequence_length: int,
+ target_length: int,
+ dtype: torch.dtype,
+ cache_position: torch.Tensor,
+ batch_size: int,
+ **kwargs,
+ ):
+ """
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
+ `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
+
+ Args:
+ attention_mask (`torch.Tensor`):
+ A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
+ `(batch_size, 1, query_length, key_value_length)`.
+ sequence_length (`int`):
+ The sequence length being processed.
+ target_length (`int`):
+ The target length: when generating with static cache, the mask should be as long as the static cache,
+ to account for the 0 padding, the part of the cache that is not filled yet.
+ dtype (`torch.dtype`):
+ The dtype to use for the 4D attention mask.
+ cache_position (`torch.Tensor`):
+ Indices depicting the position of the input sequence tokens in the sequence.
+ batch_size (`int`):
+ Batch size.
+ """
+ if attention_mask is not None and attention_mask.dim() == 4:
+ # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
+ causal_mask = attention_mask
+ else:
+ min_dtype = torch.finfo(dtype).min
+ causal_mask = torch.full(
+ (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
+ )
+ if sequence_length != 1:
+ causal_mask = torch.triu(causal_mask, diagonal=1)
+ causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
+ causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
+ if attention_mask is not None:
+ causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
+ mask_length = attention_mask.shape[-1]
+ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
+ causal_mask.device
+ )
+ padding_mask = padding_mask == 0
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+ padding_mask, min_dtype
+ )
+
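+ # Worked example (illustrative): with sequence_length=2, target_length=4 and cache_position=[2, 3]
+ # (two new tokens on top of two cached ones), the row for position 2 attends to keys 0..2 and the row
+ # for position 3 attends to keys 0..3; masked entries hold torch.finfo(dtype).min rather than -inf so
+ # the softmax stays numerically well-behaved.
+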
+ return causal_mask
+
+
+class PersimmonForCausalLM(PersimmonPreTrainedModel, GenerationMixin):
+ _tied_weights_keys = ["lm_head.weight"]
+
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.__init__ with LLAMA->PERSIMMON,Llama->Persimmon
+ def __init__(self, config):
+ super().__init__(config)
+ self.model = PersimmonModel(config)
+ self.vocab_size = config.vocab_size
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ logits_to_keep: Union[int, torch.Tensor] = 0,
+ **kwargs,
+ ) -> CausalLMOutputWithPast:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, PersimmonForCausalLM
+
+ >>> model = PersimmonForCausalLM.from_pretrained("adept/persimmon-8b-base")
+ >>> tokenizer = AutoTokenizer.from_pretrained("adept/persimmon-8b-base")
+
+ >>> prompt = "human: Hey, what should I eat for dinner?"
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+ >>> # Generate
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ 'human: Hey, what should I eat for dinner?\n\ncat: 🐱\n\nhuman: 😐\n\n'
+ ```"""
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+
+ # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
+ outputs: BaseModelOutputWithPast = self.model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+ hidden_states = outputs.last_hidden_state
+ # No upscaling to float was ever done for Persimmon
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
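+ # logits_to_keep=0 keeps every position (slice(-0, None) == slice(0, None)); an int n keeps only the
+ # last n positions, and a tensor selects explicit indices, which lets generation avoid materialising
+ # the full-sequence, full-vocab logits.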
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(
+ logits,
+ labels,
+ vocab_size=self.config.vocab_size,
+ **kwargs,
+ )
+
+ return CausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+class PersimmonForSequenceClassification(GenericForSequenceClassification, PersimmonPreTrainedModel): ...
+
+
+class PersimmonForTokenClassification(GenericForTokenClassification, PersimmonPreTrainedModel): ...
+
+
+__all__ = [
+ "PersimmonForCausalLM",
+ "PersimmonModel",
+ "PersimmonPreTrainedModel",
+ "PersimmonForSequenceClassification",
+ "PersimmonForTokenClassification",
+]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/poolformer/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/poolformer/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2e80bb47c302c88ffe33450d48a3b6ac6cf2ffc0
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/poolformer/__init__.py
@@ -0,0 +1,30 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_poolformer import *
+ from .feature_extraction_poolformer import *
+ from .image_processing_poolformer import *
+ from .image_processing_poolformer_fast import *
+ from .modeling_poolformer import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/poolformer/configuration_poolformer.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/poolformer/configuration_poolformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..eaaa89f67048c9706bad29881c793fc1e312ef05
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/poolformer/configuration_poolformer.py
@@ -0,0 +1,148 @@
+# coding=utf-8
+# Copyright 2022 Sea AI Labs and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PoolFormer model configuration"""
+
+from collections import OrderedDict
+from collections.abc import Mapping
+
+from packaging import version
+
+from ...configuration_utils import PretrainedConfig
+from ...onnx import OnnxConfig
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class PoolFormerConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of [`PoolFormerModel`]. It is used to instantiate a
+ PoolFormer model according to the specified arguments, defining the model architecture. Instantiating a
+ configuration with the defaults will yield a similar configuration to that of the PoolFormer
+ [sail/poolformer_s12](https://huggingface.co/sail/poolformer_s12) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+
+ Args:
+ num_channels (`int`, *optional*, defaults to 3):
+ The number of channels in the input image.
+ patch_size (`int`, *optional*, defaults to 16):
+ The size of the input patch.
+ stride (`int`, *optional*, defaults to 16):
+ The stride of the input patch.
+ pool_size (`int`, *optional*, defaults to 3):
+ The size of the pooling window.
+ mlp_ratio (`float`, *optional*, defaults to 4.0):
+ The ratio of the number of channels in the output of the MLP to the number of channels in the input.
+ depths (`list`, *optional*, defaults to `[2, 2, 6, 2]`):
+ The depth of each encoder block.
+ hidden_sizes (`list`, *optional*, defaults to `[64, 128, 320, 512]`):
+ The hidden sizes of each encoder block.
+ patch_sizes (`list`, *optional*, defaults to `[7, 3, 3, 3]`):
+ The size of the input patch for each encoder block.
+ strides (`list`, *optional*, defaults to `[4, 2, 2, 2]`):
+ The stride of the input patch for each encoder block.
+ padding (`list`, *optional*, defaults to `[2, 1, 1, 1]`):
+ The padding of the input patch for each encoder block.
+ num_encoder_blocks (`int`, *optional*, defaults to 4):
+ The number of encoder blocks.
+ drop_path_rate (`float`, *optional*, defaults to 0.0):
+ The drop rate for stochastic depth (drop path) applied in the encoder blocks.
+ hidden_act (`str`, *optional*, defaults to `"gelu"`):
+ The activation function for the hidden layers.
+ use_layer_scale (`bool`, *optional*, defaults to `True`):
+ Whether to use layer scale.
+ layer_scale_init_value (`float`, *optional*, defaults to 1e-05):
+ The initial value for the layer scale.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The initializer range for the weights.
+
+ Example:
+
+ ```python
+ >>> from transformers import PoolFormerConfig, PoolFormerModel
+
+ >>> # Initializing a PoolFormer sail/poolformer_s12 style configuration
+ >>> configuration = PoolFormerConfig()
+
+ >>> # Initializing a model (with random weights) from the sail/poolformer_s12 style configuration
+ >>> model = PoolFormerModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```
+ """
+
+ model_type = "poolformer"
+
+ def __init__(
+ self,
+ num_channels=3,
+ patch_size=16,
+ stride=16,
+ pool_size=3,
+ mlp_ratio=4.0,
+ depths=[2, 2, 6, 2],
+ hidden_sizes=[64, 128, 320, 512],
+ patch_sizes=[7, 3, 3, 3],
+ strides=[4, 2, 2, 2],
+ padding=[2, 1, 1, 1],
+ num_encoder_blocks=4,
+ drop_path_rate=0.0,
+ hidden_act="gelu",
+ use_layer_scale=True,
+ layer_scale_init_value=1e-5,
+ initializer_range=0.02,
+ **kwargs,
+ ):
+ self.num_channels = num_channels
+ self.patch_size = patch_size
+ self.stride = stride
+ self.padding = padding
+ self.pool_size = pool_size
+ self.hidden_sizes = hidden_sizes
+ self.mlp_ratio = mlp_ratio
+ self.depths = depths
+ self.patch_sizes = patch_sizes
+ self.strides = strides
+ self.num_encoder_blocks = num_encoder_blocks
+ self.drop_path_rate = drop_path_rate
+ self.hidden_act = hidden_act
+ self.use_layer_scale = use_layer_scale
+ self.layer_scale_init_value = layer_scale_init_value
+ self.initializer_range = initializer_range
+ super().__init__(**kwargs)
+
+
+class PoolFormerOnnxConfig(OnnxConfig):
+ torch_onnx_minimum_version = version.parse("1.11")
+
+ @property
+ def inputs(self) -> Mapping[str, Mapping[int, str]]:
+ return OrderedDict(
+ [
+ ("pixel_values", {0: "batch", 1: "num_channels", 2: "height", 3: "width"}),
+ ]
+ )
+
+ @property
+ def atol_for_validation(self) -> float:
+ return 2e-3
+
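+
+# Illustrative use of the ONNX config (the actual export would go through the optional
+# `transformers.onnx` export utilities, which are not shown here):
+#   onnx_config = PoolFormerOnnxConfig(PoolFormerConfig())
+#   list(onnx_config.inputs)   # ['pixel_values']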
+
+__all__ = ["PoolFormerConfig", "PoolFormerOnnxConfig"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/poolformer/feature_extraction_poolformer.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/poolformer/feature_extraction_poolformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..bde18b3ec0960f4425a1788b24cf838ff6459922
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/poolformer/feature_extraction_poolformer.py
@@ -0,0 +1,38 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Feature extractor class for PoolFormer."""
+
+import warnings
+
+from ...utils import logging
+from ...utils.import_utils import requires
+from .image_processing_poolformer import PoolFormerImageProcessor
+
+
+logger = logging.get_logger(__name__)
+
+
+@requires(backends=("vision",))
+class PoolFormerFeatureExtractor(PoolFormerImageProcessor):
+ def __init__(self, *args, **kwargs) -> None:
+ warnings.warn(
+ "The class PoolFormerFeatureExtractor is deprecated and will be removed in version 5 of Transformers."
+ " Please use PoolFormerImageProcessor instead.",
+ FutureWarning,
+ )
+ super().__init__(*args, **kwargs)
+
+
+__all__ = ["PoolFormerFeatureExtractor"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/poolformer/image_processing_poolformer.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/poolformer/image_processing_poolformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..ee5500c823ccee4b51431521a75fbb9841bc2f28
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/poolformer/image_processing_poolformer.py
@@ -0,0 +1,360 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Image processor class for PoolFormer."""
+
+from typing import Optional, Union
+
+import numpy as np
+
+from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
+from ...image_transforms import (
+ get_resize_output_image_size,
+ resize,
+ to_channel_dimension_format,
+)
+from ...image_utils import (
+ IMAGENET_DEFAULT_MEAN,
+ IMAGENET_DEFAULT_STD,
+ ChannelDimension,
+ ImageInput,
+ PILImageResampling,
+ infer_channel_dimension_format,
+ is_scaled_image,
+ make_flat_list_of_images,
+ to_numpy_array,
+ valid_images,
+ validate_preprocess_arguments,
+)
+from ...utils import TensorType, filter_out_non_signature_kwargs, is_vision_available, logging
+
+
+if is_vision_available():
+ import PIL
+
+
+logger = logging.get_logger(__name__)
+
+
+class PoolFormerImageProcessor(BaseImageProcessor):
+ r"""
+ Constructs a PoolFormer image processor.
+
+ Args:
+ do_resize (`bool`, *optional*, defaults to `True`):
+ Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by
+ `do_resize` in the `preprocess` method.
+ size (`dict[str, int]`, *optional*, defaults to `{"shortest_edge": 224}`):
+ Size of the image after resizing. Can be overridden by `size` in the `preprocess` method. If crop_pct is
+ unset:
+ - size is `{"height": h, "width": w}`: the image is resized to `(h, w)`.
+ - size is `{"shortest_edge": s}`: the shortest edge of the image is resized to s whilst maintaining the
+ aspect ratio.
+
+ If crop_pct is set:
+ - size is `{"height": h, "width": w}`: the image is resized to `(int(floor(h/crop_pct)),
+ int(floor(w/crop_pct)))`
+ - size is `{"height": c, "width": c}`: the shortest edge of the image is resized to `int(floor(c/crop_pct)`
+ whilst maintaining the aspect ratio.
+ - size is `{"shortest_edge": c}`: the shortest edge of the image is resized to `int(floor(c/crop_pct)`
+ whilst maintaining the aspect ratio.
+ crop_pct (`float`, *optional*, defaults to 0.9):
+ Percentage of the image to crop from the center. Can be overridden by `crop_pct` in the `preprocess`
+ method.
+ resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`):
+ Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method.
+ do_center_crop (`bool`, *optional*, defaults to `True`):
+ Whether to center crop the image. If the input size is smaller than `crop_size` along any edge, the image
+ is padded with 0's and then center cropped. Can be overridden by `do_center_crop` in the `preprocess`
+ method.
+ crop_size (`dict[str, int]`, *optional*, defaults to `{"height": 224, "width": 224}`):
+ Size of the image after applying center crop. Only has an effect if `do_center_crop` is set to `True`. Can
+ be overridden by the `crop_size` parameter in the `preprocess` method.
+ rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
+ Scale factor to use if rescaling the image. Can be overridden by the `rescale_factor` parameter in the
+ `preprocess` method.
+ do_rescale (`bool`, *optional*, defaults to `True`):
+ Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the `do_rescale`
+ parameter in the `preprocess` method.
+ do_normalize (`bool`, *optional*, defaults to `True`):
+ Controls whether to normalize the image. Can be overridden by the `do_normalize` parameter in the
+ `preprocess` method.
+ image_mean (`float` or `list[float]`, *optional*, defaults to `IMAGENET_DEFAULT_MEAN`):
+ Mean to use if normalizing the image. This is a float or list of floats the length of the number of
+ channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
+ image_std (`float` or `list[float]`, *optional*, defaults to `IMAGENET_DEFAULT_STD`):
+ Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
+ number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
+ """
+
+ model_input_names = ["pixel_values"]
+
+ def __init__(
+ self,
+ do_resize: bool = True,
+ size: Optional[dict[str, int]] = None,
+ crop_pct: float = 0.9,
+ resample: PILImageResampling = PILImageResampling.BICUBIC,
+ do_center_crop: bool = True,
+ crop_size: Optional[dict[str, int]] = None,
+ rescale_factor: Union[int, float] = 1 / 255,
+ do_rescale: bool = True,
+ do_normalize: bool = True,
+ image_mean: Optional[Union[float, list[float]]] = None,
+ image_std: Optional[Union[float, list[float]]] = None,
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+ size = size if size is not None else {"shortest_edge": 224}
+ size = get_size_dict(size, default_to_square=False)
+ crop_size = crop_size if crop_size is not None else {"height": 224, "width": 224}
+ crop_size = get_size_dict(crop_size, param_name="crop_size")
+
+ self.do_resize = do_resize
+ self.size = size
+ self.crop_pct = crop_pct
+ self.resample = resample
+ self.do_center_crop = do_center_crop
+ self.crop_size = crop_size
+ self.do_rescale = do_rescale
+ self.rescale_factor = rescale_factor
+ self.do_normalize = do_normalize
+ self.image_mean = image_mean if image_mean is not None else IMAGENET_DEFAULT_MEAN
+ self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD
+
+ def resize(
+ self,
+ image: np.ndarray,
+ size: dict[str, int],
+ crop_pct: Optional[float] = None,
+ resample: PILImageResampling = PILImageResampling.BICUBIC,
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ **kwargs,
+ ) -> np.ndarray:
+ """
+ Resize an image.
+
+ If crop_pct is unset:
+ - size is `{"height": h, "width": w}`: the image is resized to `(h, w)`.
+ - size is `{"shortest_edge": s}`: the shortest edge of the image is resized to s whilst maintaining the
+ aspect ratio.
+
+ if crop_pct is set:
+ - size is `{"height": h, "width": w}`: the image is resized to `(int(floor(h/crop_pct)),
+ int(floor(w/crop_pct)))`
+ - size is `{"height": c, "width": c}`: the shortest edge of the image is resized to `int(floor(c/crop_pct)`
+ whilst maintaining the aspect ratio.
+ - size is `{"shortest_edge": c}`: the shortest edge of the image is resized to `int(floor(c/crop_pct)`
+ whilst maintaining the aspect ratio.
+
+ Args:
+ image (`np.ndarray`):
+ Image to resize.
+ size (`dict[str, int]`):
+ Size of the output image.
+ crop_pct (`float`, *optional*):
+ Percentage of the image that will be cropped from the center. If set, the image is resized to
+ `size / crop_pct` instead of `size`, leaving a margin for the center crop that is applied afterwards.
+ resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
+ Resampling filter to use when resizing the image.
+ data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format of the image. If not provided, it will be the same as the input image.
+ input_data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format of the input image. If not provided, it will be inferred.
+ """
+ size = get_size_dict(size, default_to_square=False)
+ if "shortest_edge" not in size and ("height" not in size or "width" not in size):
+ raise ValueError(f"size must contain 'height' and 'width' or 'shortest_edge' as keys. Got {size.keys()}")
+ if crop_pct is not None:
+ if "shortest_edge" in size:
+ scale_size = int(size["shortest_edge"] / crop_pct)
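+ # e.g. the default shortest_edge=224 with crop_pct=0.9 gives scale_size=int(224 / 0.9)=248, so the
+ # later 224x224 center crop removes roughly 10% of each spatial dimension.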
+ elif "height" in size and "width" in size:
+ if size["height"] == size["width"]:
+ scale_size = int(size["height"] / crop_pct)
+ else:
+ scale_size = (int(size["height"] / crop_pct), int(size["width"] / crop_pct))
+ else:
+ raise ValueError(f"Invalid size for resize: {size}")
+
+ output_size = get_resize_output_image_size(
+ image, size=scale_size, default_to_square=False, input_data_format=input_data_format
+ )
+ else:
+ if "shortest_edge" in size:
+ output_size = get_resize_output_image_size(
+ image, size=size["shortest_edge"], default_to_square=False, input_data_format=input_data_format
+ )
+ elif "height" in size and "width" in size:
+ output_size = (size["height"], size["width"])
+ else:
+ raise ValueError(f"Invalid size for resize: {size}")
+
+ return resize(
+ image,
+ size=output_size,
+ resample=resample,
+ data_format=data_format,
+ input_data_format=input_data_format,
+ **kwargs,
+ )
+
+ @filter_out_non_signature_kwargs()
+ def preprocess(
+ self,
+ images: ImageInput,
+ do_resize: Optional[bool] = None,
+ size: Optional[dict[str, int]] = None,
+ crop_pct: Optional[float] = None,
+ resample: Optional[PILImageResampling] = None,
+ do_center_crop: Optional[bool] = None,
+ crop_size: Optional[dict[str, int]] = None,
+ do_rescale: Optional[bool] = None,
+ rescale_factor: Optional[float] = None,
+ do_normalize: Optional[bool] = None,
+ image_mean: Optional[Union[float, list[float]]] = None,
+ image_std: Optional[Union[float, list[float]]] = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ data_format: ChannelDimension = ChannelDimension.FIRST,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ ) -> BatchFeature:
+ """
+ Preprocess an image or batch of images.
+
+ Args:
+ images (`ImageInput`):
+ Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
+ passing in images with pixel values between 0 and 1, set `do_rescale=False`.
+ do_resize (`bool`, *optional*, defaults to `self.do_resize`):
+ Whether to resize the image.
+ size (`dict[str, int]`, *optional*, defaults to `self.size`):
+ Size of the image after applying resize.
+ crop_pct (`float`, *optional*, defaults to `self.crop_pct`):
+ Percentage of the image to crop. Only has an effect if `do_resize` is set to `True`.
+ resample (`int`, *optional*, defaults to `self.resample`):
+ Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
+ has an effect if `do_resize` is set to `True`.
+ do_center_crop (`bool`, *optional*, defaults to `self.do_center_crop`):
+ Whether to center crop the image.
+ crop_size (`dict[str, int]`, *optional*, defaults to `self.crop_size`):
+ Size of the image after applying center crop.
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
+ Whether to rescale the image values to the [0, 1] range.
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
+ Rescale factor to rescale the image by if `do_rescale` is set to `True`.
+ do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
+ Whether to normalize the image.
+ image_mean (`float` or `list[float]`, *optional*, defaults to `self.image_mean`):
+ Image mean.
+ image_std (`float` or `list[float]`, *optional*, defaults to `self.image_std`):
+ Image standard deviation.
+ return_tensors (`str` or `TensorType`, *optional*):
+ The type of tensors to return. Can be one of:
+ - Unset: Return a list of `np.ndarray`.
+ - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
+ - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
+ data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
+ The channel dimension format for the output image. Can be one of:
+ - `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ input_data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
+ from the input image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
+ """
+ do_resize = do_resize if do_resize is not None else self.do_resize
+ crop_pct = crop_pct if crop_pct is not None else self.crop_pct
+ resample = resample if resample is not None else self.resample
+ do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop
+ do_rescale = do_rescale if do_rescale is not None else self.do_rescale
+ rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
+ do_normalize = do_normalize if do_normalize is not None else self.do_normalize
+ image_mean = image_mean if image_mean is not None else self.image_mean
+ image_std = image_std if image_std is not None else self.image_std
+
+ size = size if size is not None else self.size
+ size = get_size_dict(size, default_to_square=False)
+ crop_size = crop_size if crop_size is not None else self.crop_size
+ crop_size = get_size_dict(crop_size, param_name="crop_size")
+
+ images = make_flat_list_of_images(images)
+
+ if not valid_images(images):
+ raise ValueError(
+ "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
+ "torch.Tensor, tf.Tensor or jax.ndarray."
+ )
+ validate_preprocess_arguments(
+ do_rescale=do_rescale,
+ rescale_factor=rescale_factor,
+ do_normalize=do_normalize,
+ image_mean=image_mean,
+ image_std=image_std,
+ do_center_crop=do_center_crop,
+ crop_size=crop_size,
+ do_resize=do_resize,
+ size=size,
+ resample=resample,
+ )
+
+ # All transformations expect numpy arrays.
+ images = [to_numpy_array(image) for image in images]
+
+ if do_rescale and is_scaled_image(images[0]):
+ logger.warning_once(
+ "It looks like you are trying to rescale already rescaled images. If the input"
+ " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
+ )
+
+ if input_data_format is None:
+ # We assume that all images have the same channel dimension format.
+ input_data_format = infer_channel_dimension_format(images[0])
+
+ if do_resize:
+ images = [
+ self.resize(
+ image=image, size=size, crop_pct=crop_pct, resample=resample, input_data_format=input_data_format
+ )
+ for image in images
+ ]
+
+ if do_center_crop:
+ images = [
+ self.center_crop(image=image, size=crop_size, input_data_format=input_data_format) for image in images
+ ]
+
+ if do_rescale:
+ images = [
+ self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
+ for image in images
+ ]
+
+ if do_normalize:
+ images = [
+ self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
+ for image in images
+ ]
+
+ images = [
+ to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
+ ]
+
+ data = {"pixel_values": images}
+ return BatchFeature(data=data, tensor_type=return_tensors)
+
+
+__all__ = ["PoolFormerImageProcessor"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/poolformer/image_processing_poolformer_fast.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/poolformer/image_processing_poolformer_fast.py
new file mode 100644
index 0000000000000000000000000000000000000000..62d5f276859f4ed754f36499316236eb962671b8
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/poolformer/image_processing_poolformer_fast.py
@@ -0,0 +1,248 @@
+# coding=utf-8
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Fast Image processor class for PoolFormer."""
+
+from typing import Optional, Union
+
+import torch
+from torchvision.transforms.v2 import functional as F
+
+from ...image_processing_utils_fast import BaseImageProcessorFast, BatchFeature, DefaultFastImageProcessorKwargs
+from ...image_transforms import (
+ ChannelDimension,
+ get_resize_output_image_size,
+ get_size_with_aspect_ratio,
+ group_images_by_shape,
+ reorder_images,
+)
+from ...image_utils import (
+ IMAGENET_DEFAULT_MEAN,
+ IMAGENET_DEFAULT_STD,
+ ImageInput,
+ PILImageResampling,
+ SizeDict,
+ get_image_size_for_max_height_width,
+)
+from ...processing_utils import Unpack
+from ...utils import (
+ TensorType,
+ auto_docstring,
+)
+
+
+class PoolFormerFastImageProcessorKwargs(DefaultFastImageProcessorKwargs):
+ """
+ Args:
+ crop_pct (`float`, *optional*, defaults to `self.crop_pct`):
+ Percentage of the image to crop. Only has an effect if `do_resize` is set to `True`.
+ """
+
+ crop_pct: Optional[float]
+
+
+@auto_docstring
+class PoolFormerImageProcessorFast(BaseImageProcessorFast):
+ resample = PILImageResampling.BICUBIC
+ image_mean = IMAGENET_DEFAULT_MEAN
+ image_std = IMAGENET_DEFAULT_STD
+ size = {"shortest_edge": 224}
+ default_to_square = False
+ crop_size = {"height": 224, "width": 224}
+ crop_pct = 0.9
+ do_resize = True
+ do_center_crop = True
+ do_rescale = True
+ do_normalize = True
+ valid_kwargs = PoolFormerFastImageProcessorKwargs
+
+ def __init__(self, **kwargs: Unpack[PoolFormerFastImageProcessorKwargs]):
+ super().__init__(**kwargs)
+
+ @auto_docstring
+ def preprocess(self, images: ImageInput, **kwargs: Unpack[PoolFormerFastImageProcessorKwargs]) -> BatchFeature:
+ return super().preprocess(images, **kwargs)
+
+ def resize(
+ self,
+ image: "torch.Tensor",
+ size: SizeDict,
+ crop_pct: Optional[float] = None,
+ interpolation: Optional["F.InterpolationMode"] = None,
+ antialias: bool = True,
+ **kwargs,
+ ) -> "torch.Tensor":
+ """
+ Resize an image.
+
+ If crop_pct is unset:
+ - size is `{"height": h, "width": w}`: the image is resized to `(h, w)`.
+ - size is `{"shortest_edge": s}`: the shortest edge of the image is resized to s whilst maintaining the
+ aspect ratio.
+
+        If crop_pct is set:
+            - size is `{"height": h, "width": w}`: the image is resized to `(int(floor(h/crop_pct)),
+              int(floor(w/crop_pct)))`
+            - size is `{"height": c, "width": c}`: the shortest edge of the image is resized to `int(floor(c/crop_pct))`
+              whilst maintaining the aspect ratio.
+            - size is `{"shortest_edge": c}`: the shortest edge of the image is resized to `int(floor(c/crop_pct))`
+              whilst maintaining the aspect ratio.
+
+ Args:
+ image (`torch.Tensor`):
+ Image to resize.
+ size (`SizeDict`):
+ Dictionary in the format `{"height": int, "width": int}` specifying the size of the output image.
+            crop_pct (`float`, *optional*):
+                Percentage of the image that will be cropped from the center. If set, the image is first resized to
+                `size / crop_pct` (as described above) so that the subsequent center crop of `size` keeps roughly
+                `crop_pct` of the resized image.
+            interpolation (`InterpolationMode`, *optional*, defaults to `InterpolationMode.BILINEAR`):
+                `InterpolationMode` filter to use when resizing the image e.g. `InterpolationMode.BICUBIC`.
+
+ Returns:
+ `torch.Tensor`: The resized image.
+ """
+ interpolation = interpolation if interpolation is not None else F.InterpolationMode.BILINEAR
+ if crop_pct is not None:
+ if size.shortest_edge:
+ scale_size = int(size.shortest_edge / crop_pct)
+ elif size.height and size.width:
+ if size.height == size.width:
+ scale_size = int(size.height / crop_pct)
+ else:
+ scale_size = (int(size.height / crop_pct), int(size.width / crop_pct))
+ else:
+ raise ValueError(f"Invalid size for resize: {size}")
+
+ new_size = get_resize_output_image_size(
+ image,
+ size=scale_size,
+ default_to_square=False,
+ input_data_format=ChannelDimension.FIRST,
+ )
+ else:
+ if size.shortest_edge and size.longest_edge:
+ # Resize the image so that the shortest edge or the longest edge is of the given size
+ # while maintaining the aspect ratio of the original image.
+ new_size = get_size_with_aspect_ratio(
+ image.size()[-2:],
+ size.shortest_edge,
+ size.longest_edge,
+ )
+ elif size.shortest_edge:
+ new_size = get_resize_output_image_size(
+ image,
+ size=size.shortest_edge,
+ default_to_square=False,
+ input_data_format=ChannelDimension.FIRST,
+ )
+ elif size.max_height and size.max_width:
+ new_size = get_image_size_for_max_height_width(image.size()[-2:], size.max_height, size.max_width)
+ elif size.height and size.width:
+ new_size = (size.height, size.width)
+ else:
+ raise ValueError(
+ "Size must contain 'height' and 'width' keys, or 'max_height' and 'max_width', or 'shortest_edge' key. Got"
+ f" {size}."
+ )
+ return F.resize(image, new_size, interpolation=interpolation, antialias=antialias)
+
+ def center_crop(
+ self,
+ image: "torch.Tensor",
+ size: SizeDict,
+ **kwargs,
+ ) -> "torch.Tensor":
+ """
+ Center crop an image to `(size["height"], size["width"])`. If the input size is smaller than `crop_size` along
+ any edge, the image is padded with 0's and then center cropped.
+
+ Args:
+            image (`torch.Tensor`):
+                Image to center crop.
+            size (`SizeDict`):
+                Size of the output image.
+
+ Returns:
+ `torch.Tensor`: The center cropped image.
+ """
+ if size.height is None or size.width is None:
+ raise ValueError(f"The size dictionary must have keys 'height' and 'width'. Got {size.keys()}")
+ image_height, image_width = image.shape[-2:]
+ crop_height, crop_width = size.height, size.width
+
+ if crop_width > image_width or crop_height > image_height:
+ padding_ltrb = [
+ (crop_width - image_width) // 2 if crop_width > image_width else 0,
+ (crop_height - image_height) // 2 if crop_height > image_height else 0,
+ (crop_width - image_width + 1) // 2 if crop_width > image_width else 0,
+ (crop_height - image_height + 1) // 2 if crop_height > image_height else 0,
+ ]
+ image = F.pad(image, padding_ltrb, fill=0) # PIL uses fill value 0
+ image_height, image_width = image.shape[-2:]
+ if crop_width == image_width and crop_height == image_height:
+ return image
+
+ crop_top = int((image_height - crop_height) / 2.0)
+ crop_left = int((image_width - crop_width) / 2.0)
+ return F.crop(image, crop_top, crop_left, crop_height, crop_width)
+
+ def _preprocess(
+ self,
+ images: list["torch.Tensor"],
+ do_resize: bool,
+ size: SizeDict,
+ crop_pct: float,
+ interpolation: Optional["F.InterpolationMode"],
+ do_center_crop: bool,
+ crop_size: SizeDict,
+ do_rescale: bool,
+ rescale_factor: float,
+ do_normalize: bool,
+ image_mean: Optional[Union[float, list[float]]],
+ image_std: Optional[Union[float, list[float]]],
+ disable_grouping: Optional[bool],
+ return_tensors: Optional[Union[str, TensorType]],
+ **kwargs,
+ ) -> BatchFeature:
+ # Group images by size for batched resizing
+ grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
+ resized_images_grouped = {}
+ for shape, stacked_images in grouped_images.items():
+ if do_resize:
+ stacked_images = self.resize(
+ image=stacked_images, size=size, crop_pct=crop_pct, interpolation=interpolation
+ )
+ resized_images_grouped[shape] = stacked_images
+ resized_images = reorder_images(resized_images_grouped, grouped_images_index)
+
+ # Group images by size for further processing
+ # Needed in case do_resize is False, or resize returns images with different sizes
+ grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping)
+ processed_images_grouped = {}
+ for shape, stacked_images in grouped_images.items():
+ if do_center_crop:
+ stacked_images = self.center_crop(stacked_images, crop_size)
+ # Fused rescale and normalize
+ stacked_images = self.rescale_and_normalize(
+ stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std
+ )
+ processed_images_grouped[shape] = stacked_images
+
+ processed_images = reorder_images(processed_images_grouped, grouped_images_index)
+ processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
+
+ return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)
+
+
+__all__ = ["PoolFormerImageProcessorFast"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/poolformer/modeling_poolformer.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/poolformer/modeling_poolformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..8c6dc8191630092ed965966209880bfc5360c7a4
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/poolformer/modeling_poolformer.py
@@ -0,0 +1,380 @@
+# coding=utf-8
+# Copyright 2022 Sea AI Lab and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch PoolFormer model."""
+
+import collections.abc
+from typing import Optional, Union
+
+import torch
+from torch import nn
+
+from ...activations import ACT2FN
+from ...modeling_outputs import BaseModelOutputWithNoAttention, ImageClassifierOutputWithNoAttention
+from ...modeling_utils import PreTrainedModel
+from ...utils import auto_docstring, logging
+from .configuration_poolformer import PoolFormerConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+# Copied from transformers.models.beit.modeling_beit.drop_path
+def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
+ """
+ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+
+ Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
+ however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
+ See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
+ layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
+ argument.
+ """
+ if drop_prob == 0.0 or not training:
+ return input
+ keep_prob = 1 - drop_prob
+ shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
+ random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
+ random_tensor.floor_() # binarize
+ output = input.div(keep_prob) * random_tensor
+ return output
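
# Illustrative sanity check (not part of the library file above; relies on the drop_path function
# defined directly above): because surviving samples are divided by keep_prob, stochastic depth
# preserves the expected value of its input in training mode.
import torch

x = torch.ones(10000, 8)
out = drop_path(x, drop_prob=0.2, training=True)
print(round(out.mean().item(), 2))  # close to 1.0: ~20% of rows are zeroed, the rest scaled by 1/0.8
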
+
+
+# Copied from transformers.models.beit.modeling_beit.BeitDropPath with Beit->PoolFormer
+class PoolFormerDropPath(nn.Module):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
+
+ def __init__(self, drop_prob: Optional[float] = None) -> None:
+ super().__init__()
+ self.drop_prob = drop_prob
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ return drop_path(hidden_states, self.drop_prob, self.training)
+
+ def extra_repr(self) -> str:
+ return f"p={self.drop_prob}"
+
+
+class PoolFormerEmbeddings(nn.Module):
+ """
+ Construct Patch Embeddings.
+ """
+
+ def __init__(self, hidden_size, num_channels, patch_size, stride, padding, norm_layer=None):
+ super().__init__()
+ patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
+ stride = stride if isinstance(stride, collections.abc.Iterable) else (stride, stride)
+ padding = padding if isinstance(padding, collections.abc.Iterable) else (padding, padding)
+
+ self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=stride, padding=padding)
+ self.norm = norm_layer(hidden_size) if norm_layer else nn.Identity()
+
+ def forward(self, pixel_values):
+ embeddings = self.projection(pixel_values)
+ embeddings = self.norm(embeddings)
+ return embeddings
+
+
+class PoolFormerGroupNorm(nn.GroupNorm):
+ """
+ Group Normalization with 1 group. Input: tensor in shape [B, C, H, W]
+ """
+
+ def __init__(self, num_channels, **kwargs):
+ super().__init__(1, num_channels, **kwargs)
+
+
+class PoolFormerPooling(nn.Module):
+ def __init__(self, pool_size):
+ super().__init__()
+ self.pool = nn.AvgPool2d(pool_size, stride=1, padding=pool_size // 2, count_include_pad=False)
+
+ def forward(self, hidden_states):
+ return self.pool(hidden_states) - hidden_states
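
# Illustrative check (not part of the library file above; uses the PoolFormerPooling module defined
# directly above): the token mixer returns AvgPool2d(x) - x, so adding the mixer output back to its
# input recovers plain local average pooling of the input. In PoolFormerLayer this is additionally
# wrapped in group norm, optional layer scale and drop path.
import torch

mixer = PoolFormerPooling(pool_size=3)
x = torch.randn(1, 4, 8, 8)
assert torch.allclose(x + mixer(x), mixer.pool(x))
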
+
+
+class PoolFormerOutput(nn.Module):
+ def __init__(self, config, dropout_prob, hidden_size, intermediate_size):
+ super().__init__()
+ self.conv1 = nn.Conv2d(hidden_size, intermediate_size, 1)
+ self.conv2 = nn.Conv2d(intermediate_size, hidden_size, 1)
+ self.drop = PoolFormerDropPath(dropout_prob)
+ if isinstance(config.hidden_act, str):
+ self.act_fn = ACT2FN[config.hidden_act]
+ else:
+ self.act_fn = config.hidden_act
+
+ def forward(self, hidden_states):
+ hidden_states = self.conv1(hidden_states)
+ hidden_states = self.act_fn(hidden_states)
+ hidden_states = self.drop(hidden_states)
+ hidden_states = self.conv2(hidden_states)
+ hidden_states = self.drop(hidden_states)
+
+ return hidden_states
+
+
+class PoolFormerLayer(nn.Module):
+ """This corresponds to the 'PoolFormerBlock' class in the original implementation."""
+
+ def __init__(self, config, num_channels, pool_size, hidden_size, intermediate_size, drop_path):
+ super().__init__()
+ self.pooling = PoolFormerPooling(pool_size)
+ self.output = PoolFormerOutput(config, drop_path, hidden_size, intermediate_size)
+ self.before_norm = PoolFormerGroupNorm(num_channels)
+ self.after_norm = PoolFormerGroupNorm(num_channels)
+
+ # Useful for training neural nets
+ self.drop_path = PoolFormerDropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+ self.use_layer_scale = config.use_layer_scale
+ if config.use_layer_scale:
+ self.layer_scale_1 = nn.Parameter(
+ config.layer_scale_init_value * torch.ones(num_channels), requires_grad=True
+ )
+ self.layer_scale_2 = nn.Parameter(
+ config.layer_scale_init_value * torch.ones(num_channels), requires_grad=True
+ )
+
+ def forward(self, hidden_states):
+ if self.use_layer_scale:
+ pooling_output = self.pooling(self.before_norm(hidden_states))
+ scaled_op = self.layer_scale_1.unsqueeze(-1).unsqueeze(-1) * pooling_output
+ # First residual connection
+ hidden_states = hidden_states + self.drop_path(scaled_op)
+ outputs = ()
+
+ layer_output = self.output(self.after_norm(hidden_states))
+ scaled_op = self.layer_scale_2.unsqueeze(-1).unsqueeze(-1) * layer_output
+ # Second residual connection
+ output = hidden_states + self.drop_path(scaled_op)
+
+ outputs = (output,) + outputs
+ return outputs
+
+ else:
+ pooling_output = self.drop_path(self.pooling(self.before_norm(hidden_states)))
+ # First residual connection
+ hidden_states = pooling_output + hidden_states
+ outputs = ()
+
+ # Second residual connection inside the PoolFormerOutput block
+ layer_output = self.drop_path(self.output(self.after_norm(hidden_states)))
+ output = hidden_states + layer_output
+
+ outputs = (output,) + outputs
+ return outputs
+
+
+class PoolFormerEncoder(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ # stochastic depth decay rule
+ dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths), device="cpu")]
+
+ # patch embeddings
+ embeddings = []
+ for i in range(config.num_encoder_blocks):
+ embeddings.append(
+ PoolFormerEmbeddings(
+ patch_size=config.patch_sizes[i],
+ stride=config.strides[i],
+ padding=config.padding[i],
+ num_channels=config.num_channels if i == 0 else config.hidden_sizes[i - 1],
+ hidden_size=config.hidden_sizes[i],
+ )
+ )
+ self.patch_embeddings = nn.ModuleList(embeddings)
+
+ # Transformer blocks
+ blocks = []
+ cur = 0
+ for i in range(config.num_encoder_blocks):
+ # each block consists of layers
+ layers = []
+ if i != 0:
+ cur += config.depths[i - 1]
+ for j in range(config.depths[i]):
+ layers.append(
+ PoolFormerLayer(
+ config,
+ num_channels=config.hidden_sizes[i],
+ pool_size=config.pool_size,
+ hidden_size=config.hidden_sizes[i],
+ intermediate_size=int(config.hidden_sizes[i] * config.mlp_ratio),
+ drop_path=dpr[cur + j],
+ )
+ )
+ blocks.append(nn.ModuleList(layers))
+
+ self.block = nn.ModuleList(blocks)
+
+ def forward(self, pixel_values, output_hidden_states=False, return_dict=True):
+ all_hidden_states = () if output_hidden_states else None
+
+ hidden_states = pixel_values
+ for idx, layers in enumerate(zip(self.patch_embeddings, self.block)):
+ embedding_layer, block_layer = layers
+ # Get patch embeddings from hidden_states
+ hidden_states = embedding_layer(hidden_states)
+ # Send the embeddings through the blocks
+ for _, blk in enumerate(block_layer):
+ layer_outputs = blk(hidden_states)
+ hidden_states = layer_outputs[0]
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, all_hidden_states] if v is not None)
+
+ return BaseModelOutputWithNoAttention(last_hidden_state=hidden_states, hidden_states=all_hidden_states)
+
+
+@auto_docstring
+class PoolFormerPreTrainedModel(PreTrainedModel):
+ config: PoolFormerConfig
+ base_model_prefix = "poolformer"
+ main_input_name = "pixel_values"
+ _no_split_modules = ["PoolFormerLayer"]
+
+ def _init_weights(self, module):
+ """Initialize the weights"""
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.GroupNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+ elif isinstance(module, PoolFormerLayer):
+ if hasattr(module, "layer_scale_1"):
+ module.layer_scale_1.data.fill_(self.config.layer_scale_init_value)
+ module.layer_scale_2.data.fill_(self.config.layer_scale_init_value)
+
+
+@auto_docstring
+class PoolFormerModel(PoolFormerPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.config = config
+
+ self.encoder = PoolFormerEncoder(config)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+        return self.encoder.patch_embeddings
+
+ @auto_docstring
+ def forward(
+ self,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[tuple, BaseModelOutputWithNoAttention]:
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if pixel_values is None:
+ raise ValueError("You have to specify pixel_values")
+
+ encoder_outputs = self.encoder(
+ pixel_values,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ sequence_output = encoder_outputs[0]
+
+ if not return_dict:
+ return (sequence_output, None) + encoder_outputs[1:]
+
+ return BaseModelOutputWithNoAttention(
+ last_hidden_state=sequence_output,
+ hidden_states=encoder_outputs.hidden_states,
+ )
+
+
+class PoolFormerFinalPooler(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+
+ def forward(self, hidden_states):
+ output = self.dense(hidden_states)
+ return output
+
+
+@auto_docstring(
+ custom_intro="""
+ PoolFormer Model transformer with an image classification head on top
+ """
+)
+class PoolFormerForImageClassification(PoolFormerPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+ self.poolformer = PoolFormerModel(config)
+
+ # Final norm
+ self.norm = PoolFormerGroupNorm(config.hidden_sizes[-1])
+ # Classifier head
+ self.classifier = (
+ nn.Linear(config.hidden_sizes[-1], config.num_labels) if config.num_labels > 0 else nn.Identity()
+ )
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @auto_docstring
+ def forward(
+ self,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[tuple, ImageClassifierOutputWithNoAttention]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.poolformer(
+ pixel_values,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ sequence_output = outputs[0]
+
+ logits = self.classifier(self.norm(sequence_output).mean([-2, -1]))
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(labels, logits, self.config)
+
+ if not return_dict:
+ output = (logits,) + outputs[2:]
+ return ((loss,) + output) if loss is not None else output
+
+ return ImageClassifierOutputWithNoAttention(loss=loss, logits=logits, hidden_states=outputs.hidden_states)
+
+
+__all__ = ["PoolFormerForImageClassification", "PoolFormerModel", "PoolFormerPreTrainedModel"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/pvt/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/pvt/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..371478776da01819236b470d7a96717003a26d74
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/pvt/__init__.py
@@ -0,0 +1,29 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_pvt import *
+ from .image_processing_pvt import *
+ from .image_processing_pvt_fast import *
+ from .modeling_pvt import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/pvt/configuration_pvt.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/pvt/configuration_pvt.py
new file mode 100644
index 0000000000000000000000000000000000000000..208295db71fb1343aebfc8730033fd76b4a8c31b
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/pvt/configuration_pvt.py
@@ -0,0 +1,163 @@
+# coding=utf-8
+# Copyright 2023 Authors: Wenhai Wang, Enze Xie, Xiang Li, Deng-Ping Fan,
+# Kaitao Song, Ding Liang, Tong Lu, Ping Luo, Ling Shao and The HuggingFace Inc. team.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Pvt model configuration"""
+
+from collections import OrderedDict
+from collections.abc import Mapping
+from typing import Callable, Union
+
+from packaging import version
+
+from ...configuration_utils import PretrainedConfig
+from ...onnx import OnnxConfig
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class PvtConfig(PretrainedConfig):
+ r"""
+    This is the configuration class to store the configuration of a [`PvtModel`]. It is used to instantiate a Pvt
+    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
+ defaults will yield a similar configuration to that of the Pvt
+ [Xrenya/pvt-tiny-224](https://huggingface.co/Xrenya/pvt-tiny-224) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ image_size (`int`, *optional*, defaults to 224):
+            The input image size.
+ num_channels (`int`, *optional*, defaults to 3):
+ The number of input channels.
+ num_encoder_blocks (`int`, *optional*, defaults to 4):
+            The number of encoder blocks (i.e. stages in the Pyramid Vision Transformer encoder).
+ depths (`list[int]`, *optional*, defaults to `[2, 2, 2, 2]`):
+ The number of layers in each encoder block.
+ sequence_reduction_ratios (`list[int]`, *optional*, defaults to `[8, 4, 2, 1]`):
+ Sequence reduction ratios in each encoder block.
+ hidden_sizes (`list[int]`, *optional*, defaults to `[64, 128, 320, 512]`):
+ Dimension of each of the encoder blocks.
+ patch_sizes (`list[int]`, *optional*, defaults to `[4, 2, 2, 2]`):
+ Patch size before each encoder block.
+ strides (`list[int]`, *optional*, defaults to `[4, 2, 2, 2]`):
+ Stride before each encoder block.
+ num_attention_heads (`list[int]`, *optional*, defaults to `[1, 2, 5, 8]`):
+ Number of attention heads for each attention layer in each block of the Transformer encoder.
+ mlp_ratios (`list[int]`, *optional*, defaults to `[8, 8, 4, 4]`):
+ Ratio of the size of the hidden layer compared to the size of the input layer of the Mix FFNs in the
+ encoder blocks.
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+ `"relu"`, `"selu"` and `"gelu_new"` are supported.
+ hidden_dropout_prob (`float`, *optional*, defaults to 0.0):
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+ attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ drop_path_rate (`float`, *optional*, defaults to 0.0):
+ The dropout probability for stochastic depth, used in the blocks of the Transformer encoder.
+ layer_norm_eps (`float`, *optional*, defaults to 1e-06):
+ The epsilon used by the layer normalization layers.
+ qkv_bias (`bool`, *optional*, defaults to `True`):
+ Whether or not a learnable bias should be added to the queries, keys and values.
+        num_labels (`int`, *optional*, defaults to 1000):
+ The number of classes.
+ Example:
+
+ ```python
+ >>> from transformers import PvtModel, PvtConfig
+
+ >>> # Initializing a PVT Xrenya/pvt-tiny-224 style configuration
+ >>> configuration = PvtConfig()
+
+ >>> # Initializing a model from the Xrenya/pvt-tiny-224 style configuration
+ >>> model = PvtModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "pvt"
+
+ def __init__(
+ self,
+ image_size: int = 224,
+ num_channels: int = 3,
+ num_encoder_blocks: int = 4,
+ depths: list[int] = [2, 2, 2, 2],
+ sequence_reduction_ratios: list[int] = [8, 4, 2, 1],
+ hidden_sizes: list[int] = [64, 128, 320, 512],
+ patch_sizes: list[int] = [4, 2, 2, 2],
+ strides: list[int] = [4, 2, 2, 2],
+ num_attention_heads: list[int] = [1, 2, 5, 8],
+ mlp_ratios: list[int] = [8, 8, 4, 4],
+        hidden_act: Union[str, Callable] = "gelu",
+ hidden_dropout_prob: float = 0.0,
+ attention_probs_dropout_prob: float = 0.0,
+ initializer_range: float = 0.02,
+ drop_path_rate: float = 0.0,
+ layer_norm_eps: float = 1e-6,
+ qkv_bias: bool = True,
+ num_labels: int = 1000,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+
+ self.image_size = image_size
+ self.num_channels = num_channels
+ self.num_encoder_blocks = num_encoder_blocks
+ self.depths = depths
+ self.sequence_reduction_ratios = sequence_reduction_ratios
+ self.hidden_sizes = hidden_sizes
+ self.patch_sizes = patch_sizes
+ self.strides = strides
+ self.mlp_ratios = mlp_ratios
+ self.num_attention_heads = num_attention_heads
+ self.hidden_act = hidden_act
+ self.hidden_dropout_prob = hidden_dropout_prob
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
+ self.initializer_range = initializer_range
+ self.drop_path_rate = drop_path_rate
+ self.layer_norm_eps = layer_norm_eps
+ self.num_labels = num_labels
+ self.qkv_bias = qkv_bias
+
+
+class PvtOnnxConfig(OnnxConfig):
+ torch_onnx_minimum_version = version.parse("1.11")
+
+ @property
+ def inputs(self) -> Mapping[str, Mapping[int, str]]:
+ return OrderedDict(
+ [
+ ("pixel_values", {0: "batch", 1: "num_channels", 2: "height", 3: "width"}),
+ ]
+ )
+
+ @property
+ def atol_for_validation(self) -> float:
+ return 1e-4
+
+ @property
+ def default_onnx_opset(self) -> int:
+ return 12
+
+
+__all__ = ["PvtConfig", "PvtOnnxConfig"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/pvt/image_processing_pvt.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/pvt/image_processing_pvt.py
new file mode 100644
index 0000000000000000000000000000000000000000..9f687fe7548ff92b40a455b07e185b43794190a1
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/pvt/image_processing_pvt.py
@@ -0,0 +1,276 @@
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Image processor class for Pvt."""
+
+from typing import Optional, Union
+
+import numpy as np
+
+from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
+from ...image_transforms import resize, to_channel_dimension_format
+from ...image_utils import (
+ IMAGENET_DEFAULT_MEAN,
+ IMAGENET_DEFAULT_STD,
+ ChannelDimension,
+ ImageInput,
+ PILImageResampling,
+ infer_channel_dimension_format,
+ is_scaled_image,
+ make_flat_list_of_images,
+ to_numpy_array,
+ valid_images,
+ validate_preprocess_arguments,
+)
+from ...utils import TensorType, filter_out_non_signature_kwargs, logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class PvtImageProcessor(BaseImageProcessor):
+ r"""
+ Constructs a PVT image processor.
+
+ Args:
+ do_resize (`bool`, *optional*, defaults to `True`):
+ Whether to resize the image's (height, width) dimensions to the specified `(size["height"],
+ size["width"])`. Can be overridden by the `do_resize` parameter in the `preprocess` method.
+ size (`dict`, *optional*, defaults to `{"height": 224, "width": 224}`):
+ Size of the output image after resizing. Can be overridden by the `size` parameter in the `preprocess`
+ method.
+ resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`):
+ Resampling filter to use if resizing the image. Can be overridden by the `resample` parameter in the
+ `preprocess` method.
+ do_rescale (`bool`, *optional*, defaults to `True`):
+ Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the `do_rescale`
+ parameter in the `preprocess` method.
+ rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
+ Scale factor to use if rescaling the image. Can be overridden by the `rescale_factor` parameter in the
+ `preprocess` method.
+ do_normalize (`bool`, *optional*, defaults to `True`):
+ Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
+ method.
+ image_mean (`float` or `list[float]`, *optional*, defaults to `IMAGENET_DEFAULT_MEAN`):
+ Mean to use if normalizing the image. This is a float or list of floats the length of the number of
+ channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
+ image_std (`float` or `list[float]`, *optional*, defaults to `IMAGENET_DEFAULT_STD`):
+ Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
+ number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
+ """
+
+ model_input_names = ["pixel_values"]
+
+ def __init__(
+ self,
+ do_resize: bool = True,
+ size: Optional[dict[str, int]] = None,
+ resample: PILImageResampling = PILImageResampling.BILINEAR,
+ do_rescale: bool = True,
+ rescale_factor: Union[int, float] = 1 / 255,
+ do_normalize: bool = True,
+ image_mean: Optional[Union[float, list[float]]] = None,
+ image_std: Optional[Union[float, list[float]]] = None,
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+ size = size if size is not None else {"height": 224, "width": 224}
+ size = get_size_dict(size)
+ self.do_resize = do_resize
+ self.do_rescale = do_rescale
+ self.do_normalize = do_normalize
+ self.size = size
+ self.resample = resample
+ self.rescale_factor = rescale_factor
+ self.image_mean = image_mean if image_mean is not None else IMAGENET_DEFAULT_MEAN
+ self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD
+
+ # Copied from transformers.models.vit.image_processing_vit.ViTImageProcessor.resize
+ def resize(
+ self,
+ image: np.ndarray,
+ size: dict[str, int],
+ resample: PILImageResampling = PILImageResampling.BILINEAR,
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ **kwargs,
+ ) -> np.ndarray:
+ """
+ Resize an image to `(size["height"], size["width"])`.
+
+ Args:
+ image (`np.ndarray`):
+ Image to resize.
+ size (`dict[str, int]`):
+ Dictionary in the format `{"height": int, "width": int}` specifying the size of the output image.
+ resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
+ `PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BILINEAR`.
+ data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format for the output image. If unset, the channel dimension format of the input
+ image is used. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
+ input_data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
+ from the input image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
+
+ Returns:
+ `np.ndarray`: The resized image.
+ """
+ size = get_size_dict(size)
+ if "height" not in size or "width" not in size:
+ raise ValueError(f"The `size` dictionary must contain the keys `height` and `width`. Got {size.keys()}")
+ output_size = (size["height"], size["width"])
+ return resize(
+ image,
+ size=output_size,
+ resample=resample,
+ data_format=data_format,
+ input_data_format=input_data_format,
+ **kwargs,
+ )
+
+ @filter_out_non_signature_kwargs()
+ def preprocess(
+ self,
+ images: ImageInput,
+ do_resize: Optional[bool] = None,
+ size: Optional[dict[str, int]] = None,
+ resample: Optional[PILImageResampling] = None,
+ do_rescale: Optional[bool] = None,
+ rescale_factor: Optional[float] = None,
+ do_normalize: Optional[bool] = None,
+ image_mean: Optional[Union[float, list[float]]] = None,
+ image_std: Optional[Union[float, list[float]]] = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ ):
+ """
+ Preprocess an image or batch of images.
+
+ Args:
+ images (`ImageInput`):
+ Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
+ passing in images with pixel values between 0 and 1, set `do_rescale=False`.
+ do_resize (`bool`, *optional*, defaults to `self.do_resize`):
+ Whether to resize the image.
+ size (`dict[str, int]`, *optional*, defaults to `self.size`):
+ Dictionary in the format `{"height": h, "width": w}` specifying the size of the output image after
+ resizing.
+ resample (`PILImageResampling` filter, *optional*, defaults to `self.resample`):
+ `PILImageResampling` filter to use if resizing the image e.g. `PILImageResampling.BILINEAR`. Only has
+ an effect if `do_resize` is set to `True`.
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
+                Whether to rescale the image pixel values to the [0, 1] range.
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
+ Rescale factor to rescale the image by if `do_rescale` is set to `True`.
+ do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
+ Whether to normalize the image.
+ image_mean (`float` or `list[float]`, *optional*, defaults to `self.image_mean`):
+ Image mean to use if `do_normalize` is set to `True`.
+ image_std (`float` or `list[float]`, *optional*, defaults to `self.image_std`):
+ Image standard deviation to use if `do_normalize` is set to `True`.
+ return_tensors (`str` or `TensorType`, *optional*):
+ The type of tensors to return. Can be one of:
+ - Unset: Return a list of `np.ndarray`.
+ - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
+ - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
+ data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
+ The channel dimension format for the output image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - Unset: Use the channel dimension format of the input image.
+ input_data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
+ from the input image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
+ """
+ do_resize = do_resize if do_resize is not None else self.do_resize
+ do_rescale = do_rescale if do_rescale is not None else self.do_rescale
+ do_normalize = do_normalize if do_normalize is not None else self.do_normalize
+ resample = resample if resample is not None else self.resample
+ rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
+ image_mean = image_mean if image_mean is not None else self.image_mean
+ image_std = image_std if image_std is not None else self.image_std
+
+ size = size if size is not None else self.size
+ size_dict = get_size_dict(size)
+
+ images = make_flat_list_of_images(images)
+
+ if not valid_images(images):
+ raise ValueError(
+ "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
+ "torch.Tensor, tf.Tensor or jax.ndarray."
+ )
+ validate_preprocess_arguments(
+ do_rescale=do_rescale,
+ rescale_factor=rescale_factor,
+ do_normalize=do_normalize,
+ image_mean=image_mean,
+ image_std=image_std,
+ do_resize=do_resize,
+ size=size,
+ resample=resample,
+ )
+
+ # All transformations expect numpy arrays.
+ images = [to_numpy_array(image) for image in images]
+
+ if do_rescale and is_scaled_image(images[0]):
+ logger.warning_once(
+ "It looks like you are trying to rescale already rescaled images. If the input"
+ " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
+ )
+
+ if input_data_format is None:
+ # We assume that all images have the same channel dimension format.
+ input_data_format = infer_channel_dimension_format(images[0])
+
+ if do_resize:
+ images = [
+ self.resize(image=image, size=size_dict, resample=resample, input_data_format=input_data_format)
+ for image in images
+ ]
+
+ if do_rescale:
+ images = [
+ self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
+ for image in images
+ ]
+
+ if do_normalize:
+ images = [
+ self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
+ for image in images
+ ]
+
+ images = [
+ to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
+ ]
+
+ data = {"pixel_values": images}
+ return BatchFeature(data=data, tensor_type=return_tensors)
+
+
+__all__ = ["PvtImageProcessor"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/pvt/image_processing_pvt_fast.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/pvt/image_processing_pvt_fast.py
new file mode 100644
index 0000000000000000000000000000000000000000..6ad7a155d4ab74b477826db19b75bb32a734fc77
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/pvt/image_processing_pvt_fast.py
@@ -0,0 +1,38 @@
+# coding=utf-8
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Fast Image processor class for Pvt."""
+
+from ...image_processing_utils_fast import BaseImageProcessorFast
+from ...image_utils import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, PILImageResampling
+from ...utils import auto_docstring
+
+
+@auto_docstring
+class PvtImageProcessorFast(BaseImageProcessorFast):
+ resample = PILImageResampling.BILINEAR
+ image_mean = IMAGENET_DEFAULT_MEAN
+ image_std = IMAGENET_DEFAULT_STD
+ size = {"height": 224, "width": 224}
+ default_to_square = True
+ crop_size = None
+ do_resize = True
+ do_center_crop = None
+ do_rescale = True
+ do_normalize = True
+ do_convert_rgb = None
+ model_input_names = ["pixel_values"]
+
+
+__all__ = ["PvtImageProcessorFast"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/pvt/modeling_pvt.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/pvt/modeling_pvt.py
new file mode 100644
index 0000000000000000000000000000000000000000..21af67542d704000157f21957f74c0b58d9a83bc
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/pvt/modeling_pvt.py
@@ -0,0 +1,591 @@
+# coding=utf-8
+# Copyright 2023 Authors: Wenhai Wang, Enze Xie, Xiang Li, Deng-Ping Fan,
+# Kaitao Song, Ding Liang, Tong Lu, Ping Luo, Ling Shao and The HuggingFace Inc. team.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch PVT model."""
+
+import collections
+import math
+from collections.abc import Iterable
+from typing import Optional, Union
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from ...activations import ACT2FN
+from ...modeling_outputs import BaseModelOutput, ImageClassifierOutput
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
+from ...utils import auto_docstring, logging
+from .configuration_pvt import PvtConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+# Copied from transformers.models.beit.modeling_beit.drop_path
+def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
+ """
+ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+
+ Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
+ however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
+ See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
+ layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
+ argument.
+ """
+ if drop_prob == 0.0 or not training:
+ return input
+ keep_prob = 1 - drop_prob
+ shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
+ random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
+ random_tensor.floor_() # binarize
+ output = input.div(keep_prob) * random_tensor
+ return output
+
+
+# Copied from transformers.models.convnext.modeling_convnext.ConvNextDropPath with ConvNext->Pvt
+class PvtDropPath(nn.Module):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
+
+ def __init__(self, drop_prob: Optional[float] = None) -> None:
+ super().__init__()
+ self.drop_prob = drop_prob
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ return drop_path(hidden_states, self.drop_prob, self.training)
+
+ def extra_repr(self) -> str:
+ return f"p={self.drop_prob}"
+
+
+class PvtPatchEmbeddings(nn.Module):
+ """
+ This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
+ `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
+ Transformer.
+ """
+
+ def __init__(
+ self,
+ config: PvtConfig,
+ image_size: Union[int, Iterable[int]],
+ patch_size: Union[int, Iterable[int]],
+ stride: int,
+ num_channels: int,
+ hidden_size: int,
+ cls_token: bool = False,
+ ):
+ super().__init__()
+ self.config = config
+ image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
+ patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
+ num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
+ self.image_size = image_size
+ self.patch_size = patch_size
+ self.num_channels = num_channels
+ self.num_patches = num_patches
+
+ self.position_embeddings = nn.Parameter(
+ torch.randn(1, num_patches + 1 if cls_token else num_patches, hidden_size)
+ )
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, hidden_size)) if cls_token else None
+ self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=stride, stride=patch_size)
+ self.layer_norm = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps)
+ self.dropout = nn.Dropout(p=config.hidden_dropout_prob)
+
+ def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
+ num_patches = height * width
+
+ # always interpolate when tracing to ensure the exported model works for dynamic input shapes
+ if not torch.jit.is_tracing() and num_patches == self.config.image_size * self.config.image_size:
+ return self.position_embeddings
+ embeddings = embeddings.reshape(1, height, width, -1).permute(0, 3, 1, 2)
+ interpolated_embeddings = F.interpolate(embeddings, size=(height, width), mode="bilinear")
+ interpolated_embeddings = interpolated_embeddings.reshape(1, -1, height * width).permute(0, 2, 1)
+ return interpolated_embeddings
+
+ def forward(self, pixel_values: torch.Tensor) -> tuple[torch.Tensor, int, int]:
+ batch_size, num_channels, height, width = pixel_values.shape
+ if num_channels != self.num_channels:
+ raise ValueError(
+ "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
+ )
+ patch_embed = self.projection(pixel_values)
+ *_, height, width = patch_embed.shape
+ patch_embed = patch_embed.flatten(2).transpose(1, 2)
+ embeddings = self.layer_norm(patch_embed)
+ if self.cls_token is not None:
+ cls_token = self.cls_token.expand(batch_size, -1, -1)
+ embeddings = torch.cat((cls_token, embeddings), dim=1)
+ position_embeddings = self.interpolate_pos_encoding(self.position_embeddings[:, 1:], height, width)
+ position_embeddings = torch.cat((self.position_embeddings[:, :1], position_embeddings), dim=1)
+ else:
+ position_embeddings = self.interpolate_pos_encoding(self.position_embeddings, height, width)
+ embeddings = self.dropout(embeddings + position_embeddings)
+
+ return embeddings, height, width
+
+
+class PvtSelfOutput(nn.Module):
+ def __init__(self, config: PvtConfig, hidden_size: int):
+ super().__init__()
+ self.dense = nn.Linear(hidden_size, hidden_size)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ return hidden_states
+
+
+class PvtEfficientSelfAttention(nn.Module):
+ """Efficient self-attention mechanism with reduction of the sequence [PvT paper](https://huggingface.co/papers/2102.12122)."""
+
+ def __init__(
+ self, config: PvtConfig, hidden_size: int, num_attention_heads: int, sequences_reduction_ratio: float
+ ):
+ super().__init__()
+ self.hidden_size = hidden_size
+ self.num_attention_heads = num_attention_heads
+
+ if self.hidden_size % self.num_attention_heads != 0:
+ raise ValueError(
+ f"The hidden size ({self.hidden_size}) is not a multiple of the number of attention "
+ f"heads ({self.num_attention_heads})"
+ )
+
+ self.attention_head_size = int(self.hidden_size / self.num_attention_heads)
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
+
+ self.query = nn.Linear(self.hidden_size, self.all_head_size, bias=config.qkv_bias)
+ self.key = nn.Linear(self.hidden_size, self.all_head_size, bias=config.qkv_bias)
+ self.value = nn.Linear(self.hidden_size, self.all_head_size, bias=config.qkv_bias)
+
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
+
+ self.sequences_reduction_ratio = sequences_reduction_ratio
+ if sequences_reduction_ratio > 1:
+ self.sequence_reduction = nn.Conv2d(
+ hidden_size, hidden_size, kernel_size=sequences_reduction_ratio, stride=sequences_reduction_ratio
+ )
+ self.layer_norm = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps)
+
+    def transpose_for_scores(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ new_shape = hidden_states.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
+ hidden_states = hidden_states.view(new_shape)
+ return hidden_states.permute(0, 2, 1, 3)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ height: int,
+ width: int,
+ output_attentions: bool = False,
+ ) -> tuple[torch.Tensor]:
+ query_layer = self.transpose_for_scores(self.query(hidden_states))
+
+ if self.sequences_reduction_ratio > 1:
+ batch_size, seq_len, num_channels = hidden_states.shape
+ # Reshape to (batch_size, num_channels, height, width)
+ hidden_states = hidden_states.permute(0, 2, 1).reshape(batch_size, num_channels, height, width)
+ # Apply sequence reduction
+ hidden_states = self.sequence_reduction(hidden_states)
+ # Reshape back to (batch_size, seq_len, num_channels)
+ hidden_states = hidden_states.reshape(batch_size, num_channels, -1).permute(0, 2, 1)
+ hidden_states = self.layer_norm(hidden_states)
+
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
+
+ # Take the dot product between "query" and "key" to get the raw attention scores.
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
+
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
+
+ # Normalize the attention scores to probabilities.
+ attention_probs = nn.functional.softmax(attention_scores, dim=-1)
+
+ # This is actually dropping out entire tokens to attend to, which might
+ # seem a bit unusual, but is taken from the original Transformer paper.
+ attention_probs = self.dropout(attention_probs)
+
+ context_layer = torch.matmul(attention_probs, value_layer)
+
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+ context_layer = context_layer.view(new_context_layer_shape)
+
+ outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
+
+ return outputs
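
# Illustrative shape walk-through (not part of the library file above): with a spatial sequence
# reduction ratio R, the keys and values are computed on a sequence that is R*R times shorter than
# the queries, which is what makes this attention cheap on high-resolution feature maps.
import torch
from torch import nn

batch_size, height, width, channels, ratio = 2, 56, 56, 64, 8
hidden_states = torch.randn(batch_size, height * width, channels)           # queries: 3136 tokens
reduction = nn.Conv2d(channels, channels, kernel_size=ratio, stride=ratio)  # same shape of op as self.sequence_reduction
reduced = reduction(hidden_states.permute(0, 2, 1).reshape(batch_size, channels, height, width))
reduced = reduced.reshape(batch_size, channels, -1).permute(0, 2, 1)         # keys/values: 3136 / 64 = 49 tokens
print(hidden_states.shape, reduced.shape)  # torch.Size([2, 3136, 64]) torch.Size([2, 49, 64])
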
+
+
+class PvtAttention(nn.Module):
+ def __init__(
+ self, config: PvtConfig, hidden_size: int, num_attention_heads: int, sequences_reduction_ratio: float
+ ):
+ super().__init__()
+ self.self = PvtEfficientSelfAttention(
+ config,
+ hidden_size=hidden_size,
+ num_attention_heads=num_attention_heads,
+ sequences_reduction_ratio=sequences_reduction_ratio,
+ )
+ self.output = PvtSelfOutput(config, hidden_size=hidden_size)
+ self.pruned_heads = set()
+
+ def prune_heads(self, heads):
+ if len(heads) == 0:
+ return
+ heads, index = find_pruneable_heads_and_indices(
+ heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
+ )
+
+ # Prune linear layers
+ self.self.query = prune_linear_layer(self.self.query, index)
+ self.self.key = prune_linear_layer(self.self.key, index)
+ self.self.value = prune_linear_layer(self.self.value, index)
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
+
+ # Update hyper params and store pruned heads
+ self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
+ self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
+ self.pruned_heads = self.pruned_heads.union(heads)
+
+ def forward(
+ self, hidden_states: torch.Tensor, height: int, width: int, output_attentions: bool = False
+ ) -> tuple[torch.Tensor]:
+ self_outputs = self.self(hidden_states, height, width, output_attentions)
+
+ attention_output = self.output(self_outputs[0])
+ outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
+ return outputs
+
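+# Worked example for prune_heads above (hypothetical numbers): with 8 heads of size 64
+# (all_head_size=512), pruning heads {0, 3} keeps the 6 * 64 = 384 output features of
+# query/key/value selected by `index`, prunes the matching input columns of
+# `output.dense` (dim=1), and updates num_attention_heads to 6 and all_head_size to 384.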
+
+class PvtFFN(nn.Module):
+ def __init__(
+ self,
+ config: PvtConfig,
+ in_features: int,
+ hidden_features: Optional[int] = None,
+ out_features: Optional[int] = None,
+ ):
+ super().__init__()
+ out_features = out_features if out_features is not None else in_features
+ hidden_features = hidden_features if hidden_features is not None else in_features
+ self.dense1 = nn.Linear(in_features, hidden_features)
+ if isinstance(config.hidden_act, str):
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
+ else:
+ self.intermediate_act_fn = config.hidden_act
+ self.dense2 = nn.Linear(hidden_features, out_features)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.dense1(hidden_states)
+ hidden_states = self.intermediate_act_fn(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.dense2(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ return hidden_states
+
+
+class PvtLayer(nn.Module):
+ def __init__(
+ self,
+ config: PvtConfig,
+ hidden_size: int,
+ num_attention_heads: int,
+ drop_path: float,
+ sequences_reduction_ratio: float,
+ mlp_ratio: float,
+ ):
+ super().__init__()
+ self.layer_norm_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps)
+ self.attention = PvtAttention(
+ config=config,
+ hidden_size=hidden_size,
+ num_attention_heads=num_attention_heads,
+ sequences_reduction_ratio=sequences_reduction_ratio,
+ )
+ self.drop_path = PvtDropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+ self.layer_norm_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps)
+ mlp_hidden_size = int(hidden_size * mlp_ratio)
+ self.mlp = PvtFFN(config=config, in_features=hidden_size, hidden_features=mlp_hidden_size)
+
+ def forward(self, hidden_states: torch.Tensor, height: int, width: int, output_attentions: bool = False):
+ self_attention_outputs = self.attention(
+ hidden_states=self.layer_norm_1(hidden_states),
+ height=height,
+ width=width,
+ output_attentions=output_attentions,
+ )
+ attention_output = self_attention_outputs[0]
+ outputs = self_attention_outputs[1:]
+
+ attention_output = self.drop_path(attention_output)
+ hidden_states = attention_output + hidden_states
+
+ mlp_output = self.mlp(self.layer_norm_2(hidden_states))
+
+ mlp_output = self.drop_path(mlp_output)
+ layer_output = hidden_states + mlp_output
+
+ outputs = (layer_output,) + outputs
+
+ return outputs
+
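+# PvtLayer above is a standard pre-norm residual block; as an editorial summary in
+# pseudocode (not original source):
+#
+#   x = x + DropPath(Attention(LayerNorm1(x), height, width))
+#   x = x + DropPath(MLP(LayerNorm2(x)))
+#
+# where the MLP hidden width is hidden_size * mlp_ratio inside PvtFFN.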
+
+class PvtEncoder(nn.Module):
+ def __init__(self, config: PvtConfig):
+ super().__init__()
+ self.config = config
+
+ # stochastic depth decay rule
+ drop_path_decays = torch.linspace(0, config.drop_path_rate, sum(config.depths), device="cpu").tolist()
+
+ # patch embeddings
+ embeddings = []
+
+ for i in range(config.num_encoder_blocks):
+ embeddings.append(
+ PvtPatchEmbeddings(
+ config=config,
+ image_size=config.image_size if i == 0 else self.config.image_size // (2 ** (i + 1)),
+ patch_size=config.patch_sizes[i],
+ stride=config.strides[i],
+ num_channels=config.num_channels if i == 0 else config.hidden_sizes[i - 1],
+ hidden_size=config.hidden_sizes[i],
+ cls_token=i == config.num_encoder_blocks - 1,
+ )
+ )
+ self.patch_embeddings = nn.ModuleList(embeddings)
+
+ # Transformer blocks
+ blocks = []
+ cur = 0
+ for i in range(config.num_encoder_blocks):
+ # each block consists of layers
+ layers = []
+ if i != 0:
+ cur += config.depths[i - 1]
+ for j in range(config.depths[i]):
+ layers.append(
+ PvtLayer(
+ config=config,
+ hidden_size=config.hidden_sizes[i],
+ num_attention_heads=config.num_attention_heads[i],
+ drop_path=drop_path_decays[cur + j],
+ sequences_reduction_ratio=config.sequence_reduction_ratios[i],
+ mlp_ratio=config.mlp_ratios[i],
+ )
+ )
+ blocks.append(nn.ModuleList(layers))
+
+ self.block = nn.ModuleList(blocks)
+
+ # Layer norms
+ self.layer_norm = nn.LayerNorm(config.hidden_sizes[-1], eps=config.layer_norm_eps)
+
+ def forward(
+ self,
+ pixel_values: torch.FloatTensor,
+ output_attentions: Optional[bool] = False,
+ output_hidden_states: Optional[bool] = False,
+ return_dict: Optional[bool] = True,
+ ) -> Union[tuple, BaseModelOutput]:
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attentions = () if output_attentions else None
+
+ batch_size = pixel_values.shape[0]
+ num_blocks = len(self.block)
+ hidden_states = pixel_values
+ for idx, (embedding_layer, block_layer) in enumerate(zip(self.patch_embeddings, self.block)):
+ # first, obtain patch embeddings
+ hidden_states, height, width = embedding_layer(hidden_states)
+ # second, send embeddings through blocks
+ for block in block_layer:
+ layer_outputs = block(hidden_states, height, width, output_attentions)
+ hidden_states = layer_outputs[0]
+ if output_attentions:
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+ if idx != num_blocks - 1:
+ hidden_states = hidden_states.reshape(batch_size, height, width, -1).permute(0, 3, 1, 2).contiguous()
+ hidden_states = self.layer_norm(hidden_states)
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+ if not return_dict:
+ return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
+ return BaseModelOutput(
+ last_hidden_state=hidden_states,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attentions,
+ )
+
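+# The encoder above is hierarchical: each stage re-embeds the previous stage's output at
+# a lower resolution. As a rough example (assuming a PVT-tiny style config with strides
+# [4, 2, 2, 2] and a 224x224 input), the stage feature maps are 56x56, 28x28, 14x14 and
+# 7x7, and a [CLS] token is appended only in the final stage (cls_token is True only for
+# the last patch embedding).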
+
+@auto_docstring
+class PvtPreTrainedModel(PreTrainedModel):
+ config: PvtConfig
+ base_model_prefix = "pvt"
+ main_input_name = "pixel_values"
+ _no_split_modules = []
+
+ def _init_weights(self, module: nn.Module) -> None:
+ """Initialize the weights"""
+ std = self.config.initializer_range
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
+ # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid
+ # `trunc_normal_cpu` not implemented in `half` issues
+ nn.init.trunc_normal_(module.weight.data, mean=0.0, std=std)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+ elif isinstance(module, PvtPatchEmbeddings):
+ module.position_embeddings.data = nn.init.trunc_normal_(
+ module.position_embeddings.data,
+ mean=0.0,
+ std=std,
+ )
+ if module.cls_token is not None:
+ module.cls_token.data = nn.init.trunc_normal_(
+ module.cls_token.data,
+ mean=0.0,
+ std=std,
+ )
+
+
+@auto_docstring
+class PvtModel(PvtPreTrainedModel):
+ def __init__(self, config: PvtConfig):
+ super().__init__(config)
+ self.config = config
+
+ # hierarchical Transformer encoder
+ self.encoder = PvtEncoder(config)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def _prune_heads(self, heads_to_prune):
+ """
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
+ class PreTrainedModel
+ """
+ for layer, heads in heads_to_prune.items():
+ self.encoder.layer[layer].attention.prune_heads(heads)
+
+ @auto_docstring
+ def forward(
+ self,
+ pixel_values: torch.FloatTensor,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[tuple, BaseModelOutput]:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ encoder_outputs = self.encoder(
+ pixel_values=pixel_values,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ sequence_output = encoder_outputs[0]
+
+ if not return_dict:
+ return (sequence_output,) + encoder_outputs[1:]
+
+ return BaseModelOutput(
+ last_hidden_state=sequence_output,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ )
+
+
+@auto_docstring(
+ custom_intro="""
+ Pvt Model transformer with an image classification head on top (a linear layer on top of the final hidden state of
+ the [CLS] token) e.g. for ImageNet.
+ """
+)
+class PvtForImageClassification(PvtPreTrainedModel):
+ def __init__(self, config: PvtConfig) -> None:
+ super().__init__(config)
+
+ self.num_labels = config.num_labels
+ self.pvt = PvtModel(config)
+
+ # Classifier head
+ self.classifier = (
+ nn.Linear(config.hidden_sizes[-1], config.num_labels) if config.num_labels > 0 else nn.Identity()
+ )
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @auto_docstring
+ def forward(
+ self,
+ pixel_values: Optional[torch.Tensor],
+ labels: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[tuple, ImageClassifierOutput]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.pvt(
+ pixel_values=pixel_values,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ sequence_output = outputs[0]
+
+ logits = self.classifier(sequence_output[:, 0, :])
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(labels, logits, self.config)
+
+ if not return_dict:
+ output = (logits,) + outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return ImageClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
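+# Minimal usage sketch for the classification model above (the checkpoint name and the
+# `image` variable are illustrative assumptions, not part of the original file):
+#
+#   from transformers import AutoImageProcessor, PvtForImageClassification
+#   import torch
+#
+#   processor = AutoImageProcessor.from_pretrained("Zetatech/pvt-tiny-224")
+#   model = PvtForImageClassification.from_pretrained("Zetatech/pvt-tiny-224")
+#   inputs = processor(images=image, return_tensors="pt")
+#   with torch.no_grad():
+#       logits = model(**inputs).logits   # classification scores from the [CLS] token
+#   predicted_class = logits.argmax(-1).item()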
+
+__all__ = ["PvtForImageClassification", "PvtModel", "PvtPreTrainedModel"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/qwen2_5_omni/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/qwen2_5_omni/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..0d7ddae0da7e1b2c16b95e94e6678f9511cb716f
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/qwen2_5_omni/__init__.py
@@ -0,0 +1,28 @@
+# Copyright 2025 The Qwen Team and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_qwen2_5_omni import *
+ from .modeling_qwen2_5_omni import *
+ from .processing_qwen2_5_omni import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
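+
+# Note: with the _LazyModule pattern above, the configuration/modeling/processing
+# submodules are only imported when one of their attributes is first accessed, e.g.
+#
+#   from transformers.models.qwen2_5_omni import Qwen2_5OmniConfig  # triggers the lazy import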
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py
new file mode 100644
index 0000000000000000000000000000000000000000..7bd36b7a3c0dda2da0ce85be66fa4d3705ba5b81
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py
@@ -0,0 +1,1091 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_qwen2_5_omni.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2025 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
+#
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from ...configuration_utils import PretrainedConfig, layer_type_validation
+from ...modeling_rope_utils import rope_config_validation
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class Qwen2_5OmniVisionEncoderConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`Qwen2_5OmniVisionEncoder`]. It is used to instantiate a
+ Qwen2.5-VL vision encoder according to the specified arguments, defining the model architecture. Instantiating a
+ configuration with the defaults will yield a similar configuration to that of the vision encoder of the Qwen2.5-VL
+ architecture.
+
+ e.g. [Qwen/Qwen2.5-Omni-7B](https://huggingface.co/Qwen/Qwen2.5-Omni-7B)
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ depth (`int`, *optional*, defaults to 32):
+ Number of layers (depth) in the model.
+ hidden_size (`int`, *optional*, defaults to 3584):
+ The size of the hidden layers.
+ hidden_act (`str`, *optional*, defaults to `"silu"`):
+ The non-linear activation function used in the model. Supported options include `"silu"`, `"quick_gelu"` and others as applicable.
+ intermediate_size (`int`, *optional*, defaults to 3420):
+ The size of the intermediate (MLP) hidden layer.
+ num_heads (`int`, *optional*, defaults to 16):
+ Number of attention heads for each attention layer.
+ in_channels (`int`, *optional*, defaults to 3):
+ Number of input channels.
+ patch_size (`int`, *optional*, defaults to 14):
+ The size of the patches extracted from the input.
+ spatial_merge_size (`int`, *optional*, defaults to 2):
+ The size used for merging spatial dimensions.
+ temporal_patch_size (`int`, *optional*, defaults to 2):
+ The size used for patches along the temporal dimension.
+ window_size (`int`, *optional*, defaults to 112):
+ The size of the local attention window used by the windowed-attention blocks.
+ fullatt_block_indexes (`list[int]`, *optional*, defaults to `[7, 15, 23, 31]`):
+ Indices of the blocks that use full attention instead of windowed attention.
+ out_hidden_size (`int`, *optional*, defaults to 3584):
+ The output hidden size of the vision encoder.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+
+ Example:
+
+ ```python
+ >>> from transformers import Qwen2_5OmniVisionEncoderConfig, Qwen2_5OmniVisionEncoder
+
+ >>> # Initializing a Qwen2_5OmniVisionEncoderConfig
+ >>> configuration = Qwen2_5OmniVisionEncoderConfig()
+
+ >>> # Initializing a Qwen2_5OmniVisionEncoder (with random weights)
+ >>> model = Qwen2_5OmniVisionEncoder(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "qwen2_5_omni_vision_encoder"
+ base_config_key = "vision_config"
+
+ def __init__(
+ self,
+ depth=32,
+ hidden_size=3584,
+ hidden_act="silu",
+ intermediate_size=3420,
+ num_heads=16,
+ in_channels=3,
+ patch_size=14,
+ spatial_merge_size=2,
+ temporal_patch_size=2,
+ window_size=112,
+ out_hidden_size=3584,
+ fullatt_block_indexes=[7, 15, 23, 31],
+ initializer_range=0.02,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+
+ self.depth = depth
+ self.hidden_size = hidden_size
+ self.hidden_act = hidden_act
+ self.intermediate_size = intermediate_size
+ self.num_heads = num_heads
+ self.in_channels = in_channels
+ self.patch_size = patch_size
+ self.spatial_merge_size = spatial_merge_size
+ self.temporal_patch_size = temporal_patch_size
+ self.window_size = window_size
+ self.fullatt_block_indexes = fullatt_block_indexes
+ self.out_hidden_size = out_hidden_size
+ self.initializer_range = initializer_range
+
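+# Illustrative reading of the defaults above (editorial, not normative): with depth=32
+# and fullatt_block_indexes=[7, 15, 23, 31], every eighth block uses full self-attention
+# while the remaining blocks restrict attention to local windows of window_size=112, and
+# out_hidden_size=3584 matches the text model's hidden size so the merged vision features
+# can be consumed by the thinker directly.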
+
+class Qwen2_5OmniAudioEncoderConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`Qwen2_5OmniAudioEncoder`]. It is used to instantiate a
+ Qwen2.5-Omni-Thinker audio encoder according to the specified arguments, defining the model architecture. Instantiating a
+ configuration with the defaults will yield a similar configuration to that of the audio encoder of the Qwen2-Audio
+ architecture.
+
+ e.g. [Qwen/Qwen2.5-Omni-7B](https://huggingface.co/Qwen/Qwen2.5-Omni-7B)
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ num_mel_bins (`int`, *optional*, defaults to 128):
+ Number of mel features used per input features. Should correspond to the value used in the
+ `Qwen2_5OmniProcessor` class.
+ encoder_layers (`int`, *optional*, defaults to 32):
+ Number of encoder layers.
+ encoder_attention_heads (`int`, *optional*, defaults to 20):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ encoder_ffn_dim (`int`, *optional*, defaults to 5120):
+ Dimensionality of the "intermediate" (often named feed-forward) layer in encoder.
+ d_model (`int`, *optional*, defaults to 1280):
+ Dimensionality of the layers.
+ dropout (`float`, *optional*, defaults to 0.0):
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ activation_function (`str`, *optional*, defaults to `"gelu"`):
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+ `"relu"`, `"silu"` and `"gelu_new"` are supported.
+ activation_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for activations inside the fully connected layer.
+ scale_embedding (`bool`, *optional*, defaults to `False`):
+ Scale embeddings by dividing by sqrt(d_model).
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ max_source_positions (`int`, *optional*, defaults to 1500):
+ The maximum sequence length of log-mel filter-bank features that this model might ever be used with.
+ n_window (`int`, *optional*, defaults to 100):
+ The chunk for conv and flash attn in AudioEncoder.
+ output_dim (`int`, *optional*, defaults to 3584):
+ The output dimension of AudioEncoder.
+
+ Example:
+
+ ```python
+ >>> from transformers import Qwen2_5OmniAudioEncoderConfig, Qwen2_5OmniAudioEncoder
+
+ >>> # Initializing a Qwen2_5OmniAudioEncoderConfig
+ >>> configuration = Qwen2_5OmniAudioEncoderConfig()
+
+ >>> # Initializing a Qwen2_5OmniAudioEncoder (with random weights)
+ >>> model = Qwen2_5OmniAudioEncoder(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "qwen2_5_omni_audio_encoder"
+
+ def __init__(
+ self,
+ num_mel_bins=128,
+ encoder_layers=32,
+ encoder_attention_heads=20,
+ encoder_ffn_dim=5120,
+ d_model=1280,
+ dropout=0.0,
+ attention_dropout=0.0,
+ activation_function="gelu",
+ activation_dropout=0.0,
+ scale_embedding=False,
+ initializer_range=0.02,
+ max_source_positions=1500,
+ n_window=100,
+ output_dim=3584,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+
+ self.num_mel_bins = num_mel_bins
+ self.d_model = d_model
+ self.encoder_layers = encoder_layers
+ self.encoder_attention_heads = encoder_attention_heads
+ self.encoder_ffn_dim = encoder_ffn_dim
+ self.dropout = dropout
+ self.attention_dropout = attention_dropout
+ self.activation_function = activation_function
+ self.activation_dropout = activation_dropout
+ self.num_hidden_layers = encoder_layers
+ self.initializer_range = initializer_range
+ self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True
+ self.max_source_positions = max_source_positions
+ self.n_window = n_window
+ self.output_dim = output_dim
+
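+# Illustrative reading of the defaults above (editorial): the audio encoder consumes
+# 128-bin log-mel features, runs 32 layers of width d_model=1280, processes the sequence
+# in windows of n_window=100 frames for its convolution/attention, and projects to
+# output_dim=3584 so audio embeddings line up with the text model's hidden size.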
+
+class Qwen2_5OmniTextConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`Qwen2_5OmniThinkerForConditionalGeneration`]. It is used to instantiate an
+ Qwen2.5-Omni-Thinker model according to the specified arguments, defining the model architecture. Instantiating a configuration
+ with the defaults will yield a similar configuration to that of the Qwen2.5-Omni-Thinker.
+
+ e.g. [Qwen/Qwen2.5-Omni-7B](https://huggingface.co/Qwen/Qwen2.5-Omni-7B)
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 152064):
+ Vocabulary size of the QwenOmni model. Defines the number of different tokens that can be represented by the
+ `input_ids` passed when calling [`Qwen2VLModel`]
+ hidden_size (`int`, *optional*, defaults to 3584):
+ Dimension of the hidden representations.
+ intermediate_size (`int`, *optional*, defaults to 18944):
+ Dimension of the MLP representations.
+ num_hidden_layers (`int`, *optional*, defaults to 28):
+ Number of hidden layers in the Transformer encoder.
+ num_attention_heads (`int`, *optional*, defaults to 28):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ num_key_value_heads (`int`, *optional*, defaults to 4):
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
+ `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+ by meanpooling all the original heads within that group. For more details, check out [this
+ paper](https://huggingface.co/papers/2305.13245).
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+ The non-linear activation function (function or string) in the decoder.
+ max_position_embeddings (`int`, *optional*, defaults to 32768):
+ The maximum sequence length that this model might ever be used with.
+ rms_norm_eps (`float`, *optional*, defaults to 1e-06):
+ The epsilon used by the rms normalization layers.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
+ relevant if `config.is_decoder=True`.
+ rope_theta (`float`, *optional*, defaults to 1000000.0):
+ The base period of the RoPE embeddings.
+ use_sliding_window (`bool`, *optional*, defaults to `False`):
+ Whether to use sliding window attention.
+ sliding_window (`int`, *optional*, defaults to 32768):
+ Sliding window attention (SWA) window size. Only used when `use_sliding_window` is `True`.
+ max_window_layers (`int`, *optional*, defaults to 28):
+ The number of layers using full attention. The first `max_window_layers` layers will use full attention, while any
+ additional layer afterwards will use SWA (Sliding Window Attention).
+ layer_types (`list`, *optional*):
+ Attention pattern for each layer.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ rope_scaling (`Dict`, *optional*):
+ Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
+ and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
+ accordingly.
+ Expected contents:
+ `rope_type` (`str`):
+ The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
+ 'llama3'], with 'default' being the original RoPE implementation.
+ `factor` (`float`, *optional*):
+ Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
+ most scaling types, a `factor` of x will enable the model to handle sequences of length x *
+ original maximum pre-trained length.
+ `original_max_position_embeddings` (`int`, *optional*):
+ Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
+ pretraining.
+ `attention_factor` (`float`, *optional*):
+ Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
+ computation. If unspecified, it defaults to value recommended by the implementation, using the
+ `factor` field to infer the suggested value.
+ `beta_fast` (`float`, *optional*):
+ Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
+ ramp function. If unspecified, it defaults to 32.
+ `beta_slow` (`float`, *optional*):
+ Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
+ ramp function. If unspecified, it defaults to 1.
+ `short_factor` (`list[float]`, *optional*):
+ Only used with 'longrope'. The scaling factor to be applied to short contexts (<
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+ size divided by the number of attention heads divided by 2
+ `long_factor` (`list[float]`, *optional*):
+ Only used with 'longrope'. The scaling factor to be applied to long contexts (>
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+ size divided by the number of attention heads divided by 2
+ `low_freq_factor` (`float`, *optional*):
+ Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
+ `high_freq_factor` (`float`, *optional*):
+ Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+
+ Example:
+
+ ```python
+ >>> from transformers import Qwen2_5OmniThinkerForConditionalGeneration, Qwen2_5OmniThinkerConfig, Qwen2_5OmniAudioEncoderConfig, Qwen2_5OmniVisionEncoderConfig
+
+ >>> # Initializing a Qwen2_5OmniAudioEncoder config
+ >>> audio_config = Qwen2_5OmniAudioEncoderConfig()
+
+ >>> # Initializing a Qwen2_5OmniVisionEncoder config
+ >>> vision_config = Qwen2_5OmniVisionEncoderConfig()
+
+ >>> # Initializing a Qwen2.5OmniThinker configuration
+ >>> configuration = Qwen2_5OmniThinkerConfig(audio_config, vision_config)
+
+ >>> # Initializing a model from the Qwen-Omni style configuration
+ >>> model = Qwen2_5OmniThinkerForConditionalGeneration(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "qwen2_5_omni_text"
+ keys_to_ignore_at_inference = ["past_key_values"]
+
+ # Default tensor parallel plan for base model `Qwen25OmniText`
+ base_model_tp_plan = {
+ "layers.*.self_attn.q_proj": "colwise",
+ "layers.*.self_attn.k_proj": "colwise",
+ "layers.*.self_attn.v_proj": "colwise",
+ "layers.*.self_attn.o_proj": "rowwise",
+ "layers.*.mlp.gate_proj": "colwise",
+ "layers.*.mlp.up_proj": "colwise",
+ "layers.*.mlp.down_proj": "rowwise",
+ }
+ base_model_pp_plan = {
+ "embed_tokens": (["input_ids"], ["inputs_embeds"]),
+ "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
+ "norm": (["hidden_states"], ["hidden_states"]),
+ }
+
+ def __init__(
+ self,
+ vocab_size=152064,
+ hidden_size=3584,
+ intermediate_size=18944,
+ num_hidden_layers=28,
+ num_attention_heads=28,
+ num_key_value_heads=4,
+ hidden_act="silu",
+ max_position_embeddings=32768,
+ initializer_range=0.02,
+ rms_norm_eps=1e-6,
+ use_cache=True,
+ tie_word_embeddings=False,
+ rope_theta=1000000.0,
+ rope_scaling=None,
+ use_sliding_window=False,
+ sliding_window=32768,
+ max_window_layers=28,
+ layer_types=None,
+ attention_dropout=0.0,
+ **kwargs,
+ ):
+ super().__init__(
+ tie_word_embeddings=tie_word_embeddings,
+ **kwargs,
+ )
+ self.vocab_size = vocab_size
+ self.max_position_embeddings = max_position_embeddings
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.use_sliding_window = use_sliding_window
+ self.sliding_window = sliding_window if self.use_sliding_window else None
+ self.max_window_layers = max_window_layers
+
+ # for backward compatibility
+ if num_key_value_heads is None:
+ num_key_value_heads = num_attention_heads
+
+ self.num_key_value_heads = num_key_value_heads
+ self.hidden_act = hidden_act
+ self.initializer_range = initializer_range
+ self.rms_norm_eps = rms_norm_eps
+ self.use_cache = use_cache
+ self.rope_theta = rope_theta
+ self.rope_scaling = rope_scaling
+ self.attention_dropout = attention_dropout
+ # Validate the correctness of rotary position embeddings parameters
+ # BC: if there is a 'type' field, move it to 'rope_type'.
+ if self.rope_scaling is not None and "type" in self.rope_scaling:
+ self.rope_scaling["rope_type"] = self.rope_scaling["type"]
+ rope_config_validation(self)
+
+ if self.rope_scaling is None:
+ self.rope_scaling = {"mrope_section": [16, 24, 24], "rope_type": "default", "type": "default"}
+
+ self.layer_types = layer_types
+ if self.layer_types is None:
+ self.layer_types = [
+ "sliding_attention"
+ if self.sliding_window is not None and i >= self.max_window_layers
+ else "full_attention"
+ for i in range(self.num_hidden_layers)
+ ]
+ layer_type_validation(self.layer_types, self.num_hidden_layers)
+
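+# Editorial summary of the __init__ logic above, shown on the defaults (illustrative):
+#
+#   cfg = Qwen2_5OmniTextConfig()
+#   cfg.sliding_window   # None, because use_sliding_window defaults to False
+#   cfg.layer_types      # ["full_attention"] * 28, since sliding window attention is disabled
+#   cfg.rope_scaling     # {"mrope_section": [16, 24, 24], "rope_type": "default", "type": "default"}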
+
+class Qwen2_5OmniThinkerConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`Qwen2_5OmniThinkerForConditionalGeneration`]. It is used to instantiate an
+ Qwen2.5-Omni-Thinker model according to the specified arguments, defining the model architecture. Instantiating a configuration
+ with the defaults will yield a similar configuration to that of the Qwen2.5-Omni-Thinker.
+
+ e.g. [Qwen/Qwen2.5-Omni-7B](https://huggingface.co/Qwen/Qwen2.5-Omni-7B)
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ audio_config (`dict`, *optional*):
+ The config dictionary of the audio backbone.
+ vision_config (`dict`, *optional*):
+ The config dictionary of the vision backbone.
+ text_config (`dict`, *optional*):
+ The config dictionary of the text backbone.
+ audio_token_index (`int`, *optional*, defaults to 151646):
+ The audio token index to encode the audio prompt.
+ image_token_index (`int`, *optional*, defaults to 151655):
+ The image token index to encode the image prompt.
+ video_token_index (`int`, *optional*, defaults to 151656):
+ The video token index to encode the video prompt.
+ position_id_per_seconds (`int`, *optional*, defaults to 25):
+ The increment of position id per second.
+ seconds_per_chunk (`int`, *optional*, defaults to 2):
+ The duration in seconds of the chunk of audio and video data.
+ audio_start_token_id (`int`, *optional*, defaults to 151647):
+ The audio start token index to encode the audio prompt.
+ audio_end_token_id (`int`, *optional*, defaults to 151648):
+ The audio end token index to encode the audio prompt.
+ user_token_id (`int`, *optional*, defaults to 872):
+ The user token index to encode the user token.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+
+ Example:
+
+ ```python
+ >>> from transformers import Qwen2_5OmniThinkerForConditionalGeneration, Qwen2_5OmniThinkerConfig, Qwen2_5OmniAudioEncoderConfig, Qwen2_5OmniVisionEncoderConfig
+
+ >>> # Initializing a Qwen2_5OmniAudioEncoder config
+ >>> audio_config = Qwen2_5OmniAudioEncoderConfig()
+
+ >>> # Initializing a Qwen2_5OmniVisionEncoder config
+ >>> vision_config = Qwen2_5OmniVisionEncoderConfig()
+
+ >>> # Initializing a Qwen2_5OmniTextConfig config
+ >>> text_config = Qwen2_5OmniTextConfig()
+
+ >>> # Initializing a Qwen2.5OmniThinker configuration
+ >>> configuration = Qwen2_5OmniThinkerConfig(audio_config, vision_config, text_config)
+
+ >>> # Initializing a model from the Qwen-Omni style configuration
+ >>> model = Qwen2_5OmniThinkerForConditionalGeneration(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "qwen2_5_omni_thinker"
+ attribute_map = {
+ "image_token_id": "image_token_index",
+ "video_token_id": "video_token_index",
+ "audio_token_id": "audio_token_index",
+ }
+ sub_configs = {
+ "audio_config": Qwen2_5OmniAudioEncoderConfig,
+ "vision_config": Qwen2_5OmniVisionEncoderConfig,
+ "text_config": Qwen2_5OmniTextConfig,
+ }
+
+ def __init__(
+ self,
+ audio_config=None,
+ vision_config=None,
+ text_config=None,
+ audio_token_index=151646,
+ image_token_index=151655,
+ video_token_index=151656,
+ position_id_per_seconds=25,
+ seconds_per_chunk=2,
+ audio_start_token_id=151647,
+ audio_end_token_id=151648,
+ user_token_id=872,
+ initializer_range=0.02,
+ **kwargs,
+ ):
+ self.audio_token_index = audio_token_index
+ self.image_token_index = image_token_index
+ self.video_token_index = video_token_index
+ self.user_token_id = user_token_id
+ self.position_id_per_seconds = position_id_per_seconds
+ self.seconds_per_chunk = seconds_per_chunk
+ self.audio_start_token_id = audio_start_token_id
+ self.audio_end_token_id = audio_end_token_id
+ self.initializer_range = initializer_range
+
+ if isinstance(vision_config, dict):
+ vision_config = Qwen2_5OmniVisionEncoderConfig(**vision_config)
+ elif vision_config is None:
+ vision_config = Qwen2_5OmniVisionEncoderConfig()
+ self.vision_config = vision_config
+
+ if isinstance(audio_config, dict):
+ audio_config = Qwen2_5OmniAudioEncoderConfig(**audio_config)
+ elif audio_config is None:
+ audio_config = Qwen2_5OmniAudioEncoderConfig()
+ self.audio_config = audio_config
+
+ if isinstance(text_config, dict):
+ text_config = Qwen2_5OmniTextConfig(**text_config)
+ elif text_config is None:
+ text_config = Qwen2_5OmniTextConfig()
+ self.text_config = text_config
+
+ super().__init__(**kwargs)
+
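+# The thinker config above normalizes its sub-configs, so dicts and config objects are
+# interchangeable when constructing it (illustrative):
+#
+#   cfg = Qwen2_5OmniThinkerConfig(vision_config={"depth": 16})
+#   type(cfg.vision_config)  # Qwen2_5OmniVisionEncoderConfig
+#   cfg.audio_config         # a default Qwen2_5OmniAudioEncoderConfig()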
+
+class Qwen2_5OmniTalkerConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`Qwen2_5OmniTalkerForConditionalGeneration`]. It is used to instantiate an
+ Qwen2.5-Omni-Talker model according to the specified arguments, defining the model architecture. Instantiating a configuration
+ with the defaults will yield a similar configuration to that of the Qwen2.5-Omni-Thinker.
+
+ e.g. [Qwen/Qwen2.5-Omni-7B](https://huggingface.co/Qwen/Qwen2.5-Omni-7B)
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ audio_token_index (`int`, *optional*, defaults to 151646):
+ The audio token index to encode the audio prompt.
+ image_token_index (`int`, *optional*, defaults to 151655):
+ The image token index to encode the image prompt.
+ video_token_index (`int`, *optional*, defaults to 151656):
+ The video token index to encode the video prompt.
+ vocab_size (`int`, *optional*, defaults to 8448):
+ Vocabulary size of the QwenOmni model. Defines the number of different tokens that can be represented by the
+ `input_ids` passed when calling [`Qwen2VLModel`]
+ tts_text_start_token_id (`int`, *optional*, defaults to 151860):
+ The tts text start token index to encode the start of tts text.
+ tts_text_end_token_id (`int`, *optional*, defaults to 151861):
+ The tts text end token index to encode the end of tts text.
+ tts_text_pad_token_id (`int`, *optional*, defaults to 151859):
+ The tts text pad token index to encode the pad of tts text.
+ tts_codec_start_token_id (`int`, *optional*, defaults to 8293):
+ The tts codec start token index to encode the start of tts codec.
+ tts_codec_end_token_id (`int`, *optional*, defaults to 8294):
+ The tts codec end token index to encode the end of tts codec.
+ tts_codec_pad_token_id (`int`, *optional*, defaults to 8292):
+ The tts codec pad token index to encode the pad of tts codec.
+ tts_codec_mask_token_id (`int`, *optional*, defaults to 8296):
+ The tts codec mask token index to encode the mask of tts codec.
+ vision_start_token_id (`int`, *optional*, defaults to 151652):
+ The tts vision start token index to encode the start of vision.
+ vision_end_token_id (`int`, *optional*, defaults to 151653):
+ The tts vision end token index to encode the end of vision.
+ embedding_size (`int`, *optional*, defaults to 3584):
+ Dimension of the embedding representations.
+ hidden_size (`int`, *optional*, defaults to 3584):
+ Dimension of the hidden representations.
+ intermediate_size (`int`, *optional*, defaults to 18944):
+ Dimension of the MLP representations.
+ num_hidden_layers (`int`, *optional*, defaults to 28):
+ Number of hidden layers in the Transformer encoder.
+ num_attention_heads (`int`, *optional*, defaults to 28):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ num_key_value_heads (`int`, *optional*, defaults to 4):
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
+ `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+ by meanpooling all the original heads within that group. For more details, check out [this
+ paper](https://huggingface.co/papers/2305.13245).
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+ The non-linear activation function (function or string) in the decoder.
+ max_position_embeddings (`int`, *optional*, defaults to 32768):
+ The maximum sequence length that this model might ever be used with.
+ rms_norm_eps (`float`, *optional*, defaults to 1e-06):
+ The epsilon used by the rms normalization layers.
+ head_dim (`int`, *optional*, defaults to 128):
+ The dimension of each attention head.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
+ relevant if `config.is_decoder=True`.
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+ Whether the model's input and output word embeddings should be tied.
+ rope_theta (`float`, *optional*, defaults to 1000000.0):
+ The base period of the RoPE embeddings.
+ use_sliding_window (`bool`, *optional*, defaults to `False`):
+ Whether to use sliding window attention.
+ sliding_window (`int`, *optional*, defaults to 32768):
+ Sliding window attention (SWA) window size. Only used when `use_sliding_window` is `True`.
+ max_window_layers (`int`, *optional*, defaults to 28):
+ The number of layers using full attention. The first `max_window_layers` layers will use full attention, while any
+ additional layer afterwards will use SWA (Sliding Window Attention).
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ rope_scaling (`Dict`, *optional*):
+ Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
+ and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
+ accordingly.
+ Expected contents:
+ `rope_type` (`str`):
+ The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
+ 'llama3'], with 'default' being the original RoPE implementation.
+ `factor` (`float`, *optional*):
+ Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
+ most scaling types, a `factor` of x will enable the model to handle sequences of length x *
+ original maximum pre-trained length.
+ `original_max_position_embeddings` (`int`, *optional*):
+ Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
+ pretraining.
+ `attention_factor` (`float`, *optional*):
+ Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
+ computation. If unspecified, it defaults to value recommended by the implementation, using the
+ `factor` field to infer the suggested value.
+ `beta_fast` (`float`, *optional*):
+ Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
+ ramp function. If unspecified, it defaults to 32.
+ `beta_slow` (`float`, *optional*):
+ Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
+ ramp function. If unspecified, it defaults to 1.
+ `short_factor` (`list[float]`, *optional*):
+ Only used with 'longrope'. The scaling factor to be applied to short contexts (<
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+ size divided by the number of attention heads divided by 2
+ `long_factor` (`list[float]`, *optional*):
+ Only used with 'longrope'. The scaling factor to be applied to long contexts (<
+ Only used with 'longrope'. The scaling factor to be applied to long contexts (>
+ size divided by the number of attention heads divided by 2
+ `low_freq_factor` (`float`, *optional*):
+ Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
+ `high_freq_factor` (`float`, *optional*):
+ Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
+ position_id_per_seconds (`int`, *optional*, defaults to 25):
+ The increment of position id per second.
+ seconds_per_chunk (`int`, *optional*, defaults to 2):
+ The duration in seconds of the chunk of audio and video data.
+ audio_start_token_id (`int`, *optional*, defaults to 151647):
+ The audio start token index to encode the audio prompt.
+ audio_end_token_id (`int`, *optional*, defaults to 151648):
+ The audio end token index to encode the audio prompt.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ spatial_merge_size (`int`, *optional*, defaults to 2):
+ The size used for merging spatial dimensions.
+ layer_types (`list`, *optional*):
+ Attention pattern for each layer.
+
+ Example:
+
+ ```python
+ >>> from transformers import Qwen2_5OmniTalkerForConditionalGeneration, Qwen2_5OmniTalkerConfig
+
+ >>> # Initializing a Qwen2.5-Omni-Talker configuration
+ >>> configuration = Qwen2_5OmniTalkerConfig()
+
+ >>> # Initializing a model from the Qwen-Omni style configuration
+ >>> model = Qwen2_5OmniTalkerForConditionalGeneration(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "qwen2_5_omni_talker"
+ attribute_map = {
+ "image_token_id": "image_token_index",
+ "video_token_id": "video_token_index",
+ "audio_token_id": "audio_token_index",
+ }
+
+ def __init__(
+ self,
+ audio_token_index=151646,
+ image_token_index=151655,
+ video_token_index=151656,
+ vocab_size=8448,
+ tts_text_start_token_id=151860,
+ tts_text_end_token_id=151861,
+ tts_text_pad_token_id=151859,
+ tts_codec_start_token_id=8293,
+ tts_codec_end_token_id=8294,
+ tts_codec_pad_token_id=8292,
+ tts_codec_mask_token_id=8296,
+ vision_start_token_id=151652,
+ vision_end_token_id=151653,
+ embedding_size=3584,
+ hidden_size=3584,
+ intermediate_size=18944,
+ num_hidden_layers=28,
+ num_attention_heads=28,
+ num_key_value_heads=4,
+ hidden_act="silu",
+ max_position_embeddings=32768,
+ rms_norm_eps=1e-06,
+ head_dim=128,
+ use_cache=True,
+ tie_word_embeddings=False,
+ rope_theta=1000000.0,
+ use_sliding_window=False,
+ sliding_window=32768,
+ max_window_layers=28,
+ attention_dropout=0.0,
+ rope_scaling=None,
+ position_id_per_seconds=25,
+ seconds_per_chunk=2,
+ audio_start_token_id=151647,
+ audio_end_token_id=151648,
+ initializer_range=0.02,
+ spatial_merge_size=2,
+ layer_types=None,
+ **kwargs,
+ ):
+ self.audio_token_index = audio_token_index
+ self.image_token_index = image_token_index
+ self.video_token_index = video_token_index
+
+ self.tts_text_start_token_id = tts_text_start_token_id
+ self.tts_text_end_token_id = tts_text_end_token_id
+ self.tts_text_pad_token_id = tts_text_pad_token_id
+ self.tts_codec_start_token_id = tts_codec_start_token_id
+ self.tts_codec_end_token_id = tts_codec_end_token_id
+ self.tts_codec_pad_token_id = tts_codec_pad_token_id
+
+ self.tts_codec_mask_token_id = tts_codec_mask_token_id
+
+ self.vision_start_token_id = vision_start_token_id
+ self.vision_end_token_id = vision_end_token_id
+
+ self.vocab_size = vocab_size
+ self.head_dim = head_dim
+ self.embedding_size = embedding_size
+ self.max_position_embeddings = max_position_embeddings
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.use_sliding_window = use_sliding_window
+ self.sliding_window = sliding_window if self.use_sliding_window else None
+ self.max_window_layers = max_window_layers
+
+ # for backward compatibility
+ if num_key_value_heads is None:
+ num_key_value_heads = num_attention_heads
+
+ self.num_key_value_heads = num_key_value_heads
+ self.hidden_act = hidden_act
+ self.rms_norm_eps = rms_norm_eps
+ self.use_cache = use_cache
+ self.rope_theta = rope_theta
+ self.attention_dropout = attention_dropout
+ self.rope_scaling = rope_scaling
+ self.position_id_per_seconds = position_id_per_seconds
+ self.seconds_per_chunk = seconds_per_chunk
+ self.audio_start_token_id = audio_start_token_id
+ self.audio_end_token_id = audio_end_token_id
+
+ self.initializer_range = initializer_range
+ self.spatial_merge_size = spatial_merge_size
+
+ self.layer_types = layer_types
+ if self.layer_types is None:
+ self.layer_types = [
+ "sliding_attention"
+ if self.sliding_window is not None and i >= self.max_window_layers
+ else "full_attention"
+ for i in range(self.num_hidden_layers)
+ ]
+ layer_type_validation(self.layer_types, self.num_hidden_layers)
+
+ super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
+
+
+class Qwen2_5OmniDiTConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of the Qwen2_5OmniToken2WavDiT used in the Qwen2.5-Omni-Token2Wav model.
+ It defines the architecture of the DiT model, which is used for generating mel-spectrograms from tokens.
+
+ Args:
+ hidden_size (`int`, *optional*, defaults to 1024):
+ The dimension of the model.
+ num_hidden_layers (`int`, *optional*, defaults to 22):
+ The number of transformer blocks in the DiT model.
+ num_attention_heads (`int`, *optional*, defaults to 16):
+ The number of attention heads in each transformer block.
+ ff_mult (`int`, *optional*, defaults to 2):
+ The multiplier for the feedforward layer in each transformer block.
+ emb_dim (`int`, *optional*, defaults to 512):
+ The dimension of the embedding layer.
+ head_dim (`int`, *optional*, defaults to 64):
+ The dimension of each attention head.
+ rope_theta (`float`, *optional*, defaults to 10000.0):
+ The base period of the RoPE embeddings.
+ max_position_embeddings (`int`, *optional*, defaults to 32768):
+ The maximum sequence length that this model might ever be used with.
+ block_size (`int`, *optional*, defaults to 24):
+ The block size used by the block-wise attention in the DiT.
+ look_ahead_layers (`list[int]`, *optional*, defaults to `[10]`):
+ Indices of the DiT layers that may attend to future blocks.
+ look_backward_layers (`list[int]`, *optional*, defaults to `[0, 20]`):
+ Indices of the DiT layers that may attend to past blocks.
+ repeats (`int`, *optional*, defaults to 2):
+ The number of times the codec embeddings are repeated.
+ num_embeds (`int`, *optional*, defaults to 8193):
+ The number of unique embeddings in the codec.
+ mel_dim (`int`, *optional*, defaults to 80):
+ The dimension of the mel-spectrogram.
+ dropout (`float`, *optional*, defaults to 0.1):
+ The dropout rate for the transformer blocks.
+
+ enc_emb_dim (`int`, *optional*, defaults to 192):
+ The dimension of the pre-trained speaker embedding.
+ enc_dim (`int`, *optional*, defaults to 128):
+ The dimension of the encoder output.
+ enc_channels (`list[int]`, *optional*, defaults to `[256, 256, 256, 256, 768]`):
+ A list of output channels for each TDNN/SERes2Net layer in the encoder.
+ enc_kernel_sizes (`list[int]`, *optional*, defaults to `[5, 3, 3, 3, 1]`):
+ A list of kernel sizes for each layer in the encoder.
+ enc_dilations (`list[int]`, *optional*, defaults to `[1, 2, 3, 4, 1]`):
+ A list of dilations for each layer in the encoder.
+ enc_attention_channels (`int`, *optional*, defaults to 64):
+ The number of attention channels in the SqueezeExcitationBlock.
+ enc_res2net_scale (`int`, *optional*, defaults to 2):
+ The scale of the Res2Net block in the encoder.
+ enc_se_channels (`int`, *optional*, defaults to 64):
+ The number of output channels after squeeze in the SqueezeExcitationBlock.
+ """
+
+ model_type = "qwen2_5_omni_dit"
+
+ def __init__(
+ self,
+ hidden_size=1024,
+ num_hidden_layers=22,
+ num_attention_heads=16,
+ ff_mult=2,
+ emb_dim=512,
+ head_dim=64,
+ rope_theta=10000.0,
+ max_position_embeddings=32768,
+ block_size=24,
+ look_ahead_layers=[10],
+ look_backward_layers=[0, 20],
+ repeats=2,
+ num_embeds=8193,
+ mel_dim=80,
+ dropout=0.1,
+ enc_emb_dim=192,
+ enc_dim=128,
+ enc_channels=[256, 256, 256, 256, 768],
+ enc_kernel_sizes=[5, 3, 3, 3, 1],
+ enc_dilations=[1, 2, 3, 4, 1],
+ enc_attention_channels=64,
+ enc_res2net_scale=2,
+ enc_se_channels=64,
+ **kwargs,
+ ):
+ self.hidden_size = hidden_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.ff_mult = ff_mult
+ self.emb_dim = emb_dim
+ self.head_dim = head_dim
+ self.rope_theta = rope_theta
+ self.max_position_embeddings = max_position_embeddings
+ self.block_size = block_size
+ self.look_ahead_layers = look_ahead_layers
+ self.look_backward_layers = look_backward_layers
+ self.repeats = repeats
+ self.num_embeds = num_embeds
+ self.mel_dim = mel_dim
+ self.dropout = dropout
+ self.enc_emb_dim = enc_emb_dim
+ self.enc_dim = enc_dim
+ self.enc_channels = enc_channels
+ self.enc_kernel_sizes = enc_kernel_sizes
+ self.enc_dilations = enc_dilations
+ self.enc_attention_channels = enc_attention_channels
+ self.enc_res2net_scale = enc_res2net_scale
+ self.enc_se_channels = enc_se_channels
+ super().__init__(**kwargs)
+
+
+class Qwen2_5OmniBigVGANConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of the Qwen2_5OmniToken2WavBigVGAN module used in the Qwen2.5-Omni-Token2Wav model.
+ It defines the architecture of the BigVGAN model, which is used for converting mel-spectrograms to waveforms.
+
+ Args:
+ mel_dim (`int`, *optional*, defaults to 80):
+ The dimension of the mel-spectrogram.
+ upsample_initial_channel (`int`, *optional*, defaults to 1536):
+ The number of channels in the initial upsampling layer.
+ resblock_kernel_sizes (`list[int]`, *optional*, defaults to `[3, 7, 11]`):
+ A list of kernel sizes for each residual block.
+ resblock_dilation_sizes (`list[list[int]]`, *optional*, defaults to `[[1, 3, 5], [1, 3, 5], [1, 3, 5]]`):
+ A list of dilation sizes for each residual block.
+ upsample_rates (`list[int]`, *optional*, defaults to `[5, 3, 2, 2, 2, 2]`):
+ A list of upsampling rates for each upsampling layer.
+ upsample_kernel_sizes (`list[int]`, *optional*, defaults to `[11, 7, 4, 4, 4, 4]`):
+ A list of kernel sizes for each upsampling layer.
+ """
+
+ model_type = "qwen2_5_omni_bigvgan"
+
+ def __init__(
+ self,
+ mel_dim=80,
+ upsample_initial_channel=1536,
+ resblock_kernel_sizes=[3, 7, 11],
+ resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
+ upsample_rates=[5, 3, 2, 2, 2, 2],
+ upsample_kernel_sizes=[11, 7, 4, 4, 4, 4],
+ **kwargs,
+ ):
+ self.mel_dim = mel_dim
+ self.upsample_initial_channel = upsample_initial_channel
+ self.resblock_kernel_sizes = resblock_kernel_sizes
+ self.resblock_dilation_sizes = resblock_dilation_sizes
+ self.upsample_rates = upsample_rates
+ self.upsample_kernel_sizes = upsample_kernel_sizes
+ super().__init__(**kwargs)
+
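+# Editorial observation on the defaults above: the product of upsample_rates,
+# 5 * 3 * 2 * 2 * 2 * 2 = 240, is the number of waveform samples generated per mel
+# frame; the target sample rate itself is not specified by this config.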
+
+class Qwen2_5OmniToken2WavConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`Qwen2_5OmniToken2WavModel`].
+ It is used to instantiate the Qwen2.5-Omni-Token2Wav model which combines a Diffusion Transformer (DiT) for mel-spectrogram generation with a BigVGAN model for waveform synthesis. The configuration contains sub-configurations for both components.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ dit_config ([`Qwen2_5OmniDiTConfig`], *optional*):
+ Configuration class for the Diffusion Transformer (DiT) module responsible for generating mel-spectrograms.
+ bigvgan_config ([`Qwen2_5OmniBigVGANConfig`], *optional*):
+ Configuration class for the BigVGAN module responsible for converting mel-spectrograms to waveforms.
+ Example:
+
+ ```python
+ >>> from transformers import Qwen2_5OmniToken2WavModel, Qwen2_5OmniToken2WavConfig
+
+ >>> # Initialize the DiT sub-configuration (as a plain dict)
+ >>> dit_config = {
+ ...     "hidden_size": 1024,
+ ...     "num_hidden_layers": 22,
+ ...     "num_attention_heads": 16,
+ ...     "ff_mult": 2,
+ ... }
+
+ >>> # Initialize the BigVGAN sub-configuration (as a plain dict)
+ >>> bigvgan_config = {
+ ...     "mel_dim": 80,
+ ...     "upsample_rates": [5, 3, 2, 2, 2, 2],
+ ... }
+
+ >>> # Initialize main configuration
+ >>> config = Qwen2_5OmniToken2WavConfig(dit_config, bigvgan_config)
+
+ >>> # Initialize model with config
+ >>> model = Qwen2_5OmniToken2WavModel(config)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```
+ """
+
+ model_type = "qwen2_5_omni_token2wav"
+ sub_configs = {
+ "dit_config": Qwen2_5OmniDiTConfig,
+ "bigvgan_config": Qwen2_5OmniBigVGANConfig,
+ }
+
+ def __init__(self, dit_config=None, bigvgan_config=None, **kwargs):
+ if dit_config is None:
+ dit_config = {}
+ if bigvgan_config is None:
+ bigvgan_config = {}
+ self.dit_config = Qwen2_5OmniDiTConfig(**dit_config)
+ self.bigvgan_config = Qwen2_5OmniBigVGANConfig(**bigvgan_config)
+ super().__init__(**kwargs)
+
+
+class Qwen2_5OmniConfig(PretrainedConfig):
+ """
+ This is the configuration class to store the configuration of a [`Qwen2_5OmniForConditionalGeneration`]. It is used to instantiate a Qwen2.5Omni
+ model according to the specified sub-models configurations, defining the model architecture.
+
+ Instantiating a configuration with the defaults will yield a similar configuration to that of the
+ [Qwen/Qwen2.5-Omni-7B](https://huggingface.co/Qwen/Qwen2.5-Omni-7B) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ thinker_config (`dict`, *optional*): Configuration of the underlying thinker sub-model.
+ talker_config (`dict`, *optional*): Configuration of the underlying talker sub-model.
+ token2wav_config (`dict`, *optional*): Configuration of the underlying codec sub-model.
+ enable_audio_output (`bool`, *optional*, defaults to `True`): Whether enable audio output and load talker and token2wav module.
+
+ Example:
+
+ ```python
+ >>> from transformers import (
+ ... Qwen2_5OmniThinkerConfig,
+ ... Qwen2_5OmniTalkerConfig,
+ ... Qwen2_5OmniToken2WavConfig,
+ ... Qwen2_5OmniForConditionalGeneration,
+ ... Qwen2_5OmniConfig,
+ ... )
+
+ >>> # Initializing sub-modules configurations.
+ >>> thinker_config = Qwen2_5OmniThinkerConfig()
+ >>> talker_config = Qwen2_5OmniTalkerConfig()
+ >>> token2wav_config = Qwen2_5OmniToken2WavConfig()
+
+
+ >>> # Initializing a Qwen2.5-Omni configuration from the sub-model configurations
+ >>> configuration = Qwen2_5OmniConfig(
+ ... thinker_config=thinker_config.to_dict(),
+ ... talker_config=talker_config.to_dict(),
+ ... token2wav_config=token2wav_config.to_dict(),
+ ... )
+
+ >>> # Initializing a model (with random weights)
+ >>> model = Qwen2_5OmniForConditionalGeneration(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```
+ """
+
+ model_type = "qwen2_5_omni"
+ sub_configs = {
+ "thinker_config": Qwen2_5OmniThinkerConfig,
+ "talker_config": Qwen2_5OmniTalkerConfig,
+ "token2wav_config": Qwen2_5OmniToken2WavConfig,
+ }
+
+ def __init__(
+ self,
+ thinker_config=None,
+ talker_config=None,
+ token2wav_config=None,
+ enable_audio_output: bool = True,
+ **kwargs,
+ ):
+ if thinker_config is None:
+ thinker_config = {}
+ logger.info("thinker_config is None. Initializing thinker model with default values")
+
+ if talker_config is None:
+ talker_config = {}
+ logger.info("talker_config is None. Initializing talker model with default values")
+
+ if token2wav_config is None:
+ token2wav_config = {}
+ logger.info("token2wav_config is None. Initializing token2wav model with default values")
+
+ self.thinker_config = Qwen2_5OmniThinkerConfig(**thinker_config)
+ self.talker_config = Qwen2_5OmniTalkerConfig(**talker_config)
+ self.token2wav_config = Qwen2_5OmniToken2WavConfig(**token2wav_config)
+ self.enable_audio_output = enable_audio_output
+
+ super().__init__(**kwargs)
+
+ def get_text_config(self, *args, **kwargs):
+ """
+ Returns the config that is meant to be used with text IO. On most models, it is the original config instance
+ itself. On specific composite models, it is under a set of valid names.
+
+ Args:
+ decoder (`Optional[bool]`, *optional*, defaults to `False`):
+ If set to `True`, then only search for decoder config names.
+ """
+ # Overridden for deeply nested configs like Qwen2.5-Omni. We don't have any omni model
+ # except for Qwen yet. This has to be generalized if more deeply nested configs are
+ # added. NOTE: this method is currently used only by vLLM.
+ return self.thinker_config.get_text_config(*args, **kwargs)
+
+
+__all__ = ["Qwen2_5OmniConfig", "Qwen2_5OmniThinkerConfig", "Qwen2_5OmniTalkerConfig", "Qwen2_5OmniToken2WavConfig"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py
new file mode 100644
index 0000000000000000000000000000000000000000..6b8910d270bb58a723cfe1766bcc94721950f4f1
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py
@@ -0,0 +1,4020 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_qwen2_5_omni.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2025 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
+#
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from dataclasses import dataclass
+from typing import Any, Callable, Optional, Union
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch import nn
+from torch.nn import Parameter
+
+from ...activations import ACT2FN
+from ...cache_utils import Cache, DynamicCache
+from ...generation import GenerationMixin
+from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
+from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, ModelOutput
+from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
+from ...utils import TransformersKwargs, auto_docstring, check_torch_load_is_safe, logging
+from ...utils.deprecation import deprecate_kwarg
+from ...utils.hub import cached_file
+from ..qwen2.modeling_qwen2 import Qwen2RMSNorm
+from .configuration_qwen2_5_omni import (
+ Qwen2_5OmniAudioEncoderConfig,
+ Qwen2_5OmniBigVGANConfig,
+ Qwen2_5OmniConfig,
+ Qwen2_5OmniDiTConfig,
+ Qwen2_5OmniTalkerConfig,
+ Qwen2_5OmniTextConfig,
+ Qwen2_5OmniThinkerConfig,
+ Qwen2_5OmniToken2WavConfig,
+ Qwen2_5OmniVisionEncoderConfig,
+)
+
+
+logger = logging.get_logger(__name__)
+
+
+@auto_docstring
+class Qwen2_5OmniPreTrainedModel(PreTrainedModel):
+ config: Qwen2_5OmniConfig
+ base_model_prefix = "model"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["Qwen2_5OmniDecoderLayer", "Qwen2_5OmniVisionBlock"]
+ _skip_keys_device_placement = "past_key_values"
+ _supports_flash_attn = True
+ _supports_sdpa = True
+ _can_compile_fullgraph = False
+ _supports_attention_backend = True
+
+
+class Qwen2_5OmniPreTrainedModelForConditionalGeneration(Qwen2_5OmniPreTrainedModel):
+ def _prepare_4d_causal_attention_mask_with_cache_position(
+ self,
+ attention_mask: torch.Tensor,
+ sequence_length: int,
+ target_length: int,
+ dtype: torch.dtype,
+ device: torch.device,
+ min_dtype: float,
+ cache_position: torch.Tensor,
+ batch_size: int,
+ ):
+ """
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
+ `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
+
+ Args:
+ attention_mask (`torch.Tensor`):
+ A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
+ sequence_length (`int`):
+ The sequence length being processed.
+ target_length (`int`):
+ The target length: when generating with a static cache, the mask should be as long as the static cache to account for the 0 padding (the part of the cache that is not filled yet).
+ dtype (`torch.dtype`):
+ The dtype to use for the 4D attention mask.
+ device (`torch.device`):
+ The device to place the 4D attention mask on.
+ min_dtype (`float`):
+ The minimum value representable with the dtype `dtype`.
+ cache_position (`torch.Tensor`):
+ Indices depicting the position of the input sequence tokens in the sequence.
+ batch_size (`int`):
+ Batch size.
+ """
+ if attention_mask is not None and attention_mask.dim() == 4:
+ # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
+ causal_mask = attention_mask
+ else:
+ causal_mask = torch.full(
+ (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
+ )
+ if sequence_length != 1:
+ causal_mask = torch.triu(causal_mask, diagonal=1)
+ causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
+ causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
+ if attention_mask is not None:
+ causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
+ mask_length = attention_mask.shape[-1]
+ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
+ padding_mask = padding_mask == 0
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+ padding_mask, min_dtype
+ )
+
+ return causal_mask
+
+ def get_llm_pos_ids_for_vision(
+ self,
+ start_idx: int,
+ vision_idx: int,
+ spatial_merge_size: int,
+ t_index: list[int],
+ grid_hs: list[int],
+ grid_ws: list[int],
+ ):
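+ # Illustrative sketch (not from the original source): with start_idx=0, spatial_merge_size=2,
+ # grid_hs=[4], grid_ws=[4] and t_index=[0, 50], this returns a (3, 8) tensor whose rows are
+ # t=[0,0,0,0,50,50,50,50], h=[0,0,1,1,0,0,1,1] and w=[0,1,0,1,0,1,0,1], i.e. one
+ # (t, h, w) position triple per merged vision token.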
+ llm_pos_ids_list = []
+ llm_grid_h = grid_hs[vision_idx] // spatial_merge_size
+ llm_grid_w = grid_ws[vision_idx] // spatial_merge_size
+ h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(len(t_index), -1, llm_grid_w).flatten()
+ w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(len(t_index), llm_grid_h, -1).flatten()
+ t_index = torch.Tensor(t_index).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten().long()
+ _llm_pos_ids = torch.stack([t_index, h_index, w_index])
+ llm_pos_ids_list.append(_llm_pos_ids + start_idx) # + 1 ) # 12.09 by malinhan
+ llm_pos_ids = torch.cat(llm_pos_ids_list, dim=1)
+ return llm_pos_ids
+
+ def get_chunked_index(
+ self, token_indices: torch.Tensor, tokens_per_chunk: int, remove_index: int
+ ) -> list[tuple[int, int]]:
+ """
+ Splits token index list into chunks based on token value ranges.
+
+ Given a list of token indices, returns a list of (start, end) index tuples representing
+ slices of the list where the token values fall within successive ranges of `tokens_per_chunk`.
+
+ For example, if `tokens_per_chunk` is 1000, the function will create chunks such that:
+ - the first chunk contains token values < 1000,
+ - the second chunk contains values >= 1000 and < 2000, and so on.
+
+ Parameters:
+ token_indices (`torch.Tensor` of shape `(seq_len, )`): A monotonically increasing list of
+ token index values.
+ tokens_per_chunk (`int`): Number of tokens per chunk (used as the chunk size threshold).
+ remove_index (`int`): An index id to subtract from `token_indices` before chunking.
+
+ Returns:
+ `list[tuple[int, int]]`: A list of tuples, each representing the start (inclusive)
+ and end (exclusive) indices of a chunk in `token_indices`.
+ """
+
+ def _iter():
+ i, start_idx = 0, 0 # skip bos token
+ current_chunk = 1
+ while i < len(token_indices): # skip eos token
+ if token_indices[i] - remove_index >= current_chunk * tokens_per_chunk:
+ yield (start_idx, i)
+ start_idx = i
+ current_chunk += 1
+ i += 1
+ yield (start_idx, len(token_indices))
+
+ return list(_iter())
+
+ def get_rope_index(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ image_grid_thw: Optional[torch.LongTensor] = None,
+ video_grid_thw: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ use_audio_in_video: bool = False,
+ audio_seqlens: Optional[torch.LongTensor] = None,
+ second_per_grids: Optional[torch.Tensor] = None,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ """
+ Calculate the 3D rope index based on image and video's temporal, height and width in LLM.
+
+ Explanation:
+ Each embedding sequence contains vision embedding and text embedding or just contains text embedding.
+
+ For pure text embedding sequence, the rotary position embedding has no difference with modern LLMs.
+ Examples:
+ input_ids: [T T T T T], here T is for text.
+ temporal position_ids: [0, 1, 2, 3, 4]
+ height position_ids: [0, 1, 2, 3, 4]
+ width position_ids: [0, 1, 2, 3, 4]
+
+ For vision and text embedding sequence, we calculate 3D rotary position embedding for vision part
+ and 1D rotary position embedding for text part.
+ Examples:
+ Temporal (Time): 3 patches, representing different segments of the video in time.
+ Height: 2 patches, dividing each frame vertically.
+ Width: 2 patches, dividing each frame horizontally.
+ We also have some important parameters:
+ fps (Frames Per Second): The video's frame rate, set to 1. This means one frame is processed each second.
+ tokens_per_second: This is a crucial parameter. It dictates how many "time-steps" or "temporal tokens" are conceptually packed into a one-second interval of the video. In this case, we have 25 tokens per second. So each second of the video will be represented with 25 separate time points. It essentially defines the temporal granularity.
+ temporal_patch_size: The number of frames that compose one temporal patch. Here, it's 2 frames.
+ interval: The step size for the temporal position IDs, calculated as tokens_per_second * temporal_patch_size / fps. In this case, 25 * 2 / 1 = 50. This means that each temporal patch will have a difference of 50 in the temporal position IDs.
+ input_ids: [V V V V V V V V V V V V T T T T T], here V is for vision.
+ vision temporal position_ids: [0, 0, 0, 0, 50, 50, 50, 50, 100, 100, 100, 100]
+ vision height position_ids: [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1]
+ vision width position_ids: [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1]
+ text temporal position_ids: [101, 102, 103, 104, 105]
+ text height position_ids: [101, 102, 103, 104, 105]
+ text width position_ids: [101, 102, 103, 104, 105]
+ Here we calculate the text start position_ids as the max vision position_ids plus 1.
+
+ Args:
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
+ it.
+ image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
+ The temporal, height and width of feature shape of each image in LLM.
+ video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
+ The temporal, height and width of feature shape of each video in LLM.
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+ use_audio_in_video (`bool`, *optional*):
+ If set to `True`, use the audio in video.
+ audio_seqlens (`torch.LongTensor` of shape `(num_audios)`, *optional*):
+ The length of feature shape of each audio in LLM.
+ second_per_grids (`torch.LongTensor` of shape `(num_videos)`, *optional*):
+ The time interval (in seconds) for each grid along the temporal dimension in the 3D position IDs.
+
+ Returns:
+ position_ids (`torch.LongTensor` of shape `(3, batch_size, sequence_length)`)
+ mrope_position_deltas (`torch.Tensor` of shape `(batch_size)`)
+ """
+ spatial_merge_size = self.spatial_merge_size
+ image_token_id = self.config.image_token_id
+ video_token_id = self.config.video_token_id
+ audio_token_id = self.config.audio_token_id
+ vision_start_token_id = self.config.vision_start_token_id
+ audio_start_token_id = self.config.audio_start_token_id
+ position_id_per_seconds = self.config.position_id_per_seconds
+ seconds_per_chunk = self.config.seconds_per_chunk
+
+ mrope_position_deltas = []
+ if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None):
+ total_input_ids = input_ids
+ if attention_mask is not None:
+ attention_mask = attention_mask == 1
+ position_ids = torch.ones(
+ 3,
+ input_ids.shape[0],
+ input_ids.shape[1],
+ dtype=input_ids.dtype,
+ device=input_ids.device,
+ )
+ image_idx, video_idx, audio_idx = 0, 0, 0
+ for i, input_ids in enumerate(total_input_ids):
+ if attention_mask is not None:
+ input_ids = input_ids[attention_mask[i]]
+ image_nums, video_nums, audio_nums = 0, 0, 0
+ vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1)
+ vision_tokens = input_ids[vision_start_indices + 1]
+ audio_nums = torch.sum(input_ids == audio_start_token_id)
+ image_nums = (vision_tokens == image_token_id).sum()
+ video_nums = (
+ (vision_tokens == audio_start_token_id).sum()
+ if use_audio_in_video
+ else (vision_tokens == video_token_id).sum()
+ )
+ input_tokens = input_ids.tolist()
+ llm_pos_ids_list: list = []
+ st = 0
+ remain_images, remain_videos, remain_audios = image_nums, video_nums, audio_nums
+ multimodal_nums = (
+ image_nums + audio_nums if use_audio_in_video else image_nums + video_nums + audio_nums
+ )
+ for _ in range(multimodal_nums):
+ st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
+ if image_token_id in input_tokens and remain_images > 0:
+ ed_image = input_tokens.index(image_token_id, st)
+ else:
+ ed_image = len(input_tokens) + 1
+ if video_token_id in input_tokens and remain_videos > 0:
+ ed_video = input_tokens.index(video_token_id, st)
+ else:
+ ed_video = len(input_tokens) + 1
+ if audio_token_id in input_tokens and remain_audios > 0:
+ ed_audio = input_tokens.index(audio_token_id, st)
+ else:
+ ed_audio = len(input_tokens) + 1
+ min_ed = min(ed_image, ed_video, ed_audio)
+ if min_ed == ed_audio:
+ text_len = min_ed - st - 1
+ if text_len != 0:
+ st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
+ llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
+
+ st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
+ bos_len = 1
+ llm_pos_ids_list.append(torch.arange(bos_len).view(1, -1).expand(3, -1) + st_idx)
+
+ st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
+ audio_len = ((audio_seqlens[audio_idx] - 1) // 2 + 1 - 2) // 2 + 1
+ llm_pos_ids = torch.arange(audio_len).view(1, -1).expand(3, -1) + st_idx
+ llm_pos_ids_list.append(llm_pos_ids)
+
+ st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
+ eos_len = 1
+ llm_pos_ids_list.append(torch.arange(eos_len).view(1, -1).expand(3, -1) + st_idx)
+
+ st += text_len + bos_len + audio_len + eos_len
+ audio_idx += 1
+ remain_audios -= 1
+
+ elif min_ed == ed_image:
+ text_len = min_ed - st - 1
+ if text_len != 0:
+ st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
+ llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
+
+ st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
+ bos_len = 1
+ llm_pos_ids_list.append(torch.arange(bos_len).view(1, -1).expand(3, -1) + st_idx)
+
+ st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
+ grid_t = image_grid_thw[image_idx][0]
+ grid_hs = image_grid_thw[:, 1]
+ grid_ws = image_grid_thw[:, 2]
+ t_index = (torch.arange(grid_t) * 1 * position_id_per_seconds).long()
+ llm_pos_ids = self.get_llm_pos_ids_for_vision(
+ st_idx, image_idx, spatial_merge_size, t_index, grid_hs, grid_ws
+ )
+ image_len = image_grid_thw[image_idx].prod() // (spatial_merge_size**2)
+ llm_pos_ids_list.append(llm_pos_ids)
+
+ st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
+ eos_len = 1
+ llm_pos_ids_list.append(torch.arange(eos_len).view(1, -1).expand(3, -1) + st_idx)
+
+ st += text_len + bos_len + image_len + eos_len
+ image_idx += 1
+ remain_images -= 1
+
+ elif min_ed == ed_video and not use_audio_in_video:
+ text_len = min_ed - st - 1
+ if text_len != 0:
+ st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
+ llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
+
+ st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
+ bos_len = 1
+ llm_pos_ids_list.append(torch.arange(bos_len).view(1, -1).expand(3, -1) + st_idx)
+
+ st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
+ grid_t = video_grid_thw[video_idx][0]
+ grid_hs = video_grid_thw[:, 1]
+ grid_ws = video_grid_thw[:, 2]
+ t_index = (
+ torch.arange(grid_t) * second_per_grids[video_idx].cpu().float() * position_id_per_seconds
+ ).long()
+ llm_pos_ids = self.get_llm_pos_ids_for_vision(
+ st_idx, video_idx, spatial_merge_size, t_index, grid_hs, grid_ws
+ )
+ video_len = video_grid_thw[video_idx].prod() // (spatial_merge_size**2)
+ llm_pos_ids_list.append(llm_pos_ids)
+
+ st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
+ eos_len = 1
+ llm_pos_ids_list.append(torch.arange(eos_len).view(1, -1).expand(3, -1) + st_idx)
+
+ st += text_len + bos_len + video_len + eos_len
+ video_idx += 1
+ remain_videos -= 1
+
+ elif min_ed == ed_video and use_audio_in_video:
+ text_len = min_ed - st - 2
+ if text_len != 0:
+ st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
+ llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
+
+ st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
+ bos_len = 1
+ llm_pos_ids_list.append(torch.arange(bos_len).view(1, -1).expand(3, -1) + st_idx)
+ llm_pos_ids_list.append(torch.arange(bos_len).view(1, -1).expand(3, -1) + st_idx)
+
+ st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
+ audio_len = ((audio_seqlens[audio_idx] - 1) // 2 + 1 - 2) // 2 + 1
+ audio_llm_pos_ids = torch.arange(audio_len).view(1, -1).expand(3, -1) + st_idx
+ grid_t = video_grid_thw[video_idx][0]
+ grid_hs = video_grid_thw[:, 1]
+ grid_ws = video_grid_thw[:, 2]
+
+ t_index = (
+ torch.arange(grid_t) * second_per_grids[video_idx].cpu().float() * position_id_per_seconds
+ ).long()
+ video_llm_pos_ids = self.get_llm_pos_ids_for_vision(
+ st_idx, video_idx, spatial_merge_size, t_index, grid_hs, grid_ws
+ )
+
+ t_ntoken_per_chunk = int(position_id_per_seconds * seconds_per_chunk)
+ video_chunk_indexes = self.get_chunked_index(video_llm_pos_ids[0], t_ntoken_per_chunk, st_idx)
+ audio_chunk_indexes = self.get_chunked_index(audio_llm_pos_ids[0], t_ntoken_per_chunk, st_idx)
+ sub_len = 0
+ for j in range(max(len(video_chunk_indexes), len(audio_chunk_indexes))):
+ video_chunk_index = video_chunk_indexes[j] if j < len(video_chunk_indexes) else None
+ audio_chunk_index = audio_chunk_indexes[j] if j < len(audio_chunk_indexes) else None
+ if video_chunk_index is not None:
+ sub_len += video_chunk_index[1] - video_chunk_index[0]
+
+ llm_pos_ids_list.append(
+ video_llm_pos_ids[:, video_chunk_index[0] : video_chunk_index[1]]
+ )
+ if audio_chunk_index is not None:
+ sub_len += audio_chunk_index[1] - audio_chunk_index[0]
+
+ llm_pos_ids_list.append(
+ audio_llm_pos_ids[:, audio_chunk_index[0] : audio_chunk_index[1]]
+ )
+ video_len = video_grid_thw[video_idx].prod() // (spatial_merge_size**2)
+
+ st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
+ eos_len = 1
+ llm_pos_ids_list.append(torch.arange(eos_len).view(1, -1).expand(3, -1) + st_idx)
+ llm_pos_ids_list.append(torch.arange(eos_len).view(1, -1).expand(3, -1) + st_idx)
+
+ st += text_len + bos_len * 2 + audio_len + video_len + eos_len * 2
+
+ audio_idx += 1
+ video_idx += 1
+ remain_videos -= 1
+ remain_audios -= 1
+
+ if st < len(input_tokens):
+ st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
+ text_len = len(input_tokens) - st
+ llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
+
+ llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
+
+ if attention_mask is not None:
+ position_ids[..., i, attention_mask[i]] = llm_positions.to(position_ids.device)
+ else:
+ position_ids[..., i, :] = llm_positions.to(position_ids.device)
+ mrope_position_deltas.append(llm_positions.max() + 1 - len(input_ids))
+ mrope_position_deltas = torch.tensor(mrope_position_deltas).unsqueeze(1).to(device=input_ids.device)
+
+ return position_ids, mrope_position_deltas
+ else:
+ position_ids = attention_mask.long().cumsum(-1) - 1
+ position_ids.masked_fill_(attention_mask == 0, 1)
+ position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device)
+ max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0]
+ mrope_position_deltas = max_position_ids + 1 - torch.sum(attention_mask, dim=-1, keepdim=True)
+
+ return position_ids, mrope_position_deltas
+
+
+############################
+# Start Thinker #
+############################
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Base class for Qwen2.5OmniThinker causal language model (or autoregressive) outputs.
+ """
+)
+class Qwen2_5OmniThinkerCausalLMOutputWithPast(ModelOutput):
+ r"""
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Language modeling loss (for next-token prediction).
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`, *optional*):
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
+ `past_key_values` input) to speed up sequential decoding.
+ rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
+ The rope index difference between sequence length and multimodal rope.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ logits: Optional[torch.FloatTensor] = None
+ past_key_values: Optional[Cache] = None
+ hidden_states: Optional[tuple[torch.FloatTensor]] = None
+ attentions: Optional[tuple[torch.FloatTensor]] = None
+ rope_deltas: Optional[torch.LongTensor] = None
+
+
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+ """
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+ """
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+ if n_rep == 1:
+ return hidden_states
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
+def eager_attention_forward(
+ module: nn.Module,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_mask: Optional[torch.Tensor],
+ scaling: float,
+ dropout: float = 0.0,
+ **kwargs,
+):
+ key_states = repeat_kv(key, module.num_key_value_groups)
+ value_states = repeat_kv(value, module.num_key_value_groups)
+
+ attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
+ if attention_mask is not None:
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+ attn_weights = attn_weights + causal_mask
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+ attn_output = torch.matmul(attn_weights, value_states)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ return attn_output, attn_weights
+
+
+class Qwen2_5OmniAudioAttention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(
+ self,
+ config: Qwen2_5OmniAudioEncoderConfig,
+ ):
+ super().__init__()
+ self.embed_dim = config.d_model
+ self.num_heads = config.encoder_attention_heads
+ self.dropout = config.attention_dropout
+ self.head_dim = self.embed_dim // self.num_heads
+ self.num_key_value_groups = 1 # needed for eager attention
+ self.config = config
+
+ if (self.head_dim * self.num_heads) != self.embed_dim:
+ raise ValueError(
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
+ f" and `num_heads`: {self.num_heads})."
+ )
+ self.scaling = self.head_dim**-0.5
+ self.attention_dropout = 0.0
+ self.is_decoder = False
+ self.is_causal = False
+
+ self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
+ self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True)
+ self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True)
+ self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ cu_seqlens: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ **kwargs,
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
+ """Input shape: Batch x Time x Channel"""
+
+ seq_length, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states).reshape(seq_length, self.num_heads, -1)
+ key_states = self.k_proj(hidden_states).reshape(seq_length, self.num_heads, -1)
+ value_states = self.v_proj(hidden_states).reshape(seq_length, self.num_heads, -1)
+
+ query_states = query_states.transpose(0, 1).unsqueeze(0)
+ key_states = key_states.transpose(0, 1).unsqueeze(0)
+ value_states = value_states.transpose(0, 1).unsqueeze(0)
+ max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
+
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ attn_output, _ = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask=attention_mask,
+ dropout=0.0 if not self.training else self.attention_dropout,
+ scaling=self.scaling,
+ cu_seq_lens_q=cu_seqlens, # pass cu seq lens for FA2
+ cu_seq_lens_k=cu_seqlens,
+ max_length_q=max_seqlen,
+ max_length_k=max_seqlen,
+ is_causal=False,
+ **kwargs,
+ )
+
+ attn_output = attn_output.reshape(seq_length, -1).contiguous()
+ attn_output = self.out_proj(attn_output)
+
+ return attn_output
+
+
+class Qwen2_5OmniAudioEncoderLayer(GradientCheckpointingLayer):
+ def __init__(self, config: Qwen2_5OmniAudioEncoderConfig):
+ super().__init__()
+ self.embed_dim = config.d_model
+ self.self_attn = Qwen2_5OmniAudioAttention(config)
+ self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
+ self.dropout = config.dropout
+ self.activation_fn = ACT2FN[config.activation_function]
+ self.activation_dropout = config.activation_dropout
+ self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
+ self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
+ self.final_layer_norm = nn.LayerNorm(self.embed_dim)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ cu_seqlens: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ **kwargs,
+ ) -> torch.Tensor:
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+ attention_mask (`torch.FloatTensor`): attention mask of size
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+ layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
+ `(encoder_attention_heads,)`.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ """
+ residual = hidden_states
+ hidden_states = self.self_attn_layer_norm(hidden_states)
+ hidden_states = self.self_attn(
+ hidden_states=hidden_states,
+ cu_seqlens=cu_seqlens,
+ attention_mask=attention_mask,
+ **kwargs,
+ )
+ hidden_states = residual + hidden_states
+ residual = hidden_states
+ hidden_states = self.final_layer_norm(hidden_states)
+ hidden_states = self.fc1(hidden_states)
+ hidden_states = self.activation_fn(hidden_states)
+ hidden_states = self.fc2(hidden_states)
+ hidden_states = residual + hidden_states
+
+ if hidden_states.dtype == torch.float16:
+ clamp_value = torch.finfo(hidden_states.dtype).max - 1000
+ hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
+
+ outputs = (hidden_states,)
+
+ return outputs
+
+
+class SinusoidsPositionEmbedding(nn.Module):
+ def __init__(self, length, channels, max_timescale=10000):
+ super().__init__()
+ if channels % 2 != 0:
+ raise ValueError("SinusoidsPositionEmbedding needs even channels input")
+ log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
+ inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2).float())
+ scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
+ self.register_buffer(
+ "positional_embedding",
+ torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1),
+ persistent=False,
+ )
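+ # The buffer has shape (length, channels): the first channels // 2 columns are sines and the
+ # last channels // 2 are cosines of the position / timescale products.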
+
+ def forward(self, seqlen: int):
+ return self.positional_embedding[:seqlen, :]
+
+
+@auto_docstring(
+ custom_intro="""
+ Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
+ [`Qwen2_5OmniAudioEncoderLayer`].
+ """
+)
+class Qwen2_5OmniAudioEncoder(Qwen2_5OmniPreTrainedModel):
+ config: Qwen2_5OmniAudioEncoderConfig
+ main_input_name = "input_features"
+ _no_split_modules = ["Qwen2_5OmniAudioEncoderLayer"]
+ _supports_sdpa = True
+
+ def __init__(self, config: Qwen2_5OmniAudioEncoderConfig):
+ super().__init__(config)
+ self.dropout = config.dropout
+
+ embed_dim = config.d_model
+ self.num_mel_bins = config.num_mel_bins
+ self.max_source_positions = config.max_source_positions
+ self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
+ self.n_window = config.n_window
+ self.conv1 = nn.Conv1d(self.num_mel_bins, embed_dim, kernel_size=3, padding=1)
+ self.conv2 = nn.Conv1d(embed_dim, embed_dim, kernel_size=3, stride=2, padding=1)
+ self.positional_embedding = SinusoidsPositionEmbedding(self.max_source_positions, embed_dim)
+ self.audio_bos_eos_token = nn.Embedding(2, config.output_dim)
+ self.layers = nn.ModuleList([Qwen2_5OmniAudioEncoderLayer(config) for _ in range(config.encoder_layers)])
+ self.ln_post = nn.LayerNorm(config.d_model)
+ self.avg_pooler = nn.AvgPool1d(2, stride=2)
+ self.proj = nn.Linear(config.d_model, config.output_dim)
+ self.gradient_checkpointing = False
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def _freeze_parameters(self):
+ for param in self.parameters():
+ param.requires_grad = False
+ self._requires_grad = False
+
+ def get_input_embeddings(self) -> nn.Module:
+ return self.conv1
+
+ def set_input_embeddings(self, value: nn.Module):
+ self.conv1 = value
+
+ def _prepare_attention_mask(self, inputs_tensor: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor:
+ # Flash Attention 2 doesn't need a 4D mask and relies on `cu_seqlens/max_seqlen`
+ # NOTE: the created attention mask only approximates the ragged FA2 attention by
+ # allowing bidirectional attention within `cu_seqlens` blocks and not attending between
+ # blocks, so it will not be a 100% match for FA2's `varlen` path.
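+ # Illustrative example (assumption, not in the original comment): with cu_seqlens = [0, 3, 5],
+ # positions 0-2 and 3-4 form two independent blocks that are set to 0 (attended), while every
+ # cross-block entry keeps the dtype minimum (masked).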
+ if self.config._attn_implementation == "flash_attention_2":
+ return None
+
+ seq_length = inputs_tensor.shape[0]
+ attention_mask = torch.full(
+ [1, 1, seq_length, seq_length],
+ torch.finfo(inputs_tensor.dtype).min,
+ device=inputs_tensor.device,
+ dtype=inputs_tensor.dtype,
+ )
+ for i in range(1, len(cu_seqlens)):
+ attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
+ return attention_mask
+
+ @auto_docstring
+ def forward(
+ self,
+ input_features,
+ feature_lens=None,
+ aftercnn_lens=None,
+ **kwargs,
+ ):
+ r"""
+ feature_lens (`torch.LongTensor` of shape `(batch_size,)`):
+ Length of the mel-spectrogram features for each audio.
+ aftercnn_lens (`torch.LongTensor` of shape `(batch_size,)`):
+ Length of the mel-spectrogram features for each audio after the convolutional downsampling.
+ """
+ chunk_num = torch.ceil(feature_lens / (self.n_window * 2)).long()
+
+ chunk_lengths = torch.tensor(
+ [self.n_window * 2] * chunk_num.sum(),
+ dtype=torch.long,
+ device=feature_lens.device,
+ )
+ tail_chunk_index = F.pad(chunk_num, (1, 0), value=-1).cumsum(0)[1:]
+ chunk_lengths[tail_chunk_index] = feature_lens % (self.n_window * 2)
+ chunk_lengths = torch.where(chunk_lengths == 0, self.n_window * 2, chunk_lengths)
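+ # Illustrative example (assumption): with self.n_window * 2 == 100 and feature_lens = [230],
+ # chunk_num = [3] and chunk_lengths becomes [100, 100, 30], i.e. two full windows plus a tail chunk.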
+
+ chunk_list = input_features.split(chunk_lengths.tolist(), dim=1)
+ padded_feature, padded_mask, padded_mask_after_cnn = self.padded_and_mask_function(
+ chunk_list, chunk_lengths, padding_value=0, padding_side="right"
+ )
+ padded_embed = nn.functional.gelu(self.conv1(padded_feature)) * padded_mask
+ padded_embed = nn.functional.gelu(self.conv2(padded_embed)).transpose(1, 2)
+
+ padded_embed = padded_embed + self.positional_embedding.positional_embedding[
+ : padded_embed.shape[1], :
+ ].unsqueeze(0).to(padded_embed.dtype)
+ hidden_states = padded_embed[padded_mask_after_cnn]
+ cu_seqlens = torch.cat(
+ (
+ torch.zeros(1, device=padded_mask_after_cnn.device, dtype=torch.int32),
+ padded_mask_after_cnn.sum(1).cumsum(0),
+ )
+ ).to(torch.int32)
+ attention_mask = self._prepare_attention_mask(hidden_states, cu_seqlens)
+
+ for encoder_layer in self.layers:
+ layer_outputs = encoder_layer(
+ hidden_states,
+ cu_seqlens=cu_seqlens,
+ attention_mask=attention_mask,
+ **kwargs,
+ )
+ hidden_states = layer_outputs[0]
+
+ hidden_states_list = hidden_states.split(aftercnn_lens.tolist(), dim=0)
+ token_audio_list = []
+ for each_audio_states in hidden_states_list:
+ each_audio_states = self.avg_pooler(each_audio_states.transpose(0, 1)).transpose_(0, 1)
+ each_audio_states = self.ln_post(each_audio_states)
+ each_audio_states = self.proj(each_audio_states)
+ token_audio_list.append(each_audio_states)
+ token_audio = torch.cat(token_audio_list, dim=0)
+ return BaseModelOutput(last_hidden_state=token_audio)
+
+ def padded_and_mask_function(self, tensor_list, tensor_len, padding_value=0, padding_side="right"):
+ """
+ Pads a sequence of tensors to their maximum length on the indicated `padding_side`.
+ Then prepares a mask so that pad tokens are not attended to.
+ """
+ max_len = tensor_len.max()
+ dim = tensor_list[0].shape[0]
+ padded_tensor = torch.full(
+ size=(len(tensor_list), dim, max_len),
+ fill_value=padding_value,
+ dtype=self.dtype,
+ device=tensor_list[0].device,
+ )
+
+ batch_mask = torch.zeros(
+ (len(tensor_len), max_len),
+ dtype=torch.long,
+ device=padded_tensor.device,
+ )
+ for i, length in enumerate(tensor_len):
+ batch_mask[i, :length] = 1
+ padded_tensor[i, :, :length] = tensor_list[i]
+
+ feature_lens_after_cnn = (tensor_len - 1) // 2 + 1
+ max_len_after_cnn = feature_lens_after_cnn.max()
+ batch_mask_after_cnn = torch.zeros(
+ (len(tensor_len), max_len_after_cnn),
+ dtype=torch.long,
+ device=padded_tensor.device,
+ )
+ for i, length in enumerate(feature_lens_after_cnn):
+ batch_mask_after_cnn[i, :length] = 1
+ return (
+ padded_tensor,
+ batch_mask.unsqueeze(1),
+ batch_mask_after_cnn.bool(),
+ )
+
+ # Ignore copy
+ def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor):
+ """
+ Computes the output length of the convolutional layers and the output length of the audio encoder
+ """
+ input_lengths = (input_lengths - 1) // 2 + 1
+ output_lengths = (input_lengths - 2) // 2 + 1
+ return input_lengths, output_lengths
+
+
+def rotate_half(x):
+ """Rotates half the hidden dims of the input."""
+ x1 = x[..., : x.shape[-1] // 2]
+ x2 = x[..., x.shape[-1] // 2 :]
+ return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb_vision(tensor: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
+ orig_dtype = tensor.dtype
+ tensor = tensor.float()
+ cos = freqs.cos()
+ sin = freqs.sin()
+ cos = cos.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
+ sin = sin.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
+ output = (tensor * cos) + (rotate_half(tensor) * sin)
+ output = output.to(orig_dtype)
+ return output
+
+
+class Qwen2_5OmniVisionAttention(nn.Module):
+ def __init__(self, config: Qwen2_5OmniVisionEncoderConfig = None) -> None:
+ super().__init__()
+ self.dim = config.hidden_size
+ self.num_heads = config.num_heads
+ self.head_dim = self.dim // self.num_heads
+ self.q = nn.Linear(self.dim, self.dim, bias=True)
+ self.k = nn.Linear(self.dim, self.dim, bias=True)
+ self.v = nn.Linear(self.dim, self.dim, bias=True)
+ self.proj = nn.Linear(self.dim, self.dim)
+ self.scaling = self.head_dim**-0.5
+ self.num_key_value_groups = 1 # needed for eager attention
+ self.config = config
+ self.attention_dropout = 0.0
+ self.is_causal = False
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ cu_seqlens: torch.Tensor,
+ rotary_pos_emb: Optional[torch.Tensor] = None,
+ **kwargs,
+ ) -> torch.Tensor:
+ seq_length = hidden_states.shape[0]
+ query_states = self.q(hidden_states).reshape(seq_length, self.num_heads, -1)
+ key_states = self.k(hidden_states).reshape(seq_length, self.num_heads, -1)
+ value_states = self.v(hidden_states).reshape(seq_length, self.num_heads, -1)
+ query_states = apply_rotary_pos_emb_vision(query_states.unsqueeze(0), rotary_pos_emb).squeeze(0)
+ key_states = apply_rotary_pos_emb_vision(key_states.unsqueeze(0), rotary_pos_emb).squeeze(0)
+
+ query_states = query_states.transpose(0, 1).unsqueeze(0)
+ key_states = key_states.transpose(0, 1).unsqueeze(0)
+ value_states = value_states.transpose(0, 1).unsqueeze(0)
+
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ if self.config._attn_implementation == "flash_attention_2":
+ # Flash Attention 2: Use cu_seqlens for variable length attention
+ max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
+ attn_output, _ = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask=None,
+ scaling=self.scaling,
+ dropout=0.0 if not self.training else self.attention_dropout,
+ cu_seq_lens_q=cu_seqlens,
+ cu_seq_lens_k=cu_seqlens,
+ max_length_q=max_seqlen,
+ max_length_k=max_seqlen,
+ is_causal=False,
+ **kwargs,
+ )
+ else:
+ # Other implementations: Process each chunk separately
+ lengths = cu_seqlens[1:] - cu_seqlens[:-1]
+ splits = [
+ torch.split(tensor, lengths.tolist(), dim=2) for tensor in (query_states, key_states, value_states)
+ ]
+
+ attn_outputs = [
+ attention_interface(
+ self,
+ q,
+ k,
+ v,
+ attention_mask=None,
+ scaling=self.scaling,
+ dropout=0.0 if not self.training else self.attention_dropout,
+ is_causal=False,
+ **kwargs,
+ )[0]
+ for q, k, v in zip(*splits)
+ ]
+ attn_output = torch.cat(attn_outputs, dim=1)
+
+ attn_output = attn_output.reshape(seq_length, -1).contiguous()
+ attn_output = self.proj(attn_output)
+ return attn_output
+
+
+class Qwen2_5OmniMLP(nn.Module):
+ def __init__(self, config, bias: bool = False):
+ super().__init__()
+ self.hidden_size = config.hidden_size
+ self.intermediate_size = config.intermediate_size
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=bias)
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=bias)
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=bias)
+ self.act_fn = ACT2FN[config.hidden_act]
+
+ def forward(self, hidden_state):
+ return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))
+
+
+class Qwen2_5OmniVisionBlock(GradientCheckpointingLayer):
+ def __init__(self, config: Qwen2_5OmniVisionEncoderConfig) -> None:
+ super().__init__()
+ self.norm1 = Qwen2RMSNorm(config.hidden_size, eps=1e-6)
+ self.norm2 = Qwen2RMSNorm(config.hidden_size, eps=1e-6)
+ self.attn = Qwen2_5OmniVisionAttention(config=config)
+ self.mlp = Qwen2_5OmniMLP(config, bias=True)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ cu_seqlens: torch.Tensor,
+ rotary_pos_emb: Optional[torch.Tensor] = None,
+ **kwargs,
+ ) -> torch.Tensor:
+ hidden_states = hidden_states + self.attn(
+ self.norm1(hidden_states),
+ cu_seqlens=cu_seqlens,
+ rotary_pos_emb=rotary_pos_emb,
+ **kwargs,
+ )
+ hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
+ return hidden_states
+
+
+class Qwen2_5_VisionPatchEmbed(nn.Module):
+ def __init__(
+ self,
+ patch_size: int = 14,
+ temporal_patch_size: int = 2,
+ in_channels: int = 3,
+ embed_dim: int = 1152,
+ ) -> None:
+ super().__init__()
+ self.patch_size = patch_size
+ self.temporal_patch_size = temporal_patch_size
+ self.in_channels = in_channels
+ self.embed_dim = embed_dim
+
+ kernel_size = [temporal_patch_size, patch_size, patch_size]
+ self.proj = nn.Conv3d(in_channels, embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=False)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ target_dtype = self.proj.weight.dtype
+ hidden_states = hidden_states.view(
+ -1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size
+ )
+ hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).view(-1, self.embed_dim)
+ return hidden_states
+
+
+class Qwen2_5_VisionRotaryEmbedding(nn.Module):
+ inv_freq: torch.Tensor # fix linting for `register_buffer`
+
+ def __init__(self, dim: int, theta: float = 10000.0) -> None:
+ super().__init__()
+ inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
+
+ def forward(self, seqlen: int) -> torch.Tensor:
+ seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
+ freqs = torch.outer(seq, self.inv_freq)
+ return freqs
+
+
+class Qwen2_5OmniPatchMerger(nn.Module):
+ def __init__(self, dim: int, context_dim: int, spatial_merge_size: int = 2) -> None:
+ super().__init__()
+ self.hidden_size = context_dim * (spatial_merge_size**2)
+ self.ln_q = Qwen2RMSNorm(context_dim, eps=1e-6)
+ self.mlp = nn.Sequential(
+ nn.Linear(self.hidden_size, self.hidden_size),
+ nn.GELU(),
+ nn.Linear(self.hidden_size, dim),
+ )
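+ # Illustrative note (assumption): with spatial_merge_size=2, every 2 x 2 group of neighbouring
+ # patch embeddings is flattened into one vector of size 4 * context_dim before being projected to `dim`.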
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = self.mlp(self.ln_q(x).view(-1, self.hidden_size))
+ return x
+
+
+class Qwen2_5OmniVisionEncoder(Qwen2_5OmniPreTrainedModel):
+ config: Qwen2_5OmniVisionEncoderConfig
+ _no_split_modules = ["Qwen2_5OmniVisionBlock"]
+
+ def __init__(self, config: Qwen2_5OmniVisionEncoderConfig, *inputs, **kwargs) -> None:
+ super().__init__(config, *inputs, **kwargs)
+ self.spatial_merge_size = config.spatial_merge_size
+ self.patch_size = config.patch_size
+ self.fullatt_block_indexes = config.fullatt_block_indexes
+ self.window_size = config.window_size
+ self.spatial_merge_unit = self.spatial_merge_size * self.spatial_merge_size
+
+ self.patch_embed = Qwen2_5_VisionPatchEmbed(
+ patch_size=config.patch_size,
+ temporal_patch_size=config.temporal_patch_size,
+ in_channels=config.in_channels,
+ embed_dim=config.hidden_size,
+ )
+
+ head_dim = config.hidden_size // config.num_heads
+ self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2)
+ self.blocks = nn.ModuleList([Qwen2_5OmniVisionBlock(config) for _ in range(config.depth)])
+ self.merger = Qwen2_5OmniPatchMerger(
+ dim=config.out_hidden_size,
+ context_dim=config.hidden_size,
+ spatial_merge_size=config.spatial_merge_size,
+ )
+ self.gradient_checkpointing = False
+
+ def rot_pos_emb(self, grid_thw):
+ pos_ids = []
+ for t, h, w in grid_thw:
+ hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
+ hpos_ids = hpos_ids.reshape(
+ h // self.spatial_merge_size,
+ self.spatial_merge_size,
+ w // self.spatial_merge_size,
+ self.spatial_merge_size,
+ )
+ hpos_ids = hpos_ids.permute(0, 2, 1, 3)
+ hpos_ids = hpos_ids.flatten()
+
+ wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
+ wpos_ids = wpos_ids.reshape(
+ h // self.spatial_merge_size,
+ self.spatial_merge_size,
+ w // self.spatial_merge_size,
+ self.spatial_merge_size,
+ )
+ wpos_ids = wpos_ids.permute(0, 2, 1, 3)
+ wpos_ids = wpos_ids.flatten()
+ pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
+ pos_ids = torch.cat(pos_ids, dim=0)
+ max_grid_size = grid_thw[:, 1:].max()
+ rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
+ rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
+ return rotary_pos_emb
+
+ def get_window_index(self, grid_thw):
+ window_index: list = []
+ cu_window_seqlens: list = [0]
+ window_index_id = 0
+ vit_merger_window_size = self.window_size // self.spatial_merge_size // self.patch_size
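+ # Illustrative (assuming the released defaults window_size=112, spatial_merge_size=2,
+ # patch_size=14): each attention window covers a 4 x 4 grid of merged patches.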
+
+ for grid_t, grid_h, grid_w in grid_thw:
+ llm_grid_h, llm_grid_w = (
+ grid_h // self.spatial_merge_size,
+ grid_w // self.spatial_merge_size,
+ )
+ index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(grid_t, llm_grid_h, llm_grid_w)
+ pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size
+ pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size
+ num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size
+ num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size
+ index_padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", -100)
+ index_padded = index_padded.reshape(
+ grid_t,
+ num_windows_h,
+ vit_merger_window_size,
+ num_windows_w,
+ vit_merger_window_size,
+ )
+ index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape(
+ grid_t,
+ num_windows_h * num_windows_w,
+ vit_merger_window_size,
+ vit_merger_window_size,
+ )
+ seqlens = (index_padded != -100).sum([2, 3]).reshape(-1)
+ index_padded = index_padded.reshape(-1)
+ index_new = index_padded[index_padded != -100]
+ window_index.append(index_new + window_index_id)
+ cu_seqlens_tmp = seqlens.cumsum(0) * self.spatial_merge_unit + cu_window_seqlens[-1]
+ cu_window_seqlens.extend(cu_seqlens_tmp.tolist())
+ window_index_id += (grid_t * llm_grid_h * llm_grid_w).item()
+ window_index = torch.cat(window_index, dim=0)
+
+ return window_index, cu_window_seqlens
+
+ def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) -> torch.Tensor:
+ """
+ Args:
+ hidden_states (`torch.Tensor` of shape `(seq_len, hidden_size)`):
+ The final hidden states of the model.
+ grid_thw (`torch.Tensor` of shape `(num_images_or_videos, 3)`):
+ The temporal, height and width of feature shape of each image in LLM.
+
+ Returns:
+ `torch.Tensor`: hidden_states.
+ """
+ hidden_states = self.patch_embed(hidden_states)
+ rotary_pos_emb = self.rot_pos_emb(grid_thw)
+
+ window_index, cu_window_seqlens = self.get_window_index(grid_thw)
+ cu_window_seqlens = torch.tensor(
+ cu_window_seqlens,
+ device=hidden_states.device,
+ dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
+ )
+ cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)
+
+ seq_len, _ = hidden_states.size()
+ hidden_states = hidden_states.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
+ hidden_states = hidden_states[window_index, :, :]
+ hidden_states = hidden_states.reshape(seq_len, -1)
+ rotary_pos_emb = rotary_pos_emb.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
+ rotary_pos_emb = rotary_pos_emb[window_index, :, :]
+ rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
+
+ cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
+ dim=0,
+ # Select dtype based on the following factors:
+ # - FA2 requires that cu_seqlens_q must have dtype int32
+ # - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw
+ # See https://github.com/huggingface/transformers/pull/34852 for more information
+ dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
+ )
+ cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
+
+ # Alternate between full attention and windowed attention depending on the block index
+ for layer_num, blk in enumerate(self.blocks):
+ if layer_num in self.fullatt_block_indexes:
+ cu_seqlens_now = cu_seqlens
+ else:
+ cu_seqlens_now = cu_window_seqlens
+
+ hidden_states = blk(
+ hidden_states,
+ cu_seqlens=cu_seqlens_now,
+ rotary_pos_emb=rotary_pos_emb,
+ **kwargs,
+ )
+ hidden_states = self.merger(hidden_states)
+ reverse_indices = torch.argsort(window_index)
+ hidden_states = hidden_states[reverse_indices, :]
+
+ return hidden_states
+
+
+class Qwen2_5OmniRotaryEmbedding(nn.Module):
+ inv_freq: torch.Tensor # fix linting for `register_buffer`
+
+ def __init__(self, config: Qwen2_5OmniThinkerConfig, device=None):
+ super().__init__()
+ # BC: "rope_type" was originally "type"
+ if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
+ else:
+ self.rope_type = "default"
+ self.max_seq_len_cached = config.max_position_embeddings
+ self.original_max_seq_len = config.max_position_embeddings
+
+ self.config = config
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
+ self.original_inv_freq = self.inv_freq
+
+ @torch.no_grad()
+ @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
+ def forward(self, x, position_ids):
+ # In contrast to other models, Qwen2_5Omni has different position ids for the grids
+ # So we expand the inv_freq to shape (3, ...)
+ inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1)
+ position_ids_expanded = position_ids[:, :, None, :].float() # shape (3, bs, 1, positions)
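+ # Shape sketch (illustrative): position_ids is (3, bs, seq), so cos/sin come out as
+ # (3, bs, seq, head_dim), i.e. one rotary table per temporal/height/width axis.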
+
+ device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
+ with torch.autocast(device_type=device_type, enabled=False): # Force float32
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3)
+ emb = torch.cat((freqs, freqs), dim=-1)
+ cos = emb.cos() * self.attention_scaling
+ sin = emb.sin() * self.attention_scaling
+
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
+def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim=1):
+ """Applies Rotary Position Embedding with Multimodal Sections to the query and key tensors (https://qwenlm.github.io/blog/qwen2-vl/).
+
+ Explanation:
+ Multimodal 3D rotary position embedding is an extension to 1D rotary position embedding. The input embedding
+ sequence contains vision (images / videos) embedding and text embedding or just contains text embedding. For
+ vision embedding part, we apply rotary position embedding on temporal, height and width dimension separately.
+ Here we split the channel dimension to 3 chunks for the temporal, height and width rotary position embedding.
+ For text embedding part, we just apply 1D rotary position embedding. The three rotary position index (temporal,
+ height and width) of text embedding is always the same, so the text embedding rotary position embedding has no
+ difference with modern LLMs.
+
+ Args:
+ q (`torch.Tensor`): The query tensor.
+ k (`torch.Tensor`): The key tensor.
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
+ mrope_section(`List(int)`):
+ Multimodal rope section is for channel dimension of temporal, height and width in rope calculation.
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+ Returns:
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+ """
+ mrope_section = mrope_section * 2
+ cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(
+ unsqueeze_dim
+ )
+ sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(
+ unsqueeze_dim
+ )
+
+ q_embed = (q * cos) + (rotate_half(q) * sin)
+ k_embed = (k * cos) + (rotate_half(k) * sin)
+ return q_embed, k_embed
+
+
+class Qwen2_5OmniAttention(nn.Module):
+ """
+ Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
+ and "Generating Long Sequences with Sparse Transformers".
+ """
+
+ def __init__(self, config: Qwen2_5OmniConfig, layer_idx: Optional[int] = None):
+ super().__init__()
+ self.config = config
+ self.layer_idx = layer_idx
+ if layer_idx is None:
+ logger.warning_once(
+ f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
+ "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
+ "when creating this class."
+ )
+
+ self.hidden_size = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads)
+ self.num_key_value_heads = config.num_key_value_heads
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
+ self.is_causal = True
+ self.attention_dropout = config.attention_dropout
+ self.rope_scaling = config.rope_scaling
+ self.scaling = self.head_dim**-0.5
+
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True)
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
+ self.sliding_window = config.sliding_window if config.layer_types[layer_idx] == "sliding_attention" else None
+
+ self.rotary_emb = Qwen2_5OmniRotaryEmbedding(config=config)
+
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
+ **kwargs: Unpack[FlashAttentionKwargs],
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+
+ cos, sin = position_embeddings
+ query_states, key_states = apply_multimodal_rotary_pos_emb(
+ query_states, key_states, cos, sin, self.rope_scaling["mrope_section"]
+ )
+
+ if past_key_values is not None:
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
+ key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ attn_output, attn_weights = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ dropout=0.0 if not self.training else self.attention_dropout,
+ scaling=self.scaling,
+ sliding_window=self.sliding_window,
+ position_ids=position_ids, # pass positions for FA2
+ **kwargs,
+ )
+
+ attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
+ attn_output = self.o_proj(attn_output)
+ return attn_output, attn_weights
+
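+# Grouped-query attention shape sketch for the module above (illustrative only;
+# the head counts are assumptions, not the released config values):
+#
+#     import torch
+#     bsz, q_len, num_heads, num_kv_heads, head_dim = 1, 8, 16, 4, 64
+#     q = torch.randn(bsz, q_len, num_heads * head_dim).view(bsz, q_len, -1, head_dim).transpose(1, 2)
+#     k = torch.randn(bsz, q_len, num_kv_heads * head_dim).view(bsz, q_len, -1, head_dim).transpose(1, 2)
+#     # q: [1, 16, 8, 64], k: [1, 4, 8, 64]; every num_heads // num_kv_heads = 4
+#     # query heads share one key/value head, which the attention backends expand
+#     # internally before computing scores.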
+
+class Qwen2MLP(nn.Module):
+ def __init__(self, config, bias: bool = False):
+ super().__init__()
+ self.hidden_size = config.hidden_size
+ self.intermediate_size = config.intermediate_size
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=bias)
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=bias)
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=bias)
+ self.act_fn = ACT2FN[config.hidden_act]
+
+ def forward(self, hidden_state):
+ return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))
+
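+# Minimal sketch of the gated MLP above, assuming `hidden_act="silu"` (SwiGLU-style
+# gating; the sizes are illustrative assumptions):
+#
+#     import torch
+#     import torch.nn as nn
+#     hidden, intermediate = 8, 32
+#     gate, up, down = nn.Linear(hidden, intermediate), nn.Linear(hidden, intermediate), nn.Linear(intermediate, hidden)
+#     x = torch.randn(2, 5, hidden)
+#     y = down(nn.functional.silu(gate(x)) * up(x))   # y: [2, 5, 8]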
+
+class Qwen2_5OmniDecoderLayer(GradientCheckpointingLayer):
+ def __init__(self, config: Qwen2_5OmniTextConfig, layer_idx: int):
+ super().__init__()
+ self.hidden_size = config.hidden_size
+
+ if config.use_sliding_window and config._attn_implementation != "flash_attention_2":
+ logger.warning_once(
+ f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; "
+ "unexpected results may be encountered."
+ )
+ self.self_attn = Qwen2_5OmniAttention(config, layer_idx)
+
+ self.mlp = Qwen2MLP(config)
+ self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.attention_type = config.layer_types[layer_idx]
+
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ output_attentions: Optional[bool] = False,
+ use_cache: Optional[bool] = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
+ **kwargs: Unpack[FlashAttentionKwargs],
+ ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
+ `(batch, sequence_length)` where padding elements are indicated by 0.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+ (see `past_key_values`).
+ past_key_values (`Cache`, *optional*): cached past key and value projection states
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+ Indices depicting the position of the input sequence tokens in the sequence.
+ position_embeddings (`tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
+ Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
+ with `head_dim` being the embedding dimension of each attention head.
+ kwargs (`dict`, *optional*):
+ Arbitrary kwargs to be ignored, used for FSDP and other methods that inject code
+ into the model.
+ """
+
+ residual = hidden_states
+
+ hidden_states = self.input_layernorm(hidden_states)
+
+ # Self Attention
+ hidden_states, self_attn_weights = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ **kwargs,
+ )
+ hidden_states = residual + hidden_states
+
+ # Fully Connected
+ residual = hidden_states
+ hidden_states = self.post_attention_layernorm(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = residual + hidden_states
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (self_attn_weights,)
+
+ return outputs
+
+
+@auto_docstring
+class Qwen2_5OmniThinkerTextModel(Qwen2_5OmniPreTrainedModel):
+ config: Qwen2_5OmniTextConfig
+ _no_split_modules = ["Qwen2_5OmniDecoderLayer"]
+
+ def __init__(self, config: Qwen2_5OmniTextConfig):
+ super().__init__(config)
+ self.padding_idx = config.pad_token_id
+ self.vocab_size = config.vocab_size
+
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+ self.layers = nn.ModuleList(
+ [Qwen2_5OmniDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+ )
+ self._attn_implementation = config._attn_implementation
+ self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.rotary_emb = Qwen2_5OmniRotaryEmbedding(config=config)
+ self.has_sliding_layers = "sliding_attention" in self.config.layer_types
+
+ self.gradient_checkpointing = False
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs: Unpack[FlashAttentionKwargs],
+ ) -> Union[tuple, BaseModelOutputWithPast]:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ if self.gradient_checkpointing and self.training:
+ if use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+ )
+ use_cache = False
+
+ # torch.jit.trace() doesn't support cache objects in the output
+ if use_cache and past_key_values is None and not torch.jit.is_tracing():
+ past_key_values = DynamicCache(config=self.config)
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+
+ if cache_position is None:
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ cache_position = torch.arange(
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+ )
+
+ # the hard coded `3` is for temporal, height and width.
+ if position_ids is None:
+ position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1)
+ elif position_ids.ndim == 2:
+ position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1)
+
+ # NOTE: we need to pass text position ids for packing. Qwen2-VL uses 3D positions
+ # where each dim indicates visual spatial positions for temporal/height/width grids.
+ # There are two scenarios when FA2-like packed masking might be activated.
+ # 1. User specifically passed packed `position_ids` and no attention mask.
+ In this case we expect the user to create correct position ids for all 3 grids
+ # and prepend text-only position ids to it. The final tensor will be [4, bs, seq-len]
+ # 2. User runs forward with no attention mask and no position ids. In this case, position ids
+ # are prepared by the model (`get_rope_index`) as `[4, bs, seq-len]` tensor. Text-only positions are
+ # prepended by us when creating positions so that the mask is constructed correctly. NOTE: failing to pass
+ text-only positions will cause incorrect mask construction; do not change `prepare_inputs_for_generation`
+ if position_ids.ndim == 3 and position_ids.shape[0] == 4:
+ text_position_ids = position_ids[0]
+ position_ids = position_ids[1:]
+ else:
+ # If inputs are not packed (usual 3D positions), do not prepare mask from position_ids
+ text_position_ids = None
+
+ # It may already have been prepared by e.g. `generate`
+ if not isinstance(causal_mask_mapping := attention_mask, dict):
+ # Prepare mask arguments
+ mask_kwargs = {
+ "config": self.config,
+ "input_embeds": inputs_embeds,
+ "attention_mask": attention_mask,
+ "cache_position": cache_position,
+ "past_key_values": past_key_values,
+ "position_ids": text_position_ids,
+ }
+ # Create the masks
+ causal_mask_mapping = {
+ "full_attention": create_causal_mask(**mask_kwargs),
+ }
+ # The sliding window alternating layers are not always activated depending on the config
+ if self.has_sliding_layers:
+ causal_mask_mapping["sliding_attention"] = create_sliding_window_causal_mask(**mask_kwargs)
+
+ hidden_states = inputs_embeds
+
+ # create position embeddings to be shared across the decoder layers
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+ # decoder layers
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+
+ for decoder_layer in self.layers:
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ layer_outputs = decoder_layer(
+ hidden_states,
+ attention_mask=causal_mask_mapping[decoder_layer.attention_type],
+ position_ids=text_position_ids,
+ past_key_values=past_key_values,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ **kwargs,
+ )
+
+ hidden_states = layer_outputs[0]
+
+ if output_attentions:
+ all_self_attns += (layer_outputs[1],)
+
+ hidden_states = self.norm(hidden_states)
+
+ # add hidden states from the last decoder layer
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ if not return_dict:
+ return tuple(
+ v for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns] if v is not None
+ )
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=past_key_values,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ )
+
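+# Position-id layout sketch for the text model above (illustrative only):
+#
+#     import torch
+#     cache_position = torch.arange(6)                                 # text-only prompt, batch of 2
+#     position_ids = cache_position.view(1, 1, -1).expand(3, 2, -1)    # [3, bs, seq]
+#     # For pure text the temporal/height/width rows are identical, so m-RoPE
+#     # degenerates to ordinary 1D RoPE. Packed multimodal inputs may instead carry
+#     # a leading text-only row, giving a [4, bs, seq] tensor that the forward pass
+#     # splits into text positions (mask building) and 3D positions (rotary embedding).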
+
+@auto_docstring(
+ custom_intro="""
+ The Qwen2.5OmniThinker model, which consists of an audio backbone, a vision backbone and a language model.
+ """
+)
+class Qwen2_5OmniThinkerForConditionalGeneration(Qwen2_5OmniPreTrainedModelForConditionalGeneration, GenerationMixin):
+ config: Qwen2_5OmniThinkerConfig
+ base_model_prefix = "thinker"
+ _tied_weights_keys = ["model.embed_tokens.weight", "lm_head.weight"]
+ _no_split_modules = ["Qwen2_5OmniAudioEncoder", "Qwen2_5OmniVisionEncoder"]
+
+ def __init__(self, config: Qwen2_5OmniThinkerConfig):
+ super().__init__(config)
+ self.audio_tower = Qwen2_5OmniAudioEncoder._from_config(config.audio_config)
+ self.visual = Qwen2_5OmniVisionEncoder._from_config(config.vision_config)
+ self.vocab_size = config.text_config.vocab_size
+ self.model = Qwen2_5OmniThinkerTextModel._from_config(config.text_config)
+ self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
+ self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
+ self.spatial_merge_size = config.vision_config.spatial_merge_size
+ self.rope_deltas = None
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.get_input_embeddings()
+
+ def set_input_embeddings(self, value):
+ self.model.set_input_embeddings(value)
+
+ def get_video_features(
+ self, pixel_values_videos: torch.FloatTensor, video_grid_thw: Optional[torch.LongTensor] = None
+ ):
+ """
+ Encodes videos into continuous embeddings that can be forwarded to the language model.
+
+ Args:
+ pixel_values_videos (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
+ The tensors corresponding to the input videos.
+ video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
+ The temporal, height and width of feature shape of each video in LLM.
+ """
+ pixel_values_videos = pixel_values_videos.type(self.visual.dtype)
+ video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
+ return video_embeds
+
+ def get_image_features(self, pixel_values: torch.FloatTensor, image_grid_thw: Optional[torch.LongTensor] = None):
+ """
+ Encodes images into continuous embeddings that can be forwarded to the language model.
+
+ Args:
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
+ The tensors corresponding to the input images.
+ image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
+ The temporal, height and width of feature shape of each image in LLM.
+ """
+ pixel_values = pixel_values.type(self.visual.dtype)
+ image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
+ return image_embeds
+
+ def get_audio_features(
+ self,
+ input_features: torch.FloatTensor,
+ feature_attention_mask: Optional[torch.LongTensor] = None,
+ audio_feature_lengths: Optional[torch.LongTensor] = None,
+ ):
+ """
+ Encodes audios into continuous embeddings that can be forwarded to the language model.
+
+ Args:
+ input_features (`torch.FloatTensor`):
+ The tensors corresponding to the input audios.
+ feature_attention_mask (`torch.LongTensor`, *optional*):
+ Mask to avoid performing attention on padding feature indices. Mask values selected in `[0, 1]`:
+ audio_feature_lengths (`torch.LongTensor` of shape `(num_audios)`, *optional*):
+ The length of feature shape of each audio in LLM.
+ """
+ if feature_attention_mask is not None:
+ audio_feature_lengths = torch.sum(feature_attention_mask, dim=1)
+ input_features = input_features.permute(0, 2, 1)[feature_attention_mask.bool()].permute(1, 0)
+ else:
+ audio_feature_lengths = None
+
+ audio_feat_lengths, audio_output_lengths = self.audio_tower._get_feat_extract_output_lengths(
+ audio_feature_lengths if audio_feature_lengths is not None else feature_attention_mask.sum(-1)
+ )
+ feature_lens = audio_feature_lengths if audio_feature_lengths is not None else feature_attention_mask.sum(-1)
+ audio_outputs = self.audio_tower(
+ input_features,
+ feature_lens=feature_lens,
+ aftercnn_lens=audio_feat_lengths,
+ )
+ audio_features = audio_outputs.last_hidden_state
+
+ if audio_features.shape[0] != sum(audio_output_lengths.tolist()):
+ raise ValueError("length of audio_features should match audio_output_lengths")
+
+ return audio_features
+
+ def get_placeholder_mask(
+ self,
+ input_ids: torch.LongTensor,
+ inputs_embeds: torch.FloatTensor,
+ image_features: Optional[torch.FloatTensor] = None,
+ video_features: Optional[torch.FloatTensor] = None,
+ ):
+ """
+ Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is
+ equal to the length of multimodal features. If the lengths are different, an error is raised.
+ """
+ if input_ids is None:
+ special_image_mask = inputs_embeds == self.get_input_embeddings()(
+ torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
+ )
+ special_image_mask = special_image_mask.all(-1)
+ special_video_mask = inputs_embeds == self.get_input_embeddings()(
+ torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device)
+ )
+ special_video_mask = special_video_mask.all(-1)
+ special_audio_mask = (
+ inputs_embeds
+ == self.get_input_embeddings()(
+ torch.tensor(self.config.audio_token_id, dtype=torch.long, device=inputs_embeds.device)
+ )
+ ).all(-1)
+ else:
+ special_image_mask = input_ids == self.config.image_token_id
+ special_video_mask = input_ids == self.config.video_token_id
+ special_audio_mask = input_ids == self.config.audio_token_id
+
+ n_image_tokens = special_image_mask.sum()
+ special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
+ if image_features is not None and inputs_embeds[special_image_mask].numel() != image_features.numel():
+ raise ValueError(
+ f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {image_features.shape[0]}"
+ )
+
+ n_video_tokens = special_video_mask.sum()
+ special_video_mask = special_video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
+ if video_features is not None and inputs_embeds[special_video_mask].numel() != video_features.numel():
+ raise ValueError(
+ f"Videos features and image tokens do not match: tokens: {n_video_tokens}, features {video_features.shape[0]}"
+ )
+
+ special_audio_mask = special_audio_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
+ return special_image_mask, special_video_mask, special_audio_mask
+
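+ # Illustrative sketch of how the placeholder masks are consumed (not part of the
+ # model; the token id and tensor sizes below are made-up assumptions):
+ #
+ #     import torch
+ #     input_ids = torch.tensor([[1, 151655, 151655, 2]])   # two image placeholders (id assumed)
+ #     inputs_embeds = torch.zeros(1, 4, 8)
+ #     image_features = torch.arange(16, dtype=torch.float).view(2, 8)
+ #     mask = (input_ids == 151655).unsqueeze(-1).expand_as(inputs_embeds)
+ #     merged = inputs_embeds.masked_scatter(mask, image_features)
+ #     # rows 1 and 2 of `merged` now hold the two image feature vectors.
+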
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ input_features: Optional[torch.FloatTensor] = None,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ pixel_values_videos: Optional[torch.FloatTensor] = None,
+ image_grid_thw: Optional[torch.LongTensor] = None,
+ video_grid_thw: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ feature_attention_mask: Optional[torch.Tensor] = None,
+ audio_feature_lengths: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ rope_deltas: Optional[torch.LongTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ use_audio_in_video: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ video_second_per_grid: Optional[torch.LongTensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> Union[tuple, Qwen2_5OmniThinkerCausalLMOutputWithPast]:
+ r"""
+ image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
+ The temporal, height and width of feature shape of each image in LLM.
+ video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
+ The temporal, height and width of feature shape of each video in LLM.
+ feature_attention_mask (`torch.Tensor` of shape `(batch_size, feature_sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding feature indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+ audio_feature_lengths (`torch.LongTensor` of shape `(num_audios)`, *optional*):
+ The length of feature shape of each audio in LLM.
+ rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
+ The rope index difference between sequence length and multimodal rope.
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+ use_audio_in_video (`bool`, *optional*):
+ Whether or not to use the audio track in the video; should be the same as the parameter in `process_audio_info`.
+ video_second_per_grid (`torch.LongTensor` of shape `(num_videos)`, *optional*):
+ Number of seconds per grid for each video, used for temporal feature mapping.
+
+ Example:
+
+ ```python
+ >>> from io import BytesIO
+ >>> from urllib.request import urlopen
+ >>> import librosa
+ >>> from qwen_vl_utils import process_vision_info
+ >>> from transformers import Qwen2_5OmniProcessor, Qwen2_5OmniThinkerForConditionalGeneration
+
+ >>> thinker = Qwen2_5OmniThinkerForConditionalGeneration.from_pretrained("Qwen/Qwen2.5-Omni-7B")
+ >>> processor = Qwen2_5OmniProcessor.from_pretrained("Qwen/Qwen2.5-Omni-7B")
+
+ >>> conversations = [
+ >>> {'role': 'system', 'content': 'You are a helpful voice chat bot, and please respond to me in a casual conversation manner using random voice.'},
+ >>> {"role": "user", "content": [
+ >>> {"type": "image", "image_url": "https://www.ilankelman.org/stopsigns/australia.jpg"},
+ >>> {"type": "audio", "audio_url": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-Audio/audio/glass-breaking-151256.mp3"},
+ >>> ]},
+ >>> ]
+
+ >>> text = processor.apply_chat_template(conversations, add_generation_prompt=True, tokenize=False)
+ >>> audios = [librosa.load(BytesIO(urlopen(conversations[1]['content'][1]['audio_url']).read()), sr=processor.feature_extractor.sampling_rate)[0]]
+ >>> images, videos = process_vision_info(conversations)
+ >>> inputs = processor(text=text, audios=audios, images=images, videos=videos, return_tensors="pt", padding=True)
+
+ >>> # Generate
+ >>> inputs['use_audio_in_video'] = True  # or False
+ >>> generation = thinker.generate(**inputs, max_new_tokens=2048)
+ >>> generate_ids = generation[:, inputs.input_ids.size(1):]
+
+ >>> response = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ ```"""
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if inputs_embeds is None:
+ # 1. Extract the input embeddings
+ inputs_embeds = self.get_input_embeddings()(input_ids)
+
+ # 2. Merge text, audio, image and video features
+ if input_features is not None:
+ audio_features = self.get_audio_features(
+ input_features,
+ feature_attention_mask=feature_attention_mask,
+ audio_feature_lengths=audio_feature_lengths,
+ )
+ audio_features = audio_features.to(inputs_embeds.device, inputs_embeds.dtype)
+ _, _, audio_mask = self.get_placeholder_mask(input_ids, inputs_embeds=inputs_embeds)
+ inputs_embeds = inputs_embeds.masked_scatter(audio_mask, audio_features)
+
+ if pixel_values is not None:
+ image_embeds = self.get_image_features(pixel_values, image_grid_thw)
+ image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
+ image_mask, _, _ = self.get_placeholder_mask(
+ input_ids, inputs_embeds=inputs_embeds, image_features=image_embeds
+ )
+ inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
+
+ if pixel_values_videos is not None:
+ video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw)
+ video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
+ _, video_mask, _ = self.get_placeholder_mask(
+ input_ids, inputs_embeds=inputs_embeds, video_features=video_embeds
+ )
+ inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
+
+ if feature_attention_mask is not None:
+ audio_feature_lengths = torch.sum(feature_attention_mask, dim=1)
+ else:
+ audio_feature_lengths = None
+
+ if attention_mask is not None and position_ids is None:
+ if (
+ cache_position is None
+ or (cache_position is not None and cache_position[0] == 0)
+ or self.rope_deltas is None
+ ):
+ delta0 = (1 - attention_mask).sum(dim=-1).unsqueeze(1)
+ position_ids, rope_deltas = self.get_rope_index(
+ input_ids,
+ image_grid_thw,
+ video_grid_thw,
+ attention_mask,
+ use_audio_in_video,
+ audio_feature_lengths,
+ video_second_per_grid,
+ )
+ rope_deltas = rope_deltas - delta0
+ self.rope_deltas = rope_deltas
+ else:
+ batch_size, seq_length = input_ids.shape
+ delta = cache_position[0] + self.rope_deltas if cache_position is not None else 0
+ position_ids = torch.arange(seq_length, device=input_ids.device)
+ position_ids = position_ids.view(1, -1).expand(batch_size, -1)
+ position_ids = position_ids.add(delta)
+ position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
+
+ outputs = self.model(
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+ hidden_states = outputs[0]
+ logits = self.lm_head(hidden_states)
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(
+ logits=logits, labels=labels, vocab_size=self.config.get_text_config().vocab_size
+ )
+
+ if not return_dict:
+ output = (logits,) + outputs[1:]
+ return (loss,) + output if loss is not None else output
+
+ return Qwen2_5OmniThinkerCausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ rope_deltas=self.rope_deltas,
+ )
+
+ def prepare_inputs_for_generation(
+ self,
+ input_ids,
+ past_key_values=None,
+ attention_mask=None,
+ inputs_embeds=None,
+ cache_position=None,
+ position_ids=None,
+ use_cache=True,
+ pixel_values=None,
+ pixel_values_videos=None,
+ image_grid_thw=None,
+ video_grid_thw=None,
+ input_features=None,
+ feature_attention_mask=None,
+ use_audio_in_video=False,
+ video_second_per_grid=None,
+ **kwargs,
+ ):
+ model_inputs = super().prepare_inputs_for_generation(
+ input_ids,
+ past_key_values=past_key_values,
+ attention_mask=attention_mask,
+ inputs_embeds=inputs_embeds,
+ cache_position=cache_position,
+ position_ids=position_ids,
+ use_cache=use_cache,
+ pixel_values=pixel_values,
+ pixel_values_videos=pixel_values_videos,
+ image_grid_thw=image_grid_thw,
+ video_grid_thw=video_grid_thw,
+ input_features=input_features,
+ feature_attention_mask=feature_attention_mask,
+ use_audio_in_video=use_audio_in_video,
+ video_second_per_grid=video_second_per_grid,
+ **kwargs,
+ )
+
+ model_inputs["position_ids"] = None
+
+ if cache_position[0] != 0:
+ model_inputs["pixel_values"] = None
+ model_inputs["pixel_values_videos"] = None
+ model_inputs["input_features"] = None
+
+ return model_inputs
+
+
+############################
+# Start Talker #
+############################
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Base class for Qwen2.5OmniTalker causal language model (or autoregressive) outputs.
+ """
+)
+class Qwen2_5OmniTalkerCausalLMOutputWithPast(ModelOutput):
+ r"""
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Language modeling loss (for next-token prediction).
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
+ `past_key_values` input) to speed up sequential decoding.
+ rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
+ The rope index difference between sequence length and multimodal rope.
+ thinker_reply_part (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+ Hidden states from the thinker model that are used as input for the talker model. These represent the encoded
+ response that the talker model will use to generate speech tokens.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ logits: Optional[torch.FloatTensor] = None
+ past_key_values: Optional[Cache] = None
+ hidden_states: Optional[tuple[torch.FloatTensor]] = None
+ attentions: Optional[tuple[torch.FloatTensor]] = None
+ rope_deltas: Optional[torch.LongTensor] = None
+ thinker_reply_part: Optional[torch.FloatTensor] = None
+
+
+@auto_docstring
+class Qwen2_5OmniTalkerModel(Qwen2_5OmniPreTrainedModel):
+ config: Qwen2_5OmniTalkerConfig
+ _no_split_modules = ["Qwen2_5OmniTalkerDecoderLayer"]
+
+ def __init__(self, config: Qwen2_5OmniTalkerConfig):
+ super().__init__(config)
+ self.padding_idx = config.pad_token_id
+ self.vocab_size = config.vocab_size
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.embedding_size, self.padding_idx)
+ self.layers = nn.ModuleList(
+ [Qwen2_5OmniDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+ )
+ self._attn_implementation = config._attn_implementation
+ self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.rotary_emb = Qwen2_5OmniRotaryEmbedding(config=config)
+ self.has_sliding_layers = "sliding_attention" in self.config.layer_types
+
+ self.gradient_checkpointing = False
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs: Unpack[FlashAttentionKwargs],
+ ) -> Union[tuple, BaseModelOutputWithPast]:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ if self.gradient_checkpointing and self.training:
+ if use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+ )
+ use_cache = False
+
+ # torch.jit.trace() doesn't support cache objects in the output
+ if use_cache and past_key_values is None and not torch.jit.is_tracing():
+ past_key_values = DynamicCache(config=self.config)
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+
+ if cache_position is None:
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ cache_position = torch.arange(
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+ )
+
+ # the hard coded `3` is for temporal, height and width.
+ if position_ids is None:
+ position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1)
+ elif position_ids.ndim == 2:
+ position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1)
+
+ # NOTE: we need to pass text position ids for packing. Qwen2-VL uses 3D positions
+ # where each dim indicates visual spatial positions for temporal/height/width grids.
+ # There are two scenarios when FA2-like packed masking might be activated.
+ # 1. User specifically passed packed `position_ids` and no attention mask.
+ In this case we expect the user to create correct position ids for all 3 grids
+ # and prepend text-only position ids to it. The final tensor will be [4, bs, seq-len]
+ # 2. User runs forward with no attention mask and no position ids. In this case, position ids
+ # are prepared by the model (`get_rope_index`) as `[4, bs, seq-len]` tensor. Text-only positions are
+ # prepended by us when creating positions so that the mask is constructed correctly. NOTE: failing to pass
+ text-only positions will cause incorrect mask construction; do not change `prepare_inputs_for_generation`
+ if position_ids.ndim == 3 and position_ids.shape[0] == 4:
+ text_position_ids = position_ids[0]
+ position_ids = position_ids[1:]
+ else:
+ # If inputs are not packed (usual 3D positions), do not prepare mask from position_ids
+ text_position_ids = None
+
+ # It may already have been prepared by e.g. `generate`
+ if not isinstance(causal_mask_mapping := attention_mask, dict):
+ # Prepare mask arguments
+ mask_kwargs = {
+ "config": self.config,
+ "input_embeds": inputs_embeds,
+ "attention_mask": attention_mask,
+ "cache_position": cache_position,
+ "past_key_values": past_key_values,
+ "position_ids": text_position_ids,
+ }
+ # Create the masks
+ causal_mask_mapping = {
+ "full_attention": create_causal_mask(**mask_kwargs),
+ }
+ # The sliding window alternating layers are not always activated depending on the config
+ if self.has_sliding_layers:
+ causal_mask_mapping["sliding_attention"] = create_sliding_window_causal_mask(**mask_kwargs)
+
+ hidden_states = inputs_embeds
+
+ # create position embeddings to be shared across the decoder layers
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+ # decoder layers
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+
+ for decoder_layer in self.layers:
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ layer_outputs = decoder_layer(
+ hidden_states,
+ attention_mask=causal_mask_mapping[decoder_layer.attention_type],
+ position_ids=text_position_ids,
+ past_key_values=past_key_values,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ **kwargs,
+ )
+
+ hidden_states = layer_outputs[0]
+
+ if output_attentions:
+ all_self_attns += (layer_outputs[1],)
+
+ hidden_states = self.norm(hidden_states)
+
+ # add hidden states from the last decoder layer
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ if not return_dict:
+ return tuple(
+ v for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns] if v is not None
+ )
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=past_key_values,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ )
+
+
+class Qwen2_5OmniTalkerForConditionalGeneration(Qwen2_5OmniPreTrainedModelForConditionalGeneration, GenerationMixin):
+ config: Qwen2_5OmniTalkerConfig
+ base_model_prefix = "talker"
+
+ def __init__(self, config: Qwen2_5OmniTalkerConfig):
+ super().__init__(config)
+
+ self.thinker_to_talker_proj = nn.Linear(config.embedding_size, config.hidden_size)
+
+ self.model = Qwen2_5OmniTalkerModel(config)
+ self.codebook_size = config.vocab_size
+ self.codec_head = nn.Linear(config.hidden_size, self.codebook_size, bias=False)
+
+ self.codec_bos_token = config.tts_codec_start_token_id
+ self.codec_eos_token = config.tts_codec_end_token_id
+ self.codec_pad_token = config.tts_codec_pad_token_id
+ self.codec_mask_token = config.tts_codec_mask_token_id
+
+ self.text_bos_token = config.tts_text_start_token_id
+ self.text_eos_token = config.tts_text_end_token_id
+ self.text_pad_token = config.tts_text_pad_token_id
+
+ self.spatial_merge_size = self.config.spatial_merge_size
+ self.rope_deltas = None
+
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.get_input_embeddings()
+
+ def set_input_embeddings(self, value):
+ self.model.set_input_embeddings(value)
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ thinker_reply_part: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ rope_deltas: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ input_text_ids: Optional[torch.LongTensor] = None,
+ image_grid_thw: Optional[torch.LongTensor] = None,
+ video_grid_thw: Optional[torch.LongTensor] = None,
+ use_audio_in_video: Optional[bool] = None,
+ audio_feature_lengths: Optional[torch.LongTensor] = None,
+ video_second_per_grid: Optional[torch.LongTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[tuple, Qwen2_5OmniTalkerCausalLMOutputWithPast]:
+ r"""
+ thinker_reply_part (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Hidden states from the thinker model's output that represent the text reply part to be processed.
+ rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
+ The rope index difference between sequence length and multimodal rope.
+ input_text_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Input token IDs for text-only content, used for position calculation in multimodal contexts.
+ image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
+ The temporal, height and width of feature shape of each image in LLM.
+ video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
+ The temporal, height and width of feature shape of each video in LLM.
+ use_audio_in_video (`bool`, *optional*):
+ Whether or not to use the audio track in the video; should be the same as the parameter in `process_audio_info`.
+ audio_feature_lengths (`torch.LongTensor` of shape `(num_audios)`, *optional*):
+ The length of feature shape of each audio in LLM.
+ video_second_per_grid (`torch.LongTensor` of shape `(num_videos)`, *optional*):
+ Number of seconds per grid for each video, used for temporal feature mapping.
+
+ Example:
+
+ ```python
+ >>> from io import BytesIO
+ >>> from urllib.request import urlopen
+ >>> import librosa
+ >>> from transformers import AutoProcessor, Qwen2_5OmniTalkerForConditionalGeneration
+
+ >>> model = Qwen2_5OmniTalkerForConditionalGeneration.from_pretrained("Qwen/Qwen2-Audio-7B")
+ >>> processor = AutoProcessor.from_pretrained("Qwen/Qwen2-Audio-7B")
+
+ >>> prompt = "<|audio_bos|><|AUDIO|><|audio_eos|>Generate the caption in English:"
+ >>> url = "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-Audio/audio/glass-breaking-151256.mp3"
+ >>> audio, _ = librosa.load(BytesIO(urlopen(url).read()), sr=processor.feature_extractor.sampling_rate)
+
+ >>> inputs = processor(text=prompt, audios=audio, return_tensors="pt")
+
+ >>> # Generate
+ >>> generate_ids = model.generate(**inputs, max_length=30)
+ >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ "Generate the caption in English: Glass is breaking."
+ ```"""
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if attention_mask is not None and position_ids is None:
+ if (
+ cache_position is None
+ or (cache_position is not None and cache_position[0] == 0)
+ or self.rope_deltas is None
+ ):
+ position_ids, rope_deltas = self.get_rope_index(
+ input_text_ids,
+ image_grid_thw,
+ video_grid_thw,
+ attention_mask,
+ use_audio_in_video,
+ audio_feature_lengths,
+ video_second_per_grid,
+ )
+
+ inputs_embeds[:, -1, :] += self.get_input_embeddings()(
+ torch.tensor([self.codec_bos_token], dtype=torch.long, device=inputs_embeds.device)
+ )
+ inputs_embeds[:, -2, :] += self.get_input_embeddings()(
+ torch.tensor([self.codec_pad_token], dtype=torch.long, device=inputs_embeds.device)
+ )
+ self.rope_deltas = rope_deltas
+
+ else:
+ batch_size, seq_length = input_ids.shape
+ delta = cache_position[0] + self.rope_deltas if cache_position is not None else 0
+ position_ids = torch.arange(seq_length, device=input_ids.device)
+ position_ids = position_ids.view(1, -1).expand(batch_size, -1)
+ position_ids = position_ids.add(delta)
+ position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
+
+ if inputs_embeds is None:
+ # 1. Inference tokens after second token
+ codec_embeds = self.get_input_embeddings()(input_ids)
+ inputs_embeds = codec_embeds + thinker_reply_part[:, :1, :]
+ if thinker_reply_part.shape[1] > 1:
+ thinker_reply_part = thinker_reply_part[:, 1:, :]
+
+ talker_lm_input = self.thinker_to_talker_proj(inputs_embeds)
+
+ if attention_mask is not None:
+ attention_mask = attention_mask.to(inputs_embeds.device)
+
+ outputs = self.model(
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=talker_lm_input,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ hidden_states = outputs[0]
+ logits = self.codec_head(hidden_states)
+ logits = logits.float()
+
+ loss = None
+
+ if not return_dict:
+ output = (logits,) + outputs[1:]
+ return (loss,) + output if loss is not None else output
+
+ return Qwen2_5OmniTalkerCausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=hidden_states,
+ attentions=outputs.attentions,
+ rope_deltas=self.rope_deltas,
+ thinker_reply_part=thinker_reply_part,
+ )
+
+ def _get_initial_cache_position(self, seq_length, device, model_kwargs):
+ # Talker needs to calculate cache_position with input_ids, so pop inputs_embeds temporarily
+ inputs_embeds = model_kwargs.pop("inputs_embeds")
+ model_kwargs = super()._get_initial_cache_position(seq_length, device, model_kwargs)
+ model_kwargs["inputs_embeds"] = inputs_embeds
+ return model_kwargs
+
+ # prepare inputs for talker lm generation
+ def prepare_inputs_for_generation(
+ self,
+ input_ids,
+ input_text_ids,
+ past_key_values=None,
+ attention_mask=None,
+ inputs_embeds=None,
+ thinker_reply_part=None,
+ cache_position=None,
+ position_ids=None,
+ use_cache=True,
+ pixel_values=None,
+ pixel_values_videos=None,
+ image_grid_thw=None,
+ video_grid_thw=None,
+ input_audio_features=None,
+ audio_feature_attention_mask=None,
+ audio_feature_lengths=None,
+ use_audio_in_video=False,
+ video_second_per_grid=None,
+ **kwargs,
+ ):
+ model_inputs = super().prepare_inputs_for_generation(
+ input_ids,
+ past_key_values,
+ attention_mask,
+ inputs_embeds,
+ cache_position,
+ use_cache=use_cache,
+ thinker_reply_part=thinker_reply_part,
+ input_text_ids=input_text_ids,
+ image_grid_thw=image_grid_thw,
+ video_grid_thw=video_grid_thw,
+ use_audio_in_video=use_audio_in_video,
+ audio_feature_lengths=audio_feature_lengths,
+ video_second_per_grid=video_second_per_grid,
+ **kwargs,
+ )
+
+ model_inputs["position_ids"] = None
+
+ return model_inputs
+
+ def _update_model_kwargs_for_generation(
+ self,
+ outputs: ModelOutput,
+ model_kwargs: dict[str, Any],
+ is_encoder_decoder: bool = False,
+ num_new_tokens: int = 1,
+ ) -> dict[str, Any]:
+ model_kwargs = super()._update_model_kwargs_for_generation(
+ outputs, model_kwargs, is_encoder_decoder, num_new_tokens
+ )
+
+ if getattr(outputs, "thinker_reply_part", None) is not None:
+ model_kwargs["thinker_reply_part"] = outputs.thinker_reply_part
+
+ return model_kwargs
+
+
+############################
+# Start Token2Wav #
+############################
+
+
+# Uses a custom RoPE implementation; will switch to LlamaRotaryEmbedding in the next version
+class Qwen2_5OmniDiTRotaryEmbedding(nn.Module):
+ inv_freq: torch.Tensor # fix linting for `register_buffer`
+
+ def __init__(self, dim, base=10000):
+ super().__init__()
+
+ inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
+ self.register_buffer("inv_freq", inv_freq)
+
+ def forward(self, x):
+ batch_size, seq_len = x.shape[0], x.shape[1]
+ t = torch.arange(seq_len, device=x.device)
+ device_type = x.device.type
+ device_type = device_type if device_type != "mps" else "cpu"
+ with torch.autocast(device_type=device_type, enabled=False):
+ freqs = t.unsqueeze(1).float() @ self.inv_freq.unsqueeze(0).float()
+ freqs = torch.stack((freqs, freqs), dim=-1)
+ freqs = freqs.reshape(*freqs.shape[:-2], -1)
+ freqs = freqs.repeat(batch_size, *([1] * freqs.dim()))
+ cos = freqs.cos()
+ sin = freqs.sin()
+
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
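+# Rotary table sketch for the DiT embedding above (illustrative; the sizes are assumptions):
+#
+#     import torch
+#     dim, seq, batch = 8, 5, 2
+#     inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))     # [dim // 2]
+#     freqs = torch.arange(seq).unsqueeze(1).float() @ inv_freq.unsqueeze(0)  # [seq, dim // 2]
+#     freqs = torch.stack((freqs, freqs), dim=-1).reshape(seq, dim)           # duplicate each frequency
+#     cos, sin = freqs.cos().repeat(batch, 1, 1), freqs.sin().repeat(batch, 1, 1)  # [batch, seq, dim]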
+
+class TimeDelayNetBlock(nn.Module):
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ dilation,
+ ):
+ super().__init__()
+ self.conv = nn.Conv1d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=kernel_size,
+ dilation=dilation,
+ padding="same",
+ padding_mode="reflect",
+ )
+ self.activation = nn.ReLU()
+
+ def forward(self, hidden_states: torch.Tensor):
+ return self.activation(self.conv(hidden_states))
+
+
+class Res2NetBlock(torch.nn.Module):
+ def __init__(self, in_channels, out_channels, scale=8, kernel_size=3, dilation=1):
+ super().__init__()
+
+ in_channel = in_channels // scale
+ hidden_channel = out_channels // scale
+
+ self.blocks = nn.ModuleList(
+ [
+ TimeDelayNetBlock(
+ in_channel,
+ hidden_channel,
+ kernel_size=kernel_size,
+ dilation=dilation,
+ )
+ for i in range(scale - 1)
+ ]
+ )
+ self.scale = scale
+
+ def forward(self, hidden_states):
+ outputs = []
+ for i, hidden_part in enumerate(torch.chunk(hidden_states, self.scale, dim=1)):
+ if i == 0:
+ output_part = hidden_part
+ elif i == 1:
+ output_part = self.blocks[i - 1](hidden_part)
+ else:
+ output_part = self.blocks[i - 1](hidden_part + output_part)
+ outputs.append(output_part)
+ output = torch.cat(outputs, dim=1)
+ return output
+
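+# Res2Net data-flow sketch for the block above (names mirror the module; the
+# channel counts are illustrative assumptions):
+#
+#     chunks = hidden_states.chunk(scale, dim=1)   # split channels into `scale` groups
+#     y0 = chunks[0]                               # first group passes through untouched
+#     y1 = blocks[0](chunks[1])                    # second group goes through a TDNN block
+#     y2 = blocks[1](chunks[2] + y1)               # later groups also see the previous output,
+#     y3 = blocks[2](chunks[3] + y2)               # giving progressively larger receptive fields
+#     out = torch.cat([y0, y1, y2, y3], dim=1)     # (example with scale = 4)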
+
+class SqueezeExcitationBlock(nn.Module):
+ def __init__(self, in_channels, se_channels, out_channels):
+ super().__init__()
+
+ self.conv1 = nn.Conv1d(
+ in_channels=in_channels,
+ out_channels=se_channels,
+ kernel_size=1,
+ padding="same",
+ padding_mode="reflect",
+ )
+ self.relu = nn.ReLU(inplace=True)
+ self.conv2 = nn.Conv1d(
+ in_channels=se_channels,
+ out_channels=out_channels,
+ kernel_size=1,
+ padding="same",
+ padding_mode="reflect",
+ )
+ self.sigmoid = nn.Sigmoid()
+
+ def forward(self, hidden_states):
+ hidden_states_mean = hidden_states.mean(dim=2, keepdim=True)
+
+ hidden_states_mean = self.relu(self.conv1(hidden_states_mean))
+ hidden_states_mean = self.sigmoid(self.conv2(hidden_states_mean))
+
+ return hidden_states * hidden_states_mean
+
+
+class AttentiveStatisticsPooling(nn.Module):
+ """This class implements an attentive statistic pooling layer for each channel.
+ It returns the concatenated mean and std of the input tensor.
+ """
+
+ def __init__(self, channels, attention_channels=128):
+ super().__init__()
+
+ self.eps = 1e-12
+ self.tdnn = TimeDelayNetBlock(channels * 3, attention_channels, 1, 1)
+ self.tanh = nn.Tanh()
+ self.conv = nn.Conv1d(
+ in_channels=attention_channels,
+ out_channels=channels,
+ kernel_size=1,
+ padding="same",
+ padding_mode="reflect",
+ )
+
+ def _length_to_mask(self, length, max_len=None, dtype=None, device=None):
+ """Creates a binary mask for each sequence.
+
+ Reference: https://discuss.pytorch.org/t/how-to-generate-variable-length-mask/23397/3
+
+ Arguments
+ ---------
+ length : torch.LongTensor
+ Containing the length of each sequence in the batch. Must be 1D.
+ max_len : int
+ Max length for the mask, also the size of the second dimension.
+ dtype : torch.dtype, default: None
+ The dtype of the generated mask.
+ device: torch.device, default: None
+ The device on which to create the mask.
+
+ Returns
+ -------
+ mask : tensor
+ The binary mask.
+ """
+
+ if max_len is None:
+ max_len = length.max().long().item() # using arange to generate mask
+ mask = torch.arange(max_len, device=length.device, dtype=length.dtype).expand(
+ len(length), max_len
+ ) < length.unsqueeze(1)
+
+ mask = torch.as_tensor(mask, dtype=dtype, device=device)
+ return mask
+
+ def _compute_statistics(self, x, m, dim=2):
+ mean = (m * x).sum(dim)
+ std = torch.sqrt((m * (x - mean.unsqueeze(dim)).pow(2)).sum(dim).clamp(self.eps))
+ return mean, std
+
+ def forward(self, hidden_states):
+ seq_length = hidden_states.shape[-1]
+ lengths = torch.ones(hidden_states.shape[0], device=hidden_states.device)
+
+ # Make binary mask of shape [N, 1, L]
+ mask = self._length_to_mask(
+ lengths * seq_length, max_len=seq_length, dtype=hidden_states.dtype, device=hidden_states.device
+ )
+ mask = mask.unsqueeze(1)
+
+ # Expand the temporal context of the pooling layer by allowing the
+ # self-attention to look at global properties of the utterance.
+ total = mask.sum(dim=2, keepdim=True)
+
+ mean, std = self._compute_statistics(hidden_states, mask / total)
+ mean = mean.unsqueeze(2).repeat(1, 1, seq_length)
+ std = std.unsqueeze(2).repeat(1, 1, seq_length)
+ attention = torch.cat([hidden_states, mean, std], dim=1)
+
+ # Apply layers
+ attention = self.conv(self.tanh(self.tdnn(attention)))
+
+ # Filter out zero-paddings
+ attention = attention.masked_fill(mask == 0, float("-inf"))
+
+ attention = F.softmax(attention, dim=2)
+ mean, std = self._compute_statistics(hidden_states, attention)
+ # Append mean and std of the batch
+ pooled_stats = torch.cat((mean, std), dim=1)
+ pooled_stats = pooled_stats.unsqueeze(2)
+
+ return pooled_stats
+
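+# Weighted-statistics sketch for the pooling above (illustrative; the sizes are assumptions):
+#
+#     import torch
+#     x = torch.randn(2, 4, 10)                          # [batch, channels, time]
+#     w = torch.softmax(torch.randn(2, 4, 10), dim=2)    # attention weights per frame
+#     mean = (w * x).sum(dim=2)                                                  # [2, 4]
+#     std = torch.sqrt((w * (x - mean.unsqueeze(2)).pow(2)).sum(dim=2).clamp(1e-12))
+#     pooled = torch.cat((mean, std), dim=1).unsqueeze(2)                        # [2, 8, 1]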
+
+class SqueezeExcitationRes2NetBlock(nn.Module):
+ """An implementation of building block in ECAPA-TDNN, i.e.,
+ TDNN-Res2Net-TDNN-SqueezeExcitationBlock.
+ """
+
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ res2net_scale=8,
+ se_channels=128,
+ kernel_size=1,
+ dilation=1,
+ ):
+ super().__init__()
+ self.out_channels = out_channels
+ self.tdnn1 = TimeDelayNetBlock(
+ in_channels,
+ out_channels,
+ kernel_size=1,
+ dilation=1,
+ )
+ self.res2net_block = Res2NetBlock(out_channels, out_channels, res2net_scale, kernel_size, dilation)
+ self.tdnn2 = TimeDelayNetBlock(
+ out_channels,
+ out_channels,
+ kernel_size=1,
+ dilation=1,
+ )
+ self.se_block = SqueezeExcitationBlock(out_channels, se_channels, out_channels)
+
+ def forward(self, hidden_state):
+ residual = hidden_state
+
+ hidden_state = self.tdnn1(hidden_state)
+ hidden_state = self.res2net_block(hidden_state)
+ hidden_state = self.tdnn2(hidden_state)
+ hidden_state = self.se_block(hidden_state)
+
+ return hidden_state + residual
+
+
+class ECAPA_TimeDelayNet(torch.nn.Module):
+ """An implementation of the speaker embedding model described in the paper
+ "ECAPA-TDNN: Emphasized Channel Attention, Propagation and Aggregation in
+ TDNN Based Speaker Verification" (https://huggingface.co/papers/2005.07143).
+ """
+
+ def __init__(self, config: Qwen2_5OmniDiTConfig):
+ super().__init__()
+ if len(config.enc_channels) != len(config.enc_kernel_sizes) or len(config.enc_channels) != len(
+ config.enc_dilations
+ ):
+ raise ValueError("enc_channels, enc_kernel_sizes and enc_dilations should have same length")
+ self.channels = config.enc_channels
+ self.blocks = nn.ModuleList()
+
+ # The initial TDNN layer
+ self.blocks.append(
+ TimeDelayNetBlock(
+ config.mel_dim,
+ config.enc_channels[0],
+ config.enc_kernel_sizes[0],
+ config.enc_dilations[0],
+ )
+ )
+
+ # SE-Res2Net layers
+ for i in range(1, len(config.enc_channels) - 1):
+ self.blocks.append(
+ SqueezeExcitationRes2NetBlock(
+ config.enc_channels[i - 1],
+ config.enc_channels[i],
+ res2net_scale=config.enc_res2net_scale,
+ se_channels=config.enc_se_channels,
+ kernel_size=config.enc_kernel_sizes[i],
+ dilation=config.enc_dilations[i],
+ )
+ )
+
+ # Multi-layer feature aggregation
+ self.mfa = TimeDelayNetBlock(
+ config.enc_channels[-1],
+ config.enc_channels[-1],
+ config.enc_kernel_sizes[-1],
+ config.enc_dilations[-1],
+ )
+
+ # Attentive Statistical Pooling
+ self.asp = AttentiveStatisticsPooling(
+ config.enc_channels[-1],
+ attention_channels=config.enc_attention_channels,
+ )
+
+ # Final linear transformation
+ self.fc = nn.Conv1d(
+ in_channels=config.enc_channels[-1] * 2,
+ out_channels=config.enc_dim,
+ kernel_size=1,
+ padding="same",
+ padding_mode="reflect",
+ )
+
+ def forward(self, hidden_states):
+ # Minimize transpose for efficiency
+ hidden_states = hidden_states.transpose(1, 2)
+
+ hidden_states_list = []
+ for layer in self.blocks:
+ hidden_states = layer(hidden_states)
+ hidden_states_list.append(hidden_states)
+
+ # Multi-layer feature aggregation
+ hidden_states = torch.cat(hidden_states_list[1:], dim=1)
+ hidden_states = self.mfa(hidden_states)
+
+ # Attentive Statistical Pooling
+ hidden_states = self.asp(hidden_states)
+
+ # Final linear transformation
+ hidden_states = self.fc(hidden_states)
+
+ hidden_states = hidden_states.squeeze(-1)
+ return hidden_states
+
+
+class DiTInputEmbedding(nn.Module):
+ def __init__(self, config: Qwen2_5OmniDiTConfig):
+ super().__init__()
+ self.proj = nn.Linear(
+ config.mel_dim + config.enc_dim + config.enc_emb_dim + config.emb_dim,
+ config.hidden_size,
+ )
+ self.spk_encoder = ECAPA_TimeDelayNet(config)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ speaker_embedding: torch.Tensor,
+ condition_vector: torch.Tensor,
+ code_embed: torch.Tensor,
+ drop_audio_cond: Optional[bool] = False,
+ code_embed_uncond: Optional[bool] = None,
+ apply_cfg: Optional[bool] = True,
+ ):
+ if apply_cfg:
+ hidden_states = torch.cat([hidden_states, hidden_states], dim=0)
+ speaker_embedding = torch.cat([speaker_embedding, torch.zeros_like(speaker_embedding)], dim=0)
+ condition_vector = torch.cat([condition_vector, torch.zeros_like(condition_vector)], dim=0)
+ code_embed = torch.cat([code_embed, code_embed_uncond], dim=0)
+ elif drop_audio_cond: # cfg for cond audio
+ condition_vector = torch.zeros_like(condition_vector)
+ speaker_embedding = torch.zeros_like(speaker_embedding)
+ condition_vector = self.spk_encoder(condition_vector).unsqueeze(1).repeat(1, hidden_states.size(1), 1)
+ hidden_states = self.proj(torch.cat((hidden_states, condition_vector, code_embed, speaker_embedding), dim=-1))
+
+ return hidden_states
+
+
+# Transformer backbone using DiT blocks
+class DiTCodecEmbedding(nn.Module):
+ def __init__(self, codec_num_embeds, codec_dim, repeats):
+ super().__init__()
+ self.repeats = repeats
+ self.codec_embed = nn.Embedding(codec_num_embeds + 1, codec_dim)
+
+ def forward(self, code, drop_code=False):
+ if drop_code:
+ code = torch.zeros_like(code)
+ code_embed = self.codec_embed(code)
+
+ code_embed = torch.repeat_interleave(code_embed, repeats=self.repeats, dim=1)
+ return code_embed
+
+
+# AdaLayerNormZero
+# return with modulated x for attn input, and params for later mlp modulation
+class Qwen2_5_OmniAdaLayerNormZero(nn.Module):
+ def __init__(self, dim):
+ super().__init__()
+
+ self.silu = nn.SiLU()
+ self.linear = nn.Linear(dim, dim * 6)
+
+ self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
+
+ def forward(self, hidden_states, emb=None):
+ emb = self.linear(self.silu(emb))
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = torch.chunk(emb, 6, dim=1)
+
+ hidden_states = self.norm(hidden_states) * (1 + scale_msa[:, None]) + shift_msa[:, None]
+ return hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp
+
+
+# AdaLayerNormZero for final layer
+# return only with modulated x for attn input, cuz no more mlp modulation
+class Qwen2_5_OmniAdaLayerNormZero_Final(nn.Module):
+ def __init__(self, dim):
+ super().__init__()
+
+ self.silu = nn.SiLU()
+ self.linear = nn.Linear(dim, dim * 2)
+
+ self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
+
+ def forward(self, hidden_states, emb):
+ emb = self.linear(self.silu(emb))
+ scale, shift = torch.chunk(emb, 2, dim=1)
+
+ hidden_states = self.norm(hidden_states) * (1 + scale)[:, None, :] + shift[:, None, :]
+ return hidden_states
+
+
+# FeedForward
+class DiTMLP(nn.Module):
+ def __init__(self, dim, mult=4, dropout=0.0):
+ super().__init__()
+ inner_dim = int(dim * mult)
+
+ self.ff = nn.ModuleList(
+ [
+ nn.Linear(dim, inner_dim),
+ nn.GELU(approximate="tanh"),
+ nn.Dropout(dropout),
+ nn.Linear(inner_dim, dim),
+ ]
+ )
+
+ def forward(self, hidden_states):
+ for layer in self.ff:
+ hidden_states = layer(hidden_states)
+ return hidden_states
+
+
+# Modified from Llama with a different rotate function; will be fixed in the next release
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+ """Applies Rotary Position Embedding to the query and key tensors.
+
+ Args:
+ q (`torch.Tensor`): The query tensor.
+ k (`torch.Tensor`): The key tensor.
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
+ position_ids (`torch.Tensor`, *optional*):
+ Deprecated and unused.
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+ Returns:
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+ """
+
+ def rotate_half_codec(x):
+ # x = rearrange(x, "... (d r) -> ... d r", r=2)
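+ # Interleaved rotation: adjacent pairs (x[..., 2i], x[..., 2i+1]) are rotated together,
+ # unlike Llama's rotate_half, which splits the last dimension into two contiguous halves.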
+ x = x.reshape(*x.shape[:-1], -1, 2)
+ x1, x2 = x.unbind(dim=-1)
+ x = torch.stack((-x2, x1), dim=-1)
+ return x.reshape(*x.shape[:-2], -1)
+
+ cos = cos.unsqueeze(unsqueeze_dim)
+ sin = sin.unsqueeze(unsqueeze_dim)
+ q_embed = (q * cos) + (rotate_half_codec(q) * sin)
+ k_embed = (k * cos) + (rotate_half_codec(k) * sin)
+ return q_embed, k_embed
+
+
+class DiTAttention(nn.Module):
+ def __init__(self, config: Qwen2_5OmniDiTConfig):
+ super().__init__()
+
+ self.config = config
+ self.dim = config.hidden_size
+ self.heads = config.num_attention_heads
+ self.inner_dim = config.head_dim * config.num_attention_heads
+ self.dropout = config.dropout
+ self.is_causal = False
+
+ self.to_q = nn.Linear(config.hidden_size, self.inner_dim)
+ self.to_k = nn.Linear(config.hidden_size, self.inner_dim)
+ self.to_v = nn.Linear(config.hidden_size, self.inner_dim)
+
+ self.to_out = nn.ModuleList([nn.Linear(self.inner_dim, config.hidden_size), nn.Dropout(config.dropout)])
+
+ def forward(
+ self,
+ hidden_states, # noised input x
+ position_embeddings=None, # rotary position embedding for x
+ attention_mask=None,
+ ) -> torch.Tensor:
+ batch_size = hidden_states.shape[0]
+
+ # `sample` projections.
+ query = self.to_q(hidden_states)
+ key = self.to_k(hidden_states)
+ value = self.to_v(hidden_states)
+
+ # attention
+ inner_dim = key.shape[-1]
+ head_dim = inner_dim // self.heads
+ query = query.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
+ key = key.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
+ value = value.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
+
+ # apply rotary position embedding
+ # Due to the training setup, RoPE is applied only to the first head; will be fixed in the next release
+ cos, sin = position_embeddings
+ query[:, :1], key[:, :1] = apply_rotary_pos_emb(query[:, :1], key[:, :1], cos, sin)
+
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+ attention_weights, _ = attention_interface(
+ self,
+ query,
+ key,
+ value,
+ attention_mask=attention_mask,
+ is_causal=False,
+ )
+
+ # merge heads back; e.g. at inference a batch may contain different target durations, with the padding masked out via attention_mask
+ attention_weights = attention_weights.reshape(batch_size, -1, self.heads * head_dim)
+ attention_weights = attention_weights.to(query.dtype)
+
+ # linear proj
+ attention_output = self.to_out[0](attention_weights)
+ attention_output = self.to_out[1](attention_output)
+
+ return attention_output
+
+
+# time step conditioning embedding
+class SinusPositionEmbedding(nn.Module):
+ def __init__(self, dim):
+ super().__init__()
+ self.dim = dim
+
+ def forward(self, hidden_states, scale=1000):
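+ # Standard sinusoidal timestep embedding: sine/cosine over geometrically spaced frequencies
+ # (1 down to 1/10000); `scale` stretches the (typically [0, 1]) timestep before encoding.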
+ device = hidden_states.device
+ half_dim = self.dim // 2
+ emb = math.log(10000) / (half_dim - 1)
+ emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb)
+ emb = scale * hidden_states.unsqueeze(1) * emb.unsqueeze(0)
+ emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
+ return emb.type_as(hidden_states)
+
+
+class DiTTimestepEmbedding(nn.Module):
+ def __init__(self, dim, freq_embed_dim=256):
+ super().__init__()
+ self.time_embed = SinusPositionEmbedding(freq_embed_dim)
+ self.time_mlp = nn.ModuleList([nn.Linear(freq_embed_dim, dim), nn.SiLU(), nn.Linear(dim, dim)])
+
+ def forward(self, timestep):
+ time_hidden = self.time_embed(timestep)
+ time_hidden = time_hidden.to(timestep.dtype)
+ for layer in self.time_mlp:
+ time_hidden = layer(time_hidden) # b d
+ return time_hidden
+
+
+class DiTDecoderLayer(nn.Module):
+ def __init__(self, config: Qwen2_5OmniDiTConfig, look_ahead_block=0, look_backward_block=0):
+ super().__init__()
+ self.attn_norm = Qwen2_5_OmniAdaLayerNormZero(config.hidden_size)
+
+ self.attn = DiTAttention(config)
+ self.look_ahead_block = look_ahead_block
+ self.look_backward_block = look_backward_block
+ self.ff_norm = nn.LayerNorm(config.hidden_size, elementwise_affine=False, eps=1e-6)
+ self.ff = DiTMLP(dim=config.hidden_size, mult=config.ff_mult, dropout=config.dropout)
+
+ def forward(
+ self, hidden_states, timestep, position_embeddings=None, block_diff=None
+ ): # x: noised input, t: time embedding
+ # pre-norm & modulation for attention input
+ norm, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.attn_norm(hidden_states, emb=timestep)
+
+ # attention
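+ # Block-local attention: a query in block i may attend to keys in block j only when
+ # j - i lies within [-look_backward_block, look_ahead_block].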
+ attn_output = self.attn(
+ hidden_states=norm,
+ position_embeddings=position_embeddings,
+ attention_mask=(block_diff >= -float(self.look_backward_block))
+ & (block_diff <= float(self.look_ahead_block)),
+ )
+
+ # process attention output for input x
+ hidden_states = hidden_states + gate_msa.unsqueeze(1) * attn_output
+
+ norm = self.ff_norm(hidden_states) * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+ ff_output = self.ff(norm)
+ hidden_states = hidden_states + gate_mlp.unsqueeze(1) * ff_output
+
+ return hidden_states
+
+
+class SnakeBeta(nn.Module):
+ """
+ A modified Snake function which uses separate parameters for the magnitude of the periodic components
+ Shape:
+ - Input: (B, C, T)
+ - Output: (B, C, T), same shape as the input
+ Parameters:
+ - alpha - trainable parameter that controls frequency
+ - beta - trainable parameter that controls magnitude
+ References:
+ - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
+ https://huggingface.co/papers/2006.08195
+ """
+
+ def __init__(self, in_features, alpha=1.0):
+ super().__init__()
+ self.in_features = in_features
+
+ # initialize alpha
+ self.alpha = Parameter(torch.zeros(in_features) * alpha)
+ self.beta = Parameter(torch.zeros(in_features) * alpha)
+
+ self.no_div_by_zero = 0.000000001
+
+ def forward(self, hidden_states):
+ """
+ Forward pass of the function.
+ Applies the function to the input elementwise.
+ SnakeBeta(x) := x + (1 / b) * sin^2(x * a)
+ """
+ alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
+ beta = self.beta.unsqueeze(0).unsqueeze(-1)
+ alpha = torch.exp(alpha)
+ beta = torch.exp(beta)
+ hidden_states = hidden_states + (1.0 / (beta + self.no_div_by_zero)) * torch.pow(
+ torch.sin(hidden_states * alpha), 2
+ )
+
+ return hidden_states
+
+
+def kaiser_sinc_filter1d(cutoff, half_width, kernel_size):
+ """Generates a 1D Kaiser-windowed sinc filter.
+
+ Args:
+ cutoff (float): Normalized cutoff frequency (0 to 0.5).
+ half_width (float): Transition bandwidth.
+ kernel_size (int): Number of filter taps.
+
+ Returns:
+ torch.Tensor: A tensor of shape (1, 1, kernel_size) representing the filter.
+ """
+ is_even = kernel_size % 2 == 0
+ half_size = kernel_size // 2
+
+ # Compute Kaiser window parameters
+ delta_f = 4 * half_width
+ attenuation = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
+
+ if attenuation > 50.0:
+ beta = 0.1102 * (attenuation - 8.7)
+ elif attenuation >= 21.0:
+ beta = 0.5842 * (attenuation - 21) ** 0.4 + 0.07886 * (attenuation - 21.0)
+ else:
+ beta = 0.0
+
+ kaiser_window = torch.kaiser_window(kernel_size, beta=beta, periodic=False, dtype=torch.float32)
+
+ # Compute time indices
+ if is_even:
+ time_indices = torch.arange(-half_size, half_size) + 0.5
+ else:
+ time_indices = torch.arange(kernel_size) - half_size
+
+ # Compute sinc filter
+ if cutoff == 0:
+ return torch.zeros((1, 1, kernel_size), dtype=torch.float32) # Ensures correct shape
+
+ sinc_filter = torch.sinc(2 * cutoff * time_indices)
+ normalized_filter = 2 * cutoff * kaiser_window * sinc_filter
+
+ # Normalize to ensure sum = 1 (avoid leakage of constant component)
+ normalized_filter /= normalized_filter.sum()
+
+ return normalized_filter.view(1, 1, kernel_size)
+
+
+class UpSample1d(nn.Module):
+ def __init__(self, ratio=2, kernel_size=None):
+ super().__init__()
+ self.ratio = ratio
+ self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
+ self.stride = ratio
+ self.pad = self.kernel_size // ratio - 1
+ self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
+ self.pad_right = self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
+
+ filter = kaiser_sinc_filter1d(cutoff=0.5 / ratio, half_width=0.6 / ratio, kernel_size=self.kernel_size)
+ self.register_buffer("filter", filter, persistent=False)
+
+ def forward(self, hidden_states):
+ channels = hidden_states.shape[1]
+
+ hidden_states = F.pad(hidden_states, (self.pad, self.pad), mode="replicate")
+ hidden_states = self.ratio * F.conv_transpose1d(
+ hidden_states, self.filter.expand(channels, -1, -1), stride=self.stride, groups=channels
+ )
+ hidden_states = hidden_states[..., self.pad_left : -self.pad_right]
+
+ return hidden_states
+
+
+class DownSample1d(nn.Module):
+ def __init__(self, ratio=2, kernel_size=None):
+ super().__init__()
+ cutoff = 0.5 / ratio
+ half_width = 0.6 / ratio
+
+ if cutoff < 0.0:
+ raise ValueError("Minimum cutoff must be larger than zero.")
+ if cutoff > 0.5:
+ raise ValueError("A cutoff above 0.5 does not make sense.")
+
+ self.even = kernel_size % 2 == 0
+ self.pad_left = kernel_size // 2 - int(self.even)
+ self.pad_right = kernel_size // 2
+ self.stride = ratio
+ filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
+ self.register_buffer("filter", filter, persistent=False)
+
+ def forward(self, hidden_states):
+ channels = hidden_states.shape[1]
+ hidden_states = F.pad(hidden_states, (self.pad_left, self.pad_right), mode="replicate")
+ out = F.conv1d(hidden_states, self.filter.expand(channels, -1, -1), stride=self.stride, groups=channels)
+ return out
+
+
+class TorchActivation1d(nn.Module):
+ def __init__(
+ self,
+ activation,
+ up_ratio: int = 2,
+ down_ratio: int = 2,
+ up_kernel_size: int = 12,
+ down_kernel_size: int = 12,
+ ):
+ super().__init__()
+ if not callable(activation):
+ raise TypeError("Activation function must be callable")
+ self.act = activation
+ self.upsample = UpSample1d(up_ratio, up_kernel_size)
+ self.downsample = DownSample1d(down_ratio, down_kernel_size)
+
+ def forward(self, hidden_states):
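+ # Anti-aliased nonlinearity: upsample first so the harmonics introduced by the activation
+ # stay below the Nyquist frequency, then low-pass filter and downsample back.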
+ hidden_states = self.upsample(hidden_states)
+ hidden_states = self.act(hidden_states)
+ hidden_states = self.downsample(hidden_states)
+
+ return hidden_states
+
+
+class AMPBlock(torch.nn.Module):
+ def __init__(
+ self,
+ channels,
+ kernel_size=3,
+ dilation=(1, 3, 5),
+ ):
+ super().__init__()
+
+ self.convs1 = nn.ModuleList(
+ [
+ nn.Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=dilation[0],
+ padding=self._get_padding(kernel_size, dilation[0]),
+ ),
+ nn.Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=dilation[1],
+ padding=self._get_padding(kernel_size, dilation[1]),
+ ),
+ nn.Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=dilation[2],
+ padding=self._get_padding(kernel_size, dilation[2]),
+ ),
+ ]
+ )
+
+ self.convs2 = nn.ModuleList(
+ [
+ nn.Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=1,
+ padding=self._get_padding(kernel_size, 1),
+ ),
+ nn.Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=1,
+ padding=self._get_padding(kernel_size, 1),
+ ),
+ nn.Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=1,
+ padding=self._get_padding(kernel_size, 1),
+ ),
+ ]
+ )
+
+ self.num_layers = len(self.convs1) + len(self.convs2) # total number of conv layers
+
+ self.activations = nn.ModuleList(
+ [TorchActivation1d(activation=SnakeBeta(channels)) for _ in range(self.num_layers)]
+ )
+
+ def _get_padding(self, kernel_size, dilation=1):
+ return int((kernel_size * dilation - dilation) / 2)
+
+ def forward(self, hidden_states):
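+ # Each residual branch applies: anti-aliased activation -> dilated conv -> anti-aliased
+ # activation -> conv (dilation 1), and adds the result back to the input.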
+ acts1, acts2 = self.activations[::2], self.activations[1::2]
+ for conv1, conv2, act1, act2 in zip(self.convs1, self.convs2, acts1, acts2):
+ residual = hidden_states
+ hidden_states = act1(hidden_states)
+ hidden_states = conv1(hidden_states)
+ hidden_states = act2(hidden_states)
+ hidden_states = conv2(hidden_states)
+ hidden_states = residual + hidden_states
+
+ return hidden_states
+
+
+@auto_docstring(
+ custom_intro="""
+ The full Qwen2.5Omni Token2WavBigVGAN model, which takes a mel spectrogram as input and predicts a waveform.
+ """
+)
+class Qwen2_5OmniToken2WavBigVGANModel(Qwen2_5OmniPreTrainedModel):
+ config: Qwen2_5OmniBigVGANConfig
+
+ def __init__(self, config: Qwen2_5OmniBigVGANConfig):
+ super().__init__(config)
+ self.num_residual_blocks = len(config.resblock_kernel_sizes)
+ self.num_upsample_layers = len(config.upsample_rates)
+
+ self.conv_pre = nn.Conv1d(config.mel_dim, config.upsample_initial_channel, 7, 1, padding=3)
+
+ # Removing the extra ModuleList wrapper would break loading the official state dict
+ ups = [
+ nn.ModuleList(
+ [
+ nn.ConvTranspose1d(
+ config.upsample_initial_channel // (2**layer_idx),
+ config.upsample_initial_channel // (2 ** (layer_idx + 1)),
+ kernel_size,
+ stride,
+ padding=(kernel_size - stride) // 2,
+ )
+ ]
+ )
+ for layer_idx, (stride, kernel_size) in enumerate(zip(config.upsample_rates, config.upsample_kernel_sizes))
+ ]
+ self.ups = nn.ModuleList(ups)
+
+ self.resblocks = nn.ModuleList(
+ [
+ AMPBlock(config.upsample_initial_channel // (2 ** (layer_idx + 1)), kernel_size, dilation)
+ for layer_idx in range(self.num_upsample_layers)
+ for kernel_size, dilation in zip(config.resblock_kernel_sizes, config.resblock_dilation_sizes)
+ ]
+ )
+
+ self.activation_post = TorchActivation1d(
+ activation=SnakeBeta(config.upsample_initial_channel // (2**self.num_upsample_layers))
+ )
+ self.conv_post = nn.Conv1d(
+ config.upsample_initial_channel // (2**self.num_upsample_layers), 1, 7, 1, padding=3, bias=False
+ )
+
+ def normalize_spectrogram(self, spectrogram, max_value, min_db):
+ return torch.clamp((2 * max_value) * ((spectrogram - min_db) / (-min_db)) - max_value, -max_value, max_value)
+
+ def amplitude_to_db(self, amplitude, min_db_level):
+ min_level = torch.exp(
+ torch.tensor(min_db_level / 20.0 * np.log(10), device=amplitude.device, dtype=amplitude.dtype)
+ )
+ return 20 * torch.log10(torch.clamp(amplitude, min=min_level))
+
+ def process_mel_spectrogram(self, mel_spectrogram):
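+ # Convert the log-mel input back to amplitude, map it to decibels (floored at -115 dB with a
+ # 20 dB reference shift), then normalize to [-1, 1] for the convolutional stack.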
+ amplitude_spectrum = torch.exp(mel_spectrogram)
+ decibel_spectrum = self.amplitude_to_db(amplitude_spectrum, -115) - 20
+ return self.normalize_spectrogram(decibel_spectrum, 1, -115)
+
+ def forward(self, mel_spectrogram):
+ processed_spectrogram = self.process_mel_spectrogram(mel_spectrogram)
+ hidden_representation = self.conv_pre(processed_spectrogram)
+
+ for layer_index in range(self.num_upsample_layers):
+ hidden_representation = self.ups[layer_index][0](hidden_representation)
+ residual_output = sum(
+ self.resblocks[layer_index * self.num_residual_blocks + block_index](hidden_representation)
+ for block_index in range(self.num_residual_blocks)
+ )
+ residual_output = residual_output / self.num_residual_blocks
+ hidden_representation = residual_output
+
+ hidden_representation = self.activation_post(hidden_representation)
+ output_waveform = self.conv_post(hidden_representation)
+ return torch.clamp(output_waveform, min=-1.0, max=1.0).squeeze().cpu()
+
+
+class RungeKutta4ODESolver:
+ def __init__(self, function, initial_value):
+ self.function = function
+ self.initial_value = initial_value
+
+ self._one_third = 1 / 3
+ self._two_thirds = 2 / 3
+
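+ # Fixed-step integrator based on the Runge-Kutta 3/8 rule: stages at t, t + h/3, t + 2h/3 and
+ # t + h are combined with weights (k1 + 3*k2 + 3*k3 + k4) * h / 8.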
+ def _rk4_step(self, function, time_start, time_step, time_end, value_start, function_value_start=None):
+ k1 = function_value_start if function_value_start is not None else function(time_start, value_start)
+ k2 = function(time_start + time_step * self._one_third, value_start + time_step * k1 * self._one_third)
+ k3 = function(time_start + time_step * self._two_thirds, value_start + time_step * (k2 - k1 * self._one_third))
+ k4 = function(time_end, value_start + time_step * (k1 - k2 + k3))
+ return (k1 + 3 * (k2 + k3) + k4) * time_step / 8
+
+ def _compute_step(self, function, time_start, time_step, time_end, value_start):
+ function_value_start = function(time_start, value_start)
+ return self._rk4_step(
+ function, time_start, time_step, time_end, value_start, function_value_start=function_value_start
+ ), function_value_start
+
+ def _linear_interpolation(self, time_start, time_end, value_start, value_end, time_point):
+ if time_point == time_start:
+ return value_start
+ if time_point == time_end:
+ return value_end
+ weight = (time_point - time_start) / (time_end - time_start)
+ return value_start + weight * (value_end - value_start)
+
+ def integrate(self, time_points):
+ solution = torch.empty(
+ len(time_points),
+ *self.initial_value.shape,
+ dtype=self.initial_value.dtype,
+ device=self.initial_value.device,
+ )
+ solution[0] = self.initial_value
+
+ current_index = 1
+ current_value = self.initial_value
+ for time_start, time_end in zip(time_points[:-1], time_points[1:]):
+ time_step = time_end - time_start
+ delta_value, _ = self._compute_step(self.function, time_start, time_step, time_end, current_value)
+ next_value = current_value + delta_value
+
+ while current_index < len(time_points) and time_end >= time_points[current_index]:
+ solution[current_index] = self._linear_interpolation(
+ time_start, time_end, current_value, next_value, time_points[current_index]
+ )
+ current_index += 1
+
+ current_value = next_value
+
+ return solution
+
+
+@auto_docstring(
+ custom_intro="""
+ The full Qwen2.5Omni Token2WavDiT model, which takes speech tokens as input and predicts a mel spectrogram.
+ """
+)
+class Qwen2_5OmniToken2WavDiTModel(Qwen2_5OmniPreTrainedModel):
+ config: Qwen2_5OmniDiTConfig
+ _no_split_modules = ["DiTDecoderLayer"]
+
+ def __init__(self, config: Qwen2_5OmniDiTConfig):
+ super().__init__(config)
+ self.mel_dim = config.mel_dim
+ self.repeats = config.repeats
+ self.time_embed = DiTTimestepEmbedding(config.hidden_size)
+
+ self.text_embed = DiTCodecEmbedding(config.num_embeds, config.emb_dim, config.repeats)
+ self.input_embed = DiTInputEmbedding(config)
+
+ self.rotary_embed = Qwen2_5OmniDiTRotaryEmbedding(config.head_dim)
+
+ self.hidden_size = config.hidden_size
+ self.layers = config.num_hidden_layers
+ self.block_size = config.block_size
+ self.num_attention_heads = config.num_attention_heads
+
+ self.transformer_blocks = nn.ModuleList()
+ for i in range(config.num_hidden_layers):
+ self.transformer_blocks.append(
+ DiTDecoderLayer(
+ config,
+ look_ahead_block=1 if i in config.look_ahead_layers else 0,
+ look_backward_block=1 if i in config.look_backward_layers else 0,
+ )
+ )
+
+ self.norm_out = Qwen2_5_OmniAdaLayerNormZero_Final(config.hidden_size) # final modulation
+ self.proj_out = nn.Linear(config.hidden_size, config.mel_dim)
+
+ def _create_block_diff(self, hidden_states):
+ batch, seq_len = hidden_states.shape[0], hidden_states.shape[1]
+ block_indices = torch.arange(seq_len, device=hidden_states.device) // self.block_size # [seq_length]
+
+ block_i = block_indices.unsqueeze(1) # [seq_length, 1]
+ block_j = block_indices.unsqueeze(0) # [1, seq_length]
+ block_diff = block_j - block_i # (n, n)
+
+ return block_diff.expand(batch, self.num_attention_heads, seq_len, seq_len)
+
+ def forward(
+ self,
+ hidden_states,
+ condition_vector,
+ speaker_embedding,
+ quantized_code,
+ time_step,
+ drop_audio_conditioning=False,
+ drop_code=False,
+ apply_cfg=True,
+ ):
+ batch_size = hidden_states.shape[0]
+ if time_step.ndim == 0:
+ time_step = time_step.repeat(batch_size)
+
+ # Compute embeddings
+ time_embedding = self.time_embed(time_step)
+ text_embedding = self.text_embed(quantized_code, drop_code=False if apply_cfg else drop_code)
+ text_embedding_unconditioned = self.text_embed(quantized_code, drop_code=True) if apply_cfg else None
+
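+ # With CFG, the code-dropped ("unconditional") embedding is also computed; DiTInputEmbedding
+ # stacks it with the conditional branch along the batch dimension.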
+ hidden_states = self.input_embed(
+ hidden_states,
+ speaker_embedding,
+ condition_vector,
+ text_embedding,
+ drop_audio_cond=drop_audio_conditioning,
+ code_embed_uncond=text_embedding_unconditioned,
+ apply_cfg=apply_cfg,
+ )
+
+ # Compute positional encodings
+ position_embeddings = self.rotary_embed(hidden_states)
+ blockwise_difference = self._create_block_diff(hidden_states)
+
+ # Transformer blocks
+ for transformer_block in self.transformer_blocks:
+ hidden_states = transformer_block(
+ hidden_states,
+ time_embedding,
+ position_embeddings=position_embeddings,
+ block_diff=blockwise_difference,
+ )
+
+ hidden_states = self.norm_out(hidden_states, time_embedding)
+ output = self.proj_out(hidden_states)
+
+ return output
+
+ @torch.no_grad()
+ def sample(
+ self,
+ conditioning_vector,
+ reference_mel_spectrogram,
+ quantized_code,
+ num_steps=10,
+ guidance_scale=0.5,
+ sway_coefficient=-1.0,
+ ):
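+ # Flow-matching style sampling: integrate the model-predicted velocity field from t = 0
+ # (Gaussian noise) to t = 1 (mel spectrogram) with a fixed-step Runge-Kutta ODE solver.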
+ noise_initialization = torch.randn([1, 30000, self.mel_dim], dtype=reference_mel_spectrogram.dtype)
+ maximum_duration = quantized_code.shape[1] * self.repeats
+ initial_state = noise_initialization[:, :maximum_duration].to(quantized_code.device)
+ batch_size = reference_mel_spectrogram.shape[0]
+ conditioning_vector = conditioning_vector.unsqueeze(1).repeat(1, maximum_duration, 1)
+
+ if batch_size != 1:
+ raise ValueError("Only batch size = 1 is currently supported")
+
+ def ode_function(time_step, hidden_states):
+ if guidance_scale < 1e-5:
+ prediction = self(
+ hidden_states=hidden_states,
+ speaker_embedding=conditioning_vector,
+ condition_vector=reference_mel_spectrogram,
+ quantized_code=quantized_code,
+ time_step=time_step,
+ drop_audio_conditioning=False,
+ drop_code=False,
+ )
+ return prediction
+
+ model_output = self(
+ hidden_states=hidden_states,
+ quantized_code=quantized_code,
+ speaker_embedding=conditioning_vector,
+ condition_vector=reference_mel_spectrogram,
+ time_step=time_step,
+ apply_cfg=True,
+ )
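+ # Classifier-free guidance: push the conditional prediction away from the unconditional
+ # one by `guidance_scale`.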
+ guided_prediction, null_prediction = torch.chunk(model_output, 2, dim=0)
+ return guided_prediction + (guided_prediction - null_prediction) * guidance_scale
+
+ initial_time = 0
+ time_embedding = torch.linspace(
+ initial_time, 1, num_steps, device=quantized_code.device, dtype=conditioning_vector.dtype
+ )
+
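+ # Sway sampling: warp the uniform time grid; a negative coefficient concentrates the
+ # integration steps near t = 0.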
+ if sway_coefficient is not None:
+ time_embedding += sway_coefficient * (torch.cos(torch.pi / 2 * time_embedding) - 1 + time_embedding)
+
+ ode_solver = RungeKutta4ODESolver(function=ode_function, initial_value=initial_state)
+ solution_trajectory = ode_solver.integrate(time_embedding)
+
+ generated_waveform = solution_trajectory[-1]
+ generated_mel_spectrogram = generated_waveform.permute(0, 2, 1)
+ return generated_mel_spectrogram
+
+
+@auto_docstring(
+ custom_intro="""
+ The full Qwen2.5Omni Token2Wav model. Consists of a DiT model that takes speech tokens as input and predicts a mel spectrogram, and a BigVGAN vocoder that takes the mel spectrogram as input and predicts a waveform.
+ """
+)
+class Qwen2_5OmniToken2WavModel(Qwen2_5OmniPreTrainedModel):
+ config: Qwen2_5OmniToken2WavConfig
+ base_model_prefix = "model"
+ _no_split_modules = ["Qwen2_5OmniToken2WavDiTModel", "Qwen2_5OmniToken2WavBigVGANModel"]
+
+ def __init__(self, config: Qwen2_5OmniToken2WavConfig):
+ super().__init__(config)
+ attn_impl = config._attn_implementation
+ if config._attn_implementation == "flash_attention_2":
+ logger.warning_once(
+ "Qwen2_5OmniToken2WavModel must inference with fp32, but flash_attention_2 only supports fp16 and bf16, "
+ "attention implementation of Qwen2_5OmniToken2WavModel will fallback to sdpa."
+ )
+ attn_impl = "sdpa"
+ elif config._attn_implementation == "eager":
+ logger.warning_once(
+ "Qwen2_5OmniToken2WavModel does not support eager attention implementation, fall back to sdpa"
+ )
+ attn_impl = "sdpa"
+ self.code2wav_dit_model = Qwen2_5OmniToken2WavDiTModel._from_config(
+ config.dit_config, attn_implementation=attn_impl
+ )
+ self.code2wav_bigvgan_model = Qwen2_5OmniToken2WavBigVGANModel._from_config(
+ config.bigvgan_config, attn_implementation=attn_impl
+ )
+
+ def forward(
+ self,
+ code,
+ conditioning,
+ reference_mel,
+ num_steps=10,
+ guidance_scale=0.5,
+ sway_coefficient=-1.0,
+ **kwargs,
+ ):
+ """Generates a waveform from input code and conditioning parameters."""
+
+ mel_spectrogram = self.code2wav_dit_model.sample(
+ conditioning,
+ reference_mel,
+ code,
+ num_steps=num_steps,
+ guidance_scale=guidance_scale,
+ sway_coefficient=sway_coefficient,
+ )
+
+ waveform = self.code2wav_bigvgan_model(mel_spectrogram)
+
+ return waveform
+
+
+############################
+# Start Qwen2.5Omni #
+############################
+
+
+@auto_docstring(
+ custom_intro="""
+ The full Qwen2.5Omni model, a multimodal model composed of 3 sub-models:
+ - [`Qwen2_5OmniThinkerForConditionalGeneration`]:
+ a causal auto-regressive transformer that takes text, audio, image and video as input and predicts text tokens.
+ - [`Qwen2_5OmniTalkerForConditionalGeneration`]:
+ a causal auto-regressive transformer that takes the thinker's hidden states and response as input and predicts speech tokens.
+ - [`Qwen2_5OmniToken2WavModel`]:
+ a DiT model that takes speech tokens as input and predicts a mel spectrogram, and a BigVGAN vocoder that takes the mel spectrogram as input and predicts a waveform.
+ """
+)
+class Qwen2_5OmniForConditionalGeneration(Qwen2_5OmniPreTrainedModel, GenerationMixin):
+ config: Qwen2_5OmniConfig
+ _no_split_modules = [
+ "Qwen2_5OmniTalkerForConditionalGeneration",
+ "Qwen2_5OmniToken2WavModel",
+ ]
+
+ def __init__(self, config):
+ super().__init__(config)
+
+ self.thinker = Qwen2_5OmniThinkerForConditionalGeneration(config.thinker_config)
+
+ self.has_talker = config.enable_audio_output
+ self.speaker_map = {}
+ if config.enable_audio_output:
+ self.enable_talker()
+ self.post_init()
+
+ def enable_talker(self):
+ self.talker = Qwen2_5OmniTalkerForConditionalGeneration(self.config.talker_config)
+ self.token2wav = Qwen2_5OmniToken2WavModel(self.config.token2wav_config)
+ self.token2wav.float()
+ self.has_talker = True
+
+ def load_speakers(self, path):
+ check_torch_load_is_safe()
+ for key, value in torch.load(path, weights_only=True).items():
+ self.speaker_map[key] = value
+ logger.info(f"Speaker {list(self.speaker_map.keys())} loaded")
+
+ def disable_talker(self):
+ if hasattr(self, "talker"):
+ del self.talker
+ if hasattr(self, "token2wav"):
+ del self.token2wav
+ self.has_talker = False
+
+ @classmethod
+ def from_pretrained(
+ cls,
+ pretrained_model_name_or_path,
+ *model_args,
+ config=None,
+ cache_dir=None,
+ ignore_mismatched_sizes=False,
+ force_download=False,
+ local_files_only=False,
+ token=None,
+ revision="main",
+ use_safetensors=None,
+ weights_only=True,
+ **kwargs,
+ ):
+ model = super().from_pretrained(
+ pretrained_model_name_or_path,
+ *model_args,
+ config=config,
+ cache_dir=cache_dir,
+ ignore_mismatched_sizes=ignore_mismatched_sizes,
+ force_download=force_download,
+ local_files_only=local_files_only,
+ token=token,
+ revision=revision,
+ use_safetensors=use_safetensors,
+ weights_only=weights_only,
+ **kwargs,
+ )
+ spk_path = cached_file(
+ pretrained_model_name_or_path,
+ "spk_dict.pt",
+ subfolder=kwargs.pop("subfolder", None),
+ cache_dir=kwargs.pop("cache_dir", None),
+ force_download=kwargs.pop("force_download", False),
+ proxies=kwargs.pop("proxies", None),
+ resume_download=kwargs.pop("resume_download", None),
+ local_files_only=kwargs.pop("local_files_only", False),
+ token=kwargs.pop("use_auth_token", None),
+ revision=kwargs.pop("revision", None),
+ )
+ if spk_path is None:
+ raise ValueError(f"""{pretrained_model_name_or_path}/{spk_path} not exists""")
+ model.load_speakers(spk_path)
+
+ return model
+
+ @torch.no_grad()
+ # TODO: raushan, defaults should be saved in generation config
+ def generate(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ speaker: str = "Chelsie",
+ use_audio_in_video: bool = False,
+ return_audio: Optional[bool] = None,
+ thinker_max_new_tokens: int = 1024,
+ talker_max_new_tokens: int = 4096,
+ talker_do_sample: bool = True,
+ talker_top_k: int = 40,
+ talker_top_p: float = 0.8,
+ talker_temperature: float = 0.9,
+ talker_eos_token_id: list[int] = [8292, 8294],
+ talker_repetition_penalty: float = 1.05,
+ **kwargs,
+ ):
+ r"""
+ Generate text response and audio from input.
+
+ Args:
+ input_ids (`Optional[torch.Tensor]`, *optional*):
+ Input ids, should obtain from processor.
+ speaker (`str`, defaults to `"Chelsie"`):
+ Which speaker should be used for the audio response.
+ use_audio_in_video (`bool`, defaults to `False`):
+ Whether or not to use the audio track in the video; should be the same as the corresponding parameter of `process_audio_info`.
+ return_audio (`Optional[bool]`, *optional*):
+ Whether or not to return the response in audio format. When `return_audio=None`, it defaults to `config.enable_audio_output`.
+ kwargs (*optional*):
+ - Without a prefix, they will be passed as `**kwargs` to the `generate` method of each sub-model.
+ - With a *thinker_*, *talker_* or *token2wav_* prefix, they will be passed to the `generate` method of the
+ thinker, talker and token2wav respectively, and take priority over the keywords without a prefix.
+ Returns:
+ When `return_audio=False`:
+ - **Text** (`torch.Tensor`): Generated text token sequence.
+ When `return_audio=True`:
+ - **Text** (`torch.Tensor`): Generated text token sequence.
+ - **Audio waveform** (`torch.Tensor`): Generated audio waveform.
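+
+ Example (a minimal sketch; the `Qwen2_5OmniProcessor` usage and chat-template call shown here are
+ assumptions, check the model card for the exact prompt format):
+
+ ```python
+ >>> from transformers import Qwen2_5OmniForConditionalGeneration, Qwen2_5OmniProcessor
+
+ >>> model = Qwen2_5OmniForConditionalGeneration.from_pretrained("Qwen/Qwen2.5-Omni-7B")
+ >>> processor = Qwen2_5OmniProcessor.from_pretrained("Qwen/Qwen2.5-Omni-7B")
+
+ >>> conversation = [{"role": "user", "content": [{"type": "text", "text": "Hello!"}]}]
+ >>> text = processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
+ >>> inputs = processor(text=text, return_tensors="pt")
+
+ >>> text_ids, audio = model.generate(**inputs, speaker="Chelsie", return_audio=True)
+ ```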
+ """
+ if speaker not in self.speaker_map:
+ raise ValueError(f"{speaker} is not available, available speakers: {self.speaker_map.keys()}")
+ if return_audio and not self.has_talker:
+ raise ValueError(
+ "Cannot use talker when talker module not initialized. Use `enable_talker` method or set enable_talker in config to enable talker."
+ )
+ if return_audio is None:
+ return_audio = self.has_talker
+ if input_ids.shape[0] != 1 and return_audio:
+ raise NotImplementedError("Qwen2.5-Omni currently does not support batched inference with audio output")
+
+ shared_kwargs = {"use_audio_in_video": use_audio_in_video}
+ thinker_kwargs = {
+ "max_new_tokens": thinker_max_new_tokens,
+ }
+ talker_kwargs = {
+ "max_new_tokens": talker_max_new_tokens,
+ "do_sample": talker_do_sample,
+ "top_k": talker_top_k,
+ "top_p": talker_top_p,
+ "temperature": talker_temperature,
+ "eos_token_id": talker_eos_token_id,
+ "repetition_penalty": talker_repetition_penalty,
+ }
+ token2wav_kwargs = {}
+
+ for key, value in kwargs.items():
+ if key.startswith("thinker_"):
+ thinker_kwargs[key[len("thinker_") :]] = value
+ elif key.startswith("talker_"):
+ talker_kwargs[key[len("talker_") :]] = value
+ elif key.startswith("token2wav_"):
+ token2wav_kwargs[key[len("token2wav_") :]] = value
+ # Process special input values
+ elif key == "feature_attention_mask":
+ thinker_kwargs[key] = value
+ talker_kwargs["audio_feature_lengths"] = torch.sum(value, dim=1)
+ elif key == "input_features" or key == "attention_mask":
+ thinker_kwargs[key] = value
+ # Put other key to shared kwargs
+ else:
+ shared_kwargs[key] = value
+
+ # Merge kwargs
+ for key, value in shared_kwargs.items():
+ if key not in thinker_kwargs:
+ thinker_kwargs[key] = value
+ if key not in talker_kwargs:
+ talker_kwargs[key] = value
+ if key not in token2wav_kwargs:
+ token2wav_kwargs[key] = value
+ speaker_params = self.speaker_map[speaker]
+
+ # 1. Generate from thinker module
+ generate_audio = return_audio and self.has_talker
+ if generate_audio:
+ thinker_kwargs["output_hidden_states"] = True
+ thinker_kwargs["return_dict_in_generate"] = True
+
+ thinker_result = self.thinker.generate(input_ids=input_ids, **thinker_kwargs)
+
+ if not generate_audio:
+ return thinker_result
+
+ # 2. Generate speech tokens from talker module
+ embeds_to_talker = thinker_result.hidden_states[0][0].clone().to(input_ids.device)
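+ # The talker is conditioned on the thinker's first-layer embeddings, with the positions of
+ # audio/image/video placeholder tokens zeroed out below.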
+ if thinker_kwargs.get("input_features") is not None:
+ audio_ids_mask = input_ids == self.config.thinker_config.audio_token_index
+ audio_mask = audio_ids_mask.unsqueeze(-1).expand_as(embeds_to_talker)
+ audio_mask_tensor = torch.zeros(
+ [audio_ids_mask.sum(), embeds_to_talker.shape[-1]],
+ dtype=embeds_to_talker.dtype,
+ device=input_ids.device,
+ )
+ embeds_to_talker.masked_scatter_(audio_mask, audio_mask_tensor)
+ if thinker_kwargs.get("pixel_values") is not None:
+ image_ids_mask = input_ids == self.config.thinker_config.image_token_index
+ image_mask = image_ids_mask.unsqueeze(-1).expand_as(embeds_to_talker)
+ image_mask_tensor = torch.zeros(
+ [image_ids_mask.sum(), embeds_to_talker.shape[-1]],
+ dtype=embeds_to_talker.dtype,
+ device=input_ids.device,
+ )
+ embeds_to_talker.masked_scatter_(image_mask, image_mask_tensor)
+ if thinker_kwargs.get("pixel_values_videos") is not None:
+ video_ids_mask = input_ids == self.config.thinker_config.video_token_index
+ video_mask = video_ids_mask.unsqueeze(-1).expand_as(embeds_to_talker)
+ video_mask_tensor = torch.zeros(
+ [video_ids_mask.sum(), embeds_to_talker.shape[-1]],
+ dtype=embeds_to_talker.dtype,
+ device=input_ids.device,
+ )
+ embeds_to_talker.masked_scatter_(video_mask, video_mask_tensor)
+
+ processed_thinker_hidden = (
+ (embeds_to_talker,) + thinker_result.hidden_states[0][1:],
+ ) + thinker_result.hidden_states[1:]
+ thinker_generate_ids = thinker_result.sequences[:, input_ids.size(1) :].to(input_ids.device)
+ thinker_token_embeds = [
+ token_hidden_states[0].to(input_ids.device) for token_hidden_states in processed_thinker_hidden
+ ]
+ thinker_hidden_states = [
+ token_hidden_states[-1].to(input_ids.device) for token_hidden_states in processed_thinker_hidden
+ ]
+
+ talker_text_bos_token = speaker_params["bos_token"]
+ talker_input_text_ids = torch.cat(
+ [
+ input_ids,
+ torch.tensor([[talker_text_bos_token]], dtype=torch.long, device=input_ids.device),
+ thinker_generate_ids[:, :1],
+ ],
+ dim=-1,
+ )
+
+ talker_input_ids = torch.cat(
+ [
+ torch.full_like(input_ids, fill_value=self.talker.codec_mask_token),
+ torch.tensor([[self.talker.codec_pad_token]], dtype=torch.long, device=input_ids.device),
+ torch.tensor([[self.talker.codec_bos_token]], dtype=torch.long, device=input_ids.device),
+ ],
+ dim=1,
+ )
+
+ thinker_embed_tokens = self.thinker.get_input_embeddings()
+ thinker_reply_part = torch.cat(thinker_hidden_states[1:], dim=1) + torch.cat(thinker_token_embeds[1:], dim=1)
+ talker_inputs_embeds = thinker_hidden_states[0] + thinker_token_embeds[0]
+ talker_text_bos_token = torch.tensor([[talker_text_bos_token]], dtype=torch.long, device=input_ids.device)
+ talker_text_bos_embed = thinker_embed_tokens(talker_text_bos_token).to(input_ids.device)
+ talker_inputs_embeds = torch.cat(
+ [
+ talker_inputs_embeds,
+ talker_text_bos_embed,
+ thinker_reply_part[:, :1, :],
+ ],
+ dim=1,
+ )
+
+ eos_token = torch.tensor([[self.talker.text_eos_token]], dtype=torch.long, device=input_ids.device)
+ eos_embedding = thinker_embed_tokens(eos_token).to(input_ids.device)
+
+ pad_token = torch.tensor([[self.talker.text_pad_token]], dtype=torch.long, device=input_ids.device)
+ pad_embedding = thinker_embed_tokens(pad_token).to(input_ids.device)
+
+ thinker_reply_part = torch.cat(
+ [
+ thinker_reply_part[:, 1:, :],
+ eos_embedding,
+ pad_embedding,
+ ],
+ dim=1,
+ )
+
+ talker_attention_mask = None
+ if "attention_mask" in kwargs:
+ talker_attention_mask = torch.cat(
+ [kwargs["attention_mask"], kwargs["attention_mask"].new_ones((1, 2))], dim=1
+ ).to(input_ids.device)
+
+ talker_result = self.talker.generate(
+ input_ids=talker_input_ids,
+ input_text_ids=talker_input_text_ids,
+ thinker_reply_part=thinker_reply_part,
+ inputs_embeds=talker_inputs_embeds,
+ attention_mask=talker_attention_mask,
+ suppress_tokens=[self.talker.codec_bos_token],
+ **{k: (v.to(input_ids.device) if torch.is_tensor(v) else v) for k, v in talker_kwargs.items()},
+ )
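+ # Keep only the newly generated codec tokens: drop the prompt portion and the final (EOS) code.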
+ talker_generate_codes = talker_result[:, talker_input_ids.shape[1] : -1]
+
+ # 3. Generate wavs from code
+ if self.token2wav.dtype != torch.float:
+ self.token2wav.float()
+
+ wav = self.token2wav(
+ talker_generate_codes.to(input_ids.device),
+ conditioning=speaker_params["cond"].to(input_ids.device).float(),
+ reference_mel=speaker_params["ref_mel"].to(input_ids.device).float(),
+ **token2wav_kwargs,
+ )
+
+ return thinker_result.sequences, wav.float()
+
+
+__all__ = [
+ "Qwen2_5OmniForConditionalGeneration",
+ "Qwen2_5OmniThinkerTextModel",
+ "Qwen2_5OmniThinkerForConditionalGeneration",
+ "Qwen2_5OmniTalkerModel",
+ "Qwen2_5OmniTalkerForConditionalGeneration",
+ "Qwen2_5OmniToken2WavDiTModel",
+ "Qwen2_5OmniToken2WavBigVGANModel",
+ "Qwen2_5OmniToken2WavModel",
+ "Qwen2_5OmniPreTrainedModel",
+ "Qwen2_5OmniPreTrainedModelForConditionalGeneration",
+]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py
new file mode 100644
index 0000000000000000000000000000000000000000..b63c301f36c3f1c9327ed8e73dcc946113a163bb
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py
@@ -0,0 +1,4322 @@
+# coding=utf-8
+# Copyright 2025 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
+#
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch Qwen2.5Omni model (Audio, Image, Video)."""
+
+import math
+from dataclasses import dataclass
+from typing import Any, Callable, Optional, Union
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch import nn
+from torch.nn import Parameter
+
+from transformers.models.llama.modeling_llama import rotate_half
+from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import Qwen2_5_VLVisionConfig
+from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
+ Qwen2_5_VisionTransformerPretrainedModel,
+ Qwen2_5_VLAttention,
+ Qwen2_5_VLMLP,
+ Qwen2_5_VLPreTrainedModel,
+ Qwen2_5_VLTextModel,
+ Qwen2_5_VLVisionBlock,
+ eager_attention_forward,
+)
+from transformers.models.qwen2_audio.configuration_qwen2_audio import Qwen2AudioEncoderConfig
+from transformers.models.qwen2_audio.modeling_qwen2_audio import Qwen2AudioEncoderLayer
+from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLRotaryEmbedding
+
+from ...cache_utils import Cache
+from ...configuration_utils import PretrainedConfig, layer_type_validation
+from ...generation import GenerationMixin
+from ...modeling_outputs import BaseModelOutput, ModelOutput
+from ...modeling_rope_utils import rope_config_validation
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
+from ...processing_utils import Unpack
+from ...utils import (
+ TransformersKwargs,
+ auto_docstring,
+ check_torch_load_is_safe,
+ logging,
+)
+from ...utils.hub import cached_file
+
+
+logger = logging.get_logger(__name__)
+
+
+class Qwen2_5OmniVisionEncoderConfig(Qwen2_5_VLVisionConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`Qwen2_5OmniThinkerVision`]. It is used to instantiate a
+ Qwen2.5-VL vision encoder according to the specified arguments, defining the model architecture. Instantiating a
+ configuration with the defaults will yield a similar configuration to that of the audio encoder of the Qwen2.5-VL
+ architecture.
+
+ e.g. [Qwen/Qwen2.5-Omni-7B](https://huggingface.co/Qwen/Qwen2.5-Omni-7B)
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ depth (`int`, *optional*, defaults to 32):
+ Number of layers (depth) in the model.
+ hidden_size (`int`, *optional*, defaults to 3584):
+ The size of the hidden layers.
+ hidden_act (`str`, *optional*, defaults to `"silu"`):
+ The non-linear activation function used in the model. Supported options include `"silu"`, `"quick_gelu"` and others as applicable.
+ intermediate_size (`int`, *optional*, defaults to 3420):
+ The size of the MLP (Multi-Layer Perceptron) hidden layer.
+ num_heads (`int`, *optional*, defaults to 16):
+ Number of attention heads for each attention layer.
+ in_channels (`int`, *optional*, defaults to 3):
+ Number of input channels.
+ patch_size (`int`, *optional*, defaults to 14):
+ The size of the patches extracted from the input.
+ spatial_merge_size (`int`, *optional*, defaults to 2):
+ The size used for merging spatial dimensions.
+ temporal_patch_size (`int`, *optional*, defaults to 2):
+ The size used for patches along the temporal dimension.
+
+ Example:
+
+ ```python
+ >>> from transformers import Qwen2_5OmniVisionEncoderConfig, Qwen2_5OmniVisionEncoder
+
+ >>> # Initializing a Qwen2_5OmniVisionEncoderConfig
+ >>> configuration = Qwen2_5OmniVisionEncoderConfig()
+
+ >>> # Initializing a Qwen2_5OmniVisionEncoder (with random weights)
+ >>> model = Qwen2_5OmniVisionEncoder(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "qwen2_5_omni_vision_encoder"
+
+ def __init__(
+ self,
+ depth=32,
+ hidden_size=3584,
+ hidden_act="silu",
+ intermediate_size=3420,
+ num_heads=16,
+ in_channels=3,
+ patch_size=14,
+ spatial_merge_size=2,
+ temporal_patch_size=2,
+ window_size=112,
+ out_hidden_size=3584,
+ fullatt_block_indexes=[7, 15, 23, 31],
+ initializer_range=0.02,
+ **kwargs,
+ ):
+ super().__init__(
+ depth,
+ hidden_size,
+ hidden_act,
+ intermediate_size,
+ num_heads,
+ in_channels,
+ patch_size,
+ spatial_merge_size,
+ temporal_patch_size,
+ window_size,
+ out_hidden_size,
+ fullatt_block_indexes,
+ initializer_range=initializer_range,
+ **kwargs,
+ )
+ del self.tokens_per_second
+
+
+class Qwen2_5OmniAudioEncoderConfig(Qwen2AudioEncoderConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`Qwen2_5OmniAudioEncoder`]. It is used to instantiate a
+ Qwen2.5-Omni-Thinker audio encoder according to the specified arguments, defining the model architecture. Instantiating a
+ configuration with the defaults will yield a similar configuration to that of the audio encoder of the Qwen2-Audio
+ architecture.
+
+ e.g. [Qwen/Qwen2.5-Omni-7B](https://huggingface.co/Qwen/Qwen2.5-Omni-7B)
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ num_mel_bins (`int`, *optional*, defaults to 128):
+ Number of mel features used per input features. Should correspond to the value used in the
+ `Qwen2_5OmniProcessor` class.
+ encoder_layers (`int`, *optional*, defaults to 32):
+ Number of encoder layers.
+ encoder_attention_heads (`int`, *optional*, defaults to 20):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ encoder_ffn_dim (`int`, *optional*, defaults to 5120):
+ Dimensionality of the "intermediate" (often named feed-forward) layer in encoder.
+ d_model (`int`, *optional*, defaults to 1280):
+ Dimensionality of the layers.
+ dropout (`float`, *optional*, defaults to 0.0):
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ activation_function (`str`, *optional*, defaults to `"gelu"`):
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+ `"relu"`, `"silu"` and `"gelu_new"` are supported.
+ activation_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for activations inside the fully connected layer.
+ scale_embedding (`bool`, *optional*, defaults to `False`):
+ Scale embeddings by dividing by sqrt(d_model).
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ max_source_positions (`int`, *optional*, defaults to 1500):
+ The maximum sequence length of log-mel filter-bank features that this model might ever be used with.
+ n_window (`int`, *optional*, defaults to 100):
+ The chunk for conv and flash attn in AudioEncoder.
+ output_dim (`int`, *optional*, defaults to 3584):
+ The output dimension of AudioEncoder.
+
+ Example:
+
+ ```python
+ >>> from transformers import Qwen2_5OmniAudioEncoderConfig, Qwen2_5OmniAudioEncoder
+
+ >>> # Initializing a Qwen2_5OmniAudioEncoderConfig
+ >>> configuration = Qwen2_5OmniAudioEncoderConfig()
+
+ >>> # Initializing a Qwen2_5OmniAudioEncoder (with random weights)
+ >>> model = Qwen2_5OmniAudioEncoder(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "qwen2_5_omni_audio_encoder"
+
+ def __init__(
+ self,
+ num_mel_bins=128,
+ encoder_layers=32,
+ encoder_attention_heads=20,
+ encoder_ffn_dim=5120,
+ d_model=1280,
+ dropout=0,
+ attention_dropout=0,
+ activation_function="gelu",
+ activation_dropout=0,
+ scale_embedding=False,
+ initializer_range=0.02,
+ max_source_positions=1500,
+ n_window=100,
+ output_dim=3584,
+ **kwargs,
+ ):
+ super().__init__(
+ num_mel_bins,
+ encoder_layers,
+ encoder_attention_heads,
+ encoder_ffn_dim,
+ d_model,
+ dropout,
+ attention_dropout,
+ activation_function,
+ activation_dropout,
+ scale_embedding,
+ initializer_range,
+ max_source_positions,
+ **kwargs,
+ )
+ self.n_window = n_window
+ self.output_dim = output_dim
+ del self.encoder_layerdrop
+
+
+class Qwen2_5OmniTextConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`Qwen2_5OmniThinkerForConditionalGeneration`]. It is used to instantiate an
+ Qwen2.5-Omni-Thinker model according to the specified arguments, defining the model architecture. Instantiating a configuration
+ with the defaults will yield a similar configuration to that of the Qwen2.5-Omni-Thinker.
+
+ e.g. [Qwen/Qwen2.5-Omni-7B](https://huggingface.co/Qwen/Qwen2.5-Omni-7B)
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 152064):
+ Vocabulary size of the QwenOmni model. Defines the number of different tokens that can be represented by the
+ `inputs_ids` passed when calling [`Qwen2VLModel`]
+ hidden_size (`int`, *optional*, defaults to 3584):
+ Dimension of the hidden representations.
+ intermediate_size (`int`, *optional*, defaults to 18944):
+ Dimension of the MLP representations.
+ num_hidden_layers (`int`, *optional*, defaults to 28):
+ Number of hidden layers in the Transformer encoder.
+ num_attention_heads (`int`, *optional*, defaults to 28):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ num_key_value_heads (`int`, *optional*, defaults to 4):
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
+ `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+ by meanpooling all the original heads within that group. For more details, check out [this
+ paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to `num_attention_heads`.
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+ The non-linear activation function (function or string) in the decoder.
+ max_position_embeddings (`int`, *optional*, defaults to 32768):
+ The maximum sequence length that this model might ever be used with.
+ rms_norm_eps (`float`, *optional*, defaults to 1e-06):
+ The epsilon used by the rms normalization layers.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
+ relevant if `config.is_decoder=True`.
+ rope_theta (`float`, *optional*, defaults to 1000000.0):
+ The base period of the RoPE embeddings.
+ use_sliding_window (`bool`, *optional*, defaults to `False`):
+ Whether to use sliding window attention.
+ sliding_window (`int`, *optional*, defaults to 32768):
+ Sliding window attention (SWA) window size.
+ max_window_layers (`int`, *optional*, defaults to 28):
+ The number of layers using full attention. The first `max_window_layers` layers will use full attention, while any
+ additional layer afterwards will use SWA (Sliding Window Attention).
+ layer_types (`list`, *optional*):
+ Attention pattern for each layer.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ rope_scaling (`Dict`, *optional*):
+ Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
+ and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
+ accordingly.
+ Expected contents:
+ `rope_type` (`str`):
+ The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
+ 'llama3'], with 'default' being the original RoPE implementation.
+ `factor` (`float`, *optional*):
+ Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
+ most scaling types, a `factor` of x will enable the model to handle sequences of length x *
+ original maximum pre-trained length.
+ `original_max_position_embeddings` (`int`, *optional*):
+ Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
+ pretraining.
+ `attention_factor` (`float`, *optional*):
+ Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
+ computation. If unspecified, it defaults to value recommended by the implementation, using the
+ `factor` field to infer the suggested value.
+ `beta_fast` (`float`, *optional*):
+ Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
+ ramp function. If unspecified, it defaults to 32.
+ `beta_slow` (`float`, *optional*):
+ Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
+ ramp function. If unspecified, it defaults to 1.
+ `short_factor` (`list[float]`, *optional*):
+ Only used with 'longrope'. The scaling factor to be applied to short contexts (<
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+ size divided by the number of attention heads divided by 2
+ `long_factor` (`list[float]`, *optional*):
+ Only used with 'longrope'. The scaling factor to be applied to long contexts (>
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+ size divided by the number of attention heads divided by 2
+ `low_freq_factor` (`float`, *optional*):
+ Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
+ `high_freq_factor` (`float`, *optional*):
+ Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+
+ Example:
+
+ ```python
+ >>> from transformers import Qwen2_5OmniThinkerForConditionalGeneration, Qwen2_5OmniThinkerConfig, Qwen2_5OmniAudioEncoderConfig, Qwen2_5OmniVisionEncoderConfig
+
+ >>> # Initializing a Qwen2_5OmniAudioEncoder config
+ >>> audio_config = Qwen2_5OmniAudioEncoderConfig()
+
+ >>> # Initializing a Qwen2_5OmniVisionEncoder config
+ >>> vision_config = Qwen2_5OmniVisionEncoderConfig()
+
+ >>> # Initializing a Qwen2.5OmniThinker configuration
+ >>> configuration = Qwen2_5OmniThinkerConfig(audio_config, vision_config)
+
+ >>> # Initializing a model from the Qwen-Omni style configuration
+ >>> model = Qwen2_5OmniThinkerForConditionalGeneration(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "qwen2_5_omni_text"
+ keys_to_ignore_at_inference = ["past_key_values"]
+
+ # Default tensor parallel plan for base model `Qwen25OmniText`
+ base_model_tp_plan = {
+ "layers.*.self_attn.q_proj": "colwise",
+ "layers.*.self_attn.k_proj": "colwise",
+ "layers.*.self_attn.v_proj": "colwise",
+ "layers.*.self_attn.o_proj": "rowwise",
+ "layers.*.mlp.gate_proj": "colwise",
+ "layers.*.mlp.up_proj": "colwise",
+ "layers.*.mlp.down_proj": "rowwise",
+ }
+ base_model_pp_plan = {
+ "embed_tokens": (["input_ids"], ["inputs_embeds"]),
+ "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
+ "norm": (["hidden_states"], ["hidden_states"]),
+ }
+
+ def __init__(
+ self,
+ vocab_size=152064,
+ hidden_size=3584,
+ intermediate_size=18944,
+ num_hidden_layers=28,
+ num_attention_heads=28,
+ num_key_value_heads=4,
+ hidden_act="silu",
+ max_position_embeddings=32768,
+ initializer_range=0.02,
+ rms_norm_eps=1e-6,
+ use_cache=True,
+ tie_word_embeddings=False,
+ rope_theta=1000000.0,
+ rope_scaling=None,
+ use_sliding_window=False,
+ sliding_window=32768,
+ max_window_layers=28,
+ layer_types=None,
+ attention_dropout=0.0,
+ **kwargs,
+ ):
+ super().__init__(
+ tie_word_embeddings=tie_word_embeddings,
+ **kwargs,
+ )
+ self.vocab_size = vocab_size
+ self.max_position_embeddings = max_position_embeddings
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.use_sliding_window = use_sliding_window
+ self.sliding_window = sliding_window if self.use_sliding_window else None
+ self.max_window_layers = max_window_layers
+
+ # for backward compatibility
+ if num_key_value_heads is None:
+ num_key_value_heads = num_attention_heads
+
+ self.num_key_value_heads = num_key_value_heads
+ self.hidden_act = hidden_act
+ self.initializer_range = initializer_range
+ self.rms_norm_eps = rms_norm_eps
+ self.use_cache = use_cache
+ self.rope_theta = rope_theta
+ self.rope_scaling = rope_scaling
+ self.attention_dropout = attention_dropout
+ # Validate the correctness of rotary position embeddings parameters
+ # BC: if there is a 'type' field, move it to 'rope_type'.
+ if self.rope_scaling is not None and "type" in self.rope_scaling:
+ self.rope_scaling["rope_type"] = self.rope_scaling["type"]
+ rope_config_validation(self)
+
+ if self.rope_scaling is None:
+ self.rope_scaling = {"mrope_section": [16, 24, 24], "rope_type": "default", "type": "default"}
+
+ self.layer_types = layer_types
+ if self.layer_types is None:
+ self.layer_types = [
+ "sliding_attention"
+ if self.sliding_window is not None and i >= self.max_window_layers
+ else "full_attention"
+ for i in range(self.num_hidden_layers)
+ ]
+ layer_type_validation(self.layer_types, self.num_hidden_layers)
+
+
+class Qwen2_5OmniThinkerConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`Qwen2_5OmniThinkerForConditionalGeneration`]. It is used to instantiate a
+ Qwen2.5-Omni-Thinker model according to the specified arguments, defining the model architecture. Instantiating a configuration
+ with the defaults will yield a similar configuration to that of the Qwen2.5-Omni-Thinker.
+
+ e.g. [Qwen/Qwen2.5-Omni-7B](https://huggingface.co/Qwen/Qwen2.5-Omni-7B)
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ audio_config (`dict`, *optional*):
+ The config dictionary of the audio backbone.
+ vision_config (`dict`, *optional*):
+ The config dictionary of the vision backbone.
+ text_config (`dict`, *optional*):
+ The config dictionary of the text backbone.
+ audio_token_index (`int`, *optional*, defaults to 151646):
+ The audio token index to encode the audio prompt.
+ image_token_index (`int`, *optional*, defaults to 151655):
+ The image token index to encode the image prompt.
+ video_token_index (`int`, *optional*, defaults to 151656):
+ The video token index to encode the video prompt.
+ position_id_per_seconds (`int`, *optional*, defaults to 25):
+ The increment of position id per second.
+ seconds_per_chunk (`int`, *optional*, defaults to 2):
+ The duration in seconds of the chunk of audio and video data.
+ audio_start_token_id (`int`, *optional*, defaults to 151647):
+ The audio start token index to encode the audio prompt.
+ audio_end_token_id (`int`, *optional*, defaults to 151648):
+ The audio end token index to encode the audio prompt.
+ user_token_id (`int`, *optional*, defaults to 872):
+ The user token index to encode the user token.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+
+ Example:
+
+ ```python
+ >>> from transformers import Qwen2_5OmniThinkerForConditionalGeneration, Qwen2_5OmniThinkerConfig, Qwen2_5OmniAudioEncoderConfig, Qwen2_5OmniVisionEncoderConfig, Qwen2_5OmniTextConfig
+
+ >>> # Initializing a Qwen2_5OmniAudioEncoder config
+ >>> audio_config = Qwen2_5OmniAudioEncoderConfig()
+
+ >>> # Initializing a Qwen2_5OmniVisionEncoder config
+ >>> vision_config = Qwen2_5OmniVisionEncoderConfig()
+
+ >>> # Initializing a Qwen2_5OmniTextConfig config
+ >>> text_config = Qwen2_5OmniTextConfig()
+
+ >>> # Initializing a Qwen2.5OmniThinker configuration
+ >>> configuration = Qwen2_5OmniThinkerConfig(audio_config, vision_config, text_config)
+
+ >>> # Initializing a model from the Qwen-Omni style configuration
+ >>> model = Qwen2_5OmniThinkerForConditionalGeneration(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "qwen2_5_omni_thinker"
+ attribute_map = {
+ "image_token_id": "image_token_index",
+ "video_token_id": "video_token_index",
+ "audio_token_id": "audio_token_index",
+ }
+ sub_configs = {
+ "audio_config": Qwen2_5OmniAudioEncoderConfig,
+ "vision_config": Qwen2_5OmniVisionEncoderConfig,
+ "text_config": Qwen2_5OmniTextConfig,
+ }
+
+ def __init__(
+ self,
+ audio_config=None,
+ vision_config=None,
+ text_config=None,
+ audio_token_index=151646,
+ image_token_index=151655,
+ video_token_index=151656,
+ position_id_per_seconds=25,
+ seconds_per_chunk=2,
+ audio_start_token_id=151647,
+ audio_end_token_id=151648,
+ user_token_id=872,
+ initializer_range=0.02,
+ **kwargs,
+ ):
+ self.audio_token_index = audio_token_index
+ self.image_token_index = image_token_index
+ self.video_token_index = video_token_index
+ self.user_token_id = user_token_id
+ self.position_id_per_seconds = position_id_per_seconds
+ self.seconds_per_chunk = seconds_per_chunk
+ self.audio_start_token_id = audio_start_token_id
+ self.audio_end_token_id = audio_end_token_id
+ self.initializer_range = initializer_range
+
+ if isinstance(vision_config, dict):
+ vision_config = Qwen2_5OmniVisionEncoderConfig(**vision_config)
+ elif vision_config is None:
+ vision_config = Qwen2_5OmniVisionEncoderConfig()
+ self.vision_config = vision_config
+
+ if isinstance(audio_config, dict):
+ audio_config = Qwen2_5OmniAudioEncoderConfig(**audio_config)
+ elif audio_config is None:
+ audio_config = Qwen2_5OmniAudioEncoderConfig()
+ self.audio_config = audio_config
+
+ if isinstance(text_config, dict):
+ text_config = Qwen2_5OmniTextConfig(**text_config)
+ elif text_config is None:
+ text_config = Qwen2_5OmniTextConfig()
+ self.text_config = text_config
+
+ super().__init__(**kwargs)
+
+
+class Qwen2_5OmniTalkerConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`Qwen2_5OmniTalkerForConditionalGeneration`]. It is used to instantiate a
+ Qwen2.5-Omni-Talker model according to the specified arguments, defining the model architecture. Instantiating a configuration
+ with the defaults will yield a similar configuration to that of the Qwen2.5-Omni-Talker.
+
+ e.g. [Qwen/Qwen2.5-Omni-7B](https://huggingface.co/Qwen/Qwen2.5-Omni-7B)
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ audio_token_index (`int`, *optional*, defaults to 151646):
+ The audio token index to encode the audio prompt.
+ image_token_index (`int`, *optional*, defaults to 151655):
+ The image token index to encode the image prompt.
+ video_token_index (`int`, *optional*, defaults to 151656):
+ The video token index to encode the video prompt.
+ vocab_size (`int`, *optional*, defaults to 8448):
+ Vocabulary size of the Qwen2.5-Omni talker model. Defines the number of different tokens that can be represented by the
+ `input_ids` passed when calling [`Qwen2_5OmniTalkerForConditionalGeneration`].
+ tts_text_start_token_id (`int`, *optional*, defaults to 151860):
+ The tts text start token index to encode the start of tts text.
+ tts_text_end_token_id (`int`, *optional*, defaults to 151861):
+ The tts text end token index to encode the end of tts text.
+ tts_text_pad_token_id (`int`, *optional*, defaults to 151859):
+ The tts text pad token index to encode the pad of tts text.
+ tts_codec_start_token_id (`int`, *optional*, defaults to 8293):
+ The tts codec start token index to encode the start of tts codec.
+ tts_codec_end_token_id (`int`, *optional*, defaults to 8294):
+ The tts codec end token index to encode the end of tts codec.
+ tts_codec_pad_token_id (`int`, *optional*, defaults to 8292):
+ The tts codec pad token index to encode the pad of tts codec.
+ tts_codec_mask_token_id (`int`, *optional*, defaults to 8296):
+ The tts codec mask token index to encode the mask of tts codec.
+ vision_start_token_id (`int`, *optional*, defaults to 151652):
+ The tts vision start token index to encode the start of vision.
+ vision_end_token_id (`int`, *optional*, defaults to 151653):
+ The tts vision end token index to encode the end of vision.
+ embedding_size (`int`, *optional*, defaults to 3584):
+ Dimension of the embedding representations.
+ hidden_size (`int`, *optional*, defaults to 3584):
+ Dimension of the hidden representations.
+ intermediate_size (`int`, *optional*, defaults to 18944):
+ Dimension of the MLP representations.
+ num_hidden_layers (`int`, *optional*, defaults to 28):
+ Number of hidden layers in the Transformer encoder.
+ num_attention_heads (`int`, *optional*, defaults to 28):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ num_key_value_heads (`int`, *optional*, defaults to 4):
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
+ `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+ by meanpooling all the original heads within that group. For more details, check out [this
+ paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to `4`.
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+ The non-linear activation function (function or string) in the decoder.
+ max_position_embeddings (`int`, *optional*, defaults to 32768):
+ The maximum sequence length that this model might ever be used with.
+ rms_norm_eps (`float`, *optional*, defaults to 1e-06):
+ The epsilon used by the rms normalization layers.
+ head_dim (`int`, *optional*, defaults to 128):
+ The dimension of each attention head.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
+ relevant if `config.is_decoder=True`.
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+ Whether the model's input and output word embeddings should be tied.
+ rope_theta (`float`, *optional*, defaults to 1000000.0):
+ The base period of the RoPE embeddings.
+ use_sliding_window (`bool`, *optional*, defaults to `False`):
+ Whether to use sliding window attention.
+ sliding_window (`int`, *optional*, defaults to 32768):
+ Sliding window attention (SWA) window size.
+ max_window_layers (`int`, *optional*, defaults to 28):
+ The number of layers using full attention. The first `max_window_layers` layers will use full attention, while any
+ additional layer afterwards will use SWA (Sliding Window Attention).
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ rope_scaling (`Dict`, *optional*):
+ Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
+ and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
+ accordingly.
+ Expected contents:
+ `rope_type` (`str`):
+ The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
+ 'llama3'], with 'default' being the original RoPE implementation.
+ `factor` (`float`, *optional*):
+ Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
+ most scaling types, a `factor` of x will enable the model to handle sequences of length x *
+ original maximum pre-trained length.
+ `original_max_position_embeddings` (`int`, *optional*):
+ Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
+ pretraining.
+ `attention_factor` (`float`, *optional*):
+ Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
+ computation. If unspecified, it defaults to value recommended by the implementation, using the
+ `factor` field to infer the suggested value.
+ `beta_fast` (`float`, *optional*):
+ Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
+ ramp function. If unspecified, it defaults to 32.
+ `beta_slow` (`float`, *optional*):
+ Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
+ ramp function. If unspecified, it defaults to 1.
+ `short_factor` (`list[float]`, *optional*):
+ Only used with 'longrope'. The scaling factor to be applied to short contexts (<
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+ size divided by the number of attention heads divided by 2.
+ `long_factor` (`list[float]`, *optional*):
+ Only used with 'longrope'. The scaling factor to be applied to long contexts (>
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+ size divided by the number of attention heads divided by 2.
+ `low_freq_factor` (`float`, *optional*):
+ Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE.
+ `high_freq_factor` (`float`, *optional*):
+ Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE.
+ position_id_per_seconds (`int`, *optional*, defaults to 25):
+ The increment of position id per second.
+ seconds_per_chunk (`int`, *optional*, defaults to 2):
+ The duration in seconds of the chunk of audio and video data.
+ audio_start_token_id (`int`, *optional*, defaults to 151647):
+ The audio start token index to encode the audio prompt.
+ audio_end_token_id (`int`, *optional*, defaults to 151648):
+ The audio end token index to encode the audio prompt.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ spatial_merge_size (`int`, *optional*, defaults to 2):
+ The size used for merging spatial dimensions.
+ layer_types (`list`, *optional*):
+ Attention pattern for each layer.
+
+ Example:
+
+ ```python
+ >>> from transformers import Qwen2_5OmniTalkerForConditionalGeneration, Qwen2_5OmniTalkerConfig
+
+ >>> # Initializing a Qwen2.5-Omni-Talker configuration
+ >>> configuration = Qwen2_5OmniTalkerConfig()
+
+ >>> # Initializing a model from the Qwen2.5-Omni-Talker style configuration
+ >>> model = Qwen2_5OmniTalkerForConditionalGeneration(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "qwen2_5_omni_talker"
+ attribute_map = {
+ "image_token_id": "image_token_index",
+ "video_token_id": "video_token_index",
+ "audio_token_id": "audio_token_index",
+ }
+
+ def __init__(
+ self,
+ audio_token_index=151646,
+ image_token_index=151655,
+ video_token_index=151656,
+ vocab_size=8448,
+ tts_text_start_token_id=151860,
+ tts_text_end_token_id=151861,
+ tts_text_pad_token_id=151859,
+ tts_codec_start_token_id=8293,
+ tts_codec_end_token_id=8294,
+ tts_codec_pad_token_id=8292,
+ tts_codec_mask_token_id=8296,
+ vision_start_token_id=151652,
+ vision_end_token_id=151653,
+ embedding_size=3584,
+ hidden_size=3584,
+ intermediate_size=18944,
+ num_hidden_layers=28,
+ num_attention_heads=28,
+ num_key_value_heads=4,
+ hidden_act="silu",
+ max_position_embeddings=32768,
+ rms_norm_eps=1e-06,
+ head_dim=128,
+ use_cache=True,
+ tie_word_embeddings=False,
+ rope_theta=1000000.0,
+ use_sliding_window=False,
+ sliding_window=32768,
+ max_window_layers=28,
+ attention_dropout=0.0,
+ rope_scaling=None,
+ position_id_per_seconds=25,
+ seconds_per_chunk=2,
+ audio_start_token_id=151647,
+ audio_end_token_id=151648,
+ initializer_range=0.02,
+ spatial_merge_size=2,
+ layer_types=None,
+ **kwargs,
+ ):
+ self.audio_token_index = audio_token_index
+ self.image_token_index = image_token_index
+ self.video_token_index = video_token_index
+
+ self.tts_text_start_token_id = tts_text_start_token_id
+ self.tts_text_end_token_id = tts_text_end_token_id
+ self.tts_text_pad_token_id = tts_text_pad_token_id
+ self.tts_codec_start_token_id = tts_codec_start_token_id
+ self.tts_codec_end_token_id = tts_codec_end_token_id
+ self.tts_codec_pad_token_id = tts_codec_pad_token_id
+
+ self.tts_codec_mask_token_id = tts_codec_mask_token_id
+
+ self.vision_start_token_id = vision_start_token_id
+ self.vision_end_token_id = vision_end_token_id
+
+ self.vocab_size = vocab_size
+ self.head_dim = head_dim
+ self.embedding_size = embedding_size
+ self.max_position_embeddings = max_position_embeddings
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.use_sliding_window = use_sliding_window
+ self.sliding_window = sliding_window if self.use_sliding_window else None
+ self.max_window_layers = max_window_layers
+
+ # for backward compatibility
+ if num_key_value_heads is None:
+ num_key_value_heads = num_attention_heads
+
+ self.num_key_value_heads = num_key_value_heads
+ self.hidden_act = hidden_act
+ self.rms_norm_eps = rms_norm_eps
+ self.use_cache = use_cache
+ self.rope_theta = rope_theta
+ self.attention_dropout = attention_dropout
+ self.rope_scaling = rope_scaling
+ self.position_id_per_seconds = position_id_per_seconds
+ self.seconds_per_chunk = seconds_per_chunk
+ self.audio_start_token_id = audio_start_token_id
+ self.audio_end_token_id = audio_end_token_id
+
+ self.initializer_range = initializer_range
+ self.spatial_merge_size = spatial_merge_size
+
+ self.layer_types = layer_types
+ if self.layer_types is None:
+ self.layer_types = [
+ "sliding_attention"
+ if self.sliding_window is not None and i >= self.max_window_layers
+ else "full_attention"
+ for i in range(self.num_hidden_layers)
+ ]
+ layer_type_validation(self.layer_types, self.num_hidden_layers)
+
+ super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
+
+
+class Qwen2_5OmniDiTConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of the Qwen2_5OmniToken2WavDiT used in the Qwen2.5-Omni-Token2Wav model.
+ It defines the architecture of the DiT model, which is used for generating mel-spectrograms from tokens.
+
+ Args:
+ hidden_size (`int`, *optional*, defaults to 1024):
+ The dimension of the model.
+ num_hidden_layers (`int`, *optional*, defaults to 22):
+ The number of transformer blocks in the DiT model.
+ num_attention_heads (`int`, *optional*, defaults to 16):
+ The number of attention heads in each transformer block.
+ ff_mult (`int`, *optional*, defaults to 2):
+ The multiplier for the feedforward layer in each transformer block.
+ emb_dim (`int`, *optional*, defaults to 512):
+ The dimension of the embedding layer.
+ head_dim (`int`, *optional*, defaults to 64):
+ The dimension of each attention head.
+ repeats (`int`, *optional*, defaults to 2):
+ The number of times the codec embeddings are repeated.
+ num_embeds (`int`, *optional*, defaults to 8193):
+ The number of unique embeddings in the codec.
+ mel_dim (`int`, *optional*, defaults to 80):
+ The dimension of the mel-spectrogram.
+ dropout (`float`, *optional*, defaults to 0.1):
+ The dropout rate for the transformer blocks.
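+ rope_theta (`float`, *optional*, defaults to 10000.0):
+ The base period of the RoPE embeddings.
+ max_position_embeddings (`int`, *optional*, defaults to 32768):
+ The maximum sequence length that this model might ever be used with.
+ block_size (`int`, *optional*, defaults to 24):
+ The block (chunk) size used by the DiT's block-wise attention.
+ look_ahead_layers (`list[int]`, *optional*, defaults to `[10]`):
+ Indices of the DiT layers that are allowed to look ahead of the current block in the block-wise attention.
+ look_backward_layers (`list[int]`, *optional*, defaults to `[0, 20]`):
+ Indices of the DiT layers that are allowed to look back beyond the current block in the block-wise attention.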
+
+ enc_emb_dim (`int`, *optional*, defaults to 192):
+ The dimension of the pre-trained speaker embedding.
+ enc_dim (`int`, *optional*, defaults to 128):
+ The dimension of the encoder output.
+ enc_channels (`list[int]`, *optional*, defaults to `[256, 256, 256, 256, 768]`):
+ A list of output channels for each TDNN/SERes2Net layer in the encoder.
+ enc_kernel_sizes (`list[int]`, *optional*, defaults to `[5, 3, 3, 3, 1]`):
+ A list of kernel sizes for each layer in the encoder.
+ enc_dilations (`list[int]`, *optional*, defaults to `[1, 2, 3, 4, 1]`):
+ A list of dilations for each layer in the encoder.
+ enc_attention_channels (`int`, *optional*, defaults to 64):
+ The number of attention channels in the SqueezeExcitationBlock.
+ enc_res2net_scale (`int`, *optional*, defaults to 2):
+ The scale of the Res2Net block in the encoder.
+ enc_se_channels (`int`, *optional*, defaults to 64):
+ The number of output channels after squeeze in the SqueezeExcitationBlock.
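+
+ Example (a minimal sketch; it assumes `Qwen2_5OmniDiTConfig` is importable from the top-level `transformers` namespace):
+
+ ```python
+ >>> from transformers import Qwen2_5OmniDiTConfig
+
+ >>> # Initializing a DiT configuration with the default arguments
+ >>> configuration = Qwen2_5OmniDiTConfig()
+
+ >>> # Accessing a configuration value
+ >>> configuration.hidden_size
+ 1024
+ ```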
+ """
+
+ model_type = "qwen2_5_omni_dit"
+
+ def __init__(
+ self,
+ hidden_size=1024,
+ num_hidden_layers=22,
+ num_attention_heads=16,
+ ff_mult=2,
+ emb_dim=512,
+ head_dim=64,
+ rope_theta=10000.0,
+ max_position_embeddings=32768,
+ block_size=24,
+ look_ahead_layers=[10],
+ look_backward_layers=[0, 20],
+ repeats=2,
+ num_embeds=8193,
+ mel_dim=80,
+ dropout=0.1,
+ enc_emb_dim=192,
+ enc_dim=128,
+ enc_channels=[256, 256, 256, 256, 768],
+ enc_kernel_sizes=[5, 3, 3, 3, 1],
+ enc_dilations=[1, 2, 3, 4, 1],
+ enc_attention_channels=64,
+ enc_res2net_scale=2,
+ enc_se_channels=64,
+ **kwargs,
+ ):
+ self.hidden_size = hidden_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.ff_mult = ff_mult
+ self.emb_dim = emb_dim
+ self.head_dim = head_dim
+ self.rope_theta = rope_theta
+ self.max_position_embeddings = max_position_embeddings
+ self.block_size = block_size
+ self.look_ahead_layers = look_ahead_layers
+ self.look_backward_layers = look_backward_layers
+ self.repeats = repeats
+ self.num_embeds = num_embeds
+ self.mel_dim = mel_dim
+ self.dropout = dropout
+ self.enc_emb_dim = enc_emb_dim
+ self.enc_dim = enc_dim
+ self.enc_channels = enc_channels
+ self.enc_kernel_sizes = enc_kernel_sizes
+ self.enc_dilations = enc_dilations
+ self.enc_attention_channels = enc_attention_channels
+ self.enc_res2net_scale = enc_res2net_scale
+ self.enc_se_channels = enc_se_channels
+ super().__init__(**kwargs)
+
+
+class Qwen2_5OmniBigVGANConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of the Qwen2_5OmniToken2WavBigVGAN module used in the Qwen2.5-Omni-Token2Wav model.
+ It defines the architecture of the BigVGAN model, which is used for converting mel-spectrograms to waveforms.
+
+ Args:
+ mel_dim (`int`, *optional*, defaults to 80):
+ The dimension of the mel-spectrogram.
+ upsample_initial_channel (`int`, *optional*, defaults to 1536):
+ The number of channels in the initial upsampling layer.
+ resblock_kernel_sizes (`list[int]`, *optional*, defaults to `[3, 7, 11]`):
+ A list of kernel sizes for each residual block.
+ resblock_dilation_sizes (`list[list[int]]`, *optional*, defaults to `[[1, 3, 5], [1, 3, 5], [1, 3, 5]]`):
+ A list of dilation sizes for each residual block.
+ upsample_rates (`list[int]`, *optional*, defaults to `[5, 3, 2, 2, 2, 2]`):
+ A list of upsampling rates for each upsampling layer.
+ upsample_kernel_sizes (`list[int]`, *optional*, defaults to `[11, 7, 4, 4, 4, 4]`):
+ A list of kernel sizes for each upsampling layer.
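+
+ Example (a minimal sketch; it assumes `Qwen2_5OmniBigVGANConfig` is importable from the top-level `transformers` namespace):
+
+ ```python
+ >>> from transformers import Qwen2_5OmniBigVGANConfig
+
+ >>> # Initializing a BigVGAN configuration with the default arguments
+ >>> configuration = Qwen2_5OmniBigVGANConfig()
+
+ >>> # Accessing a configuration value
+ >>> configuration.upsample_rates
+ [5, 3, 2, 2, 2, 2]
+ ```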
+ """
+
+ model_type = "qwen2_5_omni_bigvgan"
+
+ def __init__(
+ self,
+ mel_dim=80,
+ upsample_initial_channel=1536,
+ resblock_kernel_sizes=[3, 7, 11],
+ resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
+ upsample_rates=[5, 3, 2, 2, 2, 2],
+ upsample_kernel_sizes=[11, 7, 4, 4, 4, 4],
+ **kwargs,
+ ):
+ self.mel_dim = mel_dim
+ self.upsample_initial_channel = upsample_initial_channel
+ self.resblock_kernel_sizes = resblock_kernel_sizes
+ self.resblock_dilation_sizes = resblock_dilation_sizes
+ self.upsample_rates = upsample_rates
+ self.upsample_kernel_sizes = upsample_kernel_sizes
+ super().__init__(**kwargs)
+
+
+class Qwen2_5OmniToken2WavConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`Qwen2_5OmniToken2WavModel`].
+ It is used to instantiate the Qwen2.5-Omni-Token2Wav model which combines a Diffusion Transformer (DiT) for mel-spectrogram generation with a BigVGAN model for waveform synthesis. The configuration contains sub-configurations for both components.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ dit_config (`dict`, *optional*):
+ The config dictionary of the Diffusion Transformer (DiT) module responsible for generating mel-spectrograms.
+ bigvgan_config (`dict`, *optional*):
+ The config dictionary of the BigVGAN module responsible for converting mel-spectrograms to waveforms.
+
+ Example:
+
+ ```python
+ >>> from transformers import Qwen2_5OmniToken2WavModel, Qwen2_5OmniToken2WavConfig
+
+ >>> # Initialize the DiT sub-configuration as a dictionary
+ >>> dit_config = {
+ ... "hidden_size": 1024,
+ ... "num_hidden_layers": 22,
+ ... "num_attention_heads": 16,
+ ... "ff_mult": 2,
+ ... }
+
+ >>> # Initialize the BigVGAN sub-configuration as a dictionary
+ >>> bigvgan_config = {
+ ... "mel_dim": 80,
+ ... "upsample_rates": [5, 3, 2, 2, 2, 2],
+ ... }
+
+ >>> # Initialize main configuration
+ >>> config = Qwen2_5OmniToken2WavConfig(dit_config, bigvgan_config)
+
+ >>> # Initialize model with config
+ >>> model = Qwen2_5OmniToken2WavModel(config)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```
+ """
+
+ model_type = "qwen2_5_omni_token2wav"
+ sub_configs = {
+ "dit_config": Qwen2_5OmniDiTConfig,
+ "bigvgan_config": Qwen2_5OmniBigVGANConfig,
+ }
+
+ def __init__(self, dit_config=None, bigvgan_config=None, **kwargs):
+ if dit_config is None:
+ dit_config = {}
+ if bigvgan_config is None:
+ bigvgan_config = {}
+ self.dit_config = Qwen2_5OmniDiTConfig(**dit_config)
+ self.bigvgan_config = Qwen2_5OmniBigVGANConfig(**bigvgan_config)
+ super().__init__(**kwargs)
+
+
+class Qwen2_5OmniConfig(PretrainedConfig):
+ """
+ This is the configuration class to store the configuration of a [`Qwen2_5OmniForConditionalGeneration`]. It is used to instantiate a Qwen2.5Omni
+ model according to the specified sub-model configurations, defining the model architecture.
+
+ Instantiating a configuration with the defaults will yield a similar configuration to that of the
+ [Qwen/Qwen2.5-Omni-7B](https://huggingface.co/Qwen/Qwen2.5-Omni-7B) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ thinker_config (`dict`, *optional*): Configuration of the underlying thinker sub-model.
+ talker_config (`dict`, *optional*): Configuration of the underlying talker sub-model.
+ token2wav_config (`dict`, *optional*): Configuration of the underlying codec sub-model.
+ enable_audio_output (`bool`, *optional*, defaults to `True`): Whether to enable audio output and load the talker and token2wav modules.
+
+ Example:
+
+ ```python
+ >>> from transformers import (
+ ... Qwen2_5OmniThinkerConfig,
+ ... Qwen2_5OmniTalkerConfig,
+ ... Qwen2_5OmniToken2WavConfig,
+ ... Qwen2_5OmniForConditionalGeneration,
+ ... Qwen2_5OmniConfig,
+ ... )
+
+ >>> # Initializing sub-modules configurations.
+ >>> thinker_config = Qwen2_5OmniThinkerConfig()
+ >>> talker_config = Qwen2_5OmniTalkerConfig()
+ >>> token2wav_config = Qwen2_5OmniToken2WavConfig()
+
+
+ >>> # Initializing a module style configuration
+ >>> configuration = Qwen2_5OmniConfig.from_sub_model_configs(
+ ... thinker_config, talker_config, token2wav_config
+ ... )
+
+ >>> # Initializing a model (with random weights)
+ >>> model = Qwen2_5OmniForConditionalGeneration(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```
+ """
+
+ model_type = "qwen2_5_omni"
+ sub_configs = {
+ "thinker_config": Qwen2_5OmniThinkerConfig,
+ "talker_config": Qwen2_5OmniTalkerConfig,
+ "token2wav_config": Qwen2_5OmniToken2WavConfig,
+ }
+
+ def __init__(
+ self,
+ thinker_config=None,
+ talker_config=None,
+ token2wav_config=None,
+ enable_audio_output: bool = True,
+ **kwargs,
+ ):
+ if thinker_config is None:
+ thinker_config = {}
+ logger.info("thinker_config is None. Initializing thinker model with default values")
+
+ if talker_config is None:
+ talker_config = {}
+ logger.info("talker_config is None. Initializing talker model with default values")
+
+ if token2wav_config is None:
+ token2wav_config = {}
+ logger.info("token2wav_config is None. Initializing token2wav model with default values")
+
+ self.thinker_config = Qwen2_5OmniThinkerConfig(**thinker_config)
+ self.talker_config = Qwen2_5OmniTalkerConfig(**talker_config)
+ self.token2wav_config = Qwen2_5OmniToken2WavConfig(**token2wav_config)
+ self.enable_audio_output = enable_audio_output
+
+ super().__init__(**kwargs)
+
+ def get_text_config(self, *args, **kwargs):
+ """
+ Returns the config that is meant to be used with text IO. On most models, it is the original config instance
+ itself. On specific composite models, it is under a set of valid names.
+
+ Args:
+ decoder (`Optional[bool]`, *optional*, defaults to `False`):
+ If set to `True`, then only search for decoder config names.
+ """
+ # Overridden for deeply nested configs like Qwen2.5-Omni. We don't have any omni model
+ # except for Qwen yet. This has to be generalized if more deeply nested configs are
+ # added. NOTE: this method is currently only used by vLLM
+ return self.thinker_config.get_text_config(*args, **kwargs)
+
+
+class Qwen2_5OmniPreTrainedModel(Qwen2_5_VLPreTrainedModel):
+ config: Qwen2_5OmniConfig
+ _can_compile_fullgraph = False
+
+
+class Qwen2_5OmniPreTrainedModelForConditionalGeneration(Qwen2_5OmniPreTrainedModel):
+ def _prepare_4d_causal_attention_mask_with_cache_position(
+ self,
+ attention_mask: torch.Tensor,
+ sequence_length: int,
+ target_length: int,
+ dtype: torch.dtype,
+ device: torch.device,
+ min_dtype: float,
+ cache_position: torch.Tensor,
+ batch_size: int,
+ ):
+ """
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
+ `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
+
+ Args:
+ attention_mask (`torch.Tensor`):
+ A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
+ sequence_length (`int`):
+ The sequence length being processed.
+ target_length (`int`):
+ The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
+ dtype (`torch.dtype`):
+ The dtype to use for the 4D attention mask.
+ device (`torch.device`):
+ The device to place the 4D attention mask on.
+ min_dtype (`float`):
+ The minimum value representable with the dtype `dtype`.
+ cache_position (`torch.Tensor`):
+ Indices depicting the position of the input sequence tokens in the sequence.
+ batch_size (`int`):
+ Batch size.
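+
+ For example (illustrative), with `sequence_length=2`, `target_length=4` and `cache_position=torch.tensor([2, 3])`,
+ the returned mask has shape `(batch_size, 1, 2, 4)` and, before the padding mask is applied, each query row fills
+ with `min_dtype` exactly the key positions greater than its cache position.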
+ """
+ if attention_mask is not None and attention_mask.dim() == 4:
+ # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
+ causal_mask = attention_mask
+ else:
+ causal_mask = torch.full(
+ (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
+ )
+ if sequence_length != 1:
+ causal_mask = torch.triu(causal_mask, diagonal=1)
+ causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
+ causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
+ if attention_mask is not None:
+ causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
+ mask_length = attention_mask.shape[-1]
+ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
+ padding_mask = padding_mask == 0
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+ padding_mask, min_dtype
+ )
+
+ return causal_mask
+
+ def get_llm_pos_ids_for_vision(
+ self,
+ start_idx: int,
+ vision_idx: int,
+ spatial_merge_size: int,
+ t_index: list[int],
+ grid_hs: list[int],
+ grid_ws: list[int],
+ ):
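+ """
+ Builds the 3-row (temporal, height, width) position ids for a single image or video, offset by `start_idx`.
+ The spatial grid is reduced by `spatial_merge_size`, each value in `t_index` is repeated over the merged
+ H x W grid, and the resulting ids are stacked as `[t, h, w]` with shape `(3, len(t_index) * H * W)`.
+ """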
+ llm_pos_ids_list = []
+ llm_grid_h = grid_hs[vision_idx] // spatial_merge_size
+ llm_grid_w = grid_ws[vision_idx] // spatial_merge_size
+ h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(len(t_index), -1, llm_grid_w).flatten()
+ w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(len(t_index), llm_grid_h, -1).flatten()
+ t_index = torch.Tensor(t_index).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten().long()
+ _llm_pos_ids = torch.stack([t_index, h_index, w_index])
+ llm_pos_ids_list.append(_llm_pos_ids + start_idx)
+ llm_pos_ids = torch.cat(llm_pos_ids_list, dim=1)
+ return llm_pos_ids
+
+ def get_chunked_index(
+ self, token_indices: torch.Tensor, tokens_per_chunk: int, remove_index: int
+ ) -> list[tuple[int, int]]:
+ """
+ Splits token index list into chunks based on token value ranges.
+
+ Given a list of token indices, returns a list of (start, end) index tuples representing
+ slices of the list where the token values fall within successive ranges of `t_ntoken_per_chunk`.
+
+ For example, if `tokens_per_chunk` is 1000, the function will create chunks such that:
+ - the first chunk contains token values < 1000,
+ - the second chunk contains values >= 1000 and < 2000, and so on.
+
+ Parameters:
+ token_indices (`torch.Tensor` of shape `(seq_len, )`): A monotonically increasing list of
+ token index values.
+ tokens_per_chunk (`int`): Number of tokens per chunk (used as the chunk size threshold).
+ remove_index (`int`): An index to subtract from `token_indices` before chunking.
+
+ Returns:
+ `list[tuple[int, int]]`: A list of tuples, each representing the start (inclusive)
+ and end (exclusive) indices of a chunk in `token_indices`.
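+
+ Example (illustrative; `model` stands for any instance of this class):
+
+ ```python
+ >>> model.get_chunked_index(torch.tensor([0, 1, 2, 50, 51, 100, 101]), tokens_per_chunk=50, remove_index=0)
+ [(0, 3), (3, 5), (5, 7)]
+ ```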
+ """
+
+ def _iter():
+ i, start_idx = 0, 0 # skip bos token
+ current_chunk = 1
+ while i < len(token_indices): # skip eos token
+ if token_indices[i] - remove_index >= current_chunk * tokens_per_chunk:
+ yield (start_idx, i)
+ start_idx = i
+ current_chunk += 1
+ i += 1
+ yield (start_idx, len(token_indices))
+
+ return list(_iter())
+
+ def get_rope_index(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ image_grid_thw: Optional[torch.LongTensor] = None,
+ video_grid_thw: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ use_audio_in_video: bool = False,
+ audio_seqlens: Optional[torch.LongTensor] = None,
+ second_per_grids: Optional[torch.Tensor] = None,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ """
+ Calculate the 3D rope index based on image and video's temporal, height and width in LLM.
+
+ Explanation:
+ Each embedding sequence contains vision embedding and text embedding or just contains text embedding.
+
+ For pure text embedding sequence, the rotary position embedding has no difference with modern LLMs.
+ Examples:
+ input_ids: [T T T T T], here T is for text.
+ temporal position_ids: [0, 1, 2, 3, 4]
+ height position_ids: [0, 1, 2, 3, 4]
+ width position_ids: [0, 1, 2, 3, 4]
+
+ For vision and text embedding sequence, we calculate 3D rotary position embedding for vision part
+ and 1D rotary position embedding for text part.
+ Examples:
+ Temporal (Time): 3 patches, representing different segments of the video in time.
+ Height: 2 patches, dividing each frame vertically.
+ Width: 2 patches, dividing each frame horizontally.
+ We also have some important parameters:
+ fps (Frames Per Second): The video's frame rate, set to 1. This means one frame is processed each second.
+ tokens_per_second: This is a crucial parameter. It dictates how many "time-steps" or "temporal tokens" are conceptually packed into a one-second interval of the video. In this case, we have 25 tokens per second. So each second of the video will be represented with 25 separate time points. It essentially defines the temporal granularity.
+ temporal_patch_size: The number of frames that compose one temporal patch. Here, it's 2 frames.
+ interval: The step size for the temporal position IDs, calculated as tokens_per_second * temporal_patch_size / fps. In this case, 25 * 2 / 1 = 50. This means that each temporal patch will have a difference of 50 in the temporal position IDs.
+ input_ids: [V V V V V V V V V V V V T T T T T], here V is for vision.
+ vision temporal position_ids: [0, 0, 0, 0, 50, 50, 50, 50, 100, 100, 100, 100]
+ vision height position_ids: [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1]
+ vision width position_ids: [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1]
+ text temporal position_ids: [101, 102, 103, 104, 105]
+ text height position_ids: [101, 102, 103, 104, 105]
+ text width position_ids: [101, 102, 103, 104, 105]
+ Here we calculate the text start position_ids as the max vision position_ids plus 1.
+
+ Args:
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
+ it.
+ image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
+ The temporal, height and width of feature shape of each image in LLM.
+ video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
+ The temporal, height and width of feature shape of each video in LLM.
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+ use_audio_in_video (`bool`, *optional*):
+ If set to `True`, use the audio in video.
+ audio_seqlens (`torch.LongTensor` of shape `(num_audios)`, *optional*):
+ The length of feature shape of each audio in LLM.
+ second_per_grids (`torch.LongTensor` of shape `(num_videos)`, *optional*):
+ The time interval (in seconds) for each grid along the temporal dimension in the 3D position IDs.
+
+ Returns:
+ position_ids (`torch.LongTensor` of shape `(3, batch_size, sequence_length)`)
+ mrope_position_deltas (`torch.Tensor` of shape `(batch_size)`)
+ """
+ spatial_merge_size = self.spatial_merge_size
+ image_token_id = self.config.image_token_id
+ video_token_id = self.config.video_token_id
+ audio_token_id = self.config.audio_token_id
+ vision_start_token_id = self.config.vision_start_token_id
+ audio_start_token_id = self.config.audio_start_token_id
+ position_id_per_seconds = self.config.position_id_per_seconds
+ seconds_per_chunk = self.config.seconds_per_chunk
+
+ mrope_position_deltas = []
+ if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None):
+ total_input_ids = input_ids
+ if attention_mask is not None:
+ attention_mask = attention_mask == 1
+ position_ids = torch.ones(
+ 3,
+ input_ids.shape[0],
+ input_ids.shape[1],
+ dtype=input_ids.dtype,
+ device=input_ids.device,
+ )
+ image_idx, video_idx, audio_idx = 0, 0, 0
+ for i, input_ids in enumerate(total_input_ids):
+ if attention_mask is not None:
+ input_ids = input_ids[attention_mask[i]]
+ image_nums, video_nums, audio_nums = 0, 0, 0
+ vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1)
+ vision_tokens = input_ids[vision_start_indices + 1]
+ audio_nums = torch.sum(input_ids == audio_start_token_id)
+ image_nums = (vision_tokens == image_token_id).sum()
+ video_nums = (
+ (vision_tokens == audio_start_token_id).sum()
+ if use_audio_in_video
+ else (vision_tokens == video_token_id).sum()
+ )
+ input_tokens = input_ids.tolist()
+ llm_pos_ids_list: list = []
+ st = 0
+ remain_images, remain_videos, remain_audios = image_nums, video_nums, audio_nums
+ multimodal_nums = (
+ image_nums + audio_nums if use_audio_in_video else image_nums + video_nums + audio_nums
+ )
+ for _ in range(multimodal_nums):
+ st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
+ if image_token_id in input_tokens and remain_images > 0:
+ ed_image = input_tokens.index(image_token_id, st)
+ else:
+ ed_image = len(input_tokens) + 1
+ if video_token_id in input_tokens and remain_videos > 0:
+ ed_video = input_tokens.index(video_token_id, st)
+ else:
+ ed_video = len(input_tokens) + 1
+ if audio_token_id in input_tokens and remain_audios > 0:
+ ed_audio = input_tokens.index(audio_token_id, st)
+ else:
+ ed_audio = len(input_tokens) + 1
+ min_ed = min(ed_image, ed_video, ed_audio)
+ if min_ed == ed_audio:
+ text_len = min_ed - st - 1
+ if text_len != 0:
+ st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
+ llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
+
+ st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
+ bos_len = 1
+ llm_pos_ids_list.append(torch.arange(bos_len).view(1, -1).expand(3, -1) + st_idx)
+
+ st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
+ audio_len = ((audio_seqlens[audio_idx] - 1) // 2 + 1 - 2) // 2 + 1
+ llm_pos_ids = torch.arange(audio_len).view(1, -1).expand(3, -1) + st_idx
+ llm_pos_ids_list.append(llm_pos_ids)
+
+ st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
+ eos_len = 1
+ llm_pos_ids_list.append(torch.arange(eos_len).view(1, -1).expand(3, -1) + st_idx)
+
+ st += text_len + bos_len + audio_len + eos_len
+ audio_idx += 1
+ remain_audios -= 1
+
+ elif min_ed == ed_image:
+ text_len = min_ed - st - 1
+ if text_len != 0:
+ st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
+ llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
+
+ st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
+ bos_len = 1
+ llm_pos_ids_list.append(torch.arange(bos_len).view(1, -1).expand(3, -1) + st_idx)
+
+ st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
+ grid_t = image_grid_thw[image_idx][0]
+ grid_hs = image_grid_thw[:, 1]
+ grid_ws = image_grid_thw[:, 2]
+ t_index = (torch.arange(grid_t) * 1 * position_id_per_seconds).long()
+ llm_pos_ids = self.get_llm_pos_ids_for_vision(
+ st_idx, image_idx, spatial_merge_size, t_index, grid_hs, grid_ws
+ )
+ image_len = image_grid_thw[image_idx].prod() // (spatial_merge_size**2)
+ llm_pos_ids_list.append(llm_pos_ids)
+
+ st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
+ eos_len = 1
+ llm_pos_ids_list.append(torch.arange(eos_len).view(1, -1).expand(3, -1) + st_idx)
+
+ st += text_len + bos_len + image_len + eos_len
+ image_idx += 1
+ remain_images -= 1
+
+ elif min_ed == ed_video and not use_audio_in_video:
+ text_len = min_ed - st - 1
+ if text_len != 0:
+ st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
+ llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
+
+ st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
+ bos_len = 1
+ llm_pos_ids_list.append(torch.arange(bos_len).view(1, -1).expand(3, -1) + st_idx)
+
+ st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
+ grid_t = video_grid_thw[video_idx][0]
+ grid_hs = video_grid_thw[:, 1]
+ grid_ws = video_grid_thw[:, 2]
+ t_index = (
+ torch.arange(grid_t) * second_per_grids[video_idx].cpu().float() * position_id_per_seconds
+ ).long()
+ llm_pos_ids = self.get_llm_pos_ids_for_vision(
+ st_idx, video_idx, spatial_merge_size, t_index, grid_hs, grid_ws
+ )
+ video_len = video_grid_thw[video_idx].prod() // (spatial_merge_size**2)
+ llm_pos_ids_list.append(llm_pos_ids)
+
+ st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
+ eos_len = 1
+ llm_pos_ids_list.append(torch.arange(eos_len).view(1, -1).expand(3, -1) + st_idx)
+
+ st += text_len + bos_len + video_len + eos_len
+ video_idx += 1
+ remain_videos -= 1
+
+ elif min_ed == ed_video and use_audio_in_video:
+ text_len = min_ed - st - 2
+ if text_len != 0:
+ st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
+ llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
+
+ st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
+ bos_len = 1
+ llm_pos_ids_list.append(torch.arange(bos_len).view(1, -1).expand(3, -1) + st_idx)
+ llm_pos_ids_list.append(torch.arange(bos_len).view(1, -1).expand(3, -1) + st_idx)
+
+ st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
+ audio_len = ((audio_seqlens[audio_idx] - 1) // 2 + 1 - 2) // 2 + 1
+ audio_llm_pos_ids = torch.arange(audio_len).view(1, -1).expand(3, -1) + st_idx
+ grid_t = video_grid_thw[video_idx][0]
+ grid_hs = video_grid_thw[:, 1]
+ grid_ws = video_grid_thw[:, 2]
+
+ t_index = (
+ torch.arange(grid_t) * second_per_grids[video_idx].cpu().float() * position_id_per_seconds
+ ).long()
+ video_llm_pos_ids = self.get_llm_pos_ids_for_vision(
+ st_idx, video_idx, spatial_merge_size, t_index, grid_hs, grid_ws
+ )
+
+ t_ntoken_per_chunk = int(position_id_per_seconds * seconds_per_chunk)
+ video_chunk_indexes = self.get_chunked_index(video_llm_pos_ids[0], t_ntoken_per_chunk, st_idx)
+ audio_chunk_indexes = self.get_chunked_index(audio_llm_pos_ids[0], t_ntoken_per_chunk, st_idx)
+ sub_len = 0
+ for j in range(max(len(video_chunk_indexes), len(audio_chunk_indexes))):
+ video_chunk_index = video_chunk_indexes[j] if j < len(video_chunk_indexes) else None
+ audio_chunk_index = audio_chunk_indexes[j] if j < len(audio_chunk_indexes) else None
+ if video_chunk_index is not None:
+ sub_len += video_chunk_index[1] - video_chunk_index[0]
+
+ llm_pos_ids_list.append(
+ video_llm_pos_ids[:, video_chunk_index[0] : video_chunk_index[1]]
+ )
+ if audio_chunk_index is not None:
+ sub_len += audio_chunk_index[1] - audio_chunk_index[0]
+
+ llm_pos_ids_list.append(
+ audio_llm_pos_ids[:, audio_chunk_index[0] : audio_chunk_index[1]]
+ )
+ video_len = video_grid_thw[video_idx].prod() // (spatial_merge_size**2)
+
+ st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
+ eos_len = 1
+ llm_pos_ids_list.append(torch.arange(eos_len).view(1, -1).expand(3, -1) + st_idx)
+ llm_pos_ids_list.append(torch.arange(eos_len).view(1, -1).expand(3, -1) + st_idx)
+
+ st += text_len + bos_len * 2 + audio_len + video_len + eos_len * 2
+
+ audio_idx += 1
+ video_idx += 1
+ remain_videos -= 1
+ remain_audios -= 1
+
+ if st < len(input_tokens):
+ st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
+ text_len = len(input_tokens) - st
+ llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
+
+ llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
+
+ if attention_mask is not None:
+ position_ids[..., i, attention_mask[i]] = llm_positions.to(position_ids.device)
+ else:
+ position_ids[..., i, :] = llm_positions.to(position_ids.device)
+ mrope_position_deltas.append(llm_positions.max() + 1 - len(input_ids))
+ mrope_position_deltas = torch.tensor(mrope_position_deltas).unsqueeze(1).to(device=input_ids.device)
+
+ return position_ids, mrope_position_deltas
+ else:
+ position_ids = attention_mask.long().cumsum(-1) - 1
+ position_ids.masked_fill_(attention_mask == 0, 1)
+ position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device)
+ max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0]
+ mrope_position_deltas = max_position_ids + 1 - torch.sum(attention_mask, dim=-1, keepdim=True)
+
+ return position_ids, mrope_position_deltas
+
+
+############################
+# Start Thinker #
+############################
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Base class for Qwen2.5OmniThinker causal language model (or autoregressive) outputs.
+ """
+)
+class Qwen2_5OmniThinkerCausalLMOutputWithPast(ModelOutput):
+ r"""
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Language modeling loss (for next-token prediction).
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`, *optional*):
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
+ `past_key_values` input) to speed up sequential decoding.
+ rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
+ The rope index difference between sequence length and multimodal rope.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ logits: Optional[torch.FloatTensor] = None
+ past_key_values: Optional[Cache] = None
+ hidden_states: Optional[tuple[torch.FloatTensor]] = None
+ attentions: Optional[tuple[torch.FloatTensor]] = None
+ rope_deltas: Optional[torch.LongTensor] = None
+
+
+class Qwen2_5OmniAudioAttention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(
+ self,
+ config: Qwen2_5OmniAudioEncoderConfig,
+ ):
+ super().__init__()
+ self.embed_dim = config.d_model
+ self.num_heads = config.encoder_attention_heads
+ self.dropout = config.attention_dropout
+ self.head_dim = self.embed_dim // self.num_heads
+ self.num_key_value_groups = 1 # needed for eager attention
+ self.config = config
+
+ if (self.head_dim * self.num_heads) != self.embed_dim:
+ raise ValueError(
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
+ f" and `num_heads`: {self.num_heads})."
+ )
+ self.scaling = self.head_dim**-0.5
+ self.attention_dropout = 0.0
+ self.is_decoder = False
+ self.is_causal = False
+
+ self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
+ self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True)
+ self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True)
+ self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ cu_seqlens: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ **kwargs,
+ ) -> torch.Tensor:
+ """Input shape: Batch x Time x Channel"""
+
+ seq_length, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states).reshape(seq_length, self.num_heads, -1)
+ key_states = self.k_proj(hidden_states).reshape(seq_length, self.num_heads, -1)
+ value_states = self.v_proj(hidden_states).reshape(seq_length, self.num_heads, -1)
+
+ query_states = query_states.transpose(0, 1).unsqueeze(0)
+ key_states = key_states.transpose(0, 1).unsqueeze(0)
+ value_states = value_states.transpose(0, 1).unsqueeze(0)
+ max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
+
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ attn_output, _ = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask=attention_mask,
+ dropout=0.0 if not self.training else self.attention_dropout,
+ scaling=self.scaling,
+ cu_seq_lens_q=cu_seqlens, # pass cu seq lens for FA2
+ cu_seq_lens_k=cu_seqlens,
+ max_length_q=max_seqlen,
+ max_length_k=max_seqlen,
+ is_causal=False,
+ **kwargs,
+ )
+
+ attn_output = attn_output.reshape(seq_length, -1).contiguous()
+ attn_output = self.out_proj(attn_output)
+
+ return attn_output
+
+
+class Qwen2_5OmniAudioEncoderLayer(Qwen2AudioEncoderLayer):
+ def __init__(self, config: Qwen2_5OmniAudioEncoderConfig):
+ super().__init__(config)
+ self.self_attn = Qwen2_5OmniAudioAttention(config)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ cu_seqlens: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ **kwargs,
+ ) -> tuple[torch.Tensor]:
+ residual = hidden_states
+ hidden_states = self.self_attn_layer_norm(hidden_states)
+ hidden_states = self.self_attn(
+ hidden_states=hidden_states,
+ cu_seqlens=cu_seqlens,
+ attention_mask=attention_mask,
+ **kwargs,
+ )
+ hidden_states = residual + hidden_states
+ residual = hidden_states
+ hidden_states = self.final_layer_norm(hidden_states)
+ hidden_states = self.fc1(hidden_states)
+ hidden_states = self.activation_fn(hidden_states)
+ hidden_states = self.fc2(hidden_states)
+ hidden_states = residual + hidden_states
+
+ if hidden_states.dtype == torch.float16:
+ clamp_value = torch.finfo(hidden_states.dtype).max - 1000
+ hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
+
+ outputs = (hidden_states,)
+
+ return outputs
+
+
+class SinusoidsPositionEmbedding(nn.Module):
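+ """Fixed (non-learned) sinusoidal position embedding table, precomputed for `length` positions of dimension `channels`."""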
+ def __init__(self, length, channels, max_timescale=10000):
+ super().__init__()
+ if channels % 2 != 0:
+ raise ValueError("SinusoidsPositionEmbedding needs even channels input")
+ log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
+ inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2).float())
+ scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
+ self.register_buffer(
+ "positional_embedding",
+ torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1),
+ persistent=False,
+ )
+
+ def forward(self, seqlen: int):
+ return self.positional_embedding[:seqlen, :]
+
+
+@auto_docstring(
+ custom_intro="""
+ Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
+ [`Qwen2_5OmniAudioEncoderLayer`].
+ """
+)
+class Qwen2_5OmniAudioEncoder(Qwen2_5OmniPreTrainedModel):
+ config: Qwen2_5OmniAudioEncoderConfig
+ main_input_name = "input_features"
+ _no_split_modules = ["Qwen2_5OmniAudioEncoderLayer"]
+ _supports_sdpa = True
+
+ def __init__(self, config: Qwen2_5OmniAudioEncoderConfig):
+ super().__init__(config)
+ self.dropout = config.dropout
+
+ embed_dim = config.d_model
+ self.num_mel_bins = config.num_mel_bins
+ self.max_source_positions = config.max_source_positions
+ self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
+ self.n_window = config.n_window
+ self.conv1 = nn.Conv1d(self.num_mel_bins, embed_dim, kernel_size=3, padding=1)
+ self.conv2 = nn.Conv1d(embed_dim, embed_dim, kernel_size=3, stride=2, padding=1)
+ self.positional_embedding = SinusoidsPositionEmbedding(self.max_source_positions, embed_dim)
+ self.audio_bos_eos_token = nn.Embedding(2, config.output_dim)
+ self.layers = nn.ModuleList([Qwen2_5OmniAudioEncoderLayer(config) for _ in range(config.encoder_layers)])
+ self.ln_post = nn.LayerNorm(config.d_model)
+ self.avg_pooler = nn.AvgPool1d(2, stride=2)
+ self.proj = nn.Linear(config.d_model, config.output_dim)
+ self.gradient_checkpointing = False
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def _freeze_parameters(self):
+ for param in self.parameters():
+ param.requires_grad = False
+ self._requires_grad = False
+
+ def get_input_embeddings(self) -> nn.Module:
+ return self.conv1
+
+ def set_input_embeddings(self, value: nn.Module):
+ self.conv1 = value
+
+ def _prepare_attention_mask(self, inputs_tensor: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor:
+ # Flash Attention 2 doesn't need a 4D mask and relies on `cu_seqlens/max_seqlen`
+ # NOTE: the created attention mask only approximates the ragged FA2 attention by
+ # allowing bidirectional attention within `cu_seqlens` blocks, and not attending between
+ # blocks. Though it will not be a 100% match for FA2's `varlen` path
+ if self.config._attn_implementation == "flash_attention_2":
+ return None
+
+ seq_length = inputs_tensor.shape[0]
+ attention_mask = torch.full(
+ [1, 1, seq_length, seq_length],
+ torch.finfo(inputs_tensor.dtype).min,
+ device=inputs_tensor.device,
+ dtype=inputs_tensor.dtype,
+ )
+ for i in range(1, len(cu_seqlens)):
+ attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
+ return attention_mask
+
+ @auto_docstring
+ def forward(
+ self,
+ input_features,
+ feature_lens=None,
+ aftercnn_lens=None,
+ **kwargs,
+ ):
+ r"""
+ feature_lens (`torch.LongTensor` of shape `(batch_size,)`):
+ mel length
+ aftercnn_lens (`torch.LongTensor` of shape `(batch_size,)`):
+ mel length after cnn
+ """
+ chunk_num = torch.ceil(feature_lens / (self.n_window * 2)).long()
+
+ chunk_lengths = torch.tensor(
+ [self.n_window * 2] * chunk_num.sum(),
+ dtype=torch.long,
+ device=feature_lens.device,
+ )
+ tail_chunk_index = F.pad(chunk_num, (1, 0), value=-1).cumsum(0)[1:]
+ chunk_lengths[tail_chunk_index] = feature_lens % (self.n_window * 2)
+ chunk_lengths = torch.where(chunk_lengths == 0, self.n_window * 2, chunk_lengths)
+
+ chunk_list = input_features.split(chunk_lengths.tolist(), dim=1)
+ padded_feature, padded_mask, padded_mask_after_cnn = self.padded_and_mask_function(
+ chunk_list, chunk_lengths, padding_value=0, padding_side="right"
+ )
+ padded_embed = nn.functional.gelu(self.conv1(padded_feature)) * padded_mask
+ padded_embed = nn.functional.gelu(self.conv2(padded_embed)).transpose(1, 2)
+
+ padded_embed = padded_embed + self.positional_embedding.positional_embedding[
+ : padded_embed.shape[1], :
+ ].unsqueeze(0).to(padded_embed.dtype)
+ hidden_states = padded_embed[padded_mask_after_cnn]
+ cu_seqlens = torch.cat(
+ (
+ torch.zeros(1, device=padded_mask_after_cnn.device, dtype=torch.int32),
+ padded_mask_after_cnn.sum(1).cumsum(0),
+ )
+ ).to(torch.int32)
+ attention_mask = self._prepare_attention_mask(hidden_states, cu_seqlens)
+
+ for encoder_layer in self.layers:
+ layer_outputs = encoder_layer(
+ hidden_states,
+ cu_seqlens=cu_seqlens,
+ attention_mask=attention_mask,
+ **kwargs,
+ )
+ hidden_states = layer_outputs[0]
+
+ hidden_states_list = hidden_states.split(aftercnn_lens.tolist(), dim=0)
+ token_audio_list = []
+ for each_audio_states in hidden_states_list:
+ each_audio_states = self.avg_pooler(each_audio_states.transpose(0, 1)).transpose_(0, 1)
+ each_audio_states = self.ln_post(each_audio_states)
+ each_audio_states = self.proj(each_audio_states)
+ token_audio_list.append(each_audio_states)
+ token_audio = torch.cat(token_audio_list, dim=0)
+ return BaseModelOutput(last_hidden_state=token_audio)
+
+ def padded_and_mask_function(self, tensor_list, tensor_len, padding_value=0, padding_side="right"):
+ """
+ Pads a sequence of tensors to their maximum length on indicated `padding_side`.
+ Then prepares a mask so that pad tokens are not attended to.
+ """
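+ # Rough shape sketch (hypothetical sizes): for three chunks of lengths [200, 200, 120] and
+ # dim mel bins, this returns a (3, dim, 200) padded tensor, a (3, 1, 200) padding mask, and a
+ # (3, 100) boolean mask over the post-conv lengths (length - 1) // 2 + 1.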
+ max_len = tensor_len.max()
+ dim = tensor_list[0].shape[0]
+ padded_tensor = torch.full(
+ size=(len(tensor_list), dim, max_len),
+ fill_value=padding_value,
+ dtype=self.dtype,
+ device=tensor_list[0].device,
+ )
+
+ batch_mask = torch.zeros(
+ (len(tensor_len), max_len),
+ dtype=torch.long,
+ device=padded_tensor.device,
+ )
+ for i, length in enumerate(tensor_len):
+ batch_mask[i, :length] = 1
+ padded_tensor[i, :, :length] = tensor_list[i]
+
+ feature_lens_after_cnn = (tensor_len - 1) // 2 + 1
+ max_len_after_cnn = feature_lens_after_cnn.max()
+ batch_mask_after_cnn = torch.zeros(
+ (len(tensor_len), max_len_after_cnn),
+ dtype=torch.long,
+ device=padded_tensor.device,
+ )
+ for i, length in enumerate(feature_lens_after_cnn):
+ batch_mask_after_cnn[i, :length] = 1
+ return (
+ padded_tensor,
+ batch_mask.unsqueeze(1),
+ batch_mask_after_cnn.bool(),
+ )
+
+ # Ignore copy
+ def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor):
+ """
+ Computes the output length of the convolutional layers and the output length of the audio encoder
+ """
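+ # Illustrative arithmetic (assumed input size): a mel input of 3000 frames becomes 1500 after
+ # the strided conv and 750 after the average pooling:
+ #   input_lengths  = (3000 - 1) // 2 + 1 = 1500
+ #   output_lengths = (1500 - 2) // 2 + 1 = 750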
+ input_lengths = (input_lengths - 1) // 2 + 1
+ output_lengths = (input_lengths - 2) // 2 + 1
+ return input_lengths, output_lengths
+
+
+def apply_rotary_pos_emb_vision(tensor: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
+ orig_dtype = tensor.dtype
+ tensor = tensor.float()
+ cos = freqs.cos()
+ sin = freqs.sin()
+ cos = cos.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
+ sin = sin.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
+ output = (tensor * cos) + (rotate_half(tensor) * sin)
+ output = output.to(orig_dtype)
+ return output
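+# Shape note (assumption): `tensor` is (1, seq_len, num_heads, head_dim) and `freqs` is
+# (seq_len, head_dim // 2); cos/sin are repeated to head_dim and broadcast so the rotation
+# covers the full head dimension.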
+
+
+class Qwen2_5OmniVisionAttention(nn.Module):
+ def __init__(self, config: Qwen2_5OmniVisionEncoderConfig = None) -> None:
+ super().__init__()
+ self.dim = config.hidden_size
+ self.num_heads = config.num_heads
+ self.head_dim = self.dim // self.num_heads
+ self.q = nn.Linear(self.dim, self.dim, bias=True)
+ self.k = nn.Linear(self.dim, self.dim, bias=True)
+ self.v = nn.Linear(self.dim, self.dim, bias=True)
+ self.proj = nn.Linear(self.dim, self.dim)
+ self.scaling = self.head_dim**-0.5
+ self.num_key_value_groups = 1 # needed for eager attention
+ self.config = config
+ self.attention_dropout = 0.0
+ self.is_causal = False
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ cu_seqlens: torch.Tensor,
+ rotary_pos_emb: Optional[torch.Tensor] = None,
+ **kwargs,
+ ) -> torch.Tensor:
+ seq_length = hidden_states.shape[0]
+ query_states = self.q(hidden_states).reshape(seq_length, self.num_heads, -1)
+ key_states = self.k(hidden_states).reshape(seq_length, self.num_heads, -1)
+ value_states = self.v(hidden_states).reshape(seq_length, self.num_heads, -1)
+ query_states = apply_rotary_pos_emb_vision(query_states.unsqueeze(0), rotary_pos_emb).squeeze(0)
+ key_states = apply_rotary_pos_emb_vision(key_states.unsqueeze(0), rotary_pos_emb).squeeze(0)
+
+ query_states = query_states.transpose(0, 1).unsqueeze(0)
+ key_states = key_states.transpose(0, 1).unsqueeze(0)
+ value_states = value_states.transpose(0, 1).unsqueeze(0)
+
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ if self.config._attn_implementation == "flash_attention_2":
+ # Flash Attention 2: Use cu_seqlens for variable length attention
+ max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
+ attn_output, _ = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask=None,
+ scaling=self.scaling,
+ dropout=0.0 if not self.training else self.attention_dropout,
+ cu_seq_lens_q=cu_seqlens,
+ cu_seq_lens_k=cu_seqlens,
+ max_length_q=max_seqlen,
+ max_length_k=max_seqlen,
+ is_causal=False,
+ **kwargs,
+ )
+ else:
+ # Other implementations: Process each chunk separately
+ lengths = cu_seqlens[1:] - cu_seqlens[:-1]
+ splits = [
+ torch.split(tensor, lengths.tolist(), dim=2) for tensor in (query_states, key_states, value_states)
+ ]
+
+ attn_outputs = [
+ attention_interface(
+ self,
+ q,
+ k,
+ v,
+ attention_mask=None,
+ scaling=self.scaling,
+ dropout=0.0 if not self.training else self.attention_dropout,
+ is_causal=False,
+ **kwargs,
+ )[0]
+ for q, k, v in zip(*splits)
+ ]
+ attn_output = torch.cat(attn_outputs, dim=1)
+
+ attn_output = attn_output.reshape(seq_length, -1).contiguous()
+ attn_output = self.proj(attn_output)
+ return attn_output
+
+
+class Qwen2_5OmniVisionBlock(Qwen2_5_VLVisionBlock):
+ def __init__(self, config: Qwen2_5OmniVisionEncoderConfig) -> None:
+ super().__init__(config, config._attn_implementation)
+ self.attn = Qwen2_5OmniVisionAttention(config=config)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ cu_seqlens: torch.Tensor,
+ rotary_pos_emb: Optional[torch.Tensor] = None,
+ **kwargs,
+ ) -> torch.Tensor:
+ hidden_states = hidden_states + self.attn(
+ self.norm1(hidden_states),
+ cu_seqlens=cu_seqlens,
+ rotary_pos_emb=rotary_pos_emb,
+ **kwargs,
+ )
+ hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
+ return hidden_states
+
+
+class Qwen2_5OmniVisionEncoder(Qwen2_5_VisionTransformerPretrainedModel):
+ config: Qwen2_5OmniVisionEncoderConfig
+ _no_split_modules = ["Qwen2_5OmniVisionBlock"]
+
+ def __init__(self, config: Qwen2_5OmniVisionEncoderConfig, *inputs, **kwargs) -> None:
+ super().__init__(config, *inputs, **kwargs)
+ self.blocks = nn.ModuleList([Qwen2_5OmniVisionBlock(config) for _ in range(config.depth)])
+
+ def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) -> torch.Tensor:
+ """
+ Args:
+ hidden_states (`torch.Tensor` of shape `(seq_len, hidden_size)`):
+ The packed, flattened image patches to be embedded and encoded by the vision encoder.
+ grid_thw (`torch.Tensor` of shape `(num_images_or_videos, 3)`):
+ The temporal, height and width of feature shape of each image in LLM.
+
+ Returns:
+ `torch.Tensor`: hidden_states.
+ """
+ hidden_states = self.patch_embed(hidden_states)
+ rotary_pos_emb = self.rot_pos_emb(grid_thw)
+
+ window_index, cu_window_seqlens = self.get_window_index(grid_thw)
+ cu_window_seqlens = torch.tensor(
+ cu_window_seqlens,
+ device=hidden_states.device,
+ dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
+ )
+ cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)
+
+ seq_len, _ = hidden_states.size()
+ hidden_states = hidden_states.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
+ hidden_states = hidden_states[window_index, :, :]
+ hidden_states = hidden_states.reshape(seq_len, -1)
+ rotary_pos_emb = rotary_pos_emb.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
+ rotary_pos_emb = rotary_pos_emb[window_index, :, :]
+ rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
+
+ cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
+ dim=0,
+ # Select dtype based on the following factors:
+ # - FA2 requires that cu_seqlens_q must have dtype int32
+ # - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw
+ # See https://github.com/huggingface/transformers/pull/34852 for more information
+ dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
+ )
+ cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
+
+ # Alternate between full attention and windowed attention depending on the layer index
+ for layer_num, blk in enumerate(self.blocks):
+ if layer_num in self.fullatt_block_indexes:
+ cu_seqlens_now = cu_seqlens
+ else:
+ cu_seqlens_now = cu_window_seqlens
+
+ hidden_states = blk(
+ hidden_states,
+ cu_seqlens=cu_seqlens_now,
+ rotary_pos_emb=rotary_pos_emb,
+ **kwargs,
+ )
+ hidden_states = self.merger(hidden_states)
+ reverse_indices = torch.argsort(window_index)
+ hidden_states = hidden_states[reverse_indices, :]
+
+ return hidden_states
+
+
+class Qwen2_5OmniRotaryEmbedding(Qwen2VLRotaryEmbedding):
+ def __init__(self, config: Qwen2_5OmniThinkerConfig, device=None):
+ super().__init__(config, device)
+
+
+# It's the same as `Qwen2_5_VLAttention`, but the talker model's hidden_size isn't divisible by num_heads.
+# Removes the value error as a workaround.
+class Qwen2_5OmniAttention(Qwen2_5_VLAttention):
+ def __init__(self, config: Qwen2_5OmniConfig, layer_idx: Optional[int] = None):
+ nn.Module.__init__(self)
+ self.config = config
+ self.layer_idx = layer_idx
+ if layer_idx is None:
+ logger.warning_once(
+ f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will lead "
+ "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
+ "when creating this class."
+ )
+
+ self.hidden_size = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads)
+ self.num_key_value_heads = config.num_key_value_heads
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
+ self.is_causal = True
+ self.attention_dropout = config.attention_dropout
+ self.rope_scaling = config.rope_scaling
+ self.scaling = self.head_dim**-0.5
+
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True)
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
+ self.sliding_window = config.sliding_window if config.layer_types[layer_idx] == "sliding_attention" else None
+
+ self.rotary_emb = Qwen2_5OmniRotaryEmbedding(config=config)
+
+
+class Qwen2MLP(Qwen2_5_VLMLP):
+ pass
+
+
+class Qwen2_5OmniThinkerTextModel(Qwen2_5_VLTextModel):
+ config: Qwen2_5OmniTextConfig
+ _no_split_modules = ["Qwen2_5OmniDecoderLayer"]
+
+ def __init__(self, config: Qwen2_5OmniTextConfig):
+ super().__init__(config)
+
+
+@auto_docstring(
+ custom_intro="""
+ The Qwen2.5OmniThinker model, which consists of an audio backbone and a language model.
+ """
+)
+class Qwen2_5OmniThinkerForConditionalGeneration(Qwen2_5OmniPreTrainedModelForConditionalGeneration, GenerationMixin):
+ config: Qwen2_5OmniThinkerConfig
+ base_model_prefix = "thinker"
+ _tied_weights_keys = ["model.embed_tokens.weight", "lm_head.weight"]
+ _no_split_modules = ["Qwen2_5OmniAudioEncoder", "Qwen2_5OmniVisionEncoder"]
+
+ def __init__(self, config: Qwen2_5OmniThinkerConfig):
+ super().__init__(config)
+ self.audio_tower = Qwen2_5OmniAudioEncoder._from_config(config.audio_config)
+ self.visual = Qwen2_5OmniVisionEncoder._from_config(config.vision_config)
+ self.vocab_size = config.text_config.vocab_size
+ self.model = Qwen2_5OmniThinkerTextModel._from_config(config.text_config)
+ self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
+ self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
+ self.spatial_merge_size = config.vision_config.spatial_merge_size
+ self.rope_deltas = None
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.get_input_embeddings()
+
+ def set_input_embeddings(self, value):
+ self.model.set_input_embeddings(value)
+
+ def get_video_features(
+ self, pixel_values_videos: torch.FloatTensor, video_grid_thw: Optional[torch.LongTensor] = None
+ ):
+ """
+ Encodes videos into continuous embeddings that can be forwarded to the language model.
+
+ Args:
+ pixel_values_videos (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
+ The tensors corresponding to the input videos.
+ video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
+ The temporal, height and width of feature shape of each video in LLM.
+ """
+ pixel_values_videos = pixel_values_videos.type(self.visual.dtype)
+ video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
+ return video_embeds
+
+ def get_image_features(self, pixel_values: torch.FloatTensor, image_grid_thw: Optional[torch.LongTensor] = None):
+ """
+ Encodes images into continuous embeddings that can be forwarded to the language model.
+
+ Args:
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
+ The tensors corresponding to the input images.
+ image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
+ The temporal, height and width of feature shape of each image in LLM.
+ """
+ pixel_values = pixel_values.type(self.visual.dtype)
+ image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
+ return image_embeds
+
+ def get_audio_features(
+ self,
+ input_features: torch.FloatTensor,
+ feature_attention_mask: Optional[torch.LongTensor] = None,
+ audio_feature_lengths: Optional[torch.LongTensor] = None,
+ ):
+ """
+ Encodes audios into continuous embeddings that can be forwarded to the language model.
+
+ Args:
+ input_features (`torch.FloatTensor`):
+ The tensors corresponding to the input audios.
+ feature_attention_mask (`torch.LongTensor`, *optional*):
+ Mask to avoid performing attention on padding feature indices. Mask values selected in `[0, 1]`:
+ 1 for tokens that are **not masked**, 0 for tokens that are **masked**.
+ audio_feature_lengths (`torch.LongTensor` of shape `(num_audios)`, *optional*):
+ The length of feature shape of each audio in LLM.
+ """
+ if feature_attention_mask is not None:
+ audio_feature_lengths = torch.sum(feature_attention_mask, dim=1)
+ input_features = input_features.permute(0, 2, 1)[feature_attention_mask.bool()].permute(1, 0)
+ else:
+ audio_feature_lengths = None
+
+ audio_feat_lengths, audio_output_lengths = self.audio_tower._get_feat_extract_output_lengths(
+ audio_feature_lengths if audio_feature_lengths is not None else feature_attention_mask.sum(-1)
+ )
+ feature_lens = audio_feature_lengths if audio_feature_lengths is not None else feature_attention_mask.sum(-1)
+ audio_outputs = self.audio_tower(
+ input_features,
+ feature_lens=feature_lens,
+ aftercnn_lens=audio_feat_lengths,
+ )
+ audio_features = audio_outputs.last_hidden_state
+
+ if audio_features.shape[0] != sum(audio_output_lengths.tolist()):
+ raise ValueError("length of audio_features should match audio_output_lengths")
+
+ return audio_features
+
+ def get_placeholder_mask(
+ self,
+ input_ids: torch.LongTensor,
+ inputs_embeds: torch.FloatTensor,
+ image_features: Optional[torch.FloatTensor] = None,
+ video_features: Optional[torch.FloatTensor] = None,
+ ):
+ """
+ Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is
+ equal to the length of multimodal features. If the lengths are different, an error is raised.
+ """
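+ # Illustrative sketch (assumed ids): if input_ids contains two image placeholder tokens, the
+ # image mask marks exactly those two positions; expanding it over the hidden dimension lets
+ # `masked_scatter` write one image feature row per placeholder token.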
+ if input_ids is None:
+ special_image_mask = inputs_embeds == self.get_input_embeddings()(
+ torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
+ )
+ special_image_mask = special_image_mask.all(-1)
+ special_video_mask = inputs_embeds == self.get_input_embeddings()(
+ torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device)
+ )
+ special_video_mask = special_video_mask.all(-1)
+ special_audio_mask = (
+ inputs_embeds
+ == self.get_input_embeddings()(
+ torch.tensor(self.config.audio_token_id, dtype=torch.long, device=inputs_embeds.device)
+ )
+ ).all(-1)
+ else:
+ special_image_mask = input_ids == self.config.image_token_id
+ special_video_mask = input_ids == self.config.video_token_id
+ special_audio_mask = input_ids == self.config.audio_token_id
+
+ n_image_tokens = special_image_mask.sum()
+ special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
+ if image_features is not None and inputs_embeds[special_image_mask].numel() != image_features.numel():
+ raise ValueError(
+ f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {image_features.shape[0]}"
+ )
+
+ n_video_tokens = special_video_mask.sum()
+ special_video_mask = special_video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
+ if video_features is not None and inputs_embeds[special_video_mask].numel() != video_features.numel():
+ raise ValueError(
+ f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {video_features.shape[0]}"
+ )
+
+ special_audio_mask = special_audio_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
+ return special_image_mask, special_video_mask, special_audio_mask
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ input_features: Optional[torch.FloatTensor] = None,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ pixel_values_videos: Optional[torch.FloatTensor] = None,
+ image_grid_thw: Optional[torch.LongTensor] = None,
+ video_grid_thw: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ feature_attention_mask: Optional[torch.Tensor] = None,
+ audio_feature_lengths: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ rope_deltas: Optional[torch.LongTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ use_audio_in_video: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ video_second_per_grid: Optional[torch.LongTensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> Union[tuple, Qwen2_5OmniThinkerCausalLMOutputWithPast]:
+ r"""
+ image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
+ The temporal, height and width of feature shape of each image in LLM.
+ video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
+ The temporal, height and width of feature shape of each video in LLM.
+ feature_attention_mask (`torch.Tensor` of shape `(batch_size, feature_sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding feature indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+ audio_feature_lengths (`torch.LongTensor` of shape `(num_audios)`, *optional*):
+ The length of feature shape of each audio in LLM.
+ rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
+ The rope index difference between sequence length and multimodal rope.
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+ use_audio_in_video (`bool`, *optional*):
+ Whether or not to use the audio track in the video; should be the same as the parameter in `process_audio_info`.
+ video_second_per_grid (`torch.LongTensor` of shape `(num_videos)`, *optional*):
+ Number of seconds per grid for each video, used for temporal feature mapping.
+
+ Example:
+
+ ```python
+ >>> from io import BytesIO
+ >>> from urllib.request import urlopen
+ >>> import librosa
+ >>> from qwen_vl_utils import process_vision_info
+ >>> from transformers import Qwen2_5OmniProcessor, Qwen2_5OmniThinkerForConditionalGeneration
+
+ >>> thinker = Qwen2_5OmniThinkerForConditionalGeneration.from_pretrained("Qwen/Qwen2.5-Omni-7B")
+ >>> processor = Qwen2_5OmniProcessor.from_pretrained("Qwen/Qwen2.5-Omni-7B")
+
+ >>> conversations = [
+ >>> {'role': 'system', 'content': 'You are a helpful voice chat bot, and please respond to me in a casual conversation manner using random voice.'},
+ >>> {"role": "user", "content": [
+ >>> {"type": "image", "image_url": "https://www.ilankelman.org/stopsigns/australia.jpg"},
+ >>> {"type": "audio", "audio_url": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-Audio/audio/glass-breaking-151256.mp3"},
+ >>> ]},
+ >>> ]
+
+ >>> text = processor.apply_chat_template(conversations, add_generation_prompt=True, tokenize=False)
+ >>> audios = [librosa.load(BytesIO(urlopen(conversations[1]['content'][1]['audio_url']).read()), sr=processor.feature_extractor.sampling_rate)[0]]
+ >>> images, videos = process_vision_info(conversations)
+ >>> inputs = processor(text=text, audios=audios, images=images, videos=videos, return_tensors="pt", padding=True)
+
+ >>> # Generate
+ >>> inputs['use_audio_in_video'] = True  # or False
+ >>> generation = thinker.generate(**inputs, max_new_tokens=2048)
+ >>> generate_ids = generation[:, inputs.input_ids.size(1):]
+
+ >>> response = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ ```"""
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if inputs_embeds is None:
+ # 1. Extract the input embeddings
+ inputs_embeds = self.get_input_embeddings()(input_ids)
+
+ # 2. Merge text, audio, image and video features
+ if input_features is not None:
+ audio_features = self.get_audio_features(
+ input_features,
+ feature_attention_mask=feature_attention_mask,
+ audio_feature_lengths=audio_feature_lengths,
+ )
+ audio_features = audio_features.to(inputs_embeds.device, inputs_embeds.dtype)
+ _, _, audio_mask = self.get_placeholder_mask(input_ids, inputs_embeds=inputs_embeds)
+ inputs_embeds = inputs_embeds.masked_scatter(audio_mask, audio_features)
+
+ if pixel_values is not None:
+ image_embeds = self.get_image_features(pixel_values, image_grid_thw)
+ image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
+ image_mask, _, _ = self.get_placeholder_mask(
+ input_ids, inputs_embeds=inputs_embeds, image_features=image_embeds
+ )
+ inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
+
+ if pixel_values_videos is not None:
+ video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw)
+ video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
+ _, video_mask, _ = self.get_placeholder_mask(
+ input_ids, inputs_embeds=inputs_embeds, video_features=video_embeds
+ )
+ inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
+
+ if feature_attention_mask is not None:
+ audio_feature_lengths = torch.sum(feature_attention_mask, dim=1)
+ else:
+ audio_feature_lengths = None
+
+ if attention_mask is not None and position_ids is None:
+ if (
+ cache_position is None
+ or (cache_position is not None and cache_position[0] == 0)
+ or self.rope_deltas is None
+ ):
+ delta0 = (1 - attention_mask).sum(dim=-1).unsqueeze(1)
+ position_ids, rope_deltas = self.get_rope_index(
+ input_ids,
+ image_grid_thw,
+ video_grid_thw,
+ attention_mask,
+ use_audio_in_video,
+ audio_feature_lengths,
+ video_second_per_grid,
+ )
+ rope_deltas = rope_deltas - delta0
+ self.rope_deltas = rope_deltas
+ else:
+ batch_size, seq_length = input_ids.shape
+ delta = cache_position[0] + self.rope_deltas if cache_position is not None else 0
+ position_ids = torch.arange(seq_length, device=input_ids.device)
+ position_ids = position_ids.view(1, -1).expand(batch_size, -1)
+ position_ids = position_ids.add(delta)
+ position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
+
+ outputs = self.model(
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+ hidden_states = outputs[0]
+ logits = self.lm_head(hidden_states)
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(
+ logits=logits, labels=labels, vocab_size=self.config.get_text_config().vocab_size
+ )
+
+ if not return_dict:
+ output = (logits,) + outputs
+ return (loss,) + output if loss is not None else output
+
+ return Qwen2_5OmniThinkerCausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ rope_deltas=self.rope_deltas,
+ )
+
+ def prepare_inputs_for_generation(
+ self,
+ input_ids,
+ past_key_values=None,
+ attention_mask=None,
+ inputs_embeds=None,
+ cache_position=None,
+ position_ids=None,
+ use_cache=True,
+ pixel_values=None,
+ pixel_values_videos=None,
+ image_grid_thw=None,
+ video_grid_thw=None,
+ input_features=None,
+ feature_attention_mask=None,
+ use_audio_in_video=False,
+ video_second_per_grid=None,
+ **kwargs,
+ ):
+ model_inputs = super().prepare_inputs_for_generation(
+ input_ids,
+ past_key_values=past_key_values,
+ attention_mask=attention_mask,
+ inputs_embeds=inputs_embeds,
+ cache_position=cache_position,
+ position_ids=position_ids,
+ use_cache=use_cache,
+ pixel_values=pixel_values,
+ pixel_values_videos=pixel_values_videos,
+ image_grid_thw=image_grid_thw,
+ video_grid_thw=video_grid_thw,
+ input_features=input_features,
+ feature_attention_mask=feature_attention_mask,
+ use_audio_in_video=use_audio_in_video,
+ video_second_per_grid=video_second_per_grid,
+ **kwargs,
+ )
+
+ model_inputs["position_ids"] = None
+
+ if cache_position[0] != 0:
+ model_inputs["pixel_values"] = None
+ model_inputs["pixel_values_videos"] = None
+ model_inputs["input_features"] = None
+
+ return model_inputs
+
+
+############################
+# Start Talker #
+############################
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Base class for Qwen2.5OmniTalker causal language model (or autoregressive) outputs.
+ """
+)
+class Qwen2_5OmniTalkerCausalLMOutputWithPast(ModelOutput):
+ r"""
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Language modeling loss (for next-token prediction).
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
+ `past_key_values` input) to speed up sequential decoding.
+ rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
+ The rope index difference between sequence length and multimodal rope.
+ thinker_reply_part (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+ Hidden states from the thinker model that are used as input for the talker model. These represent the encoded
+ response that the talker model will use to generate speech tokens.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ logits: Optional[torch.FloatTensor] = None
+ past_key_values: Optional[Cache] = None
+ hidden_states: Optional[tuple[torch.FloatTensor]] = None
+ attentions: Optional[tuple[torch.FloatTensor]] = None
+ rope_deltas: Optional[torch.LongTensor] = None
+ thinker_reply_part: Optional[torch.FloatTensor] = None
+
+
+class Qwen2_5OmniTalkerModel(Qwen2_5_VLTextModel):
+ config: Qwen2_5OmniTalkerConfig
+ _no_split_modules = ["Qwen2_5OmniTalkerDecoderLayer"]
+
+ def __init__(self, config: Qwen2_5OmniTalkerConfig):
+ super().__init__(config)
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.embedding_size, self.padding_idx)
+
+
+class Qwen2_5OmniTalkerForConditionalGeneration(Qwen2_5OmniPreTrainedModelForConditionalGeneration, GenerationMixin):
+ config: Qwen2_5OmniTalkerConfig
+ base_model_prefix = "talker"
+
+ def __init__(self, config: Qwen2_5OmniTalkerConfig):
+ super().__init__(config)
+
+ self.thinker_to_talker_proj = nn.Linear(config.embedding_size, config.hidden_size)
+
+ self.model = Qwen2_5OmniTalkerModel(config)
+ self.codebook_size = config.vocab_size
+ self.codec_head = nn.Linear(config.hidden_size, self.codebook_size, bias=False)
+
+ self.codec_bos_token = config.tts_codec_start_token_id
+ self.codec_eos_token = config.tts_codec_end_token_id
+ self.codec_pad_token = config.tts_codec_pad_token_id
+ self.codec_mask_token = config.tts_codec_mask_token_id
+
+ self.text_bos_token = config.tts_text_start_token_id
+ self.text_eos_token = config.tts_text_end_token_id
+ self.text_pad_token = config.tts_text_pad_token_id
+
+ self.spatial_merge_size = self.config.spatial_merge_size
+ self.rope_deltas = None
+
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.get_input_embeddings()
+
+ def set_input_embeddings(self, value):
+ self.model.set_input_embeddings(value)
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ thinker_reply_part: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ rope_deltas: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ input_text_ids: Optional[torch.LongTensor] = None,
+ image_grid_thw: Optional[torch.LongTensor] = None,
+ video_grid_thw: Optional[torch.LongTensor] = None,
+ use_audio_in_video: Optional[bool] = None,
+ audio_feature_lengths: Optional[torch.LongTensor] = None,
+ video_second_per_grid: Optional[torch.LongTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[tuple, Qwen2_5OmniTalkerCausalLMOutputWithPast]:
+ r"""
+ thinker_reply_part (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Hidden states from the thinker model's output that represent the text reply part to be processed.
+ rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
+ The rope index difference between sequence length and multimodal rope.
+ input_text_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Input token IDs for text-only content, used for position calculation in multimodal contexts.
+ image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
+ The temporal, height and width of feature shape of each image in LLM.
+ video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
+ The temporal, height and width of feature shape of each video in LLM.
+ use_audio_in_video (`bool`, *optional*):
+ Whether or not to use the audio track in the video; should be the same as the parameter in `process_audio_info`.
+ audio_feature_lengths (`torch.LongTensor` of shape `(num_audios)`, *optional*):
+ The length of feature shape of each audio in LLM.
+ video_second_per_grid (`torch.LongTensor` of shape `(num_videos)`, *optional*):
+ Number of seconds per grid for each video, used for temporal feature mapping.
+
+ Example:
+
+ ```python
+ >>> from io import BytesIO
+ >>> from urllib.request import urlopen
+ >>> import librosa
+ >>> from transformers import AutoProcessor, Qwen2_5OmniTalkerForConditionalGeneration
+
+ >>> model = Qwen2_5OmniTalkerForConditionalGeneration.from_pretrained("Qwen/Qwen2-Audio-7B")
+ >>> processor = AutoProcessor.from_pretrained("Qwen/Qwen2-Audio-7B")
+
+ >>> prompt = "<|audio_bos|><|AUDIO|><|audio_eos|>Generate the caption in English:"
+ >>> url = "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-Audio/audio/glass-breaking-151256.mp3"
+ >>> audio, _ = librosa.load(BytesIO(urlopen(url).read()), sr=processor.feature_extractor.sampling_rate)
+
+ >>> inputs = processor(text=prompt, audios=audio, return_tensors="pt")
+
+ >>> # Generate
+ >>> generate_ids = model.generate(**inputs, max_length=30)
+ >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ "Generate the caption in English: Glass is breaking."
+ ```"""
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if attention_mask is not None and position_ids is None:
+ if (
+ cache_position is None
+ or (cache_position is not None and cache_position[0] == 0)
+ or self.rope_deltas is None
+ ):
+ position_ids, rope_deltas = self.get_rope_index(
+ input_text_ids,
+ image_grid_thw,
+ video_grid_thw,
+ attention_mask,
+ use_audio_in_video,
+ audio_feature_lengths,
+ video_second_per_grid,
+ )
+
+ inputs_embeds[:, -1, :] += self.get_input_embeddings()(
+ torch.tensor([self.codec_bos_token], dtype=torch.long, device=inputs_embeds.device)
+ )
+ inputs_embeds[:, -2, :] += self.get_input_embeddings()(
+ torch.tensor([self.codec_pad_token], dtype=torch.long, device=inputs_embeds.device)
+ )
+ self.rope_deltas = rope_deltas
+
+ else:
+ batch_size, seq_length = input_ids.shape
+ delta = cache_position[0] + self.rope_deltas if cache_position is not None else 0
+ position_ids = torch.arange(seq_length, device=input_ids.device)
+ position_ids = position_ids.view(1, -1).expand(batch_size, -1)
+ position_ids = position_ids.add(delta)
+ position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
+
+ if inputs_embeds is None:
+ # 1. Inference for tokens after the second token: add the next thinker reply frame to the codec embedding
+ codec_embeds = self.get_input_embeddings()(input_ids)
+ inputs_embeds = codec_embeds + thinker_reply_part[:, :1, :]
+ if thinker_reply_part.shape[1] > 1:
+ thinker_reply_part = thinker_reply_part[:, 1:, :]
+
+ talker_lm_input = self.thinker_to_talker_proj(inputs_embeds)
+
+ if attention_mask is not None:
+ attention_mask = attention_mask.to(inputs_embeds.device)
+
+ outputs = self.model(
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=talker_lm_input,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ hidden_states = outputs[0]
+ logits = self.codec_head(hidden_states)
+ logits = logits.float()
+
+ loss = None
+
+ if not return_dict:
+ output = (logits,) + outputs[1:]
+ return (loss,) + output if loss is not None else output
+
+ return Qwen2_5OmniTalkerCausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=hidden_states,
+ attentions=outputs.attentions,
+ rope_deltas=self.rope_deltas,
+ thinker_reply_part=thinker_reply_part,
+ )
+
+ def _get_initial_cache_position(self, seq_length, device, model_kwargs):
+ # Talker needs to calculate cache_position with input_ids, so pop inputs_embeds temporarily
+ inputs_embeds = model_kwargs.pop("inputs_embeds")
+ model_kwargs = super()._get_initial_cache_position(seq_length, device, model_kwargs)
+ model_kwargs["inputs_embeds"] = inputs_embeds
+ return model_kwargs
+
+ # prepare inputs for talker lm generation
+ def prepare_inputs_for_generation(
+ self,
+ input_ids,
+ input_text_ids,
+ past_key_values=None,
+ attention_mask=None,
+ inputs_embeds=None,
+ thinker_reply_part=None,
+ cache_position=None,
+ position_ids=None,
+ use_cache=True,
+ pixel_values=None,
+ pixel_values_videos=None,
+ image_grid_thw=None,
+ video_grid_thw=None,
+ input_audio_features=None,
+ audio_feature_attention_mask=None,
+ audio_feature_lengths=None,
+ use_audio_in_video=False,
+ video_second_per_grid=None,
+ **kwargs,
+ ):
+ model_inputs = super().prepare_inputs_for_generation(
+ input_ids,
+ past_key_values,
+ attention_mask,
+ inputs_embeds,
+ cache_position,
+ use_cache=use_cache,
+ thinker_reply_part=thinker_reply_part,
+ input_text_ids=input_text_ids,
+ image_grid_thw=image_grid_thw,
+ video_grid_thw=video_grid_thw,
+ use_audio_in_video=use_audio_in_video,
+ audio_feature_lengths=audio_feature_lengths,
+ video_second_per_grid=video_second_per_grid,
+ **kwargs,
+ )
+
+ model_inputs["position_ids"] = None
+
+ return model_inputs
+
+ def _update_model_kwargs_for_generation(
+ self,
+ outputs: ModelOutput,
+ model_kwargs: dict[str, Any],
+ is_encoder_decoder: bool = False,
+ num_new_tokens: int = 1,
+ ) -> dict[str, Any]:
+ model_kwargs = super()._update_model_kwargs_for_generation(
+ outputs, model_kwargs, is_encoder_decoder, num_new_tokens
+ )
+
+ if getattr(outputs, "thinker_reply_part", None) is not None:
+ model_kwargs["thinker_reply_part"] = outputs.thinker_reply_part
+
+ return model_kwargs
+
+
+############################
+# Start Token2Wav #
+############################
+
+
+# Uses a custom RoPE; will switch to LlamaRotaryEmbedding in the next version
+class Qwen2_5OmniDiTRotaryEmbedding(nn.Module):
+ inv_freq: torch.Tensor # fix linting for `register_buffer`
+
+ def __init__(self, dim, base=10000):
+ super().__init__()
+
+ inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
+ self.register_buffer("inv_freq", inv_freq)
+
+ def forward(self, x):
+ batch_size, seq_len = x.shape[0], x.shape[1]
+ t = torch.arange(seq_len, device=x.device)
+ device_type = x.device.type
+ device_type = device_type if device_type != "mps" else "cpu"
+ with torch.autocast(device_type=device_type, enabled=False):
+ freqs = t.unsqueeze(1).float() @ self.inv_freq.unsqueeze(0).float()
+ freqs = torch.stack((freqs, freqs), dim=-1)
+ freqs = freqs.reshape(*freqs.shape[:-2], -1)
+ freqs = freqs.repeat(batch_size, *([1] * freqs.dim()))
+ cos = freqs.cos()
+ sin = freqs.sin()
+
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
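+ # Shape note (assumption): returns cos/sin of shape (batch, seq_len, dim), with each frequency
+ # duplicated along the last axis to match the pairwise rotation in `apply_rotary_pos_emb` below.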
+
+
+# Modified from Llama with a different rotate function; will be fixed in the next release
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+ """Applies Rotary Position Embedding to the query and key tensors.
+
+ Args:
+ q (`torch.Tensor`): The query tensor.
+ k (`torch.Tensor`): The key tensor.
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
+ position_ids (`torch.Tensor`, *optional*):
+ Deprecated and unused.
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+ Returns:
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+ """
+
+ def rotate_half_codec(x):
+ # x = rearrange(x, "... (d r) -> ... d r", r=2)
+ x = x.reshape(*x.shape[:-1], -1, 2)
+ x1, x2 = x.unbind(dim=-1)
+ x = torch.stack((-x2, x1), dim=-1)
+ return x.reshape(*x.shape[:-2], -1)
+
+ cos = cos.unsqueeze(unsqueeze_dim)
+ sin = sin.unsqueeze(unsqueeze_dim)
+ q_embed = (q * cos) + (rotate_half_codec(q) * sin)
+ k_embed = (k * cos) + (rotate_half_codec(k) * sin)
+ return q_embed, k_embed
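+# Quick sanity sketch (illustrative): `rotate_half_codec` pairs adjacent features, so a last-dim
+# vector [x1, x2, x3, x4] becomes [-x2, x1, -x4, x3], matching the duplicated-frequency layout
+# produced by `Qwen2_5OmniDiTRotaryEmbedding` above.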
+
+
+class TimeDelayNetBlock(nn.Module):
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ dilation,
+ ):
+ super().__init__()
+ self.conv = nn.Conv1d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=kernel_size,
+ dilation=dilation,
+ padding="same",
+ padding_mode="reflect",
+ )
+ self.activation = nn.ReLU()
+
+ def forward(self, hidden_states: torch.Tensor):
+ return self.activation(self.conv(hidden_states))
+
+
+class Res2NetBlock(torch.nn.Module):
+ def __init__(self, in_channels, out_channels, scale=8, kernel_size=3, dilation=1):
+ super().__init__()
+
+ in_channel = in_channels // scale
+ hidden_channel = out_channels // scale
+
+ self.blocks = nn.ModuleList(
+ [
+ TimeDelayNetBlock(
+ in_channel,
+ hidden_channel,
+ kernel_size=kernel_size,
+ dilation=dilation,
+ )
+ for i in range(scale - 1)
+ ]
+ )
+ self.scale = scale
+
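+ # Hierarchical flow (illustrative, scale=8): the input is chunked into 8 groups along the channel
+ # dim; group 0 passes through unchanged, group 1 goes through blocks[0], and each later group is
+ # summed with the previous output before its TDNN block, yielding multi-scale receptive fields.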
+ def forward(self, hidden_states):
+ outputs = []
+ for i, hidden_part in enumerate(torch.chunk(hidden_states, self.scale, dim=1)):
+ if i == 0:
+ output_part = hidden_part
+ elif i == 1:
+ output_part = self.blocks[i - 1](hidden_part)
+ else:
+ output_part = self.blocks[i - 1](hidden_part + output_part)
+ outputs.append(output_part)
+ output = torch.cat(outputs, dim=1)
+ return output
+
+
+class SqueezeExcitationBlock(nn.Module):
+ def __init__(self, in_channels, se_channels, out_channels):
+ super().__init__()
+
+ self.conv1 = nn.Conv1d(
+ in_channels=in_channels,
+ out_channels=se_channels,
+ kernel_size=1,
+ padding="same",
+ padding_mode="reflect",
+ )
+ self.relu = nn.ReLU(inplace=True)
+ self.conv2 = nn.Conv1d(
+ in_channels=se_channels,
+ out_channels=out_channels,
+ kernel_size=1,
+ padding="same",
+ padding_mode="reflect",
+ )
+ self.sigmoid = nn.Sigmoid()
+
+ def forward(self, hidden_states):
+ hidden_states_mean = hidden_states.mean(dim=2, keepdim=True)
+
+ hidden_states_mean = self.relu(self.conv1(hidden_states_mean))
+ hidden_states_mean = self.sigmoid(self.conv2(hidden_states_mean))
+
+ return hidden_states * hidden_states_mean
+
+
+class AttentiveStatisticsPooling(nn.Module):
+ """This class implements an attentive statistics pooling layer for each channel.
+ It returns the concatenated mean and std of the input tensor.
+ """
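+ # Shape note (assumption): the input is (batch, channels, time) and the pooled output is
+ # (batch, 2 * channels, 1), i.e. the attention-weighted mean and std stacked per channel.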
+
+ def __init__(self, channels, attention_channels=128):
+ super().__init__()
+
+ self.eps = 1e-12
+ self.tdnn = TimeDelayNetBlock(channels * 3, attention_channels, 1, 1)
+ self.tanh = nn.Tanh()
+ self.conv = nn.Conv1d(
+ in_channels=attention_channels,
+ out_channels=channels,
+ kernel_size=1,
+ padding="same",
+ padding_mode="reflect",
+ )
+
+ def _length_to_mask(self, length, max_len=None, dtype=None, device=None):
+ """Creates a binary mask for each sequence.
+
+ Reference: https://discuss.pytorch.org/t/how-to-generate-variable-length-mask/23397/3
+
+ Arguments
+ ---------
+ length : torch.LongTensor
+ Containing the length of each sequence in the batch. Must be 1D.
+ max_len : int
+ Max length for the mask, also the size of the second dimension.
+ dtype : torch.dtype, default: None
+ The dtype of the generated mask.
+ device: torch.device, default: None
+ The device to put the mask variable.
+
+ Returns
+ -------
+ mask : tensor
+ The binary mask.
+ """
+
+ if max_len is None:
+ max_len = length.max().long().item() # using arange to generate mask
+ mask = torch.arange(max_len, device=length.device, dtype=length.dtype).expand(
+ len(length), max_len
+ ) < length.unsqueeze(1)
+
+ mask = torch.as_tensor(mask, dtype=dtype, device=device)
+ return mask
+
+ def _compute_statistics(self, x, m, dim=2):
+ mean = (m * x).sum(dim)
+ std = torch.sqrt((m * (x - mean.unsqueeze(dim)).pow(2)).sum(dim).clamp(self.eps))
+ return mean, std
+
+ def forward(self, hidden_states):
+ seq_length = hidden_states.shape[-1]
+ lengths = torch.ones(hidden_states.shape[0], device=hidden_states.device)
+
+ # Make binary mask of shape [N, 1, L]
+ mask = self._length_to_mask(
+ lengths * seq_length, max_len=seq_length, dtype=hidden_states.dtype, device=hidden_states.device
+ )
+ mask = mask.unsqueeze(1)
+
+ # Expand the temporal context of the pooling layer by allowing the
+ # self-attention to look at global properties of the utterance.
+ total = mask.sum(dim=2, keepdim=True)
+
+ mean, std = self._compute_statistics(hidden_states, mask / total)
+ mean = mean.unsqueeze(2).repeat(1, 1, seq_length)
+ std = std.unsqueeze(2).repeat(1, 1, seq_length)
+ attention = torch.cat([hidden_states, mean, std], dim=1)
+
+ # Apply layers
+ attention = self.conv(self.tanh(self.tdnn(attention)))
+
+ # Filter out zero-paddings
+ attention = attention.masked_fill(mask == 0, float("-inf"))
+
+ attention = F.softmax(attention, dim=2)
+ mean, std = self._compute_statistics(hidden_states, attention)
+ # Append mean and std of the batch
+ pooled_stats = torch.cat((mean, std), dim=1)
+ pooled_stats = pooled_stats.unsqueeze(2)
+
+ return pooled_stats
+
+
+class SqueezeExcitationRes2NetBlock(nn.Module):
+ """An implementation of the building block in ECAPA-TDNN, i.e.,
+ TDNN-Res2Net-TDNN-SqueezeExcitationBlock.
+ """
+
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ res2net_scale=8,
+ se_channels=128,
+ kernel_size=1,
+ dilation=1,
+ ):
+ super().__init__()
+ self.out_channels = out_channels
+ self.tdnn1 = TimeDelayNetBlock(
+ in_channels,
+ out_channels,
+ kernel_size=1,
+ dilation=1,
+ )
+ self.res2net_block = Res2NetBlock(out_channels, out_channels, res2net_scale, kernel_size, dilation)
+ self.tdnn2 = TimeDelayNetBlock(
+ out_channels,
+ out_channels,
+ kernel_size=1,
+ dilation=1,
+ )
+ self.se_block = SqueezeExcitationBlock(out_channels, se_channels, out_channels)
+
+ def forward(self, hidden_state):
+ residual = hidden_state
+
+ hidden_state = self.tdnn1(hidden_state)
+ hidden_state = self.res2net_block(hidden_state)
+ hidden_state = self.tdnn2(hidden_state)
+ hidden_state = self.se_block(hidden_state)
+
+ return hidden_state + residual
+
+
+class ECAPA_TimeDelayNet(torch.nn.Module):
+ """An implementation of the speaker embedding model from the paper
+ "ECAPA-TDNN: Emphasized Channel Attention, Propagation and Aggregation in
+ TDNN Based Speaker Verification" (https://huggingface.co/papers/2005.07143).
+ """
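+ # Rough usage sketch (hypothetical sizes): a (batch, time, mel_dim) reference mel is encoded into
+ # a (batch, enc_dim) speaker/condition embedding, which `DiTInputEmbedding` below broadcasts over
+ # the DiT sequence length.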
+
+ def __init__(self, config: Qwen2_5OmniDiTConfig):
+ super().__init__()
+ if len(config.enc_channels) != len(config.enc_kernel_sizes) or len(config.enc_channels) != len(
+ config.enc_dilations
+ ):
+ raise ValueError("enc_channels, enc_kernel_sizes and enc_dilations should have same length")
+ self.channels = config.enc_channels
+ self.blocks = nn.ModuleList()
+
+ # The initial TDNN layer
+ self.blocks.append(
+ TimeDelayNetBlock(
+ config.mel_dim,
+ config.enc_channels[0],
+ config.enc_kernel_sizes[0],
+ config.enc_dilations[0],
+ )
+ )
+
+ # SE-Res2Net layers
+ for i in range(1, len(config.enc_channels) - 1):
+ self.blocks.append(
+ SqueezeExcitationRes2NetBlock(
+ config.enc_channels[i - 1],
+ config.enc_channels[i],
+ res2net_scale=config.enc_res2net_scale,
+ se_channels=config.enc_se_channels,
+ kernel_size=config.enc_kernel_sizes[i],
+ dilation=config.enc_dilations[i],
+ )
+ )
+
+ # Multi-layer feature aggregation
+ self.mfa = TimeDelayNetBlock(
+ config.enc_channels[-1],
+ config.enc_channels[-1],
+ config.enc_kernel_sizes[-1],
+ config.enc_dilations[-1],
+ )
+
+ # Attentive Statistical Pooling
+ self.asp = AttentiveStatisticsPooling(
+ config.enc_channels[-1],
+ attention_channels=config.enc_attention_channels,
+ )
+
+ # Final linear transformation
+ self.fc = nn.Conv1d(
+ in_channels=config.enc_channels[-1] * 2,
+ out_channels=config.enc_dim,
+ kernel_size=1,
+ padding="same",
+ padding_mode="reflect",
+ )
+
+ def forward(self, hidden_states):
+ # Minimize transpose for efficiency
+ hidden_states = hidden_states.transpose(1, 2)
+
+ hidden_states_list = []
+ for layer in self.blocks:
+ hidden_states = layer(hidden_states)
+ hidden_states_list.append(hidden_states)
+
+ # Multi-layer feature aggregation
+ hidden_states = torch.cat(hidden_states_list[1:], dim=1)
+ hidden_states = self.mfa(hidden_states)
+
+ # Attentive Statistical Pooling
+ hidden_states = self.asp(hidden_states)
+
+ # Final linear transformation
+ hidden_states = self.fc(hidden_states)
+
+ hidden_states = hidden_states.squeeze(-1)
+ return hidden_states
+
+
+class DiTInputEmbedding(nn.Module):
+ def __init__(self, config: Qwen2_5OmniDiTConfig):
+ super().__init__()
+ self.proj = nn.Linear(
+ config.mel_dim + config.enc_dim + config.enc_emb_dim + config.emb_dim,
+ config.hidden_size,
+ )
+ self.spk_encoder = ECAPA_TimeDelayNet(config)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ speaker_embedding: torch.Tensor,
+ condition_vector: torch.Tensor,
+ code_embed: torch.Tensor,
+ drop_audio_cond: Optional[bool] = False,
+ code_embed_uncond: Optional[bool] = None,
+ apply_cfg: Optional[bool] = True,
+ ):
+ if apply_cfg:
+ hidden_states = torch.cat([hidden_states, hidden_states], dim=0)
+ speaker_embedding = torch.cat([speaker_embedding, torch.zeros_like(speaker_embedding)], dim=0)
+ condition_vector = torch.cat([condition_vector, torch.zeros_like(condition_vector)], dim=0)
+ code_embed = torch.cat([code_embed, code_embed_uncond], dim=0)
+ elif drop_audio_cond: # cfg for cond audio
+ condition_vector = torch.zeros_like(condition_vector)
+ speaker_embedding = torch.zeros_like(speaker_embedding)
+ condition_vector = self.spk_encoder(condition_vector).unsqueeze(1).repeat(1, hidden_states.size(1), 1)
+ hidden_states = self.proj(torch.cat((hidden_states, condition_vector, code_embed, speaker_embedding), dim=-1))
+
+ return hidden_states
+
+
+# Transformer backbone using DiT blocks
+class DiTCodecEmbedding(nn.Module):
+ def __init__(self, codec_num_embeds, codec_dim, repeats):
+ super().__init__()
+ self.repeats = repeats
+ self.codec_embed = nn.Embedding(codec_num_embeds + 1, codec_dim)
+
+ def forward(self, code, drop_code=False):
+ if drop_code:
+ code = torch.zeros_like(code)
+ code_embed = self.codec_embed(code)
+
+ code_embed = torch.repeat_interleave(code_embed, repeats=self.repeats, dim=1)
+ return code_embed
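+ # Note (illustrative): each codec token embedding is repeated `repeats` times along the time axis
+ # so the code sequence matches the mel-frame resolution expected by the DiT backbone.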
+
+
+# AdaLayerNormZero
+# Returns the modulated x for the attention input, plus the parameters for the later MLP modulation.
+class Qwen2_5_OmniAdaLayerNormZero(nn.Module):
+ def __init__(self, dim):
+ super().__init__()
+
+ self.silu = nn.SiLU()
+ self.linear = nn.Linear(dim, dim * 6)
+
+ self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
+
+ def forward(self, hidden_states, emb=None):
+ emb = self.linear(self.silu(emb))
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = torch.chunk(emb, 6, dim=1)
+
+ hidden_states = self.norm(hidden_states) * (1 + scale_msa[:, None]) + shift_msa[:, None]
+ return hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp
+
+
+# AdaLayerNormZero for the final layer
+# Returns only the modulated x for the attention input, since there is no further MLP modulation.
+class Qwen2_5_OmniAdaLayerNormZero_Final(nn.Module):
+ def __init__(self, dim):
+ super().__init__()
+
+ self.silu = nn.SiLU()
+ self.linear = nn.Linear(dim, dim * 2)
+
+ self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
+
+ def forward(self, hidden_states, emb):
+ emb = self.linear(self.silu(emb))
+ scale, shift = torch.chunk(emb, 2, dim=1)
+
+ hidden_states = self.norm(hidden_states) * (1 + scale)[:, None, :] + shift[:, None, :]
+ return hidden_states
+
+
+# FeedForward
+class DiTMLP(nn.Module):
+ def __init__(self, dim, mult=4, dropout=0.0):
+ super().__init__()
+ inner_dim = int(dim * mult)
+
+ self.ff = nn.ModuleList(
+ [
+ nn.Linear(dim, inner_dim),
+ nn.GELU(approximate="tanh"),
+ nn.Dropout(dropout),
+ nn.Linear(inner_dim, dim),
+ ]
+ )
+
+ def forward(self, hidden_states):
+ for layer in self.ff:
+ hidden_states = layer(hidden_states)
+ return hidden_states
+
+
+class DiTAttention(nn.Module):
+ def __init__(self, config: Qwen2_5OmniDiTConfig):
+ super().__init__()
+
+ self.config = config
+ self.dim = config.hidden_size
+ self.heads = config.num_attention_heads
+ self.inner_dim = config.head_dim * config.num_attention_heads
+ self.dropout = config.dropout
+ self.is_causal = False
+
+ self.to_q = nn.Linear(config.hidden_size, self.inner_dim)
+ self.to_k = nn.Linear(config.hidden_size, self.inner_dim)
+ self.to_v = nn.Linear(config.hidden_size, self.inner_dim)
+
+ self.to_out = nn.ModuleList([nn.Linear(self.inner_dim, config.hidden_size), nn.Dropout(config.dropout)])
+
+ def forward(
+ self,
+ hidden_states, # noised input x
+ position_embeddings=None, # rotary position embedding for x
+ attention_mask=None,
+ ) -> torch.Tensor:
+ batch_size = hidden_states.shape[0]
+
+ # `sample` projections.
+ query = self.to_q(hidden_states)
+ key = self.to_k(hidden_states)
+ value = self.to_v(hidden_states)
+
+ # attention
+ inner_dim = key.shape[-1]
+ head_dim = inner_dim // self.heads
+ query = query.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
+ key = key.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
+ value = value.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
+
+ # apply rotary position embedding
+ # Due to the training process, RoPE is only applied to the first head; will be fixed in the next release
+ cos, sin = position_embeddings
+ query[:, :1], key[:, :1] = apply_rotary_pos_emb(query[:, :1], key[:, :1], cos, sin)
+
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+ attention_weights, _ = attention_interface(
+ self,
+ query,
+ key,
+ value,
+ attention_mask=attention_mask,
+ is_causal=False,
+ )
+
+ # mask, e.g. when inference gets a batch with different target durations, mask out the padding
+ attention_weights = attention_weights.reshape(batch_size, -1, self.heads * head_dim)
+ attention_weights = attention_weights.to(query.dtype)
+
+ # linear proj
+ attention_output = self.to_out[0](attention_weights)
+ attention_output = self.to_out[1](attention_output)
+
+ return attention_output
+
+
+# time step conditioning embedding
+class SinusPositionEmbedding(nn.Module):
+ def __init__(self, dim):
+ super().__init__()
+ self.dim = dim
+
+ def forward(self, hidden_states, scale=1000):
+ device = hidden_states.device
+ half_dim = self.dim // 2
+ emb = math.log(10000) / (half_dim - 1)
+ emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb)
+ emb = scale * hidden_states.unsqueeze(1) * emb.unsqueeze(0)
+ emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
+ return emb.type_as(hidden_states)
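+ # Shape note (assumption): `hidden_states` is a (batch,) tensor of timesteps; the output is
+ # (batch, dim), concatenating the sin and cos halves of the scaled frequencies.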
+
+
+class DiTTimestepEmbedding(nn.Module):
+ def __init__(self, dim, freq_embed_dim=256):
+ super().__init__()
+ self.time_embed = SinusPositionEmbedding(freq_embed_dim)
+ self.time_mlp = nn.ModuleList([nn.Linear(freq_embed_dim, dim), nn.SiLU(), nn.Linear(dim, dim)])
+
+ def forward(self, timestep):
+ time_hidden = self.time_embed(timestep)
+ time_hidden = time_hidden.to(timestep.dtype)
+ for layer in self.time_mlp:
+ time_hidden = layer(time_hidden) # b d
+ return time_hidden
+
+
+class DiTDecoderLayer(nn.Module):
+ def __init__(self, config: Qwen2_5OmniDiTConfig, look_ahead_block=0, look_backward_block=0):
+ super().__init__()
+ self.attn_norm = Qwen2_5_OmniAdaLayerNormZero(config.hidden_size)
+
+ self.attn = DiTAttention(config)
+ self.look_ahead_block = look_ahead_block
+ self.look_backward_block = look_backward_block
+ self.ff_norm = nn.LayerNorm(config.hidden_size, elementwise_affine=False, eps=1e-6)
+ self.ff = DiTMLP(dim=config.hidden_size, mult=config.ff_mult, dropout=config.dropout)
+
+ def forward(
+ self, hidden_states, timestep, position_embeddings=None, block_diff=None
+ ): # x: noised input, t: time embedding
+ # pre-norm & modulation for attention input
+ norm, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.attn_norm(hidden_states, emb=timestep)
+
+ # attention
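+ # block_diff[i, j] = block(j) - block(i); a query position may attend to key blocks at most
+ # `look_backward_block` blocks behind it and at most `look_ahead_block` blocks ahead of it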
+ attn_output = self.attn(
+ hidden_states=norm,
+ position_embeddings=position_embeddings,
+ attention_mask=(block_diff >= -float(self.look_backward_block))
+ & (block_diff <= float(self.look_ahead_block)),
+ )
+
+ # process attention output for input x
+ hidden_states = hidden_states + gate_msa.unsqueeze(1) * attn_output
+
+ norm = self.ff_norm(hidden_states) * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+ ff_output = self.ff(norm)
+ hidden_states = hidden_states + gate_mlp.unsqueeze(1) * ff_output
+
+ return hidden_states
+
+
+class SnakeBeta(nn.Module):
+ """
+ A modified Snake function which uses separate parameters for the magnitude of the periodic components
+ Shape:
+ - Input: (B, C, T)
+ - Output: (B, C, T), same shape as the input
+ Parameters:
+ - alpha - trainable parameter that controls frequency
+ - beta - trainable parameter that controls magnitude
+ References:
+ - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
+ https://huggingface.co/papers/2006.08195
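+
+ Example (a minimal illustrative sketch; the input sizes are arbitrary):
+
+ ```python
+ >>> import torch
+ >>> act = SnakeBeta(in_features=4)
+ >>> x = torch.randn(2, 4, 16) # (B, C, T)
+ >>> act(x).shape
+ torch.Size([2, 4, 16])
+ ```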
+ """
+
+ def __init__(self, in_features, alpha=1.0):
+ super().__init__()
+ self.in_features = in_features
+
+ # initialize alpha and beta to zero in log scale (both are exponentiated in the forward pass)
+ self.alpha = Parameter(torch.zeros(in_features) * alpha)
+ self.beta = Parameter(torch.zeros(in_features) * alpha)
+
+ self.no_div_by_zero = 0.000000001
+
+ def forward(self, hidden_states):
+ """
+ Forward pass of the function.
+ Applies the function to the input elementwise.
+ SnakeBeta(x) := x + (1 / beta) * sin^2(alpha * x)
+ """
+ alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
+ beta = self.beta.unsqueeze(0).unsqueeze(-1)
+ alpha = torch.exp(alpha)
+ beta = torch.exp(beta)
+ hidden_states = hidden_states + (1.0 / (beta + self.no_div_by_zero)) * torch.pow(
+ torch.sin(hidden_states * alpha), 2
+ )
+
+ return hidden_states
+
+
+def kaiser_sinc_filter1d(cutoff, half_width, kernel_size):
+ """Generates a 1D Kaiser-windowed sinc filter.
+
+ Args:
+ cutoff (float): Normalized cutoff frequency (0 to 0.5).
+ half_width (float): Transition bandwidth.
+ kernel_size (int): Number of filter taps.
+
+ Returns:
+ torch.Tensor: A tensor of shape (1, 1, kernel_size) representing the filter.
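+
+ Example (illustrative sketch; only the shape and the unit-sum normalization are checked):
+
+ ```python
+ >>> import torch
+ >>> filt = kaiser_sinc_filter1d(cutoff=0.25, half_width=0.3, kernel_size=12)
+ >>> filt.shape
+ torch.Size([1, 1, 12])
+ >>> bool(torch.isclose(filt.sum(), torch.tensor(1.0)))
+ True
+ ```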
+ """
+ is_even = kernel_size % 2 == 0
+ half_size = kernel_size // 2
+
+ # Compute Kaiser window parameters
+ delta_f = 4 * half_width
+ attenuation = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
+
+ if attenuation > 50.0:
+ beta = 0.1102 * (attenuation - 8.7)
+ elif attenuation >= 21.0:
+ beta = 0.5842 * (attenuation - 21) ** 0.4 + 0.07886 * (attenuation - 21.0)
+ else:
+ beta = 0.0
+
+ kaiser_window = torch.kaiser_window(kernel_size, beta=beta, periodic=False, dtype=torch.float32)
+
+ # Compute time indices
+ if is_even:
+ time_indices = torch.arange(-half_size, half_size) + 0.5
+ else:
+ time_indices = torch.arange(kernel_size) - half_size
+
+ # Compute sinc filter
+ if cutoff == 0:
+ return torch.zeros((1, 1, kernel_size), dtype=torch.float32) # Ensures correct shape
+
+ sinc_filter = torch.sinc(2 * cutoff * time_indices)
+ normalized_filter = 2 * cutoff * kaiser_window * sinc_filter
+
+ # Normalize to ensure sum = 1 (avoid leakage of constant component)
+ normalized_filter /= normalized_filter.sum()
+
+ return normalized_filter.view(1, 1, kernel_size)
+
+
+class UpSample1d(nn.Module):
+ def __init__(self, ratio=2, kernel_size=None):
+ super().__init__()
+ self.ratio = ratio
+ self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
+ self.stride = ratio
+ self.pad = self.kernel_size // ratio - 1
+ self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
+ self.pad_right = self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
+
+ filter = kaiser_sinc_filter1d(cutoff=0.5 / ratio, half_width=0.6 / ratio, kernel_size=self.kernel_size)
+ self.register_buffer("filter", filter, persistent=False)
+
+ def forward(self, hidden_states):
+ channels = hidden_states.shape[1]
+
+ hidden_states = F.pad(hidden_states, (self.pad, self.pad), mode="replicate")
+ hidden_states = self.ratio * F.conv_transpose1d(
+ hidden_states, self.filter.expand(channels, -1, -1), stride=self.stride, groups=channels
+ )
+ hidden_states = hidden_states[..., self.pad_left : -self.pad_right]
+
+ return hidden_states
+
+
+class DownSample1d(nn.Module):
+ def __init__(self, ratio=2, kernel_size=None):
+ super().__init__()
+ cutoff = 0.5 / ratio
+ half_width = 0.6 / ratio
+
+ if cutoff < 0.0:
+ raise ValueError("Minimum cutoff must be larger than zero.")
+ if cutoff > 0.5:
+ raise ValueError("A cutoff above 0.5 does not make sense.")
+
+ self.even = kernel_size % 2 == 0
+ self.pad_left = kernel_size // 2 - int(self.even)
+ self.pad_right = kernel_size // 2
+ self.stride = ratio
+ filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
+ self.register_buffer("filter", filter, persistent=False)
+
+ def forward(self, hidden_states):
+ channels = hidden_states.shape[1]
+ hidden_states = F.pad(hidden_states, (self.pad_left, self.pad_right), mode="replicate")
+ out = F.conv1d(hidden_states, self.filter.expand(channels, -1, -1), stride=self.stride, groups=channels)
+ return out
+
+
+class TorchActivation1d(nn.Module):
+ def __init__(
+ self,
+ activation,
+ up_ratio: int = 2,
+ down_ratio: int = 2,
+ up_kernel_size: int = 12,
+ down_kernel_size: int = 12,
+ ):
+ super().__init__()
+ if not callable(activation):
+ raise TypeError("Activation function must be callable")
+ self.act = activation
+ self.upsample = UpSample1d(up_ratio, up_kernel_size)
+ self.downsample = DownSample1d(down_ratio, down_kernel_size)
+
+ def forward(self, hidden_states):
+ hidden_states = self.upsample(hidden_states)
+ hidden_states = self.act(hidden_states)
+ hidden_states = self.downsample(hidden_states)
+
+ return hidden_states
+
+
+class AMPBlock(torch.nn.Module):
+ def __init__(
+ self,
+ channels,
+ kernel_size=3,
+ dilation=(1, 3, 5),
+ ):
+ super().__init__()
+
+ self.convs1 = nn.ModuleList(
+ [
+ nn.Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=dilation[0],
+ padding=self._get_padding(kernel_size, dilation[0]),
+ ),
+ nn.Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=dilation[1],
+ padding=self._get_padding(kernel_size, dilation[1]),
+ ),
+ nn.Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=dilation[2],
+ padding=self._get_padding(kernel_size, dilation[2]),
+ ),
+ ]
+ )
+
+ self.convs2 = nn.ModuleList(
+ [
+ nn.Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=1,
+ padding=self._get_padding(kernel_size, 1),
+ ),
+ nn.Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=1,
+ padding=self._get_padding(kernel_size, 1),
+ ),
+ nn.Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=1,
+ padding=self._get_padding(kernel_size, 1),
+ ),
+ ]
+ )
+
+ self.num_layers = len(self.convs1) + len(self.convs2) # total number of conv layers
+
+ self.activations = nn.ModuleList(
+ [TorchActivation1d(activation=SnakeBeta(channels)) for _ in range(self.num_layers)]
+ )
+
+ def _get_padding(self, kernel_size, dilation=1):
+ return int((kernel_size * dilation - dilation) / 2)
+
+ def forward(self, hidden_states):
+ acts1, acts2 = self.activations[::2], self.activations[1::2]
+ for conv1, conv2, act1, act2 in zip(self.convs1, self.convs2, acts1, acts2):
+ residual = hidden_states
+ hidden_states = act1(hidden_states)
+ hidden_states = conv1(hidden_states)
+ hidden_states = act2(hidden_states)
+ hidden_states = conv2(hidden_states)
+ hidden_states = residual + hidden_states
+
+ return hidden_states
+
+
+@auto_docstring(
+ custom_intro="""
+ The full Qwen2.5Omni Token2WavBigVGAN model, which takes a mel spectrogram as input and predicts a waveform.
+ """
+)
+class Qwen2_5OmniToken2WavBigVGANModel(Qwen2_5OmniPreTrainedModel):
+ config: Qwen2_5OmniBigVGANConfig
+
+ def __init__(self, config: Qwen2_5OmniBigVGANConfig):
+ super().__init__(config)
+ self.num_residual_blocks = len(config.resblock_kernel_sizes)
+ self.num_upsample_layers = len(config.upsample_rates)
+
+ self.conv_pre = nn.Conv1d(config.mel_dim, config.upsample_initial_channel, 7, 1, padding=3)
+
+ # Removing extra ModuleList breaks official state dict
+ ups = [
+ nn.ModuleList(
+ [
+ nn.ConvTranspose1d(
+ config.upsample_initial_channel // (2**layer_idx),
+ config.upsample_initial_channel // (2 ** (layer_idx + 1)),
+ kernel_size,
+ stride,
+ padding=(kernel_size - stride) // 2,
+ )
+ ]
+ )
+ for layer_idx, (stride, kernel_size) in enumerate(zip(config.upsample_rates, config.upsample_kernel_sizes))
+ ]
+ self.ups = nn.ModuleList(ups)
+
+ self.resblocks = nn.ModuleList(
+ [
+ AMPBlock(config.upsample_initial_channel // (2 ** (layer_idx + 1)), kernel_size, dilation)
+ for layer_idx in range(self.num_upsample_layers)
+ for kernel_size, dilation in zip(config.resblock_kernel_sizes, config.resblock_dilation_sizes)
+ ]
+ )
+
+ self.activation_post = TorchActivation1d(
+ activation=SnakeBeta(config.upsample_initial_channel // (2**self.num_upsample_layers))
+ )
+ self.conv_post = nn.Conv1d(
+ config.upsample_initial_channel // (2**self.num_upsample_layers), 1, 7, 1, padding=3, bias=False
+ )
+
+ def normalize_spectrogram(self, spectrogram, max_value, min_db):
+ return torch.clamp((2 * max_value) * ((spectrogram - min_db) / (-min_db)) - max_value, -max_value, max_value)
+
+ def amplitude_to_db(self, amplitude, min_db_level):
+ min_level = torch.exp(
+ torch.tensor(min_db_level / 20.0 * np.log(10), device=amplitude.device, dtype=amplitude.dtype)
+ )
+ return 20 * torch.log10(torch.clamp(amplitude, min=min_level))
+
+ def process_mel_spectrogram(self, mel_spectrogram):
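+ # log-mel input -> linear amplitude -> dB (floored at -115 dB, shifted by a 20 dB headroom) -> clamped to [-1, 1]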
+ amplitude_spectrum = torch.exp(mel_spectrogram)
+ decibel_spectrum = self.amplitude_to_db(amplitude_spectrum, -115) - 20
+ return self.normalize_spectrogram(decibel_spectrum, 1, -115)
+
+ def forward(self, mel_spectrogram):
+ processed_spectrogram = self.process_mel_spectrogram(mel_spectrogram)
+ hidden_representation = self.conv_pre(processed_spectrogram)
+
+ for layer_index in range(self.num_upsample_layers):
+ hidden_representation = self.ups[layer_index][0](hidden_representation)
+ residual_output = sum(
+ self.resblocks[layer_index * self.num_residual_blocks + block_index](hidden_representation)
+ for block_index in range(self.num_residual_blocks)
+ )
+ residual_output = residual_output / self.num_residual_blocks
+ hidden_representation = residual_output
+
+ hidden_representation = self.activation_post(hidden_representation)
+ output_waveform = self.conv_post(hidden_representation)
+ return torch.clamp(output_waveform, min=-1.0, max=1.0).squeeze().cpu()
+
+
+class RungeKutta4ODESolver:
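+ """
+ Fixed-step Runge-Kutta integrator using the 3/8-rule variant of RK4 (nodes 0, 1/3, 2/3, 1 and
+ weights 1/8, 3/8, 3/8, 1/8), with linear interpolation onto the requested time points.
+
+ Example (a minimal sketch integrating dy/dt = -y from an initial value of 1):
+
+ ```python
+ >>> import torch
+ >>> solver = RungeKutta4ODESolver(function=lambda t, y: -y, initial_value=torch.tensor([1.0]))
+ >>> solution = solver.integrate(torch.linspace(0.0, 1.0, 11))
+ >>> solution.shape
+ torch.Size([11, 1])
+ ```
+ """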
+ def __init__(self, function, initial_value):
+ self.function = function
+ self.initial_value = initial_value
+
+ self._one_third = 1 / 3
+ self._two_thirds = 2 / 3
+
+ def _rk4_step(self, function, time_start, time_step, time_end, value_start, function_value_start=None):
+ k1 = function_value_start if function_value_start is not None else function(time_start, value_start)
+ k2 = function(time_start + time_step * self._one_third, value_start + time_step * k1 * self._one_third)
+ k3 = function(time_start + time_step * self._two_thirds, value_start + time_step * (k2 - k1 * self._one_third))
+ k4 = function(time_end, value_start + time_step * (k1 - k2 + k3))
+ return (k1 + 3 * (k2 + k3) + k4) * time_step / 8
+
+ def _compute_step(self, function, time_start, time_step, time_end, value_start):
+ function_value_start = function(time_start, value_start)
+ return self._rk4_step(
+ function, time_start, time_step, time_end, value_start, function_value_start=function_value_start
+ ), function_value_start
+
+ def _linear_interpolation(self, time_start, time_end, value_start, value_end, time_point):
+ if time_point == time_start:
+ return value_start
+ if time_point == time_end:
+ return value_end
+ weight = (time_point - time_start) / (time_end - time_start)
+ return value_start + weight * (value_end - value_start)
+
+ def integrate(self, time_points):
+ solution = torch.empty(
+ len(time_points),
+ *self.initial_value.shape,
+ dtype=self.initial_value.dtype,
+ device=self.initial_value.device,
+ )
+ solution[0] = self.initial_value
+
+ current_index = 1
+ current_value = self.initial_value
+ for time_start, time_end in zip(time_points[:-1], time_points[1:]):
+ time_step = time_end - time_start
+ delta_value, _ = self._compute_step(self.function, time_start, time_step, time_end, current_value)
+ next_value = current_value + delta_value
+
+ while current_index < len(time_points) and time_end >= time_points[current_index]:
+ solution[current_index] = self._linear_interpolation(
+ time_start, time_end, current_value, next_value, time_points[current_index]
+ )
+ current_index += 1
+
+ current_value = next_value
+
+ return solution
+
+
+@auto_docstring(
+ custom_intro="""
+ The full Qwen2.5Omni Token2WavDiT model, which takes speech tokens as input and predicts a mel spectrogram.
+ """
+)
+class Qwen2_5OmniToken2WavDiTModel(Qwen2_5OmniPreTrainedModel):
+ config: Qwen2_5OmniDiTConfig
+ _no_split_modules = ["DiTDecoderLayer"]
+
+ def __init__(self, config: Qwen2_5OmniDiTConfig):
+ super().__init__(config)
+ self.mel_dim = config.mel_dim
+ self.repeats = config.repeats
+ self.time_embed = DiTTimestepEmbedding(config.hidden_size)
+
+ self.text_embed = DiTCodecEmbedding(config.num_embeds, config.emb_dim, config.repeats)
+ self.input_embed = DiTInputEmbedding(config)
+
+ self.rotary_embed = Qwen2_5OmniDiTRotaryEmbedding(config.head_dim)
+
+ self.hidden_size = config.hidden_size
+ self.layers = config.num_hidden_layers
+ self.block_size = config.block_size
+ self.num_attention_heads = config.num_attention_heads
+
+ self.transformer_blocks = nn.ModuleList()
+ for i in range(config.num_hidden_layers):
+ self.transformer_blocks.append(
+ DiTDecoderLayer(
+ config,
+ look_ahead_block=1 if i in config.look_ahead_layers else 0,
+ look_backward_block=1 if i in config.look_backward_layers else 0,
+ )
+ )
+
+ self.norm_out = Qwen2_5_OmniAdaLayerNormZero_Final(config.hidden_size) # final modulation
+ self.proj_out = nn.Linear(config.hidden_size, config.mel_dim)
+
+ def _create_block_diff(self, hidden_states):
+ batch, seq_len = hidden_states.shape[0], hidden_states.shape[1]
+ block_indices = torch.arange(seq_len, device=hidden_states.device) // self.block_size # [seq_length]
+
+ block_i = block_indices.unsqueeze(1) # [seq_length, 1]
+ block_j = block_indices.unsqueeze(0) # [1, seq_length]
+ block_diff = block_j - block_i # (n, n)
+
+ return block_diff.expand(batch, self.num_attention_heads, seq_len, seq_len)
+
+ def forward(
+ self,
+ hidden_states,
+ condition_vector,
+ speaker_embedding,
+ quantized_code,
+ time_step,
+ drop_audio_conditioning=False,
+ drop_code=False,
+ apply_cfg=True,
+ ):
+ batch_size = hidden_states.shape[0]
+ if time_step.ndim == 0:
+ time_step = time_step.repeat(batch_size)
+
+ # Compute embeddings
+ time_embedding = self.time_embed(time_step)
+ text_embedding = self.text_embed(quantized_code, drop_code=False if apply_cfg else drop_code)
+ text_embedding_unconditioned = self.text_embed(quantized_code, drop_code=True) if apply_cfg else None
+
+ hidden_states = self.input_embed(
+ hidden_states,
+ speaker_embedding,
+ condition_vector,
+ text_embedding,
+ drop_audio_cond=drop_audio_conditioning,
+ code_embed_uncond=text_embedding_unconditioned,
+ apply_cfg=apply_cfg,
+ )
+
+ # Compute positional encodings
+ position_embeddings = self.rotary_embed(hidden_states)
+ blockwise_difference = self._create_block_diff(hidden_states)
+
+ # Transformer blocks
+ for transformer_block in self.transformer_blocks:
+ hidden_states = transformer_block(
+ hidden_states,
+ time_embedding,
+ position_embeddings=position_embeddings,
+ block_diff=blockwise_difference,
+ )
+
+ hidden_states = self.norm_out(hidden_states, time_embedding)
+ output = self.proj_out(hidden_states)
+
+ return output
+
+ @torch.no_grad()
+ def sample(
+ self,
+ conditioning_vector,
+ reference_mel_spectrogram,
+ quantized_code,
+ num_steps=10,
+ guidance_scale=0.5,
+ sway_coefficient=-1.0,
+ ):
+ noise_initialization = torch.randn([1, 30000, self.mel_dim], dtype=reference_mel_spectrogram.dtype)
+ maximum_duration = quantized_code.shape[1] * self.repeats
+ initial_state = noise_initialization[:, :maximum_duration].to(quantized_code.device)
+ batch_size = reference_mel_spectrogram.shape[0]
+ conditioning_vector = conditioning_vector.unsqueeze(1).repeat(1, maximum_duration, 1)
+
+ if batch_size != 1:
+ raise ValueError("Only batch size = 1 is currently supported")
+
+ def ode_function(time_step, hidden_states):
+ if guidance_scale < 1e-5:
+ prediction = self(
+ hidden_states=hidden_states,
+ speaker_embedding=conditioning_vector,
+ condition_vector=reference_mel_spectrogram,
+ quantized_code=quantized_code,
+ time_step=time_step,
+ drop_audio_conditioning=False,
+ drop_code=False,
+ )
+ return prediction
+
+ model_output = self(
+ hidden_states=hidden_states,
+ quantized_code=quantized_code,
+ speaker_embedding=conditioning_vector,
+ condition_vector=reference_mel_spectrogram,
+ time_step=time_step,
+ apply_cfg=True,
+ )
+ guided_prediction, null_prediction = torch.chunk(model_output, 2, dim=0)
+ return guided_prediction + (guided_prediction - null_prediction) * guidance_scale
+
+ initial_time = 0
+ time_embedding = torch.linspace(
+ initial_time, 1, num_steps, device=quantized_code.device, dtype=conditioning_vector.dtype
+ )
+
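+ # A negative sway coefficient (default -1.0) warps the uniform grid so that integration steps are
+ # denser near t=0 (the start of the flow) and sparser near t=1.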
+ if sway_coefficient is not None:
+ time_embedding += sway_coefficient * (torch.cos(torch.pi / 2 * time_embedding) - 1 + time_embedding)
+
+ ode_solver = RungeKutta4ODESolver(function=ode_function, initial_value=initial_state)
+ solution_trajectory = ode_solver.integrate(time_embedding)
+
+ generated_waveform = solution_trajectory[-1]
+ generated_mel_spectrogram = generated_waveform.permute(0, 2, 1)
+ return generated_mel_spectrogram
+
+
+@auto_docstring(
+ custom_intro="""
+ The full Qwen2.5Omni Token2Wav model. It consists of a DiT model that takes speech tokens as input and predicts a mel spectrogram, and a BigVGAN vocoder that takes the mel spectrogram as input and predicts a waveform.
+ """
+)
+class Qwen2_5OmniToken2WavModel(Qwen2_5OmniPreTrainedModel):
+ config: Qwen2_5OmniToken2WavConfig
+ base_model_prefix = "model"
+ _no_split_modules = ["Qwen2_5OmniToken2WavDiTModel", "Qwen2_5OmniToken2WavBigVGANModel"]
+
+ def __init__(self, config: Qwen2_5OmniToken2WavConfig):
+ super().__init__(config)
+ attn_impl = config._attn_implementation
+ if config._attn_implementation == "flash_attention_2":
+ logger.warning_once(
+ "Qwen2_5OmniToken2WavModel must inference with fp32, but flash_attention_2 only supports fp16 and bf16, "
+ "attention implementation of Qwen2_5OmniToken2WavModel will fallback to sdpa."
+ )
+ attn_impl = "sdpa"
+ elif config._attn_implementation == "eager":
+ logger.warning_once(
+ "Qwen2_5OmniToken2WavModel does not support eager attention implementation, fall back to sdpa"
+ )
+ attn_impl = "sdpa"
+ self.code2wav_dit_model = Qwen2_5OmniToken2WavDiTModel._from_config(
+ config.dit_config, attn_implementation=attn_impl
+ )
+ self.code2wav_bigvgan_model = Qwen2_5OmniToken2WavBigVGANModel._from_config(
+ config.bigvgan_config, attn_implementation=attn_impl
+ )
+
+ def forward(
+ self,
+ code,
+ conditioning,
+ reference_mel,
+ num_steps=10,
+ guidance_scale=0.5,
+ sway_coefficient=-1.0,
+ **kwargs,
+ ):
+ """Generates a waveform from input code and conditioning parameters."""
+
+ mel_spectrogram = self.code2wav_dit_model.sample(
+ conditioning,
+ reference_mel,
+ code,
+ num_steps=num_steps,
+ guidance_scale=guidance_scale,
+ sway_coefficient=sway_coefficient,
+ )
+
+ waveform = self.code2wav_bigvgan_model(mel_spectrogram)
+
+ return waveform
+
+
+############################
+# Start Qwen2.5Omni #
+############################
+
+
+@auto_docstring(
+ custom_intro="""
+ The full Qwen2.5Omni model, a multimodal model composed of 3 sub-models:
+ - [`Qwen2_5OmniThinkerForConditionalGeneration`]:
+ a causal auto-regressive transformer that takes text, audio, image and video as input and predicts text tokens.
+ - [`Qwen2_5OmniTalkerForConditionalGeneration`]:
+ a causal auto-regressive transformer that takes the thinker's hidden states and response as input and predicts speech tokens.
+ - [`Qwen2_5OmniToken2WavModel`]:
+ a DiT model that takes speech tokens as input and predicts a mel spectrogram, plus a BigVGAN vocoder that takes the mel spectrogram as input and predicts a waveform.
+ """
+)
+class Qwen2_5OmniForConditionalGeneration(Qwen2_5OmniPreTrainedModel, GenerationMixin):
+ config: Qwen2_5OmniConfig
+ _no_split_modules = [
+ "Qwen2_5OmniTalkerForConditionalGeneration",
+ "Qwen2_5OmniToken2WavModel",
+ ]
+
+ def __init__(self, config):
+ super().__init__(config)
+
+ self.thinker = Qwen2_5OmniThinkerForConditionalGeneration(config.thinker_config)
+
+ self.has_talker = config.enable_audio_output
+ self.speaker_map = {}
+ if config.enable_audio_output:
+ self.enable_talker()
+ self.post_init()
+
+ def enable_talker(self):
+ self.talker = Qwen2_5OmniTalkerForConditionalGeneration(self.config.talker_config)
+ self.token2wav = Qwen2_5OmniToken2WavModel(self.config.token2wav_config)
+ self.token2wav.float()
+ self.has_talker = True
+
+ def load_speakers(self, path):
+ check_torch_load_is_safe()
+ for key, value in torch.load(path, weights_only=True).items():
+ self.speaker_map[key] = value
+ logger.info(f"Speaker {list(self.speaker_map.keys())} loaded")
+
+ def disable_talker(self):
+ if hasattr(self, "talker"):
+ del self.talker
+ if hasattr(self, "token2wav"):
+ del self.token2wav
+ self.has_talker = False
+
+ @classmethod
+ def from_pretrained(
+ cls,
+ pretrained_model_name_or_path,
+ *model_args,
+ config=None,
+ cache_dir=None,
+ ignore_mismatched_sizes=False,
+ force_download=False,
+ local_files_only=False,
+ token=None,
+ revision="main",
+ use_safetensors=None,
+ weights_only=True,
+ **kwargs,
+ ):
+ model = super().from_pretrained(
+ pretrained_model_name_or_path,
+ *model_args,
+ config=config,
+ cache_dir=cache_dir,
+ ignore_mismatched_sizes=ignore_mismatched_sizes,
+ force_download=force_download,
+ local_files_only=local_files_only,
+ token=token,
+ revision=revision,
+ use_safetensors=use_safetensors,
+ weights_only=weights_only,
+ **kwargs,
+ )
+ spk_path = cached_file(
+ pretrained_model_name_or_path,
+ "spk_dict.pt",
+ subfolder=kwargs.pop("subfolder", None),
+ cache_dir=kwargs.pop("cache_dir", None),
+ force_download=kwargs.pop("force_download", False),
+ proxies=kwargs.pop("proxies", None),
+ resume_download=kwargs.pop("resume_download", None),
+ local_files_only=kwargs.pop("local_files_only", False),
+ token=kwargs.pop("use_auth_token", None),
+ revision=kwargs.pop("revision", None),
+ )
+ if spk_path is None:
+ raise ValueError(f"""{pretrained_model_name_or_path}/{spk_path} not exists""")
+ model.load_speakers(spk_path)
+
+ return model
+
+ @torch.no_grad()
+ # TODO: raushan, defaults should be saved in generation config
+ def generate(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ speaker: str = "Chelsie",
+ use_audio_in_video: bool = False,
+ return_audio: Optional[bool] = None,
+ thinker_max_new_tokens: int = 1024,
+ talker_max_new_tokens: int = 4096,
+ talker_do_sample: bool = True,
+ talker_top_k: int = 40,
+ talker_top_p: float = 0.8,
+ talker_temperature: float = 0.9,
+ talker_eos_token_id: list[int] = [8292, 8294],
+ talker_repetition_penalty: float = 1.05,
+ **kwargs,
+ ):
+ r"""
+ Generate text response and audio from input.
+
+ Args:
+ input_ids (`Optional[torch.Tensor]`, *optional*):
+ Input ids, obtained from the processor.
+ speaker (`str`, defaults to `"Chelsie"`):
+ Which speaker to use for the audio response.
+ use_audio_in_video (`bool`, defaults to `False`):
+ Whether or not to use the audio track of the video; should be the same as the corresponding parameter of `process_audio_info`.
+ return_audio (`Optional[bool]`, *optional*):
+ Whether or not to return the response in audio format. When `return_audio=None`, it defaults to `config.enable_audio_output`.
+ kwargs (*optional*):
+ - Without a prefix, they will be passed as `**kwargs` to the `generate` method of each sub-model.
+ - With a *thinker_*, *talker_* or *token2wav_* prefix, they will be passed to the `generate` method of the
+ thinker, talker and token2wav respectively, taking priority over the keywords without a prefix.
+ Returns:
+ When `return_audio=False`:
+ - **Text** (`torch.Tensor`): Generated text token sequence.
+ When `return_audio=True`:
+ - **Text** (`torch.Tensor`): Generated text token sequence.
+ - **Audio waveform** (`torch.Tensor`): Generated audio waveform.
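+
+ Example (illustrative sketch; assumes `model` is a loaded `Qwen2_5OmniForConditionalGeneration` and
+ `inputs` was prepared with `Qwen2_5OmniProcessor`; with audio output enabled, a text/audio pair is returned):
+
+ ```python
+ >>> text_ids, audio = model.generate(**inputs, speaker="Chelsie", thinker_max_new_tokens=256, talker_temperature=0.8)
+ ```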
+ """
+ if speaker not in self.speaker_map:
+ raise ValueError(f"{speaker} is not available, available speakers: {self.speaker_map.keys()}")
+ if return_audio and not self.has_talker:
+ raise ValueError(
+ "Cannot use talker when talker module not initialized. Use `enable_talker` method or set enable_talker in config to enable talker."
+ )
+ if return_audio is None:
+ return_audio = self.has_talker
+ if input_ids.shape[0] != 1 and return_audio:
+ raise NotImplementedError("Qwen2.5-Omni currently does not support batched inference with audio output")
+
+ shared_kwargs = {"use_audio_in_video": use_audio_in_video}
+ thinker_kwargs = {
+ "max_new_tokens": thinker_max_new_tokens,
+ }
+ talker_kwargs = {
+ "max_new_tokens": talker_max_new_tokens,
+ "do_sample": talker_do_sample,
+ "top_k": talker_top_k,
+ "top_p": talker_top_p,
+ "temperature": talker_temperature,
+ "eos_token_id": talker_eos_token_id,
+ "repetition_penalty": talker_repetition_penalty,
+ }
+ token2wav_kwargs = {}
+
+ for key, value in kwargs.items():
+ if key.startswith("thinker_"):
+ thinker_kwargs[key[len("thinker_") :]] = value
+ elif key.startswith("talker_"):
+ talker_kwargs[key[len("talker_") :]] = value
+ elif key.startswith("token2wav_"):
+ token2wav_kwargs[key[len("token2wav_") :]] = value
+ # Process special input values
+ elif key == "feature_attention_mask":
+ thinker_kwargs[key] = value
+ talker_kwargs["audio_feature_lengths"] = torch.sum(value, dim=1)
+ elif key == "input_features" or key == "attention_mask":
+ thinker_kwargs[key] = value
+ # Put other key to shared kwargs
+ else:
+ shared_kwargs[key] = value
+
+ # Merge kwargs
+ for key, value in shared_kwargs.items():
+ if key not in thinker_kwargs:
+ thinker_kwargs[key] = value
+ if key not in talker_kwargs:
+ talker_kwargs[key] = value
+ if key not in token2wav_kwargs:
+ token2wav_kwargs[key] = value
+ speaker_params = self.speaker_map[speaker]
+
+ # 1. Generate from thinker module
+ generate_audio = return_audio and self.has_talker
+ if generate_audio:
+ thinker_kwargs["output_hidden_states"] = True
+ thinker_kwargs["return_dict_in_generate"] = True
+
+ thinker_result = self.thinker.generate(input_ids=input_ids, **thinker_kwargs)
+
+ if not generate_audio:
+ return thinker_result
+
+ # 2. Generate speech tokens from talker module
+ embeds_to_talker = thinker_result.hidden_states[0][0].clone().to(input_ids.device)
+ if thinker_kwargs.get("input_features") is not None:
+ audio_ids_mask = input_ids == self.config.thinker_config.audio_token_index
+ audio_mask = audio_ids_mask.unsqueeze(-1).expand_as(embeds_to_talker)
+ audio_mask_tensor = torch.zeros(
+ [audio_ids_mask.sum(), embeds_to_talker.shape[-1]],
+ dtype=embeds_to_talker.dtype,
+ device=input_ids.device,
+ )
+ embeds_to_talker.masked_scatter_(audio_mask, audio_mask_tensor)
+ if thinker_kwargs.get("pixel_values") is not None:
+ image_ids_mask = input_ids == self.config.thinker_config.image_token_index
+ image_mask = image_ids_mask.unsqueeze(-1).expand_as(embeds_to_talker)
+ image_mask_tensor = torch.zeros(
+ [image_ids_mask.sum(), embeds_to_talker.shape[-1]],
+ dtype=embeds_to_talker.dtype,
+ device=input_ids.device,
+ )
+ embeds_to_talker.masked_scatter_(image_mask, image_mask_tensor)
+ if thinker_kwargs.get("pixel_values_videos") is not None:
+ video_ids_mask = input_ids == self.config.thinker_config.video_token_index
+ video_mask = video_ids_mask.unsqueeze(-1).expand_as(embeds_to_talker)
+ video_mask_tensor = torch.zeros(
+ [video_ids_mask.sum(), embeds_to_talker.shape[-1]],
+ dtype=embeds_to_talker.dtype,
+ device=input_ids.device,
+ )
+ embeds_to_talker.masked_scatter_(video_mask, video_mask_tensor)
+
+ processed_thinker_hidden = (
+ (embeds_to_talker,) + thinker_result.hidden_states[0][1:],
+ ) + thinker_result.hidden_states[1:]
+ thinker_generate_ids = thinker_result.sequences[:, input_ids.size(1) :].to(input_ids.device)
+ thinker_token_embeds = [
+ token_hidden_states[0].to(input_ids.device) for token_hidden_states in processed_thinker_hidden
+ ]
+ thinker_hidden_states = [
+ token_hidden_states[-1].to(input_ids.device) for token_hidden_states in processed_thinker_hidden
+ ]
+
+ talker_text_bos_token = speaker_params["bos_token"]
+ talker_input_text_ids = torch.cat(
+ [
+ input_ids,
+ torch.tensor([[talker_text_bos_token]], dtype=torch.long, device=input_ids.device),
+ thinker_generate_ids[:, :1],
+ ],
+ dim=-1,
+ )
+
+ talker_input_ids = torch.cat(
+ [
+ torch.full_like(input_ids, fill_value=self.talker.codec_mask_token),
+ torch.tensor([[self.talker.codec_pad_token]], dtype=torch.long, device=input_ids.device),
+ torch.tensor([[self.talker.codec_bos_token]], dtype=torch.long, device=input_ids.device),
+ ],
+ dim=1,
+ )
+
+ thinker_embed_tokens = self.thinker.get_input_embeddings()
+ thinker_reply_part = torch.cat(thinker_hidden_states[1:], dim=1) + torch.cat(thinker_token_embeds[1:], dim=1)
+ talker_inputs_embeds = thinker_hidden_states[0] + thinker_token_embeds[0]
+ talker_text_bos_token = torch.tensor([[talker_text_bos_token]], dtype=torch.long, device=input_ids.device)
+ talker_text_bos_embed = thinker_embed_tokens(talker_text_bos_token).to(input_ids.device)
+ talker_inputs_embeds = torch.cat(
+ [
+ talker_inputs_embeds,
+ talker_text_bos_embed,
+ thinker_reply_part[:, :1, :],
+ ],
+ dim=1,
+ )
+
+ eos_token = torch.tensor([[self.talker.text_eos_token]], dtype=torch.long, device=input_ids.device)
+ eos_embedding = thinker_embed_tokens(eos_token).to(input_ids.device)
+
+ pad_token = torch.tensor([[self.talker.text_pad_token]], dtype=torch.long, device=input_ids.device)
+ pad_embedding = thinker_embed_tokens(pad_token).to(input_ids.device)
+
+ thinker_reply_part = torch.cat(
+ [
+ thinker_reply_part[:, 1:, :],
+ eos_embedding,
+ pad_embedding,
+ ],
+ dim=1,
+ )
+
+ talker_attention_mask = None
+ if "attention_mask" in kwargs:
+ talker_attention_mask = torch.cat(
+ [kwargs["attention_mask"], kwargs["attention_mask"].new_ones((1, 2))], dim=1
+ ).to(input_ids.device)
+
+ talker_result = self.talker.generate(
+ input_ids=talker_input_ids,
+ input_text_ids=talker_input_text_ids,
+ thinker_reply_part=thinker_reply_part,
+ inputs_embeds=talker_inputs_embeds,
+ attention_mask=talker_attention_mask,
+ suppress_tokens=[self.talker.codec_bos_token],
+ **{k: (v.to(input_ids.device) if torch.is_tensor(v) else v) for k, v in talker_kwargs.items()},
+ )
+ talker_generate_codes = talker_result[:, talker_input_ids.shape[1] : -1]
+
+ # 3. Generate wavs from code
+ if self.token2wav.dtype != torch.float:
+ self.token2wav.float()
+
+ wav = self.token2wav(
+ talker_generate_codes.to(input_ids.device),
+ conditioning=speaker_params["cond"].to(input_ids.device).float(),
+ reference_mel=speaker_params["ref_mel"].to(input_ids.device).float(),
+ **token2wav_kwargs,
+ )
+
+ return thinker_result.sequences, wav.float()
+
+
+__all__ = [
+ "Qwen2_5OmniConfig",
+ "Qwen2_5OmniThinkerConfig",
+ "Qwen2_5OmniTalkerConfig",
+ "Qwen2_5OmniToken2WavConfig",
+ "Qwen2_5OmniForConditionalGeneration",
+ "Qwen2_5OmniThinkerTextModel",
+ "Qwen2_5OmniThinkerForConditionalGeneration",
+ "Qwen2_5OmniTalkerModel",
+ "Qwen2_5OmniTalkerForConditionalGeneration",
+ "Qwen2_5OmniToken2WavDiTModel",
+ "Qwen2_5OmniToken2WavBigVGANModel",
+ "Qwen2_5OmniToken2WavModel",
+ "Qwen2_5OmniPreTrainedModel",
+ "Qwen2_5OmniPreTrainedModelForConditionalGeneration",
+]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/qwen2_5_omni/processing_qwen2_5_omni.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/qwen2_5_omni/processing_qwen2_5_omni.py
new file mode 100644
index 0000000000000000000000000000000000000000..5fcbb0c535f92b912e4638a7b2b5f6fd6bb58f81
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/qwen2_5_omni/processing_qwen2_5_omni.py
@@ -0,0 +1,360 @@
+# coding=utf-8
+# Copyright 2025 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
+#
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Processor class for Qwen2.5Omni.
+"""
+
+import logging
+import re
+from typing import Optional, Union
+
+import numpy as np
+
+from ...feature_extraction_utils import BatchFeature
+from ...image_utils import ImageInput
+from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin, Unpack, VideosKwargs
+from ...tokenization_utils_base import AudioInput, PreTokenizedInput, TextInput
+from ...video_utils import VideoInput
+
+
+class Qwen2_5_OmniVideosKwargs(VideosKwargs):
+ fps: Optional[list[Union[int, float]]]
+ use_audio_in_video: Optional[bool]
+ seconds_per_chunk: Optional[float]
+ position_id_per_seconds: Optional[int]
+ min_pixels: Optional[int]
+ max_pixels: Optional[int]
+ patch_size: Optional[int]
+ temporal_patch_size: Optional[int]
+ merge_size: Optional[int]
+
+
+class Qwen2_5_OmniImagesKwargs(ImagesKwargs):
+ min_pixels: Optional[int]
+ max_pixels: Optional[int]
+ patch_size: Optional[int]
+ temporal_patch_size: Optional[int]
+ merge_size: Optional[int]
+
+
+class Qwen2_5OmniProcessorKwargs(ProcessingKwargs, total=False):
+ videos_kwargs: Qwen2_5_OmniVideosKwargs
+ images_kwargs: Qwen2_5_OmniImagesKwargs
+ _defaults = {
+ "text_kwargs": {
+ "padding": False,
+ "padding_side": "left",
+ },
+ "videos_kwargs": {
+ "seconds_per_chunk": 2.0,
+ "position_id_per_seconds": 25,
+ "use_audio_in_video": False,
+ "size": {
+ "shortest_edge": 128 * 28 * 28,
+ "longest_edge": 768 * 28 * 28,
+ },
+ },
+ "audio_kwargs": {
+ "sampling_rate": 16000,
+ "padding": "max_length",
+ "return_attention_mask": True,
+ },
+ }
+
+
+class Qwen2_5OmniProcessor(ProcessorMixin):
+ r"""
+ Constructs a Qwen2.5Omni processor.
+ [`Qwen2_5OmniProcessor`] offers all the functionalities of [`Qwen2VLImageProcessor`], [`WhisperFeatureExtractor`], and [`Qwen2TokenizerFast`]. See the
+ [`~Qwen2_5OmniProcessor.__call__`] and [`~Qwen2_5OmniProcessor.decode`] for more information.
+
+ Args:
+ image_processor ([`Qwen2VLImageProcessor`], *optional*):
+ The image processor.
+ video_processor ([`Qwen2VLVideoProcessor`], *optional*):
+ The video processor.
+ feature_extractor ([`WhisperFeatureExtractor`], *optional*):
+ The audio feature extractor.
+ tokenizer ([`Qwen2TokenizerFast`], *optional*):
+ The text tokenizer.
+ chat_template (`Optional[str]`, *optional*):
+ The Jinja template to use for formatting the conversation. If not provided, the default chat template is used.
+ """
+
+ attributes = ["image_processor", "video_processor", "feature_extractor", "tokenizer"]
+ image_processor_class = "AutoImageProcessor"
+ video_processor_class = "AutoVideoProcessor"
+ feature_extractor_class = "WhisperFeatureExtractor"
+ tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast")
+
+ def __init__(
+ self, image_processor=None, video_processor=None, feature_extractor=None, tokenizer=None, chat_template=None
+ ):
+ super().__init__(image_processor, video_processor, feature_extractor, tokenizer, chat_template=chat_template)
+ self.image_token = self.tokenizer.image_token
+ self.audio_token = self.tokenizer.audio_token
+ self.video_token = self.tokenizer.video_token
+ self.vision_bos_token = self.tokenizer.vision_bos_token
+ self.vision_eos_token = self.tokenizer.vision_eos_token
+ self.audio_bos_token = self.tokenizer.audio_bos_token
+ self.audio_eos_token = self.tokenizer.audio_eos_token
+
+ def __call__(
+ self,
+ text: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]] = None,
+ images: Optional[ImageInput] = None,
+ videos: Optional[VideoInput] = None,
+ audio: Optional[AudioInput] = None,
+ **kwargs: Unpack[Qwen2_5OmniProcessorKwargs],
+ ) -> BatchFeature:
+ """
+ Main method to prepare one or several sequence(s) and audio(s) for the model. This method forwards the `text`
+ and `kwargs` arguments to Qwen2TokenizerFast's [`~Qwen2TokenizerFast.__call__`] if `text` is not `None` to encode
+ the text. To prepare the audio(s), this method forwards the `audio` and `kwargs` arguments to
+ WhisperFeatureExtractor's [`~WhisperFeatureExtractor.__call__`] if `audio` is not `None`. To prepare the vision inputs,
+ this method forwards the `images`/`videos` and `kwargs` arguments to Qwen2VLImageProcessor's [`~Qwen2VLImageProcessor.__call__`]
+ if `images`/`videos` is not `None`. Please refer to the docstring
+ of the above methods for more information.
+
+ Args:
+ text (`str`, `list[str]`, `list[list[str]]`):
+ The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
+ (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
+ `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
+ images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `list[PIL.Image.Image]`, `list[np.ndarray]`, `list[torch.Tensor]`):
+ The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
+ tensor. Both channels-first and channels-last formats are supported.
+ videos (`np.ndarray`, `torch.Tensor`, `list[np.ndarray]`, `list[torch.Tensor]`):
+ The video or batch of videos to be prepared. Each video can be a 4D NumPy array or PyTorch
+ tensor, or a nested list of 3D frames. Both channels-first and channels-last formats are supported.
+ audio (`np.ndarray`, `list[np.ndarray]`):
+ The audio or batch of audio to be prepared. Each audio can be a NumPy array.
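+
+ Example (illustrative sketch; assumes the `Qwen/Qwen2.5-Omni-7B` checkpoint is reachable):
+
+ ```python
+ >>> from transformers import Qwen2_5OmniProcessor
+ >>> processor = Qwen2_5OmniProcessor.from_pretrained("Qwen/Qwen2.5-Omni-7B")
+ >>> batch = processor(text=["Hello, who are you?"], return_tensors="pt")
+ >>> sorted(batch.keys())
+ ['attention_mask', 'input_ids']
+ ```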
+ """
+
+ if text is None:
+ raise ValueError("You need to specify either a `text` input to process.")
+
+ output_kwargs = self._merge_kwargs(
+ Qwen2_5OmniProcessorKwargs,
+ tokenizer_init_kwargs=self.tokenizer.init_kwargs,
+ **kwargs,
+ )
+
+ seconds_per_chunk = output_kwargs["videos_kwargs"].pop("seconds_per_chunk")
+ position_id_per_seconds = output_kwargs["videos_kwargs"].pop("position_id_per_seconds")
+ use_audio_in_video = output_kwargs["videos_kwargs"].pop("use_audio_in_video")
+
+ if audio is not None:
+ output_kwargs["audio_kwargs"]["padding"] = "max_length" # Support "max_length" padding only here
+ audio_inputs = self.feature_extractor(audio, **output_kwargs["audio_kwargs"])
+ audio_inputs["feature_attention_mask"] = audio_inputs.pop(
+ "attention_mask"
+ ) # rename attention_mask to feature_attention_mask to avoid a clash with the text attention_mask later on
+ audio_inputs["input_features"] = audio_inputs.pop(
+ "input_features"
+ ) # rename input_features to prevent conflicts later on
+ input_lengths = (audio_inputs["feature_attention_mask"].sum(-1) - 1) // 2 + 1
+ audio_lengths = iter((input_lengths - 2) // 2 + 1)
+ else:
+ audio_inputs = {}
+ audio_lengths = iter([])
+
+ if images is not None:
+ images_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"])
+ image_grid_thw = iter(images_inputs["image_grid_thw"])
+ else:
+ images_inputs = {}
+ image_grid_thw = iter([])
+
+ if videos is not None:
+ videos_inputs = self.video_processor(videos=videos, **output_kwargs["videos_kwargs"])
+
+ fps = output_kwargs["videos_kwargs"].get("fps", 2.0)
+ video_grid_thw = videos_inputs["video_grid_thw"]
+ second_per_grid_ts = [self.video_processor.temporal_patch_size / fps] * len(video_grid_thw)
+ videos_inputs["video_second_per_grid"] = second_per_grid_ts
+
+ video_grid_thw = iter(video_grid_thw)
+ video_second_per_grid = iter(second_per_grid_ts)
+ else:
+ videos_inputs = {}
+ video_grid_thw = iter([])
+ video_second_per_grid = iter([])
+
+ if not isinstance(text, list):
+ text = [text]
+
+ if images is not None or videos is not None or audio is not None:
+ text = self.replace_multimodal_special_tokens(
+ text,
+ audio_lengths,
+ image_grid_thw,
+ video_grid_thw,
+ video_second_per_grid=video_second_per_grid,
+ use_audio_in_video=use_audio_in_video,
+ position_id_per_seconds=position_id_per_seconds,
+ seconds_per_chunk=seconds_per_chunk,
+ )
+
+ texts_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
+
+ return BatchFeature(
+ data={**texts_inputs, **images_inputs, **videos_inputs, **audio_inputs},
+ tensor_type=kwargs.get("return_tensors"),
+ )
+
+ def replace_multimodal_special_tokens(
+ self,
+ text,
+ audio_lengths,
+ image_grid_thw,
+ video_grid_thw,
+ video_second_per_grid,
+ use_audio_in_video,
+ position_id_per_seconds,
+ seconds_per_chunk,
+ ):
+ # Extend mm token length
+ merge_length_image = self.image_processor.merge_size**2
+ merge_length_video = self.video_processor.merge_size**2
+
+ processed_text = []
+ for sample in text:
+ special_tokens = [re.escape(tok) for tok in [self.audio_token, self.image_token, self.video_token]]
+ pattern = "|".join(special_tokens)
+ positions = sorted((match.start(), match.group()) for match in re.finditer(pattern, sample))
+
+ for _, special_token in positions:
+ if special_token == self.audio_token:
+ sample = sample.replace(self.audio_token, "<|audio_placeholder|>" * next(audio_lengths), 1)
+ elif special_token == self.image_token:
+ image_seq_length = next(image_grid_thw).prod() // merge_length_image
+ sample = sample.replace(self.image_token, "<|image_placeholder|>" * image_seq_length, 1)
+ elif special_token == self.video_token:
+ if not use_audio_in_video:
+ video_seq_length = next(video_grid_thw).prod() // merge_length_video
+ sample = sample.replace(self.video_token, "<|video_placeholder|>" * video_seq_length, 1)
+ else:
+ audio_token_indices = np.arange(next(audio_lengths))
+ curr_video_grid_thw = next(video_grid_thw)
+ height = curr_video_grid_thw[1] // self.video_processor.merge_size
+ width = curr_video_grid_thw[2] // self.video_processor.merge_size
+ video_token_indices = np.arange(curr_video_grid_thw[0]).reshape(-1, 1, 1)
+ video_token_indices = np.broadcast_to(
+ video_token_indices, (video_token_indices.shape[0], height, width)
+ ).reshape(-1)
+ video_token_indices = (
+ video_token_indices * next(video_second_per_grid) * position_id_per_seconds
+ )
+
+ tokens_per_chunk = int(position_id_per_seconds * seconds_per_chunk)
+ video_chunk_indexes = self.get_chunked_index(video_token_indices, tokens_per_chunk)
+ audio_chunk_indexes = self.get_chunked_index(audio_token_indices, tokens_per_chunk)
+
+ placeholder_string = self.vision_bos_token + self.audio_bos_token
+ for j in range(max(len(video_chunk_indexes), len(audio_chunk_indexes))):
+ if j < len(video_chunk_indexes):
+ video_seq_length = video_chunk_indexes[j][1] - video_chunk_indexes[j][0]
+ placeholder_string += "<|video_placeholder|>" * video_seq_length
+ if j < len(audio_chunk_indexes):
+ audio_seq_length = audio_chunk_indexes[j][1] - audio_chunk_indexes[j][0]
+ placeholder_string += "<|audio_placeholder|>" * audio_seq_length
+ placeholder_string += self.audio_eos_token + self.vision_eos_token
+ sample = sample.replace(
+ self.vision_bos_token + self.video_token + self.vision_eos_token,
+ placeholder_string,
+ 1,
+ )
+
+ sample = sample.replace("<|audio_placeholder|>", self.audio_token)
+ sample = sample.replace("<|image_placeholder|>", self.image_token)
+ sample = sample.replace("<|video_placeholder|>", self.video_token)
+ processed_text.append(sample)
+ return processed_text
+
+ def get_chunked_index(self, token_indices: np.ndarray, tokens_per_chunk: int) -> list[tuple[int, int]]:
+ """
+ Splits token index list into chunks based on token value ranges.
+
+ Given a list of token indices, returns a list of (start, end) index tuples representing
+ slices of the list where the token values fall within successive ranges of `tokens_per_chunk`.
+
+ For example, if `tokens_per_chunk` is 1000, the function will create chunks such that:
+ - the first chunk contains token values < 1000,
+ - the second chunk contains values >= 1000 and < 2000, and so on.
+
+ Parameters:
+ token_indices (`np.ndarray`): A monotonically increasing list of token index values.
+ tokens_per_chunk (`int`): Number of tokens per chunk (used as the chunk size threshold).
+
+ Returns:
+ `list[tuple[int, int]]`: A list of tuples, each representing the start (inclusive)
+ and end (exclusive) indices of a chunk in `token_indices`.
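+
+ Example (illustrative sketch; assumes `processor` is an instantiated `Qwen2_5OmniProcessor`; the
+ method does not depend on instance state):
+
+ ```python
+ >>> import numpy as np
+ >>> processor.get_chunked_index(np.array([0, 10, 20, 30, 40, 50]), tokens_per_chunk=25)
+ [(0, 3), (3, 5), (5, 6)]
+ ```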
+ """
+
+ def _iter():
+ i, start_idx = 0, 0 # skip bos token
+ current_chunk = 1
+ while i < len(token_indices): # skip eos token
+ if token_indices[i] >= current_chunk * tokens_per_chunk:
+ yield (start_idx, i)
+ start_idx = i
+ current_chunk += 1
+ i += 1
+ yield (start_idx, len(token_indices))
+
+ return list(_iter())
+
+ def apply_chat_template(self, conversations, chat_template=None, **kwargs):
+ is_batched = False
+ if isinstance(conversations[0], dict):
+ conversations = [conversations]
+ is_batched = True
+
+ for conversation in conversations:
+ if (
+ conversation[0]["role"] != "system"
+ or conversation[0]["content"][0]["text"]
+ != "You are Qwen, a virtual human developed by the Qwen Team, Alibaba Group, capable of perceiving auditory and visual inputs, as well as generating text and speech."
+ ):
+ logging.warning(
+ "System prompt modified, audio output may not work as expected. "
+ + "Audio output mode only works when using default system prompt 'You are Qwen, a virtual human developed by the Qwen Team, Alibaba Group, capable of perceiving auditory and visual inputs, as well as generating text and speech.'"
+ )
+ if is_batched:
+ conversations = conversations[0]
+
+ return super().apply_chat_template(conversations, chat_template, **kwargs)
+
+ @property
+ def model_input_names(self):
+ tokenizer_input_names = self.tokenizer.model_input_names
+ feature_extractor_input_names = self.feature_extractor.model_input_names
+ image_processor_input_names = self.image_processor.model_input_names
+ return list(
+ dict.fromkeys(
+ tokenizer_input_names
+ + feature_extractor_input_names
+ + image_processor_input_names
+ + ["feature_attention_mask"]
+ + ["video_second_per_grid"]
+ )
+ )
+
+
+__all__ = ["Qwen2_5OmniProcessor"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/qwen2_5_vl/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/qwen2_5_vl/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a9f44a7a05fc65136552948be9acef53e5ea106
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/qwen2_5_vl/__init__.py
@@ -0,0 +1,28 @@
+# Copyright 2025 The Qwen Team and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_qwen2_5_vl import *
+ from .modeling_qwen2_5_vl import *
+ from .processing_qwen2_5_vl import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py
new file mode 100644
index 0000000000000000000000000000000000000000..c65fc5edb3c1a69c365f1f4eb2928e3fe45e8550
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py
@@ -0,0 +1,364 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_qwen2_5_vl.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2025 The Qwen Team and The HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from ...configuration_utils import PretrainedConfig, layer_type_validation
+from ...modeling_rope_utils import rope_config_validation
+
+
+class Qwen2_5_VLVisionConfig(PretrainedConfig):
+ model_type = "qwen2_5_vl"
+ base_config_key = "vision_config"
+
+ def __init__(
+ self,
+ depth=32,
+ hidden_size=3584,
+ hidden_act="silu",
+ intermediate_size=3420,
+ num_heads=16,
+ in_channels=3,
+ patch_size=14,
+ spatial_merge_size=2,
+ temporal_patch_size=2,
+ tokens_per_second=4,
+ window_size=112,
+ out_hidden_size=3584,
+ fullatt_block_indexes=[7, 15, 23, 31],
+ initializer_range=0.02,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+
+ self.depth = depth
+ self.hidden_size = hidden_size
+ self.hidden_act = hidden_act
+ self.intermediate_size = intermediate_size
+ self.num_heads = num_heads
+ self.in_channels = in_channels
+ self.patch_size = patch_size
+ self.spatial_merge_size = spatial_merge_size
+ self.temporal_patch_size = temporal_patch_size
+ self.tokens_per_second = tokens_per_second
+ self.window_size = window_size
+ self.fullatt_block_indexes = fullatt_block_indexes
+ self.out_hidden_size = out_hidden_size
+ self.initializer_range = initializer_range
+
+
+class Qwen2_5_VLTextConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`Qwen2_5_VLTextModel`]. It is used to instantiate a
+ Qwen2.5-VL model according to the specified arguments, defining the model architecture. Instantiating a configuration
+ with the defaults will yield a similar configuration to that of
+ Qwen2-VL-7B-Instruct [Qwen/Qwen2-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct).
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 152064):
+ Vocabulary size of the Qwen2_5_VL model. Defines the number of different tokens that can be represented by the
+ `inputs_ids` passed when calling [`Qwen2_5_VLModel`]
+ hidden_size (`int`, *optional*, defaults to 8192):
+ Dimension of the hidden representations.
+ intermediate_size (`int`, *optional*, defaults to 29568):
+ Dimension of the MLP representations.
+ num_hidden_layers (`int`, *optional*, defaults to 80):
+ Number of hidden layers in the Transformer encoder.
+ num_attention_heads (`int`, *optional*, defaults to 64):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ num_key_value_heads (`int`, *optional*, defaults to 8):
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
+ `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+ by meanpooling all the original heads within that group. For more details, check out [this
+ paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to `8`.
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+ The non-linear activation function (function or string) in the decoder.
+ max_position_embeddings (`int`, *optional*, defaults to 32768):
+ The maximum sequence length that this model might ever be used with.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ rms_norm_eps (`float`, *optional*, defaults to 1e-05):
+ The epsilon used by the rms normalization layers.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
+ relevant if `config.is_decoder=True`.
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+ Whether the model's input and output word embeddings should be tied.
+ rope_theta (`float`, *optional*, defaults to 1000000.0):
+ The base period of the RoPE embeddings.
+ use_sliding_window (`bool`, *optional*, defaults to `False`):
+ Whether to use sliding window attention.
+ sliding_window (`int`, *optional*, defaults to 4096):
+ Sliding window attention (SWA) window size. If not specified, will default to `4096`.
+ max_window_layers (`int`, *optional*, defaults to 80):
+ The number of layers using full attention. The first `max_window_layers` layers will use full attention, while any
+ additional layer afterwards will use SWA (Sliding Window Attention).
+ layer_types (`list`, *optional*):
+ Attention pattern for each layer.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ rope_scaling (`Dict`, *optional*):
+ Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
+ and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
+ accordingly.
+ Expected contents:
+ `rope_type` (`str`):
+ The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
+ 'llama3'], with 'default' being the original RoPE implementation.
+ `factor` (`float`, *optional*):
+ Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
+ most scaling types, a `factor` of x will enable the model to handle sequences of length x *
+ original maximum pre-trained length.
+ `original_max_position_embeddings` (`int`, *optional*):
+ Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
+ pretraining.
+ `attention_factor` (`float`, *optional*):
+ Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
+ computation. If unspecified, it defaults to value recommended by the implementation, using the
+ `factor` field to infer the suggested value.
+ `beta_fast` (`float`, *optional*):
+ Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
+ ramp function. If unspecified, it defaults to 32.
+ `beta_slow` (`float`, *optional*):
+ Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
+ ramp function. If unspecified, it defaults to 1.
+ `short_factor` (`list[float]`, *optional*):
+ Only used with 'longrope'. The scaling factor to be applied to short contexts (<
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+ size divided by the number of attention heads divided by 2.
+ `long_factor` (`list[float]`, *optional*):
+ Only used with 'longrope'. The scaling factor to be applied to long contexts (>
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+ size divided by the number of attention heads divided by 2.
+ `low_freq_factor` (`float`, *optional*):
+ Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE.
+ `high_freq_factor` (`float`, *optional*):
+ Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE.
+
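+ As an illustration (hypothetical values, not tuned recommendations), a 'yarn' scaling configuration
+ could look like `{"rope_type": "yarn", "factor": 4.0, "original_max_position_embeddings": 32768}`.
+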
+ ```python
+ >>> from transformers import Qwen2_5_VLTextModel, Qwen2_5_VLConfig
+
+ >>> # Initializing a Qwen2_5_VL style configuration
+ >>> configuration = Qwen2_5_VLConfig()
+
+ >>> # Initializing a model from the Qwen2-VL-7B style configuration
+ >>> model = Qwen2_5_VLTextModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "qwen2_5_vl_text"
+ base_config_key = "text_config"
+ keys_to_ignore_at_inference = ["past_key_values"]
+ # Default tensor parallel plan for base model `Qwen2_5_VL`
+ base_model_tp_plan = {
+ "layers.*.self_attn.q_proj": "colwise",
+ "layers.*.self_attn.k_proj": "colwise",
+ "layers.*.self_attn.v_proj": "colwise",
+ "layers.*.self_attn.o_proj": "rowwise",
+ "layers.*.mlp.gate_proj": "colwise",
+ "layers.*.mlp.up_proj": "colwise",
+ "layers.*.mlp.down_proj": "rowwise",
+ }
+ base_model_pp_plan = {
+ "embed_tokens": (["input_ids"], ["inputs_embeds"]),
+ "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
+ "norm": (["hidden_states"], ["hidden_states"]),
+ }
+
+ def __init__(
+ self,
+ vocab_size=152064,
+ hidden_size=8192,
+ intermediate_size=29568,
+ num_hidden_layers=80,
+ num_attention_heads=64,
+ num_key_value_heads=8,
+ hidden_act="silu",
+ max_position_embeddings=32768,
+ initializer_range=0.02,
+ rms_norm_eps=1e-05,
+ use_cache=True,
+ tie_word_embeddings=False,
+ rope_theta=1000000.0,
+ use_sliding_window=False,
+ sliding_window=4096,
+ max_window_layers=80,
+ layer_types=None,
+ attention_dropout=0.0,
+ rope_scaling=None,
+ **kwargs,
+ ):
+ self.vocab_size = vocab_size
+ self.max_position_embeddings = max_position_embeddings
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.use_sliding_window = use_sliding_window
+ self.sliding_window = sliding_window if self.use_sliding_window else None
+ self.max_window_layers = max_window_layers
+
+ # for backward compatibility
+ if num_key_value_heads is None:
+ num_key_value_heads = num_attention_heads
+
+ self.num_key_value_heads = num_key_value_heads
+ self.hidden_act = hidden_act
+ self.initializer_range = initializer_range
+ self.rms_norm_eps = rms_norm_eps
+ self.use_cache = use_cache
+ self.rope_theta = rope_theta
+ self.attention_dropout = attention_dropout
+ self.rope_scaling = rope_scaling
+
+ self.layer_types = layer_types
+ if self.layer_types is None:
+ self.layer_types = [
+ "sliding_attention"
+ if self.sliding_window is not None and i >= self.max_window_layers
+ else "full_attention"
+ for i in range(self.num_hidden_layers)
+ ]
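+ # Illustrative example (hypothetical values): with use_sliding_window=True, sliding_window=4096,
+ # num_hidden_layers=4 and max_window_layers=2, this yields
+ # ["full_attention", "full_attention", "sliding_attention", "sliding_attention"].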
+ layer_type_validation(self.layer_types, self.num_hidden_layers)
+
+ # Validate the correctness of rotary position embeddings parameters
+ # BC: if there is a 'type' field, move it to 'rope_type'.
+ # and change type from 'mrope' to 'default' because `mrope` does default RoPE calculations
+ # one can set it to "linear"/"dynamic" etc. to have scaled RoPE
+ # TODO: @raushan update config in the hub
+ if self.rope_scaling is not None and "type" in self.rope_scaling:
+ if self.rope_scaling["type"] == "mrope":
+ self.rope_scaling["type"] = "default"
+ self.rope_scaling["rope_type"] = self.rope_scaling["type"]
+ rope_config_validation(self, ignore_keys={"mrope_section"})
+ super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
+
+
+class Qwen2_5_VLConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`Qwen2_5_VLModel`]. It is used to instantiate a
+ Qwen2-VL model according to the specified arguments, defining the model architecture. Instantiating a configuration
+ with the defaults will yield a similar configuration to that of
+ Qwen2-VL-7B-Instruct [Qwen/Qwen2-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct).
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+
+ Args:
+ text_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `Qwen2_5_VLTextConfig`):
+ The config object or dictionary of the text backbone.
+ vision_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `Qwen2_5_VLVisionConfig`):
+ The config object or dictionary of the vision backbone.
+ image_token_id (`int`, *optional*, defaults to 151655):
+ The image token index to encode the image prompt.
+ video_token_id (`int`, *optional*, defaults to 151656):
+ The video token index to encode the video prompt.
+ vision_start_token_id (`int`, *optional*, defaults to 151652):
+ The token index to denote start of vision input.
+ vision_end_token_id (`int`, *optional*, defaults to 151653):
+ The token index to denote end of vision input.
+
+ ```python
+ >>> from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLConfig
+
+ >>> # Initializing a Qwen2_5_VL style configuration
+ >>> configuration = Qwen2_5_VLConfig()
+
+ >>> # Initializing a model from the Qwen2-VL-7B style configuration
+ >>> model = Qwen2_5_VLForConditionalGeneration(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "qwen2_5_vl"
+ sub_configs = {"vision_config": Qwen2_5_VLVisionConfig, "text_config": Qwen2_5_VLTextConfig}
+ keys_to_ignore_at_inference = ["past_key_values"]
+
+ def __init__(
+ self,
+ text_config=None,
+ vision_config=None,
+ image_token_id=151655,
+ video_token_id=151656,
+ vision_start_token_id=151652,
+ vision_end_token_id=151653,
+ **kwargs,
+ ):
+ # We need to init super() here so that it does not reset values
+ # that are in the text config to the BaseClass defaults. The Base
+ # config has many text-related defaults, and not all of them are the same as for `Qwen2_5_VLTextConfig`.
+ super().__init__(**kwargs)
+
+ if isinstance(vision_config, dict):
+ self.vision_config = self.sub_configs["vision_config"](**vision_config)
+ elif vision_config is None:
+ self.vision_config = self.sub_configs["vision_config"]()
+
+ if isinstance(text_config, dict):
+ self.text_config = self.sub_configs["text_config"](**text_config)
+ elif text_config is None:
+ # For BC use all kwargs to init `TextConfig`
+ self.text_config = self.sub_configs["text_config"](**kwargs)
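+ # e.g. the legacy call `Qwen2_5_VLConfig(vocab_size=151936)` (no explicit `text_config`) would
+ # forward `vocab_size` into the text sub-config here.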
+
+ self.image_token_id = image_token_id
+ self.video_token_id = video_token_id
+ self.vision_start_token_id = vision_start_token_id
+ self.vision_end_token_id = vision_end_token_id
+
+ # Attention implementation to use. It sets it recursively on sub-configs so we call it again in the end
+ self._attn_implementation = kwargs.pop("attn_implementation", None)
+
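+ # The two methods below proxy attribute access: reads/writes of attributes that live on the text
+ # sub-config (e.g. `config.vocab_size`) are transparently forwarded to `config.text_config`, except
+ # for the few composite-level names kept in their exclusion lists.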
+ def __setattr__(self, key, value):
+ if (
+ (text_config := super().__getattribute__("__dict__").get("text_config")) is not None
+ and key not in ["_name_or_path", "model_type", "dtype", "_attn_implementation_internal"]
+ and key in text_config.__dict__
+ ):
+ setattr(text_config, key, value)
+ else:
+ super().__setattr__(key, value)
+
+ def __getattribute__(self, key):
+ if "text_config" in super().__getattribute__("__dict__") and key not in [
+ "_name_or_path",
+ "model_type",
+ "dtype",
+ "_attn_implementation_internal",
+ ]:
+ text_config = super().__getattribute__("text_config")
+ if key in text_config.__dict__:
+ return getattr(text_config, key)
+
+ return super().__getattribute__(key)
+
+
+__all__ = ["Qwen2_5_VLConfig", "Qwen2_5_VLTextConfig"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py
new file mode 100644
index 0000000000000000000000000000000000000000..a9857455192227369ea73e08268b2fc59c865f94
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py
@@ -0,0 +1,1724 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_qwen2_5_vl.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2025 The Qwen Team and The HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass
+from typing import Any, Callable, Optional, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ...activations import ACT2FN
+from ...cache_utils import Cache, DynamicCache
+from ...generation import GenerationMixin
+from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
+from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput
+from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
+from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torchdynamo_compiling, logging
+from ...utils.deprecation import deprecate_kwarg
+from ..qwen2.modeling_qwen2 import Qwen2RMSNorm
+from .configuration_qwen2_5_vl import Qwen2_5_VLConfig, Qwen2_5_VLTextConfig, Qwen2_5_VLVisionConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+class Qwen2_5_VLMLP(nn.Module):
+ def __init__(self, config, bias: bool = False):
+ super().__init__()
+ self.hidden_size = config.hidden_size
+ self.intermediate_size = config.intermediate_size
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=bias)
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=bias)
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=bias)
+ self.act_fn = ACT2FN[config.hidden_act]
+
+ def forward(self, hidden_state):
+ return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))
+
+
+class Qwen2_5_VisionPatchEmbed(nn.Module):
+ def __init__(
+ self,
+ patch_size: int = 14,
+ temporal_patch_size: int = 2,
+ in_channels: int = 3,
+ embed_dim: int = 1152,
+ ) -> None:
+ super().__init__()
+ self.patch_size = patch_size
+ self.temporal_patch_size = temporal_patch_size
+ self.in_channels = in_channels
+ self.embed_dim = embed_dim
+
+ kernel_size = [temporal_patch_size, patch_size, patch_size]
+ self.proj = nn.Conv3d(in_channels, embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=False)
+
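+ # Shape sketch: the input is packed patches of shape
+ # (num_patches, in_channels * temporal_patch_size * patch_size * patch_size); it is viewed back into
+ # a 5D Conv3d input and projected (kernel == stride, so each patch maps to one embedding),
+ # giving an output of shape (num_patches, embed_dim).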
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ target_dtype = self.proj.weight.dtype
+ hidden_states = hidden_states.view(
+ -1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size
+ )
+ hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).view(-1, self.embed_dim)
+ return hidden_states
+
+
+class Qwen2_5_VisionRotaryEmbedding(nn.Module):
+ inv_freq: torch.Tensor # fix linting for `register_buffer`
+
+ def __init__(self, dim: int, theta: float = 10000.0) -> None:
+ super().__init__()
+ inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
+
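+ # For a given seqlen, forward returns the outer product of positions and inverse frequencies,
+ # i.e. a tensor of shape (seqlen, dim // 2) holding the per-position rotary angles.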
+ def forward(self, seqlen: int) -> torch.Tensor:
+ seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
+ freqs = torch.outer(seq, self.inv_freq)
+ return freqs
+
+
+class Qwen2_5_VLPatchMerger(nn.Module):
+ def __init__(self, dim: int, context_dim: int, spatial_merge_size: int = 2) -> None:
+ super().__init__()
+ self.hidden_size = context_dim * (spatial_merge_size**2)
+ self.ln_q = Qwen2RMSNorm(context_dim, eps=1e-6)
+ self.mlp = nn.Sequential(
+ nn.Linear(self.hidden_size, self.hidden_size),
+ nn.GELU(),
+ nn.Linear(self.hidden_size, dim),
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = self.mlp(self.ln_q(x).view(-1, self.hidden_size))
+ return x
+
+
+def rotate_half(x):
+ """Rotates half the hidden dims of the input."""
+ x1 = x[..., : x.shape[-1] // 2]
+ x2 = x[..., x.shape[-1] // 2 :]
+ return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb_vision(
+ q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
+) -> tuple[torch.Tensor, torch.Tensor]:
+ orig_q_dtype = q.dtype
+ orig_k_dtype = k.dtype
+ q, k = q.float(), k.float()
+ cos, sin = cos.unsqueeze(-2).float(), sin.unsqueeze(-2).float()
+ q_embed = (q * cos) + (rotate_half(q) * sin)
+ k_embed = (k * cos) + (rotate_half(k) * sin)
+ q_embed = q_embed.to(orig_q_dtype)
+ k_embed = k_embed.to(orig_k_dtype)
+ return q_embed, k_embed
+
+
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+ """
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+ """
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+ if n_rep == 1:
+ return hidden_states
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
+def eager_attention_forward(
+ module: nn.Module,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_mask: Optional[torch.Tensor],
+ scaling: float,
+ dropout: float = 0.0,
+ **kwargs,
+):
+ key_states = repeat_kv(key, module.num_key_value_groups)
+ value_states = repeat_kv(value, module.num_key_value_groups)
+
+ attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
+ if attention_mask is not None:
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+ attn_weights = attn_weights + causal_mask
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+ attn_output = torch.matmul(attn_weights, value_states)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ return attn_output, attn_weights
+
+
+class Qwen2_5_VLVisionAttention(nn.Module):
+ def __init__(self, config: Qwen2_5_VLVisionConfig) -> None:
+ super().__init__()
+ self.dim = config.hidden_size
+ self.num_heads = config.num_heads
+ self.head_dim = self.dim // self.num_heads
+ self.num_key_value_groups = 1 # needed for eager attention
+ self.qkv = nn.Linear(self.dim, self.dim * 3, bias=True)
+ self.proj = nn.Linear(self.dim, self.dim)
+ self.scaling = self.head_dim**-0.5
+ self.config = config
+ self.attention_dropout = 0.0
+ self.is_causal = False
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ cu_seqlens: torch.Tensor,
+ rotary_pos_emb: Optional[torch.Tensor] = None,
+ position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
+ **kwargs,
+ ) -> torch.Tensor:
+ seq_length = hidden_states.shape[0]
+ query_states, key_states, value_states = (
+ self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
+ )
+ cos, sin = position_embeddings
+ query_states, key_states = apply_rotary_pos_emb_vision(query_states, key_states, cos, sin)
+
+ query_states = query_states.transpose(0, 1).unsqueeze(0)
+ key_states = key_states.transpose(0, 1).unsqueeze(0)
+ value_states = value_states.transpose(0, 1).unsqueeze(0)
+
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ if self.config._attn_implementation == "flash_attention_2":
+ # Flash Attention 2: Use cu_seqlens for variable length attention
+ max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
+ attn_output, _ = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask=None,
+ scaling=self.scaling,
+ dropout=0.0 if not self.training else self.attention_dropout,
+ cu_seq_lens_q=cu_seqlens,
+ cu_seq_lens_k=cu_seqlens,
+ max_length_q=max_seqlen,
+ max_length_k=max_seqlen,
+ is_causal=False,
+ **kwargs,
+ )
+ else:
+ # Other implementations: Process each chunk separately
+ lengths = cu_seqlens[1:] - cu_seqlens[:-1]
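+ # e.g. cu_seqlens = [0, 4, 10] (two images of 4 and 6 patches) gives lengths = [4, 6], so the
+ # packed sequence is split back into per-image chunks before attention.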
+ splits = [
+ torch.split(tensor, lengths.tolist(), dim=2) for tensor in (query_states, key_states, value_states)
+ ]
+
+ attn_outputs = [
+ attention_interface(
+ self,
+ q,
+ k,
+ v,
+ attention_mask=None,
+ scaling=self.scaling,
+ dropout=0.0 if not self.training else self.attention_dropout,
+ is_causal=False,
+ **kwargs,
+ )[0]
+ for q, k, v in zip(*splits)
+ ]
+ attn_output = torch.cat(attn_outputs, dim=1)
+
+ attn_output = attn_output.reshape(seq_length, -1).contiguous()
+ attn_output = self.proj(attn_output)
+ return attn_output
+
+
+class Qwen2_5_VLVisionBlock(GradientCheckpointingLayer):
+ def __init__(self, config, attn_implementation: str = "sdpa") -> None:
+ super().__init__()
+ self.norm1 = Qwen2RMSNorm(config.hidden_size, eps=1e-6)
+ self.norm2 = Qwen2RMSNorm(config.hidden_size, eps=1e-6)
+ self.attn = Qwen2_5_VLVisionAttention(config=config)
+ self.mlp = Qwen2_5_VLMLP(config, bias=True)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ cu_seqlens: torch.Tensor,
+ rotary_pos_emb: Optional[torch.Tensor] = None,
+ position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
+ **kwargs,
+ ) -> torch.Tensor:
+ hidden_states = hidden_states + self.attn(
+ self.norm1(hidden_states),
+ cu_seqlens=cu_seqlens,
+ rotary_pos_emb=rotary_pos_emb,
+ position_embeddings=position_embeddings,
+ **kwargs,
+ )
+ hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
+ return hidden_states
+
+
+@auto_docstring
+class Qwen2_5_VLPreTrainedModel(PreTrainedModel):
+ config: Qwen2_5_VLConfig
+ base_model_prefix = "model"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["Qwen2_5_VLDecoderLayer", "Qwen2_5_VLVisionBlock"]
+ _skip_keys_device_placement = "past_key_values"
+ _supports_flash_attn = True
+ _supports_sdpa = True
+
+ _can_compile_fullgraph = True
+ _supports_attention_backend = True
+
+
+class Qwen2_5_VisionTransformerPretrainedModel(Qwen2_5_VLPreTrainedModel):
+ config: Qwen2_5_VLVisionConfig
+ _no_split_modules = ["Qwen2_5_VLVisionBlock"]
+
+ def __init__(self, config, *inputs, **kwargs) -> None:
+ super().__init__(config, *inputs, **kwargs)
+ self.spatial_merge_size = config.spatial_merge_size
+ self.patch_size = config.patch_size
+ self.fullatt_block_indexes = config.fullatt_block_indexes
+ self.window_size = config.window_size
+ self.spatial_merge_unit = self.spatial_merge_size * self.spatial_merge_size
+
+ self.patch_embed = Qwen2_5_VisionPatchEmbed(
+ patch_size=config.patch_size,
+ temporal_patch_size=config.temporal_patch_size,
+ in_channels=config.in_channels,
+ embed_dim=config.hidden_size,
+ )
+
+ head_dim = config.hidden_size // config.num_heads
+ self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2)
+
+ self.blocks = nn.ModuleList([Qwen2_5_VLVisionBlock(config) for _ in range(config.depth)])
+ self.merger = Qwen2_5_VLPatchMerger(
+ dim=config.out_hidden_size,
+ context_dim=config.hidden_size,
+ spatial_merge_size=config.spatial_merge_size,
+ )
+ self.gradient_checkpointing = False
+
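+ # rot_pos_emb builds per-patch (h, w) position ids, ordered so that patches of the same spatial-merge
+ # block are contiguous, then gathers the matching rotary angles; the result has shape
+ # (num_patches, head_dim // 2), combining the height and width angles.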
+ def rot_pos_emb(self, grid_thw):
+ pos_ids = []
+ for t, h, w in grid_thw:
+ hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
+ hpos_ids = hpos_ids.reshape(
+ h // self.spatial_merge_size,
+ self.spatial_merge_size,
+ w // self.spatial_merge_size,
+ self.spatial_merge_size,
+ )
+ hpos_ids = hpos_ids.permute(0, 2, 1, 3)
+ hpos_ids = hpos_ids.flatten()
+
+ wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
+ wpos_ids = wpos_ids.reshape(
+ h // self.spatial_merge_size,
+ self.spatial_merge_size,
+ w // self.spatial_merge_size,
+ self.spatial_merge_size,
+ )
+ wpos_ids = wpos_ids.permute(0, 2, 1, 3)
+ wpos_ids = wpos_ids.flatten()
+ pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
+ pos_ids = torch.cat(pos_ids, dim=0)
+ max_grid_size = grid_thw[:, 1:].max()
+ rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
+ rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
+ return rotary_pos_emb
+
+ def get_window_index(self, grid_thw):
+ window_index: list = []
+ cu_window_seqlens: list = [0]
+ window_index_id = 0
+ vit_merger_window_size = self.window_size // self.spatial_merge_size // self.patch_size
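+ # e.g. with window_size=112, spatial_merge_size=2 and patch_size=14 (typical vision-config values),
+ # each attention window spans 112 // 2 // 14 = 4 merged patches per side.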
+
+ for grid_t, grid_h, grid_w in grid_thw:
+ llm_grid_h, llm_grid_w = (
+ grid_h // self.spatial_merge_size,
+ grid_w // self.spatial_merge_size,
+ )
+ index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(grid_t, llm_grid_h, llm_grid_w)
+ pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size
+ pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size
+ num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size
+ num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size
+ index_padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", -100)
+ index_padded = index_padded.reshape(
+ grid_t,
+ num_windows_h,
+ vit_merger_window_size,
+ num_windows_w,
+ vit_merger_window_size,
+ )
+ index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape(
+ grid_t,
+ num_windows_h * num_windows_w,
+ vit_merger_window_size,
+ vit_merger_window_size,
+ )
+ seqlens = (index_padded != -100).sum([2, 3]).reshape(-1)
+ index_padded = index_padded.reshape(-1)
+ index_new = index_padded[index_padded != -100]
+ window_index.append(index_new + window_index_id)
+ cu_seqlens_tmp = seqlens.cumsum(0) * self.spatial_merge_unit + cu_window_seqlens[-1]
+ cu_window_seqlens.extend(cu_seqlens_tmp.tolist())
+ window_index_id += (grid_t * llm_grid_h * llm_grid_w).item()
+ window_index = torch.cat(window_index, dim=0)
+
+ return window_index, cu_window_seqlens
+
+ def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) -> torch.Tensor:
+ """
+ Args:
+ hidden_states (`torch.Tensor` of shape `(seq_len, in_channels * temporal_patch_size * patch_size * patch_size)`):
+ The packed, flattened image/video patches to embed.
+ grid_thw (`torch.Tensor` of shape `(num_images_or_videos, 3)`):
+ The temporal, height and width of feature shape of each image in LLM.
+
+ Returns:
+ `torch.Tensor`: The merged visual embeddings.
+ """
+ hidden_states = self.patch_embed(hidden_states)
+ rotary_pos_emb = self.rot_pos_emb(grid_thw)
+ window_index, cu_window_seqlens = self.get_window_index(grid_thw)
+ cu_window_seqlens = torch.tensor(
+ cu_window_seqlens,
+ device=hidden_states.device,
+ dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
+ )
+ cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)
+
+ seq_len, _ = hidden_states.size()
+ hidden_states = hidden_states.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
+ hidden_states = hidden_states[window_index, :, :]
+ hidden_states = hidden_states.reshape(seq_len, -1)
+ rotary_pos_emb = rotary_pos_emb.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
+ rotary_pos_emb = rotary_pos_emb[window_index, :, :]
+ rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
+ emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
+ position_embeddings = (emb.cos(), emb.sin())
+
+ cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
+ dim=0,
+ # Select dtype based on the following factors:
+ # - FA2 requires that cu_seqlens_q must have dtype int32
+ # - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw
+ # See https://github.com/huggingface/transformers/pull/34852 for more information
+ dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
+ )
+ cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
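+ # e.g. grid_thw = [[2, 4, 4]] (one video: 2 temporal patches over a 4x4 spatial grid) gives
+ # cu_seqlens = [0, 16, 32], i.e. one full-attention segment per temporal slice.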
+
+ for layer_num, blk in enumerate(self.blocks):
+ if layer_num in self.fullatt_block_indexes:
+ cu_seqlens_now = cu_seqlens
+ else:
+ cu_seqlens_now = cu_window_seqlens
+
+ hidden_states = blk(
+ hidden_states,
+ cu_seqlens=cu_seqlens_now,
+ position_embeddings=position_embeddings,
+ **kwargs,
+ )
+
+ hidden_states = self.merger(hidden_states)
+ reverse_indices = torch.argsort(window_index)
+ hidden_states = hidden_states[reverse_indices, :]
+
+ return hidden_states
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Base class for Qwen2_5_VL outputs, with hidden states and attentions.
+ """
+)
+class Qwen2_5_VLModelOutputWithPast(ModelOutput):
+ r"""
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
+ `past_key_values` input) to speed up sequential decoding.
+ rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
+ The rope index difference between sequence length and multimodal rope.
+ """
+
+ last_hidden_state: Optional[torch.FloatTensor] = None
+ past_key_values: Optional[Cache] = None
+ hidden_states: Optional[tuple[torch.FloatTensor]] = None
+ attentions: Optional[tuple[torch.FloatTensor]] = None
+ rope_deltas: Optional[torch.LongTensor] = None
+
+
+class Qwen2_5_VLRotaryEmbedding(nn.Module):
+ inv_freq: torch.Tensor # fix linting for `register_buffer`
+
+ def __init__(self, config: Qwen2_5_VLTextConfig, device=None):
+ super().__init__()
+ # BC: "rope_type" was originally "type"
+ if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
+ else:
+ self.rope_type = "default"
+ self.max_seq_len_cached = config.max_position_embeddings
+ self.original_max_seq_len = config.max_position_embeddings
+
+ self.config = config
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
+ self.original_inv_freq = self.inv_freq
+
+ @torch.no_grad()
+ @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
+ def forward(self, x, position_ids):
+ # In contrast to other models, Qwen2_5_VL has different position ids for the grids
+ # So we expand the inv_freq to shape (3, ...)
+ inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1)
+ position_ids_expanded = position_ids[:, :, None, :].float() # shape (3, bs, 1, positions)
+
+ device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
+ with torch.autocast(device_type=device_type, enabled=False): # Force float32
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3)
+ emb = torch.cat((freqs, freqs), dim=-1)
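+ # emb has shape (3, bs, seq_len, head_dim); the leading dimension holds the separate
+ # temporal / height / width position angles consumed later by `apply_multimodal_rotary_pos_emb`.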
+ cos = emb.cos() * self.attention_scaling
+ sin = emb.sin() * self.attention_scaling
+
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
+class Qwen2MLP(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.hidden_size = config.hidden_size
+ self.intermediate_size = config.intermediate_size
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+ self.act_fn = ACT2FN[config.hidden_act]
+
+ def forward(self, x):
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+ return down_proj
+
+
+def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim=1):
+ """Applies Rotary Position Embedding with Multimodal Sections to the query and key tensors (https://qwenlm.github.io/blog/qwen2-vl/).
+
+ Explanation:
+ Multimodal 3D rotary position embedding is an extension to 1D rotary position embedding. The input embedding
+ sequence contains vision (images / videos) embedding and text embedding or just contains text embedding. For
+ vision embedding part, we apply rotary position embedding on temporal, height and width dimension separately.
+ Here we split the channel dimension to 3 chunks for the temporal, height and width rotary position embedding.
+ For text embedding part, we just apply 1D rotary position embedding. The three rotary position index (temporal,
+ height and width) of text embedding is always the same, so the text embedding rotary position embedding has no
+ difference with modern LLMs.
+
+ Args:
+ q (`torch.Tensor`): The query tensor.
+ k (`torch.Tensor`): The key tensor.
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
+ mrope_section (`list[int]`):
+ Multimodal rope section is for the channel dimension of temporal, height and width in rope calculation.
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+ Returns:
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+ """
+ mrope_section = mrope_section * 2
+ cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(
+ unsqueeze_dim
+ )
+ sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(
+ unsqueeze_dim
+ )
+
+ q_embed = (q * cos) + (rotate_half(q) * sin)
+ k_embed = (k * cos) + (rotate_half(k) * sin)
+ return q_embed, k_embed
+
+
+class Qwen2_5_VLAttention(nn.Module):
+ """
+ Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
+ and "Generating Long Sequences with Sparse Transformers".
+ """
+
+ def __init__(self, config: Qwen2_5_VLTextConfig, layer_idx: Optional[int] = None):
+ super().__init__()
+ self.config = config
+ self.layer_idx = layer_idx
+ if layer_idx is None:
+ logger.warning_once(
+ f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
+ "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
+ "when creating this class."
+ )
+
+ self.hidden_size = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.head_dim = self.hidden_size // self.num_heads
+ self.num_key_value_heads = config.num_key_value_heads
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
+ self.is_causal = True
+ self.attention_dropout = config.attention_dropout
+ self.rope_scaling = config.rope_scaling
+ self.scaling = self.head_dim**-0.5
+
+ if (self.head_dim * self.num_heads) != self.hidden_size:
+ raise ValueError(
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
+ f" and `num_heads`: {self.num_heads})."
+ )
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True)
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
+ self.sliding_window = config.sliding_window if config.layer_types[layer_idx] == "sliding_attention" else None
+
+ self.rotary_emb = Qwen2_5_VLRotaryEmbedding(config=config)
+
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
+ **kwargs: Unpack[FlashAttentionKwargs],
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+
+ cos, sin = position_embeddings
+ query_states, key_states = apply_multimodal_rotary_pos_emb(
+ query_states, key_states, cos, sin, self.rope_scaling["mrope_section"]
+ )
+
+ if past_key_values is not None:
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
+ key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ attn_output, attn_weights = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ dropout=0.0 if not self.training else self.attention_dropout,
+ scaling=self.scaling,
+ sliding_window=self.sliding_window,
+ position_ids=position_ids, # pass positions for FA2
+ **kwargs,
+ )
+
+ attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
+ attn_output = self.o_proj(attn_output)
+ return attn_output, attn_weights
+
+
+class Qwen2_5_VLDecoderLayer(GradientCheckpointingLayer):
+ def __init__(self, config: Qwen2_5_VLTextConfig, layer_idx: int):
+ super().__init__()
+ self.hidden_size = config.hidden_size
+
+ if config.use_sliding_window and config._attn_implementation != "flash_attention_2":
+ logger.warning_once(
+ f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; "
+ "unexpected results may be encountered."
+ )
+ self.self_attn = Qwen2_5_VLAttention(config, layer_idx)
+
+ self.mlp = Qwen2MLP(config)
+ self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.attention_type = config.layer_types[layer_idx]
+
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ output_attentions: Optional[bool] = False,
+ use_cache: Optional[bool] = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
+ **kwargs: Unpack[FlashAttentionKwargs],
+ ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
+ `(batch, sequence_length)` where padding elements are indicated by 0.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+ (see `past_key_values`).
+ past_key_values (`Cache`, *optional*): cached past key and value projection states
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+ Indices depicting the position of the input sequence tokens in the sequence.
+ position_embeddings (`tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
+ Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
+ with `head_dim` being the embedding dimension of each attention head.
+ kwargs (`dict`, *optional*):
+ Arbitrary kwargs to be ignored, used for FSDP and other methods that inject code
+ into the model.
+ """
+
+ residual = hidden_states
+
+ hidden_states = self.input_layernorm(hidden_states)
+
+ # Self Attention
+ hidden_states, self_attn_weights = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ **kwargs,
+ )
+ hidden_states = residual + hidden_states
+
+ # Fully Connected
+ residual = hidden_states
+ hidden_states = self.post_attention_layernorm(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = residual + hidden_states
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (self_attn_weights,)
+
+ return outputs
+
+
+@auto_docstring
+class Qwen2_5_VLTextModel(Qwen2_5_VLPreTrainedModel):
+ config: Qwen2_5_VLTextConfig
+
+ def __init__(self, config: Qwen2_5_VLTextConfig):
+ super().__init__(config)
+ self.padding_idx = config.pad_token_id
+ self.vocab_size = config.vocab_size
+
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+ self.layers = nn.ModuleList(
+ [Qwen2_5_VLDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+ )
+ self._attn_implementation = config._attn_implementation
+ self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.rotary_emb = Qwen2_5_VLRotaryEmbedding(config=config)
+ self.has_sliding_layers = "sliding_attention" in self.config.layer_types
+
+ self.gradient_checkpointing = False
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs: Unpack[FlashAttentionKwargs],
+ ) -> Union[tuple, BaseModelOutputWithPast]:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ if self.gradient_checkpointing and self.training:
+ if use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+ )
+ use_cache = False
+
+ # torch.jit.trace() doesn't support cache objects in the output
+ if use_cache and past_key_values is None and not torch.jit.is_tracing():
+ past_key_values = DynamicCache(config=self.config)
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+
+ if cache_position is None:
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ cache_position = torch.arange(
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+ )
+
+ # the hard coded `3` is for temporal, height and width.
+ if position_ids is None:
+ position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1)
+ elif position_ids.ndim == 2:
+ position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1)
+
+ # NOTE: we need to pass text position ids for packing. Qwen2-VL uses 3D positions
+ # where each dim indicates visual spatial positions for temporal/height/width grids.
+ # There are two scenarios when FA2-like packed masking might be activated.
+ # 1. User specifically passed packed `position_ids` and no attention mask.
+ # In this case we expect the user to create correct position ids for all 3 grids
+ # and prepend text-only position ids to it. The final tensor will be [4, bs, seq-len].
+ # 2. User runs forward with no attention mask and no position ids. In this case, position ids
+ # are prepared by the model (`get_rope_index`) as a `[4, bs, seq-len]` tensor. Text-only positions are
+ # prepended by us when creating positions so that the mask is constructed correctly. NOTE: failing to pass
+ # text-only positions will cause incorrect mask construction, so do not change `prepare_inputs_for_generation`.
+ if position_ids.ndim == 3 and position_ids.shape[0] == 4:
+ text_position_ids = position_ids[0]
+ position_ids = position_ids[1:]
+ else:
+ # If inputs are not packed (usual 3D positions), do not prepare mask from position_ids
+ text_position_ids = None
+
+ # It may already have been prepared by e.g. `generate`
+ if not isinstance(causal_mask_mapping := attention_mask, dict):
+ # Prepare mask arguments
+ mask_kwargs = {
+ "config": self.config,
+ "input_embeds": inputs_embeds,
+ "attention_mask": attention_mask,
+ "cache_position": cache_position,
+ "past_key_values": past_key_values,
+ "position_ids": text_position_ids,
+ }
+ # Create the masks
+ causal_mask_mapping = {
+ "full_attention": create_causal_mask(**mask_kwargs),
+ }
+ # The sliding window alternating layers are not always activated depending on the config
+ if self.has_sliding_layers:
+ causal_mask_mapping["sliding_attention"] = create_sliding_window_causal_mask(**mask_kwargs)
+
+ hidden_states = inputs_embeds
+
+ # create position embeddings to be shared across the decoder layers
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+ # decoder layers
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+
+ for decoder_layer in self.layers:
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ layer_outputs = decoder_layer(
+ hidden_states,
+ attention_mask=causal_mask_mapping[decoder_layer.attention_type],
+ position_ids=text_position_ids,
+ past_key_values=past_key_values,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ **kwargs,
+ )
+
+ hidden_states = layer_outputs[0]
+
+ if output_attentions:
+ all_self_attns += (layer_outputs[1],)
+
+ hidden_states = self.norm(hidden_states)
+
+ # add hidden states from the last decoder layer
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ if not return_dict:
+ return tuple(
+ v for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns] if v is not None
+ )
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=past_key_values,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ )
+
+
+@auto_docstring
+class Qwen2_5_VLModel(Qwen2_5_VLPreTrainedModel):
+ base_model_prefix = ""
+ _checkpoint_conversion_mapping = {"^model": "language_model"}
+ # Reference: fix gemma3 grad acc #37208
+ accepts_loss_kwargs = False
+ config: Qwen2_5_VLConfig
+ _no_split_modules = ["Qwen2_5_VLDecoderLayer", "Qwen2_5_VLVisionBlock"]
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.visual = Qwen2_5_VisionTransformerPretrainedModel._from_config(config.vision_config)
+ self.language_model = Qwen2_5_VLTextModel._from_config(config.text_config)
+ self.rope_deltas = None # cache rope_deltas here
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.language_model.get_input_embeddings()
+
+ def set_input_embeddings(self, value):
+ self.language_model.set_input_embeddings(value)
+
+ def set_decoder(self, decoder):
+ self.language_model = decoder
+
+ def get_decoder(self):
+ return self.language_model
+
+ def get_rope_index(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ image_grid_thw: Optional[torch.LongTensor] = None,
+ video_grid_thw: Optional[torch.LongTensor] = None,
+ second_per_grid_ts: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ """
+ Calculate the 3D rope index based on image and video's temporal, height and width in LLM.
+
+ Explanation:
+ Each embedding sequence contains vision embedding and text embedding or just contains text embedding.
+
+ For pure text embedding sequence, the rotary position embedding has no difference with modern LLMs.
+ Examples:
+ input_ids: [T T T T T], here T is for text.
+ temporal position_ids: [0, 1, 2, 3, 4]
+ height position_ids: [0, 1, 2, 3, 4]
+ width position_ids: [0, 1, 2, 3, 4]
+
+ For vision and text embedding sequence, we calculate 3D rotary position embedding for vision part
+ and 1D rotary position embedding for text part.
+ Examples:
+ Temporal (Time): 3 patches, representing different segments of the video in time.
+ Height: 2 patches, dividing each frame vertically.
+ Width: 2 patches, dividing each frame horizontally.
+ We also have some important parameters:
+ fps (Frames Per Second): The video's frame rate, set to 1. This means one frame is processed each second.
+ tokens_per_second: This is a crucial parameter. It dictates how many "time-steps" or "temporal tokens" are conceptually packed into a one-second interval of the video. In this case, we have 25 tokens per second. So each second of the video will be represented with 25 separate time points. It essentially defines the temporal granularity.
+ temporal_patch_size: The number of frames that compose one temporal patch. Here, it's 2 frames.
+ interval: The step size for the temporal position IDs, calculated as tokens_per_second * temporal_patch_size / fps. In this case, 25 * 2 / 1 = 50. This means that each temporal patch will have a difference of 50 in the temporal position IDs.
+ input_ids: [V V V V V V V V V V V V T T T T T], here V is for vision.
+ vision temporal position_ids: [0, 0, 0, 0, 50, 50, 50, 50, 100, 100, 100, 100]
+ vision height position_ids: [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1]
+ vision width position_ids: [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1]
+ text temporal position_ids: [101, 102, 103, 104, 105]
+ text height position_ids: [101, 102, 103, 104, 105]
+ text width position_ids: [101, 102, 103, 104, 105]
+ Here we calculate the text start position_ids as the max vision position_ids plus 1.
+
+ Args:
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
+ it.
+ image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
+ The temporal, height and width of feature shape of each image in LLM.
+ video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
+ The temporal, height and width of feature shape of each video in LLM.
+ second_per_grid_ts (`torch.Tensor` of shape `(num_videos)`, *optional*):
+ The time interval (in seconds) for each grid along the temporal dimension in the 3D position IDs.
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ Returns:
+ position_ids (`torch.LongTensor` of shape `(3, batch_size, sequence_length)`)
+ mrope_position_deltas (`torch.Tensor` of shape `(batch_size)`)
+ """
+ spatial_merge_size = self.config.vision_config.spatial_merge_size
+ image_token_id = self.config.image_token_id
+ video_token_id = self.config.video_token_id
+ vision_start_token_id = self.config.vision_start_token_id
+ mrope_position_deltas = []
+ if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None):
+ total_input_ids = input_ids
+ if attention_mask is not None:
+ attention_mask = attention_mask == 1
+ position_ids = torch.ones(
+ 3,
+ input_ids.shape[0],
+ input_ids.shape[1],
+ dtype=input_ids.dtype,
+ device=input_ids.device,
+ )
+ image_index, video_index = 0, 0
+ for i, input_ids in enumerate(total_input_ids):
+ if attention_mask is not None:
+ input_ids = input_ids[attention_mask[i]]
+ image_nums, video_nums = 0, 0
+ vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1)
+ vision_tokens = input_ids[vision_start_indices + 1]
+ image_nums = (vision_tokens == image_token_id).sum()
+ video_nums = (vision_tokens == video_token_id).sum()
+ input_tokens = input_ids.tolist()
+ llm_pos_ids_list: list = []
+ st = 0
+ remain_images, remain_videos = image_nums, video_nums
+ for _ in range(image_nums + video_nums):
+ if image_token_id in input_tokens and remain_images > 0:
+ ed_image = input_tokens.index(image_token_id, st)
+ else:
+ ed_image = len(input_tokens) + 1
+ if video_token_id in input_tokens and remain_videos > 0:
+ ed_video = input_tokens.index(video_token_id, st)
+ else:
+ ed_video = len(input_tokens) + 1
+ if ed_image < ed_video:
+ t, h, w = (
+ image_grid_thw[image_index][0],
+ image_grid_thw[image_index][1],
+ image_grid_thw[image_index][2],
+ )
+ second_per_grid_t = 0
+ image_index += 1
+ remain_images -= 1
+ ed = ed_image
+
+ else:
+ t, h, w = (
+ video_grid_thw[video_index][0],
+ video_grid_thw[video_index][1],
+ video_grid_thw[video_index][2],
+ )
+ if second_per_grid_ts is not None:
+ second_per_grid_t = second_per_grid_ts[video_index]
+ else:
+ second_per_grid_t = 1.0
+ video_index += 1
+ remain_videos -= 1
+ ed = ed_video
+ llm_grid_t, llm_grid_h, llm_grid_w = (
+ t.item(),
+ h.item() // spatial_merge_size,
+ w.item() // spatial_merge_size,
+ )
+ text_len = ed - st
+
+ st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
+ llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
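+ # e.g. if the previous block ended at position 49 and there are 5 text tokens before the next
+ # vision block, this appends positions [50, 51, 52, 53, 54] replicated across the
+ # temporal/height/width rows.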
+
+ range_tensor = torch.arange(llm_grid_t).view(-1, 1)
+ expanded_range = range_tensor.expand(-1, llm_grid_h * llm_grid_w)
+
+ ## normalize type, send to device.
+ second_per_grid_t = torch.as_tensor(
+ second_per_grid_t, dtype=range_tensor.dtype, device=range_tensor.device
+ )
+
+ time_tensor = expanded_range * second_per_grid_t * self.config.vision_config.tokens_per_second
+
+ time_tensor_long = time_tensor.long()
+ t_index = time_tensor_long.flatten()
+
+ h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten()
+ w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten()
+ llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
+ st = ed + llm_grid_t * llm_grid_h * llm_grid_w
+
+ if st < len(input_tokens):
+ st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
+ text_len = len(input_tokens) - st
+ llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
+
+ llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
+ if attention_mask is not None:
+ position_ids[..., i, attention_mask[i]] = llm_positions.to(position_ids.device)
+ else:
+ position_ids[..., i, :] = llm_positions.to(position_ids.device)
+ mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i]))
+ mrope_position_deltas = torch.tensor(mrope_position_deltas).unsqueeze(1).to(device=input_ids.device)
+ return position_ids, mrope_position_deltas
+ else:
+ if attention_mask is not None:
+ position_ids = attention_mask.long().cumsum(-1) - 1
+ position_ids.masked_fill_(attention_mask == 0, 1)
+ position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device)
+ max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0]
+ mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1]
+ else:
+ position_ids = (
+ torch.arange(input_ids.shape[1], device=input_ids.device)
+ .view(1, 1, -1)
+ .expand(3, input_ids.shape[0], -1)
+ )
+ mrope_position_deltas = torch.zeros(
+ [input_ids.shape[0], 1],
+ device=input_ids.device,
+ dtype=input_ids.dtype,
+ )
+
+ return position_ids, mrope_position_deltas
+
+ def get_video_features(
+ self, pixel_values_videos: torch.FloatTensor, video_grid_thw: Optional[torch.LongTensor] = None
+ ):
+ """
+ Encodes videos into continuous embeddings that can be forwarded to the language model.
+
+ Args:
+ pixel_values_videos (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
+ The tensors corresponding to the input videos.
+ video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
+ The temporal, height and width of feature shape of each video in LLM.
+ """
+ pixel_values_videos = pixel_values_videos.type(self.visual.dtype)
+ video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
+ split_sizes = (video_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist()
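+ # e.g. a video grid of (t=2, h=16, w=16) with spatial_merge_size=2 contributes 2*16*16 / 4 = 128
+ # merged tokens, so the packed visual sequence is split back into one chunk per video.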
+ video_embeds = torch.split(video_embeds, split_sizes)
+ return video_embeds
+
+ def get_image_features(self, pixel_values: torch.FloatTensor, image_grid_thw: Optional[torch.LongTensor] = None):
+ """
+ Encodes images into continuous embeddings that can be forwarded to the language model.
+
+ Args:
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
+ The tensors corresponding to the input images.
+ image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
+ The temporal, height and width of feature shape of each image in LLM.
+ """
+ pixel_values = pixel_values.type(self.visual.dtype)
+ image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
+ split_sizes = (image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist()
+ image_embeds = torch.split(image_embeds, split_sizes)
+ return image_embeds
+
+ def get_placeholder_mask(
+ self,
+ input_ids: torch.LongTensor,
+ inputs_embeds: torch.FloatTensor,
+ image_features: Optional[torch.FloatTensor] = None,
+ video_features: Optional[torch.FloatTensor] = None,
+ ):
+ """
+ Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is
+ equal to the length of multimodal features. If the lengths are different, an error is raised.
+ """
+ if input_ids is None:
+ special_image_mask = inputs_embeds == self.get_input_embeddings()(
+ torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
+ )
+ special_image_mask = special_image_mask.all(-1)
+ special_video_mask = inputs_embeds == self.get_input_embeddings()(
+ torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device)
+ )
+ special_video_mask = special_video_mask.all(-1)
+ else:
+ special_image_mask = input_ids == self.config.image_token_id
+ special_video_mask = input_ids == self.config.video_token_id
+
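+        # Sanity-check that the number of placeholder positions matches the number of vision features.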
+ n_image_tokens = special_image_mask.sum()
+ special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
+ if image_features is not None and inputs_embeds[special_image_mask].numel() != image_features.numel():
+ raise ValueError(
+ f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {image_features.shape[0]}"
+ )
+
+ n_video_tokens = special_video_mask.sum()
+ special_video_mask = special_video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
+ if video_features is not None and inputs_embeds[special_video_mask].numel() != video_features.numel():
+ raise ValueError(
+ f"Videos features and video tokens do not match: tokens: {n_video_tokens}, features {video_features.shape[0]}"
+ )
+
+ return special_image_mask, special_video_mask
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ pixel_values: Optional[torch.Tensor] = None,
+ pixel_values_videos: Optional[torch.FloatTensor] = None,
+ image_grid_thw: Optional[torch.LongTensor] = None,
+ video_grid_thw: Optional[torch.LongTensor] = None,
+ rope_deltas: Optional[torch.LongTensor] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ second_per_grid_ts: Optional[torch.Tensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> Union[tuple, Qwen2_5_VLModelOutputWithPast]:
+ r"""
+ image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
+ The temporal, height and width of feature shape of each image in LLM.
+ video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
+ The temporal, height and width of feature shape of each video in LLM.
+ rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
+ The rope index difference between sequence length and multimodal rope.
+ second_per_grid_ts (`torch.Tensor` of shape `(num_videos)`, *optional*):
+ The time interval (in seconds) for each grid along the temporal dimension in the 3D position IDs.
+ """
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if inputs_embeds is None:
+ inputs_embeds = self.get_input_embeddings()(input_ids)
+
+ if pixel_values is not None:
+ image_embeds = self.get_image_features(pixel_values, image_grid_thw)
+ image_embeds = torch.cat(image_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype)
+ image_mask, _ = self.get_placeholder_mask(
+ input_ids, inputs_embeds=inputs_embeds, image_features=image_embeds
+ )
+ inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
+
+ if pixel_values_videos is not None:
+ video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw)
+ video_embeds = torch.cat(video_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype)
+ _, video_mask = self.get_placeholder_mask(
+ input_ids, inputs_embeds=inputs_embeds, video_features=video_embeds
+ )
+ inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
+
+ if position_ids is None:
+ # Calculate RoPE index once per generation in the pre-fill stage only.
+ # When compiling, we can't check tensor values thus we check only input length
+ # It is safe to assume that `length!=1` means we're in pre-fill because compiled
+            # models currently cannot do assisted decoding
+ prefill_compiled_stage = is_torchdynamo_compiling() and (
+ (input_ids is not None and input_ids.shape[1] != 1)
+ or (inputs_embeds is not None and inputs_embeds.shape[1] != 1)
+ )
+ prefill_noncompiled_stage = not is_torchdynamo_compiling() and (
+ (cache_position is not None and cache_position[0] == 0)
+ or (past_key_values is None or past_key_values.get_seq_length() == 0)
+ )
+ if (prefill_compiled_stage or prefill_noncompiled_stage) or self.rope_deltas is None:
+ position_ids, rope_deltas = self.get_rope_index(
+ input_ids,
+ image_grid_thw,
+ video_grid_thw,
+ second_per_grid_ts=second_per_grid_ts,
+ attention_mask=attention_mask,
+ )
+ self.rope_deltas = rope_deltas
+ else:
+ batch_size, seq_length, _ = inputs_embeds.shape
+ position_ids = torch.arange(seq_length, device=inputs_embeds.device)
+ position_ids = position_ids.view(1, 1, -1).expand(3, batch_size, -1)
+ if cache_position is not None:
+ delta = (cache_position[0] + self.rope_deltas).to(inputs_embeds.device)
+ else:
+ delta = torch.zeros((batch_size, seq_length), device=inputs_embeds.device)
+ delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=1)
+ position_ids = position_ids + delta.to(position_ids.device)
+
+ outputs = self.language_model(
+ input_ids=None,
+ position_ids=position_ids,
+ attention_mask=attention_mask,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=True,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+ output = Qwen2_5_VLModelOutputWithPast(
+ last_hidden_state=outputs.last_hidden_state,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ rope_deltas=self.rope_deltas,
+ )
+ return output if return_dict else output.to_tuple()
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Base class for Qwen2_5_VL causal language model (or autoregressive) outputs.
+ """
+)
+class Qwen2_5_VLCausalLMOutputWithPast(ModelOutput):
+ r"""
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Language modeling loss (for next-token prediction).
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
+ `past_key_values` input) to speed up sequential decoding.
+ rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
+ The rope index difference between sequence length and multimodal rope.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ logits: Optional[torch.FloatTensor] = None
+ past_key_values: Optional[Cache] = None
+ hidden_states: Optional[tuple[torch.FloatTensor]] = None
+ attentions: Optional[tuple[torch.FloatTensor]] = None
+ rope_deltas: Optional[torch.LongTensor] = None
+
+
+class Qwen2_5_VLForConditionalGeneration(Qwen2_5_VLPreTrainedModel, GenerationMixin):
+ _checkpoint_conversion_mapping = {
+ "^visual": "model.visual",
+ r"^model(?!\.(language_model|visual))": "model.language_model",
+ }
+ _tied_weights_keys = ["lm_head.weight"]
+ # Reference: fix gemma3 grad acc #37208
+ accepts_loss_kwargs = False
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.model = Qwen2_5_VLModel(config)
+ self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
+
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.get_input_embeddings()
+
+ def set_input_embeddings(self, value):
+ self.model.set_input_embeddings(value)
+
+ def set_decoder(self, decoder):
+ self.model.set_decoder(decoder)
+
+ def get_decoder(self):
+ return self.model.get_decoder()
+
+ def get_video_features(
+ self, pixel_values_videos: torch.FloatTensor, video_grid_thw: Optional[torch.LongTensor] = None
+ ):
+ return self.model.get_video_features(pixel_values_videos, video_grid_thw)
+
+ def get_image_features(self, pixel_values: torch.FloatTensor, image_grid_thw: Optional[torch.LongTensor] = None):
+ return self.model.get_image_features(pixel_values, image_grid_thw)
+
+ # Make modules available through conditional class for BC
+ @property
+ def language_model(self):
+ return self.model.language_model
+
+ @property
+ def visual(self):
+ return self.model.visual
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ pixel_values: Optional[torch.Tensor] = None,
+ pixel_values_videos: Optional[torch.FloatTensor] = None,
+ image_grid_thw: Optional[torch.LongTensor] = None,
+ video_grid_thw: Optional[torch.LongTensor] = None,
+ rope_deltas: Optional[torch.LongTensor] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ second_per_grid_ts: Optional[torch.Tensor] = None,
+ logits_to_keep: Union[int, torch.Tensor] = 0,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> Union[tuple, Qwen2_5_VLCausalLMOutputWithPast]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+ image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
+ The temporal, height and width of feature shape of each image in LLM.
+ video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
+ The temporal, height and width of feature shape of each video in LLM.
+ rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
+ The rope index difference between sequence length and multimodal rope.
+ second_per_grid_ts (`torch.Tensor` of shape `(num_videos)`, *optional*):
+ The time interval (in seconds) for each grid along the temporal dimension in the 3D position IDs.
+
+ Example:
+
+ ```python
+ >>> from PIL import Image
+ >>> import requests
+ >>> from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
+
+ >>> model = Qwen2_5_VLForConditionalGeneration.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
+ >>> processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
+
+        >>> messages = [
+        ...     {
+        ...         "role": "user",
+        ...         "content": [
+        ...             {"type": "image"},
+        ...             {"type": "text", "text": "What is shown in this image?"},
+        ...         ],
+        ...     },
+        ... ]
+        >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
+        >>> image = Image.open(requests.get(url, stream=True).raw)
+
+        >>> text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+        >>> inputs = processor(text=[text], images=[image], return_tensors="pt")
+
+        >>> # Generate
+        >>> generate_ids = model.generate(**inputs, max_new_tokens=30)
+        >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ "The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..."
+ ```"""
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+
+ outputs = self.model(
+ input_ids=input_ids,
+ pixel_values=pixel_values,
+ pixel_values_videos=pixel_values_videos,
+ image_grid_thw=image_grid_thw,
+ video_grid_thw=video_grid_thw,
+ second_per_grid_ts=second_per_grid_ts,
+ position_ids=position_ids,
+ attention_mask=attention_mask,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=True,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+ hidden_states = outputs[0]
+
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(
+ logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
+ )
+
+ return Qwen2_5_VLCausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ rope_deltas=outputs.rope_deltas,
+ )
+
+ def prepare_inputs_for_generation(
+ self,
+ input_ids,
+ past_key_values=None,
+ attention_mask=None,
+ inputs_embeds=None,
+ cache_position=None,
+ position_ids=None,
+ use_cache=True,
+ pixel_values=None,
+ pixel_values_videos=None,
+ image_grid_thw=None,
+ video_grid_thw=None,
+ second_per_grid_ts=None,
+ **kwargs,
+ ):
+ # Overwritten -- in specific circumstances we don't want to forward image inputs to the model
+
+ model_inputs = super().prepare_inputs_for_generation(
+ input_ids,
+ past_key_values=past_key_values,
+ attention_mask=attention_mask,
+ inputs_embeds=inputs_embeds,
+ cache_position=cache_position,
+ position_ids=position_ids,
+ pixel_values=pixel_values,
+ pixel_values_videos=pixel_values_videos,
+ image_grid_thw=image_grid_thw,
+ video_grid_thw=video_grid_thw,
+ second_per_grid_ts=second_per_grid_ts,
+ use_cache=use_cache,
+ **kwargs,
+ )
+
+ # Qwen2-5-VL position_ids are prepared with rope_deltas
+ if position_ids is None:
+ # Calculate RoPE index once per generation in the pre-fill stage only.
+ # When compiling, we can't check tensor values thus we check only input length
+ # It is safe to assume that `length!=1` means we're in pre-fill because compiled
+ # models currently cannot do assisted decoding
+ if cache_position[0] == 0 or self.model.rope_deltas is None:
+ vision_positions, rope_deltas = self.model.get_rope_index(
+ model_inputs.get("input_ids", None),
+ image_grid_thw=image_grid_thw,
+ video_grid_thw=video_grid_thw,
+ second_per_grid_ts=second_per_grid_ts,
+ attention_mask=attention_mask,
+ )
+ self.model.rope_deltas = rope_deltas
+ # then use the prev pre-calculated rope-deltas to get the correct position ids
+ elif "position_ids" in model_inputs:
+ batch_size, seq_length = model_inputs["position_ids"].shape
+ device = model_inputs["position_ids"].device
+ position_ids = torch.arange(seq_length, device=device)
+ position_ids = position_ids.view(1, 1, -1).expand(3, batch_size, -1)
+ delta = cache_position[0] + self.model.rope_deltas
+ delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
+ vision_positions = position_ids + delta.expand_as(position_ids)
+
+ # Concatenate "text + vision" positions into [4, bs, seq-len]
+ text_positions = model_inputs["position_ids"][None, ...]
+ model_inputs["position_ids"] = torch.cat([text_positions, vision_positions], dim=0)
+
+ if cache_position[0] != 0:
+ model_inputs["pixel_values"] = None
+ model_inputs["pixel_values_videos"] = None
+
+ return model_inputs
+
+ def _get_image_nums_and_video_nums(
+ self,
+ input_ids: Optional[torch.LongTensor],
+ inputs_embeds: Optional[torch.Tensor] = None,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ """
+ Get the number of images and videos for each sample to calculate the separation length of the sample tensor.
+ These parameters are not passed through the processor to avoid unpredictable impacts from interface modifications.
+
+ Args:
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary.
+
+ Returns:
+            image_nums (`torch.LongTensor` of shape `(batch_size,)`)
+            video_nums (`torch.LongTensor` of shape `(batch_size,)`)
+ """
+ image_token_id = self.config.image_token_id
+ video_token_id = self.config.video_token_id
+ vision_start_token_id = self.config.vision_start_token_id
+
+ if inputs_embeds is not None:
+ vision_start_mask = (
+ inputs_embeds
+ == self.get_input_embeddings()(
+ torch.tensor(vision_start_token_id, dtype=torch.long, device=inputs_embeds.device)
+ )
+ )[..., 0]
+ image_mask = (
+ inputs_embeds
+ == self.get_input_embeddings()(
+ torch.tensor(image_token_id, dtype=torch.long, device=inputs_embeds.device)
+ )
+ )[..., 0]
+ video_mask = (
+ inputs_embeds
+ == self.get_input_embeddings()(
+ torch.tensor(video_token_id, dtype=torch.long, device=inputs_embeds.device)
+ )
+ )[..., 0]
+ else:
+ vision_start_mask = input_ids == vision_start_token_id
+ image_mask = input_ids == image_token_id
+ video_mask = input_ids == video_token_id
+
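+        # Each vision block starts with `vision_start_token_id`; rolling the start mask right by one
+        # aligns it with the following image/video token, so the sums below count blocks per sample.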
+ vision_first_mask = torch.roll(vision_start_mask, shifts=1, dims=1)
+ image_nums = torch.sum(vision_first_mask & image_mask, dim=1)
+ video_nums = torch.sum(vision_first_mask & video_mask, dim=1)
+
+ return image_nums, video_nums
+
+ def _expand_inputs_for_generation(
+ self,
+ expand_size: int = 1,
+ is_encoder_decoder: bool = False,
+ input_ids: Optional[torch.LongTensor] = None,
+ **model_kwargs,
+ ) -> tuple[torch.LongTensor, dict[str, Any]]:
+ # Overwritten -- Support for expanding tensors without a batch size dimension
+        # e.g., pixel_values, image_grid_thw, pixel_values_videos, video_grid_thw, second_per_grid_ts
+ # pixel_values.shape[0] is sum(seqlen_images for samples)
+ # image_grid_thw.shape[0] is sum(num_images for samples)
+
+ if expand_size == 1:
+ return input_ids, model_kwargs
+
+ visual_keys = ["pixel_values", "image_grid_thw", "pixel_values_videos", "video_grid_thw", "second_per_grid_ts"]
+
+ def _expand_dict_for_generation_visual(dict_to_expand):
+ image_grid_thw = model_kwargs.get("image_grid_thw", None)
+ video_grid_thw = model_kwargs.get("video_grid_thw", None)
+ image_nums, video_nums = self._get_image_nums_and_video_nums(
+ input_ids, inputs_embeds=model_kwargs.get("inputs_embeds", None)
+ )
+
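+            # Repeat each per-sample slice of `x` `repeat_times` times, keeping slices contiguous
+            # (a sample-wise repeat_interleave for tensors whose first dim concatenates all samples).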
+ def _repeat_interleave_samples(x, lengths, repeat_times):
+ samples = torch.split(x, lengths)
+ repeat_args = [repeat_times] + [1] * (x.dim() - 1)
+ result = torch.cat([sample.repeat(*repeat_args) for sample in samples], dim=0)
+ return result
+
+ for key in dict_to_expand:
+ if key == "pixel_values":
+ # split images into samples
+ samples = torch.split(image_grid_thw, list(image_nums))
+ # compute the sequence length of images for each sample
+ lengths = [torch.prod(sample, dim=1).sum() for sample in samples]
+ dict_to_expand[key] = _repeat_interleave_samples(
+ dict_to_expand[key], lengths=lengths, repeat_times=expand_size
+ )
+ elif key == "image_grid_thw":
+ # get the num of images for each sample
+ lengths = list(image_nums)
+ dict_to_expand[key] = _repeat_interleave_samples(
+ dict_to_expand[key], lengths=lengths, repeat_times=expand_size
+ )
+ elif key == "pixel_values_videos":
+ samples = torch.split(video_grid_thw, list(video_nums))
+ lengths = [torch.prod(sample, dim=1).sum() for sample in samples]
+ dict_to_expand[key] = _repeat_interleave_samples(
+ dict_to_expand[key], lengths=lengths, repeat_times=expand_size
+ )
+ elif key == "video_grid_thw":
+ lengths = list(video_nums)
+ dict_to_expand[key] = _repeat_interleave_samples(
+ dict_to_expand[key], lengths=lengths, repeat_times=expand_size
+ )
+ elif key == "second_per_grid_ts":
+ dict_to_expand[key] = _repeat_interleave_samples(
+ dict_to_expand[key], lengths=list(video_nums), repeat_times=expand_size
+ )
+ return dict_to_expand
+
+ def _expand_dict_for_generation(dict_to_expand):
+ for key in dict_to_expand:
+ if (
+ key != "cache_position"
+ and dict_to_expand[key] is not None
+ and isinstance(dict_to_expand[key], torch.Tensor)
+ and key not in visual_keys
+ ):
+ dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0)
+ return dict_to_expand
+
+ model_kwargs = _expand_dict_for_generation_visual(model_kwargs)
+
+ if input_ids is not None:
+ input_ids = input_ids.repeat_interleave(expand_size, dim=0)
+
+ model_kwargs = _expand_dict_for_generation(model_kwargs)
+
+ if is_encoder_decoder:
+ if model_kwargs.get("encoder_outputs") is None:
+ raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.")
+ model_kwargs["encoder_outputs"] = _expand_dict_for_generation(model_kwargs["encoder_outputs"])
+
+ return input_ids, model_kwargs
+
+
+__all__ = ["Qwen2_5_VLForConditionalGeneration", "Qwen2_5_VLModel", "Qwen2_5_VLPreTrainedModel", "Qwen2_5_VLTextModel"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py
new file mode 100644
index 0000000000000000000000000000000000000000..f2eac303213c99b067856861fbde46bb186213c1
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py
@@ -0,0 +1,1040 @@
+# coding=utf-8
+# Copyright 2025 The Qwen Team and The HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch Qwen2.5-VL model."""
+
+from typing import Optional, Union
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from transformers.models.qwen2_vl.configuration_qwen2_vl import Qwen2VLConfig, Qwen2VLTextConfig
+from transformers.models.qwen2_vl.modeling_qwen2_vl import (
+ PatchEmbed,
+ PatchMerger,
+ Qwen2RMSNorm,
+ Qwen2VLCausalLMOutputWithPast,
+ Qwen2VLForConditionalGeneration,
+ Qwen2VLModel,
+ Qwen2VLModelOutputWithPast,
+ Qwen2VLPreTrainedModel,
+ TransformersKwargs,
+ VisionAttention,
+ VisionRotaryEmbedding,
+)
+from transformers.models.qwen2_vl.processing_qwen2_vl import Qwen2VLImagesKwargs, Qwen2VLProcessor
+
+from ...activations import ACT2FN
+from ...cache_utils import Cache
+from ...configuration_utils import PretrainedConfig
+from ...feature_extraction_utils import BatchFeature
+from ...image_utils import ImageInput
+from ...modeling_flash_attention_utils import is_flash_attn_available
+from ...modeling_layers import GradientCheckpointingLayer
+from ...processing_utils import MultiModalData, ProcessingKwargs, Unpack, VideosKwargs
+from ...tokenization_utils_base import PreTokenizedInput, TextInput
+from ...utils import is_torchdynamo_compiling, logging
+from ...video_utils import VideoInput
+
+
+if is_flash_attn_available():
+ pass
+
+
+logger = logging.get_logger(__name__)
+
+
+class Qwen2_5_VLVisionConfig(PretrainedConfig):
+ model_type = "qwen2_5_vl"
+ base_config_key = "vision_config"
+
+ def __init__(
+ self,
+ depth=32,
+ hidden_size=3584,
+ hidden_act="silu",
+ intermediate_size=3420,
+ num_heads=16,
+ in_channels=3,
+ patch_size=14,
+ spatial_merge_size=2,
+ temporal_patch_size=2,
+ tokens_per_second=4,
+ window_size=112,
+ out_hidden_size=3584,
+ fullatt_block_indexes=[7, 15, 23, 31],
+ initializer_range=0.02,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+
+ self.depth = depth
+ self.hidden_size = hidden_size
+ self.hidden_act = hidden_act
+ self.intermediate_size = intermediate_size
+ self.num_heads = num_heads
+ self.in_channels = in_channels
+ self.patch_size = patch_size
+ self.spatial_merge_size = spatial_merge_size
+ self.temporal_patch_size = temporal_patch_size
+ self.tokens_per_second = tokens_per_second
+ self.window_size = window_size
+ self.fullatt_block_indexes = fullatt_block_indexes
+ self.out_hidden_size = out_hidden_size
+ self.initializer_range = initializer_range
+
+
+class Qwen2_5_VLTextConfig(Qwen2VLTextConfig):
+ model_type = "qwen2_5_vl_text"
+
+
+class Qwen2_5_VLConfig(Qwen2VLConfig):
+ model_type = "qwen2_5_vl"
+ sub_configs = {"vision_config": Qwen2_5_VLVisionConfig, "text_config": Qwen2_5_VLTextConfig}
+
+
+class Qwen2_5_VLMLP(nn.Module):
+ def __init__(self, config, bias: bool = False):
+ super().__init__()
+ self.hidden_size = config.hidden_size
+ self.intermediate_size = config.intermediate_size
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=bias)
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=bias)
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=bias)
+ self.act_fn = ACT2FN[config.hidden_act]
+
+ def forward(self, hidden_state):
+ return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))
+
+
+class Qwen2_5_VisionPatchEmbed(PatchEmbed):
+ pass
+
+
+class Qwen2_5_VisionRotaryEmbedding(VisionRotaryEmbedding):
+ pass
+
+
+class Qwen2_5_VLPatchMerger(PatchMerger):
+ def __init__(self, dim: int, context_dim: int, spatial_merge_size: int = 2) -> None:
+ super().__init__(dim, context_dim, spatial_merge_size)
+ self.ln_q = Qwen2RMSNorm(context_dim, eps=1e-6)
+
+
+class Qwen2_5_VLVisionAttention(VisionAttention):
+ def __init__(self, config: Qwen2_5_VLVisionConfig) -> None:
+ super().__init__(config)
+ self.dim = config.hidden_size
+
+
+class Qwen2_5_VLVisionBlock(GradientCheckpointingLayer):
+ def __init__(self, config, attn_implementation: str = "sdpa") -> None:
+ super().__init__()
+ self.norm1 = Qwen2RMSNorm(config.hidden_size, eps=1e-6)
+ self.norm2 = Qwen2RMSNorm(config.hidden_size, eps=1e-6)
+ self.attn = Qwen2_5_VLVisionAttention(config=config)
+ self.mlp = Qwen2_5_VLMLP(config, bias=True)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ cu_seqlens: torch.Tensor,
+ rotary_pos_emb: Optional[torch.Tensor] = None,
+ position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
+ **kwargs,
+ ) -> torch.Tensor:
+ hidden_states = hidden_states + self.attn(
+ self.norm1(hidden_states),
+ cu_seqlens=cu_seqlens,
+ rotary_pos_emb=rotary_pos_emb,
+ position_embeddings=position_embeddings,
+ **kwargs,
+ )
+ hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
+ return hidden_states
+
+
+class Qwen2_5_VLPreTrainedModel(Qwen2VLPreTrainedModel):
+ pass
+
+
+class Qwen2_5_VisionTransformerPretrainedModel(Qwen2_5_VLPreTrainedModel):
+ config: Qwen2_5_VLVisionConfig
+ _no_split_modules = ["Qwen2_5_VLVisionBlock"]
+
+ def __init__(self, config, *inputs, **kwargs) -> None:
+ super().__init__(config, *inputs, **kwargs)
+ self.spatial_merge_size = config.spatial_merge_size
+ self.patch_size = config.patch_size
+ self.fullatt_block_indexes = config.fullatt_block_indexes
+ self.window_size = config.window_size
+ self.spatial_merge_unit = self.spatial_merge_size * self.spatial_merge_size
+
+ self.patch_embed = Qwen2_5_VisionPatchEmbed(
+ patch_size=config.patch_size,
+ temporal_patch_size=config.temporal_patch_size,
+ in_channels=config.in_channels,
+ embed_dim=config.hidden_size,
+ )
+
+ head_dim = config.hidden_size // config.num_heads
+ self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2)
+
+ self.blocks = nn.ModuleList([Qwen2_5_VLVisionBlock(config) for _ in range(config.depth)])
+ self.merger = Qwen2_5_VLPatchMerger(
+ dim=config.out_hidden_size,
+ context_dim=config.hidden_size,
+ spatial_merge_size=config.spatial_merge_size,
+ )
+ self.gradient_checkpointing = False
+
+ def rot_pos_emb(self, grid_thw):
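+        # Build (height, width) position ids for every patch, ordered block-by-block so that
+        # patches merged by `spatial_merge_size` stay contiguous, then look them up in the
+        # precomputed rotary embedding table.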
+ pos_ids = []
+ for t, h, w in grid_thw:
+ hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
+ hpos_ids = hpos_ids.reshape(
+ h // self.spatial_merge_size,
+ self.spatial_merge_size,
+ w // self.spatial_merge_size,
+ self.spatial_merge_size,
+ )
+ hpos_ids = hpos_ids.permute(0, 2, 1, 3)
+ hpos_ids = hpos_ids.flatten()
+
+ wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
+ wpos_ids = wpos_ids.reshape(
+ h // self.spatial_merge_size,
+ self.spatial_merge_size,
+ w // self.spatial_merge_size,
+ self.spatial_merge_size,
+ )
+ wpos_ids = wpos_ids.permute(0, 2, 1, 3)
+ wpos_ids = wpos_ids.flatten()
+ pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
+ pos_ids = torch.cat(pos_ids, dim=0)
+ max_grid_size = grid_thw[:, 1:].max()
+ rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
+ rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
+ return rotary_pos_emb
+
+ def get_window_index(self, grid_thw):
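+        # Assign every merged patch to a local attention window of `window_size` pixels and return
+        # a permutation (`window_index`) that groups tokens window by window, plus the cumulative
+        # sequence lengths (`cu_window_seqlens`) delimiting those windows.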
+ window_index: list = []
+ cu_window_seqlens: list = [0]
+ window_index_id = 0
+ vit_merger_window_size = self.window_size // self.spatial_merge_size // self.patch_size
+
+ for grid_t, grid_h, grid_w in grid_thw:
+ llm_grid_h, llm_grid_w = (
+ grid_h // self.spatial_merge_size,
+ grid_w // self.spatial_merge_size,
+ )
+ index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(grid_t, llm_grid_h, llm_grid_w)
+ pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size
+ pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size
+ num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size
+ num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size
+ index_padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", -100)
+ index_padded = index_padded.reshape(
+ grid_t,
+ num_windows_h,
+ vit_merger_window_size,
+ num_windows_w,
+ vit_merger_window_size,
+ )
+ index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape(
+ grid_t,
+ num_windows_h * num_windows_w,
+ vit_merger_window_size,
+ vit_merger_window_size,
+ )
+ seqlens = (index_padded != -100).sum([2, 3]).reshape(-1)
+ index_padded = index_padded.reshape(-1)
+ index_new = index_padded[index_padded != -100]
+ window_index.append(index_new + window_index_id)
+ cu_seqlens_tmp = seqlens.cumsum(0) * self.spatial_merge_unit + cu_window_seqlens[-1]
+ cu_window_seqlens.extend(cu_seqlens_tmp.tolist())
+ window_index_id += (grid_t * llm_grid_h * llm_grid_w).item()
+ window_index = torch.cat(window_index, dim=0)
+
+ return window_index, cu_window_seqlens
+
+ def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) -> torch.Tensor:
+ """
+ Args:
+ hidden_states (`torch.Tensor` of shape `(seq_len, hidden_size)`):
+                The flattened image/video patches to be embedded and encoded by the vision transformer.
+ grid_thw (`torch.Tensor` of shape `(num_images_or_videos, 3)`):
+ The temporal, height and width of feature shape of each image in LLM.
+
+ Returns:
+ `torch.Tensor`: hidden_states.
+ """
+ hidden_states = self.patch_embed(hidden_states)
+ rotary_pos_emb = self.rot_pos_emb(grid_thw)
+ window_index, cu_window_seqlens = self.get_window_index(grid_thw)
+ cu_window_seqlens = torch.tensor(
+ cu_window_seqlens,
+ device=hidden_states.device,
+ dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
+ )
+ cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)
+
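+        # Reorder patch embeddings (and their rotary embeddings) window by window so each attention
+        # window is a contiguous chunk; the permutation is undone with `reverse_indices` at the end.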
+ seq_len, _ = hidden_states.size()
+ hidden_states = hidden_states.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
+ hidden_states = hidden_states[window_index, :, :]
+ hidden_states = hidden_states.reshape(seq_len, -1)
+ rotary_pos_emb = rotary_pos_emb.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
+ rotary_pos_emb = rotary_pos_emb[window_index, :, :]
+ rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
+ emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
+ position_embeddings = (emb.cos(), emb.sin())
+
+ cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
+ dim=0,
+ # Select dtype based on the following factors:
+ # - FA2 requires that cu_seqlens_q must have dtype int32
+ # - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw
+ # See https://github.com/huggingface/transformers/pull/34852 for more information
+ dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
+ )
+ cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
+
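+        # Blocks listed in `fullatt_block_indexes` attend over the whole image/video; all other
+        # blocks attend only within the local windows defined by `cu_window_seqlens`.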
+ for layer_num, blk in enumerate(self.blocks):
+ if layer_num in self.fullatt_block_indexes:
+ cu_seqlens_now = cu_seqlens
+ else:
+ cu_seqlens_now = cu_window_seqlens
+
+ hidden_states = blk(
+ hidden_states,
+ cu_seqlens=cu_seqlens_now,
+ position_embeddings=position_embeddings,
+ **kwargs,
+ )
+
+ hidden_states = self.merger(hidden_states)
+ reverse_indices = torch.argsort(window_index)
+ hidden_states = hidden_states[reverse_indices, :]
+
+ return hidden_states
+
+
+class Qwen2_5_VLModelOutputWithPast(Qwen2VLModelOutputWithPast):
+ pass
+
+
+class Qwen2_5_VLModel(Qwen2VLModel):
+ config: Qwen2_5_VLConfig
+ base_model_prefix = ""
+ _no_split_modules = ["Qwen2_5_VLDecoderLayer", "Qwen2_5_VLVisionBlock"]
+ # Reference: fix gemma3 grad acc #37208
+ accepts_loss_kwargs = False
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.visual = Qwen2_5_VisionTransformerPretrainedModel._from_config(config.vision_config)
+
+ def get_rope_index(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ image_grid_thw: Optional[torch.LongTensor] = None,
+ video_grid_thw: Optional[torch.LongTensor] = None,
+ second_per_grid_ts: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ """
+ Calculate the 3D rope index based on image and video's temporal, height and width in LLM.
+
+ Explanation:
+            Each embedding sequence contains both vision and text embeddings, or text embeddings only.
+
+ For pure text embedding sequence, the rotary position embedding has no difference with modern LLMs.
+ Examples:
+ input_ids: [T T T T T], here T is for text.
+ temporal position_ids: [0, 1, 2, 3, 4]
+ height position_ids: [0, 1, 2, 3, 4]
+ width position_ids: [0, 1, 2, 3, 4]
+
+ For vision and text embedding sequence, we calculate 3D rotary position embedding for vision part
+ and 1D rotary position embedding for text part.
+ Examples:
+ Temporal (Time): 3 patches, representing different segments of the video in time.
+ Height: 2 patches, dividing each frame vertically.
+ Width: 2 patches, dividing each frame horizontally.
+ We also have some important parameters:
+ fps (Frames Per Second): The video's frame rate, set to 1. This means one frame is processed each second.
+ tokens_per_second: This is a crucial parameter. It dictates how many "time-steps" or "temporal tokens" are conceptually packed into a one-second interval of the video. In this case, we have 25 tokens per second. So each second of the video will be represented with 25 separate time points. It essentially defines the temporal granularity.
+ temporal_patch_size: The number of frames that compose one temporal patch. Here, it's 2 frames.
+                interval: The step size for the temporal position IDs, calculated as tokens_per_second * temporal_patch_size / fps. In this case, 25 * 2 / 1 = 50. This means that each temporal patch will have a difference of 50 in the temporal position IDs.
+ input_ids: [V V V V V V V V V V V V T T T T T], here V is for vision.
+ vision temporal position_ids: [0, 0, 0, 0, 50, 50, 50, 50, 100, 100, 100, 100]
+ vision height position_ids: [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1]
+ vision width position_ids: [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1]
+ text temporal position_ids: [101, 102, 103, 104, 105]
+ text height position_ids: [101, 102, 103, 104, 105]
+ text width position_ids: [101, 102, 103, 104, 105]
+ Here we calculate the text start position_ids as the max vision position_ids plus 1.
+
+ Args:
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
+ it.
+ image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
+ The temporal, height and width of feature shape of each image in LLM.
+ video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
+ The temporal, height and width of feature shape of each video in LLM.
+ second_per_grid_ts (`torch.Tensor` of shape `(num_videos)`, *optional*):
+ The time interval (in seconds) for each grid along the temporal dimension in the 3D position IDs.
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ Returns:
+ position_ids (`torch.LongTensor` of shape `(3, batch_size, sequence_length)`)
+ mrope_position_deltas (`torch.Tensor` of shape `(batch_size)`)
+ """
+ spatial_merge_size = self.config.vision_config.spatial_merge_size
+ image_token_id = self.config.image_token_id
+ video_token_id = self.config.video_token_id
+ vision_start_token_id = self.config.vision_start_token_id
+ mrope_position_deltas = []
+ if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None):
+ total_input_ids = input_ids
+ if attention_mask is not None:
+ attention_mask = attention_mask == 1
+ position_ids = torch.ones(
+ 3,
+ input_ids.shape[0],
+ input_ids.shape[1],
+ dtype=input_ids.dtype,
+ device=input_ids.device,
+ )
+ image_index, video_index = 0, 0
+ for i, input_ids in enumerate(total_input_ids):
+ if attention_mask is not None:
+ input_ids = input_ids[attention_mask[i]]
+ image_nums, video_nums = 0, 0
+ vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1)
+ vision_tokens = input_ids[vision_start_indices + 1]
+ image_nums = (vision_tokens == image_token_id).sum()
+ video_nums = (vision_tokens == video_token_id).sum()
+ input_tokens = input_ids.tolist()
+ llm_pos_ids_list: list = []
+ st = 0
+ remain_images, remain_videos = image_nums, video_nums
+ for _ in range(image_nums + video_nums):
+ if image_token_id in input_tokens and remain_images > 0:
+ ed_image = input_tokens.index(image_token_id, st)
+ else:
+ ed_image = len(input_tokens) + 1
+ if video_token_id in input_tokens and remain_videos > 0:
+ ed_video = input_tokens.index(video_token_id, st)
+ else:
+ ed_video = len(input_tokens) + 1
+ if ed_image < ed_video:
+ t, h, w = (
+ image_grid_thw[image_index][0],
+ image_grid_thw[image_index][1],
+ image_grid_thw[image_index][2],
+ )
+ second_per_grid_t = 0
+ image_index += 1
+ remain_images -= 1
+ ed = ed_image
+
+ else:
+ t, h, w = (
+ video_grid_thw[video_index][0],
+ video_grid_thw[video_index][1],
+ video_grid_thw[video_index][2],
+ )
+ if second_per_grid_ts is not None:
+ second_per_grid_t = second_per_grid_ts[video_index]
+ else:
+ second_per_grid_t = 1.0
+ video_index += 1
+ remain_videos -= 1
+ ed = ed_video
+ llm_grid_t, llm_grid_h, llm_grid_w = (
+ t.item(),
+ h.item() // spatial_merge_size,
+ w.item() // spatial_merge_size,
+ )
+ text_len = ed - st
+
+ st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
+ llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
+
+ range_tensor = torch.arange(llm_grid_t).view(-1, 1)
+ expanded_range = range_tensor.expand(-1, llm_grid_h * llm_grid_w)
+
+ ## normalize type, send to device.
+ second_per_grid_t = torch.as_tensor(
+ second_per_grid_t, dtype=range_tensor.dtype, device=range_tensor.device
+ )
+
+ time_tensor = expanded_range * second_per_grid_t * self.config.vision_config.tokens_per_second
+
+ time_tensor_long = time_tensor.long()
+ t_index = time_tensor_long.flatten()
+
+ h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten()
+ w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten()
+ llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
+ st = ed + llm_grid_t * llm_grid_h * llm_grid_w
+
+ if st < len(input_tokens):
+ st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
+ text_len = len(input_tokens) - st
+ llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
+
+ llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
+ if attention_mask is not None:
+ position_ids[..., i, attention_mask[i]] = llm_positions.to(position_ids.device)
+ else:
+ position_ids[..., i, :] = llm_positions.to(position_ids.device)
+ mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i]))
+ mrope_position_deltas = torch.tensor(mrope_position_deltas).unsqueeze(1).to(device=input_ids.device)
+ return position_ids, mrope_position_deltas
+ else:
+ if attention_mask is not None:
+ position_ids = attention_mask.long().cumsum(-1) - 1
+ position_ids.masked_fill_(attention_mask == 0, 1)
+ position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device)
+ max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0]
+ mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1]
+ else:
+ position_ids = (
+ torch.arange(input_ids.shape[1], device=input_ids.device)
+ .view(1, 1, -1)
+ .expand(3, input_ids.shape[0], -1)
+ )
+ mrope_position_deltas = torch.zeros(
+ [input_ids.shape[0], 1],
+ device=input_ids.device,
+ dtype=input_ids.dtype,
+ )
+
+ return position_ids, mrope_position_deltas
+
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ pixel_values: Optional[torch.Tensor] = None,
+ pixel_values_videos: Optional[torch.FloatTensor] = None,
+ image_grid_thw: Optional[torch.LongTensor] = None,
+ video_grid_thw: Optional[torch.LongTensor] = None,
+ rope_deltas: Optional[torch.LongTensor] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ second_per_grid_ts: Optional[torch.Tensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> Union[tuple, Qwen2_5_VLModelOutputWithPast]:
+ r"""
+ image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
+ The temporal, height and width of feature shape of each image in LLM.
+ video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
+ The temporal, height and width of feature shape of each video in LLM.
+ rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
+ The rope index difference between sequence length and multimodal rope.
+ second_per_grid_ts (`torch.Tensor` of shape `(num_videos)`, *optional*):
+ The time interval (in seconds) for each grid along the temporal dimension in the 3D position IDs.
+ """
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if inputs_embeds is None:
+ inputs_embeds = self.get_input_embeddings()(input_ids)
+
+ if pixel_values is not None:
+ image_embeds = self.get_image_features(pixel_values, image_grid_thw)
+ image_embeds = torch.cat(image_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype)
+ image_mask, _ = self.get_placeholder_mask(
+ input_ids, inputs_embeds=inputs_embeds, image_features=image_embeds
+ )
+ inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
+
+ if pixel_values_videos is not None:
+ video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw)
+ video_embeds = torch.cat(video_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype)
+ _, video_mask = self.get_placeholder_mask(
+ input_ids, inputs_embeds=inputs_embeds, video_features=video_embeds
+ )
+ inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
+
+ if position_ids is None:
+ # Calculate RoPE index once per generation in the pre-fill stage only.
+ # When compiling, we can't check tensor values thus we check only input length
+ # It is safe to assume that `length!=1` means we're in pre-fill because compiled
+            # models currently cannot do assisted decoding
+ prefill_compiled_stage = is_torchdynamo_compiling() and (
+ (input_ids is not None and input_ids.shape[1] != 1)
+ or (inputs_embeds is not None and inputs_embeds.shape[1] != 1)
+ )
+ prefill_noncompiled_stage = not is_torchdynamo_compiling() and (
+ (cache_position is not None and cache_position[0] == 0)
+ or (past_key_values is None or past_key_values.get_seq_length() == 0)
+ )
+ if (prefill_compiled_stage or prefill_noncompiled_stage) or self.rope_deltas is None:
+ position_ids, rope_deltas = self.get_rope_index(
+ input_ids,
+ image_grid_thw,
+ video_grid_thw,
+ second_per_grid_ts=second_per_grid_ts,
+ attention_mask=attention_mask,
+ )
+ self.rope_deltas = rope_deltas
+ else:
+ batch_size, seq_length, _ = inputs_embeds.shape
+ position_ids = torch.arange(seq_length, device=inputs_embeds.device)
+ position_ids = position_ids.view(1, 1, -1).expand(3, batch_size, -1)
+ if cache_position is not None:
+ delta = (cache_position[0] + self.rope_deltas).to(inputs_embeds.device)
+ else:
+ delta = torch.zeros((batch_size, seq_length), device=inputs_embeds.device)
+ delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=1)
+ position_ids = position_ids + delta.to(position_ids.device)
+
+ outputs = self.language_model(
+ input_ids=None,
+ position_ids=position_ids,
+ attention_mask=attention_mask,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=True,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+ output = Qwen2_5_VLModelOutputWithPast(
+ last_hidden_state=outputs.last_hidden_state,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ rope_deltas=self.rope_deltas,
+ )
+ return output if return_dict else output.to_tuple()
+
+
+class Qwen2_5_VLCausalLMOutputWithPast(Qwen2VLCausalLMOutputWithPast):
+ pass
+
+
+class Qwen2_5_VLForConditionalGeneration(Qwen2VLForConditionalGeneration):
+ # Reference: fix gemma3 grad acc #37208
+ accepts_loss_kwargs = False
+
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ pixel_values: Optional[torch.Tensor] = None,
+ pixel_values_videos: Optional[torch.FloatTensor] = None,
+ image_grid_thw: Optional[torch.LongTensor] = None,
+ video_grid_thw: Optional[torch.LongTensor] = None,
+ rope_deltas: Optional[torch.LongTensor] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ second_per_grid_ts: Optional[torch.Tensor] = None,
+ logits_to_keep: Union[int, torch.Tensor] = 0,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> Union[tuple, Qwen2_5_VLCausalLMOutputWithPast]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+ image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
+ The temporal, height and width of feature shape of each image in LLM.
+ video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
+ The temporal, height and width of feature shape of each video in LLM.
+ rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
+ The rope index difference between sequence length and multimodal rope.
+ second_per_grid_ts (`torch.Tensor` of shape `(num_videos)`, *optional*):
+ The time interval (in seconds) for each grid along the temporal dimension in the 3D position IDs.
+
+ Example:
+
+ ```python
+ >>> from PIL import Image
+ >>> import requests
+ >>> from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
+
+ >>> model = Qwen2_5_VLForConditionalGeneration.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
+ >>> processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
+
+        >>> messages = [
+        ...     {
+        ...         "role": "user",
+        ...         "content": [
+        ...             {"type": "image"},
+        ...             {"type": "text", "text": "What is shown in this image?"},
+        ...         ],
+        ...     },
+        ... ]
+        >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
+        >>> image = Image.open(requests.get(url, stream=True).raw)
+
+        >>> text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+        >>> inputs = processor(text=[text], images=[image], return_tensors="pt")
+
+        >>> # Generate
+        >>> generate_ids = model.generate(**inputs, max_new_tokens=30)
+        >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ "The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..."
+ ```"""
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+
+ outputs = self.model(
+ input_ids=input_ids,
+ pixel_values=pixel_values,
+ pixel_values_videos=pixel_values_videos,
+ image_grid_thw=image_grid_thw,
+ video_grid_thw=video_grid_thw,
+ second_per_grid_ts=second_per_grid_ts,
+ position_ids=position_ids,
+ attention_mask=attention_mask,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=True,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+ hidden_states = outputs[0]
+
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(
+ logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
+ )
+
+ return Qwen2_5_VLCausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ rope_deltas=outputs.rope_deltas,
+ )
+
+ def prepare_inputs_for_generation(
+ self,
+ input_ids,
+ past_key_values=None,
+ attention_mask=None,
+ inputs_embeds=None,
+ cache_position=None,
+ position_ids=None,
+ use_cache=True,
+ pixel_values=None,
+ pixel_values_videos=None,
+ image_grid_thw=None,
+ video_grid_thw=None,
+ second_per_grid_ts=None,
+ **kwargs,
+ ):
+ # Overwritten -- in specific circumstances we don't want to forward image inputs to the model
+
+ model_inputs = super().prepare_inputs_for_generation(
+ input_ids,
+ past_key_values=past_key_values,
+ attention_mask=attention_mask,
+ inputs_embeds=inputs_embeds,
+ cache_position=cache_position,
+ position_ids=position_ids,
+ pixel_values=pixel_values,
+ pixel_values_videos=pixel_values_videos,
+ image_grid_thw=image_grid_thw,
+ video_grid_thw=video_grid_thw,
+ second_per_grid_ts=second_per_grid_ts,
+ use_cache=use_cache,
+ **kwargs,
+ )
+
+ # Qwen2-5-VL position_ids are prepared with rope_deltas
+ if position_ids is None:
+ # Calculate RoPE index once per generation in the pre-fill stage only.
+ # When compiling, we can't check tensor values thus we check only input length
+ # It is safe to assume that `length!=1` means we're in pre-fill because compiled
+ # models currently cannot do assisted decoding
+ if cache_position[0] == 0 or self.model.rope_deltas is None:
+ vision_positions, rope_deltas = self.model.get_rope_index(
+ model_inputs.get("input_ids", None),
+ image_grid_thw=image_grid_thw,
+ video_grid_thw=video_grid_thw,
+ second_per_grid_ts=second_per_grid_ts,
+ attention_mask=attention_mask,
+ )
+ self.model.rope_deltas = rope_deltas
+ # then use the prev pre-calculated rope-deltas to get the correct position ids
+ elif "position_ids" in model_inputs:
+ batch_size, seq_length = model_inputs["position_ids"].shape
+ device = model_inputs["position_ids"].device
+ position_ids = torch.arange(seq_length, device=device)
+ position_ids = position_ids.view(1, 1, -1).expand(3, batch_size, -1)
+ delta = cache_position[0] + self.model.rope_deltas
+ delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
+ vision_positions = position_ids + delta.expand_as(position_ids)
+
+ # Concatenate "text + vision" positions into [4, bs, seq-len]
+ text_positions = model_inputs["position_ids"][None, ...]
+ model_inputs["position_ids"] = torch.cat([text_positions, vision_positions], dim=0)
+
+ if cache_position[0] != 0:
+ model_inputs["pixel_values"] = None
+ model_inputs["pixel_values_videos"] = None
+
+ return model_inputs
+
+
+class Qwen2_5_VLVideosProcessorKwargs(VideosKwargs, total=False):
+ fps: Union[list[float], float]
+
+
+class Qwen2_5_VLImagesKwargs(Qwen2VLImagesKwargs):
+ pass
+
+
+class Qwen2_5_VLProcessorKwargs(ProcessingKwargs, total=False):
+ images_kwargs: Qwen2_5_VLImagesKwargs
+ videos_kwargs: Qwen2_5_VLVideosProcessorKwargs
+ _defaults = {
+ "text_kwargs": {
+ "padding": False,
+ "return_mm_token_type_ids": False,
+ },
+ }
+
+
+class Qwen2_5_VLProcessor(Qwen2VLProcessor):
+ r"""
+ Constructs a Qwen2.5-VL processor which wraps a Qwen2.5-VL image processor and a Qwen2 tokenizer into a single processor.
+ [`Qwen2_5_VLProcessor`] offers all the functionalities of [`Qwen2VLImageProcessor`] and [`Qwen2TokenizerFast`]. See the
+ [`~Qwen2_5_VLProcessor.__call__`] and [`~Qwen2_5_VLProcessor.decode`] for more information.
+ Args:
+ image_processor ([`Qwen2VLImageProcessor`], *optional*):
+ The image processor is a required input.
+ tokenizer ([`Qwen2TokenizerFast`], *optional*):
+ The tokenizer is a required input.
+ video_processor ([`Qwen2_5_VLVideoProcessor`], *optional*):
+ The video processor is a required input.
+ chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
+ in a chat into a tokenizable string.
+ """
+
+ image_processor_class = "AutoImageProcessor"
+
+ @property
+ def model_input_names(self):
+ tokenizer_input_names = self.tokenizer.model_input_names
+ image_processor_input_names = self.image_processor.model_input_names
+ names_from_processor = list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
+ return names_from_processor + ["second_per_grid_ts"]
+
+ def __call__(
+ self,
+ images: Optional[ImageInput] = None,
+ text: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]] = None,
+ videos: Optional[VideoInput] = None,
+ **kwargs: Unpack[Qwen2_5_VLProcessorKwargs],
+ ) -> BatchFeature:
+ """
+        Main method to prepare one or several sequence(s) and image(s) for the model. This method forwards the `text`
+        and `kwargs` arguments to Qwen2TokenizerFast's [`~Qwen2TokenizerFast.__call__`] if `text` is not `None` to encode
+        the text. To prepare the vision inputs, this method forwards the `images`, `videos` and `kwargs` arguments to
+        Qwen2VLImageProcessor's [`~Qwen2VLImageProcessor.__call__`] if `images` or `videos` is not `None`.
+
+ Args:
+ images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `list[PIL.Image.Image]`, `list[np.ndarray]`, `list[torch.Tensor]`):
+ The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
+ tensor. Both channels-first and channels-last formats are supported.
+ text (`str`, `list[str]`, `list[list[str]]`):
+ The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
+ (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
+ `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
+ videos (`np.ndarray`, `torch.Tensor`, `list[np.ndarray]`, `list[torch.Tensor]`):
+ The video or batch of videos to be prepared. Each video can be a 4D NumPy array or PyTorch
+ tensor, or a nested list of 3D frames. Both channels-first and channels-last formats are supported.
+ return_tensors (`str` or [`~utils.TensorType`], *optional*):
+ If set, will return tensors of a particular framework. Acceptable values are:
+ - `'tf'`: Return TensorFlow `tf.constant` objects.
+ - `'pt'`: Return PyTorch `torch.Tensor` objects.
+ - `'np'`: Return NumPy `np.ndarray` objects.
+ - `'jax'`: Return JAX `jnp.ndarray` objects.
+
+ Returns:
+ [`BatchFeature`]: A [`BatchFeature`] with the following fields:
+
+ - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
+ - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
+ `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
+ `None`).
+ - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
+ - **pixel_values_videos** -- Pixel values of videos to be fed to a model. Returned when `videos` is not `None`.
+ - **image_grid_thw** -- List of 3D grids (temporal, height, width) of the images for the LLM. Returned when `images` is not `None`.
+ - **video_grid_thw** -- List of 3D grids (temporal, height, width) of the videos for the LLM. Returned when `videos` is not `None`.
+ - **second_per_grid_ts** -- List of seconds covered by each temporal grid of the videos. Returned when `videos` is not `None`.
+ """
+ output_kwargs = self._merge_kwargs(
+ Qwen2_5_VLProcessorKwargs,
+ tokenizer_init_kwargs=self.tokenizer.init_kwargs,
+ **kwargs,
+ )
+
+ image_inputs = videos_inputs = {}
+ if images is not None:
+ image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"])
+ image_grid_thw = image_inputs["image_grid_thw"]
+
+ if videos is not None:
+ fps = output_kwargs["videos_kwargs"].get("fps", 2.0)
+ videos_inputs = self.video_processor(videos=videos, **output_kwargs["videos_kwargs"])
+ video_grid_thw = videos_inputs["video_grid_thw"]
+
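+ # Each temporal grid spans `temporal_patch_size` sampled frames, so one grid covers
+ # temporal_patch_size / fps seconds (illustrative: 2 frames at 2.0 fps -> 1.0 s per grid)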
+ if isinstance(fps, (int, float)):
+ second_per_grid_ts = [self.video_processor.temporal_patch_size / fps] * len(video_grid_thw)
+ elif hasattr(fps, "__len__") and len(fps) == len(video_grid_thw):
+ second_per_grid_ts = [self.video_processor.temporal_patch_size / tmp for tmp in fps]
+ else:
+ raise ValueError(
+ f"The length of fps ({len(fps) if hasattr(fps, '__len__') else fps}) must be equal to the length of video_grid_thw ({len(video_grid_thw)}) or fps should be a single number."
+ )
+ videos_inputs.update({"second_per_grid_ts": second_per_grid_ts})
+
+ if not isinstance(text, list):
+ text = [text]
+
+ text = text.copy() # below lines change text in-place
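+ # Expand each multimodal placeholder to the number of tokens the vision tower will emit:
+ # grid_thw.prod() patches collapsed by merge_size**2 (illustrative: a (1, 28, 28) image grid
+ # with merge_size=2 becomes 1*28*28 // 4 = 196 <|image_pad|> tokens)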
+ if images is not None:
+ merge_length = self.image_processor.merge_size**2
+ index = 0
+ for i in range(len(text)):
+ while self.image_token in text[i]:
+ num_image_tokens = image_grid_thw[index].prod() // merge_length
+ text[i] = text[i].replace(self.image_token, "<|placeholder|>" * num_image_tokens, 1)
+ index += 1
+ text[i] = text[i].replace("<|placeholder|>", self.image_token)
+
+ if videos is not None:
+ merge_length = self.video_processor.merge_size**2
+ index = 0
+ for i in range(len(text)):
+ while self.video_token in text[i]:
+ num_video_tokens = video_grid_thw[index].prod() // merge_length
+ text[i] = text[i].replace(self.video_token, "<|placeholder|>" * num_video_tokens, 1)
+ index += 1
+ text[i] = text[i].replace("<|placeholder|>", self.video_token)
+
+ return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
+ return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", None)
+ text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
+ self._check_special_mm_tokens(text, text_inputs, modalities=["image", "video"])
+
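+ # Optionally emit token-type ids marking multimodal positions: as written, only image pad
+ # tokens are flagged with 1; text (and video pad) positions stay 0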
+ if return_mm_token_type_ids:
+ array_ids = np.array(text_inputs["input_ids"])
+ mm_token_type_ids = np.zeros_like(text_inputs["input_ids"])
+ mm_token_type_ids[array_ids == self.image_token_id] = 1
+ text_inputs["mm_token_type_ids"] = mm_token_type_ids.tolist()
+
+ return BatchFeature(data={**text_inputs, **image_inputs, **videos_inputs}, tensor_type=return_tensors)
+
+ def _get_num_multimodal_tokens(self, image_sizes=None, video_sizes=None, **kwargs):
+ """
+ Computes the number of placeholder tokens needed for multimodal inputs with the given sizes.
+ Args:
+ image_sizes (`list[list[int]]`, *optional*):
+ The input sizes formatted as (height, width) per each image.
+ video_sizes (`list[list[int]]`, *optional*):
+ The input sizes formatted as (num_frames, height, width) per each video.
+ Returns:
+ `MultiModalData`: A `MultiModalData` object holding number of tokens per each of the provided
+ input modalities, along with other useful data.
+ """
+
+ vision_data = {}
+ if image_sizes is not None:
+ images_kwargs = Qwen2_5_VLProcessorKwargs._defaults.get("images_kwargs", {})
+ images_kwargs.update(kwargs)
+ merge_size = images_kwargs.get("merge_size", None) or self.image_processor.merge_size
+
+ num_image_patches = [
+ self.image_processor.get_number_of_image_patches(*image_size, images_kwargs)
+ for image_size in image_sizes
+ ]
+ num_image_tokens = [(num_patches // merge_size**2) for num_patches in num_image_patches]
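+ # e.g. (illustrative): a (1, 48, 48) patch grid gives 2304 patches; with merge_size=2 that is 2304 // 4 = 576 image tokens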
+ vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches})
+
+ if video_sizes is not None:
+ videos_kwargs = Qwen2_5_VLProcessorKwargs._defaults.get("videos_kwargs", {})
+ videos_kwargs.update(kwargs)
+ # resolve merge_size here as well so the video-only path does not depend on the image branch above
+ merge_size = videos_kwargs.get("merge_size", None) or self.video_processor.merge_size
+ num_video_patches = [
+ self.video_processor.get_number_of_video_patches(*video_size, videos_kwargs)
+ for video_size in video_sizes
+ ]
+ num_video_tokens = [(num_patches // merge_size**2) for num_patches in num_video_patches]
+ vision_data["num_video_tokens"] = num_video_tokens
+
+ return MultiModalData(**vision_data)
+
+
+__all__ = [
+ "Qwen2_5_VLConfig",
+ "Qwen2_5_VLTextConfig",
+ "Qwen2_5_VLForConditionalGeneration",
+ "Qwen2_5_VLModel",
+ "Qwen2_5_VLPreTrainedModel",
+ "Qwen2_5_VLProcessor",
+ "Qwen2_5_VLTextModel", # noqa: F822
+]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/qwen2_5_vl/processing_qwen2_5_vl.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/qwen2_5_vl/processing_qwen2_5_vl.py
new file mode 100644
index 0000000000000000000000000000000000000000..b357ba850debac6afbcff27f349d47083529d750
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/qwen2_5_vl/processing_qwen2_5_vl.py
@@ -0,0 +1,278 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_qwen2_5_vl.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2025 The Qwen Team and The HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional, Union
+
+import numpy as np
+
+from ...feature_extraction_utils import BatchFeature
+from ...image_utils import ImageInput
+from ...processing_utils import ImagesKwargs, MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack, VideosKwargs
+from ...tokenization_utils_base import PreTokenizedInput, TextInput
+from ...video_utils import VideoInput
+
+
+class Qwen2_5_VLVideosProcessorKwargs(VideosKwargs, total=False):
+ fps: Union[list[float], float]
+
+
+class Qwen2_5_VLImagesKwargs(ImagesKwargs):
+ min_pixels: Optional[int]
+ max_pixels: Optional[int]
+ patch_size: Optional[int]
+ temporal_patch_size: Optional[int]
+ merge_size: Optional[int]
+
+
+class Qwen2_5_VLProcessorKwargs(ProcessingKwargs, total=False):
+ images_kwargs: Qwen2_5_VLImagesKwargs
+ videos_kwargs: Qwen2_5_VLVideosProcessorKwargs
+ _defaults = {
+ "text_kwargs": {
+ "padding": False,
+ "return_mm_token_type_ids": False,
+ },
+ }
+
+
+class Qwen2_5_VLProcessor(ProcessorMixin):
+ r"""
+ Constructs a Qwen2.5-VL processor which wraps a Qwen2.5-VL image processor, a Qwen2.5-VL video processor and a Qwen2 tokenizer into a single processor.
+ [`Qwen2_5_VLProcessor`] offers all the functionalities of [`Qwen2VLImageProcessor`] and [`Qwen2TokenizerFast`]. See the
+ [`~Qwen2_5_VLProcessor.__call__`] and [`~Qwen2_5_VLProcessor.decode`] for more information.
+ Args:
+ image_processor ([`Qwen2VLImageProcessor`], *optional*):
+ The image processor is a required input.
+ tokenizer ([`Qwen2TokenizerFast`], *optional*):
+ The tokenizer is a required input.
+ video_processor ([`Qwen2_5_VLVideoProcessor`], *optional*):
+ The video processor is a required input.
+ chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
+ in a chat into a tokenizable string.
+ """
+
+ attributes = ["image_processor", "tokenizer", "video_processor"]
+
+ image_processor_class = "AutoImageProcessor"
+ video_processor_class = "AutoVideoProcessor"
+ tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast")
+
+ def __init__(self, image_processor=None, tokenizer=None, video_processor=None, chat_template=None, **kwargs):
+ self.image_token = "<|image_pad|>" if not hasattr(tokenizer, "image_token") else tokenizer.image_token
+ self.video_token = "<|video_pad|>" if not hasattr(tokenizer, "video_token") else tokenizer.video_token
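+ # Resolve placeholder token ids up front: prefer ids the tokenizer already exposes,
+ # otherwise look them up from the literal pad-token strings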
+ self.image_token_id = (
+ tokenizer.image_token_id
+ if getattr(tokenizer, "image_token_id", None)
+ else tokenizer.convert_tokens_to_ids(self.image_token)
+ )
+ self.video_token_id = (
+ tokenizer.video_token_id
+ if getattr(tokenizer, "video_token_id", None)
+ else tokenizer.convert_tokens_to_ids(self.video_token)
+ )
+ super().__init__(image_processor, tokenizer, video_processor, chat_template=chat_template)
+
+ def __call__(
+ self,
+ images: Optional[ImageInput] = None,
+ text: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]] = None,
+ videos: Optional[VideoInput] = None,
+ **kwargs: Unpack[Qwen2_5_VLProcessorKwargs],
+ ) -> BatchFeature:
+ """
+ Main method to prepare one or several sequence(s) and image(s) for the model. This method forwards the `text`
+ and `kwargs` arguments to Qwen2TokenizerFast's [`~Qwen2TokenizerFast.__call__`] if `text` is not `None` to encode
+ the text. To prepare the vision inputs, this method forwards the `images`/`videos` and `kwargs` arguments to
+ Qwen2VLImageProcessor's [`~Qwen2VLImageProcessor.__call__`] if `images`/`videos` is not `None`.
+
+ Args:
+ images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `list[PIL.Image.Image]`, `list[np.ndarray]`, `list[torch.Tensor]`):
+ The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
+ tensor. Both channels-first and channels-last formats are supported.
+ text (`str`, `list[str]`, `list[list[str]]`):
+ The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
+ (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
+ `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
+ videos (`np.ndarray`, `torch.Tensor`, `list[np.ndarray]`, `list[torch.Tensor]`):
+ The video or batch of videos to be prepared. Each video can be a 4D NumPy array or PyTorch
+ tensor, or a nested list of 3D frames. Both channels-first and channels-last formats are supported.
+ return_tensors (`str` or [`~utils.TensorType`], *optional*):
+ If set, will return tensors of a particular framework. Acceptable values are:
+ - `'tf'`: Return TensorFlow `tf.constant` objects.
+ - `'pt'`: Return PyTorch `torch.Tensor` objects.
+ - `'np'`: Return NumPy `np.ndarray` objects.
+ - `'jax'`: Return JAX `jnp.ndarray` objects.
+
+ Returns:
+ [`BatchFeature`]: A [`BatchFeature`] with the following fields:
+
+ - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
+ - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
+ `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
+ `None`).
+ - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
+ - **pixel_values_videos** -- Pixel values of videos to be fed to a model. Returned when `videos` is not `None`.
+ - **image_grid_thw** -- List of 3D grids (temporal, height, width) of the images for the LLM. Returned when `images` is not `None`.
+ - **video_grid_thw** -- List of 3D grids (temporal, height, width) of the videos for the LLM. Returned when `videos` is not `None`.
+ - **second_per_grid_ts** -- List of seconds covered by each temporal grid of the videos. Returned when `videos` is not `None`.
+ """
+ output_kwargs = self._merge_kwargs(
+ Qwen2_5_VLProcessorKwargs,
+ tokenizer_init_kwargs=self.tokenizer.init_kwargs,
+ **kwargs,
+ )
+
+ image_inputs = videos_inputs = {}
+ if images is not None:
+ image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"])
+ image_grid_thw = image_inputs["image_grid_thw"]
+
+ if videos is not None:
+ fps = output_kwargs["videos_kwargs"].get("fps", 2.0)
+ videos_inputs = self.video_processor(videos=videos, **output_kwargs["videos_kwargs"])
+ video_grid_thw = videos_inputs["video_grid_thw"]
+
+ if isinstance(fps, (int, float)):
+ second_per_grid_ts = [self.video_processor.temporal_patch_size / fps] * len(video_grid_thw)
+ elif hasattr(fps, "__len__") and len(fps) == len(video_grid_thw):
+ second_per_grid_ts = [self.video_processor.temporal_patch_size / tmp for tmp in fps]
+ else:
+ raise ValueError(
+ f"The length of fps ({len(fps) if hasattr(fps, '__len__') else fps}) must be equal to the length of video_grid_thw ({len(video_grid_thw)}) or fps should be a single number."
+ )
+ videos_inputs.update({"second_per_grid_ts": second_per_grid_ts})
+
+ if not isinstance(text, list):
+ text = [text]
+
+ text = text.copy() # below lines change text in-place
+ if images is not None:
+ merge_length = self.image_processor.merge_size**2
+ index = 0
+ for i in range(len(text)):
+ while self.image_token in text[i]:
+ num_image_tokens = image_grid_thw[index].prod() // merge_length
+ text[i] = text[i].replace(self.image_token, "<|placeholder|>" * num_image_tokens, 1)
+ index += 1
+ text[i] = text[i].replace("<|placeholder|>", self.image_token)
+
+ if videos is not None:
+ merge_length = self.video_processor.merge_size**2
+ index = 0
+ for i in range(len(text)):
+ while self.video_token in text[i]:
+ num_video_tokens = video_grid_thw[index].prod() // merge_length
+ text[i] = text[i].replace(self.video_token, "<|placeholder|>" * num_video_tokens, 1)
+ index += 1
+ text[i] = text[i].replace("<|placeholder|>", self.video_token)
+
+ return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
+ return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", None)
+ text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
+ self._check_special_mm_tokens(text, text_inputs, modalities=["image", "video"])
+
+ if return_mm_token_type_ids:
+ array_ids = np.array(text_inputs["input_ids"])
+ mm_token_type_ids = np.zeros_like(text_inputs["input_ids"])
+ mm_token_type_ids[array_ids == self.image_token_id] = 1
+ text_inputs["mm_token_type_ids"] = mm_token_type_ids.tolist()
+
+ return BatchFeature(data={**text_inputs, **image_inputs, **videos_inputs}, tensor_type=return_tensors)
+
+ def _get_num_multimodal_tokens(self, image_sizes=None, video_sizes=None, **kwargs):
+ """
+ Computes the number of placeholder tokens needed for multimodal inputs with the given sizes.
+ Args:
+ image_sizes (`list[list[int]]`, *optional*):
+ The input sizes formatted as (height, width) per each image.
+ video_sizes (`list[list[int]]`, *optional*):
+ The input sizes formatted as (num_frames, height, width) per each video.
+ Returns:
+ `MultiModalData`: A `MultiModalData` object holding number of tokens per each of the provided
+ input modalities, along with other useful data.
+ """
+
+ vision_data = {}
+ if image_sizes is not None:
+ images_kwargs = Qwen2_5_VLProcessorKwargs._defaults.get("images_kwargs", {})
+ images_kwargs.update(kwargs)
+ merge_size = images_kwargs.get("merge_size", None) or self.image_processor.merge_size
+
+ num_image_patches = [
+ self.image_processor.get_number_of_image_patches(*image_size, images_kwargs)
+ for image_size in image_sizes
+ ]
+ num_image_tokens = [(num_patches // merge_size**2) for num_patches in num_image_patches]
+ vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches})
+
+ if video_sizes is not None:
+ videos_kwargs = Qwen2_5_VLProcessorKwargs._defaults.get("videos_kwargs", {})
+ videos_kwargs.update(kwargs)
+ # resolve merge_size here as well so the video-only path does not depend on the image branch above
+ merge_size = videos_kwargs.get("merge_size", None) or self.video_processor.merge_size
+ num_video_patches = [
+ self.video_processor.get_number_of_video_patches(*video_size, videos_kwargs)
+ for video_size in video_sizes
+ ]
+ num_video_tokens = [(num_patches // merge_size**2) for num_patches in num_video_patches]
+ vision_data["num_video_tokens"] = num_video_tokens
+
+ return MultiModalData(**vision_data)
+
+ def post_process_image_text_to_text(
+ self, generated_outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False, **kwargs
+ ):
+ """
+ Post-process the output of the model to decode the text.
+
+ Args:
+ generated_outputs (`torch.Tensor` or `np.ndarray`):
+ The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)`
+ or `(sequence_length,)`.
+ skip_special_tokens (`bool`, *optional*, defaults to `True`):
+ Whether or not to remove special tokens in the output. Argument passed to the tokenizer's `batch_decode` method.
+ clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
+ Whether or not to clean up the tokenization spaces. Argument passed to the tokenizer's `batch_decode` method.
+ **kwargs:
+ Additional arguments to be passed to the tokenizer's `batch_decode` method.
+
+ Returns:
+ `list[str]`: The decoded text.
+ """
+ return self.tokenizer.batch_decode(
+ generated_outputs,
+ skip_special_tokens=skip_special_tokens,
+ clean_up_tokenization_spaces=clean_up_tokenization_spaces,
+ **kwargs,
+ )
+
+ @property
+ def model_input_names(self):
+ tokenizer_input_names = self.tokenizer.model_input_names
+ image_processor_input_names = self.image_processor.model_input_names
+ names_from_processor = list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
+ return names_from_processor + ["second_per_grid_ts"]
+
+
+__all__ = ["Qwen2_5_VLProcessor"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/qwen3_vl_moe/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/qwen3_vl_moe/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a4000cb272723d9920136a6c78465e8413a8b4d1
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/qwen3_vl_moe/__init__.py
@@ -0,0 +1,27 @@
+# Copyright 2025 The Qwen Team and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_qwen3_vl_moe import *
+ from .modeling_qwen3_vl_moe import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/qwen3_vl_moe/configuration_qwen3_vl_moe.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/qwen3_vl_moe/configuration_qwen3_vl_moe.py
new file mode 100644
index 0000000000000000000000000000000000000000..25358aa79bff482437632829f6319effad36f138
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/qwen3_vl_moe/configuration_qwen3_vl_moe.py
@@ -0,0 +1,335 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_qwen3_vl_moe.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2025 The Qwen Team and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from ...configuration_utils import PretrainedConfig
+from ...modeling_rope_utils import rope_config_validation
+
+
+class Qwen3VLMoeTextConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`Qwen3VLMoeTextModel`]. It is used to instantiate a
+ Qwen3-VL-MOE model according to the specified arguments, defining the model architecture. Instantiating a configuration
+ with the defaults will yield a similar configuration to that of
+ Qwen3-VL-30B-A3B-Instruct [Qwen/Qwen3-VL-30B-A3B-Instruct](https://huggingface.co/Qwen/Qwen3-VL-30B-A3B-Instruct).
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 151936):
+ Vocabulary size of the Qwen3-VL-MoE text model. Defines the number of different tokens that can be represented by the
+ `input_ids` passed when calling [`Qwen3VLMoeTextModel`].
+ hidden_size (`int`, *optional*, defaults to 2048):
+ Dimension of the hidden representations.
+ intermediate_size (`int`, *optional*, defaults to 5632):
+ Dimension of the MLP representations.
+ num_hidden_layers (`int`, *optional*, defaults to 24):
+ Number of hidden layers in the Transformer encoder.
+ num_attention_heads (`int`, *optional*, defaults to 16):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ num_key_value_heads (`int`, *optional*, defaults to 16):
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
+ `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+ by meanpooling all the original heads within that group. For more details, check out [this
+ paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `num_attention_heads`.
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+ The non-linear activation function (function or string) in the decoder.
+ max_position_embeddings (`int`, *optional*, defaults to 128000):
+ The maximum sequence length that this model might ever be used with.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ rms_norm_eps (`float`, *optional*, defaults to 1e-06):
+ The epsilon used by the rms normalization layers.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
+ relevant if `config.is_decoder=True`.
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+ Whether the model's input and output word embeddings should be tied.
+ rope_theta (`float`, *optional*, defaults to 5000000.0):
+ The base period of the RoPE embeddings.
+ attention_bias (`bool`, *optional*, defaults to `False`):
+ Whether to use a bias in the query, key, value and output projection layers during self-attention.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ decoder_sparse_step (`int`, *optional*, defaults to 1):
+ The frequency of the MoE layer.
+ moe_intermediate_size (`int`, *optional*, defaults to 1408):
+ Intermediate size of the routed expert.
+ num_experts_per_tok (`int`, *optional*, defaults to 4):
+ Number of selected experts.
+ num_experts (`int`, *optional*, defaults to 60):
+ Number of routed experts.
+ norm_topk_prob (`bool`, *optional*, defaults to `True`):
+ Whether to normalize the topk probabilities.
+ router_aux_loss_coef (`float`, *optional*, defaults to 0.001):
+ The aux loss factor for the total loss.
+ mlp_only_layers (`List[int]`, *optional*, defaults to `[]`):
+ Indicates which layers use Qwen3VLMoeTextMLP rather than Qwen3VLMoeTextSparseMoeBlock.
+ The list contains layer indices, from 0 to num_layers-1 if we have num_layers layers.
+ If `mlp_only_layers` is empty, `decoder_sparse_step` is used to determine the sparsity.
+ rope_scaling (`Dict`, *optional*):
+ Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
+ and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
+ accordingly.
+ Expected contents:
+ `rope_type` (`str`):
+ The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
+ 'llama3'], with 'default' being the original RoPE implementation.
+ `factor` (`float`, *optional*):
+ Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
+ most scaling types, a `factor` of x will enable the model to handle sequences of length x *
+ original maximum pre-trained length.
+ `original_max_position_embeddings` (`int`, *optional*):
+ Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
+ pretraining.
+ `attention_factor` (`float`, *optional*):
+ Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
+ computation. If unspecified, it defaults to value recommended by the implementation, using the
+ `factor` field to infer the suggested value.
+ `beta_fast` (`float`, *optional*):
+ Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
+ ramp function. If unspecified, it defaults to 32.
+ `beta_slow` (`float`, *optional*):
+ Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
+ ramp function. If unspecified, it defaults to 1.
+ `short_factor` (`List[float]`, *optional*):
+ Only used with 'longrope'. The scaling factor to be applied to short contexts (<
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+ size divided by the number of attention heads divided by 2
+ `long_factor` (`List[float]`, *optional*):
+ Only used with 'longrope'. The scaling factor to be applied to long contexts (>
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+ size divided by the number of attention heads divided by 2
+ `low_freq_factor` (`float`, *optional*):
+ Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
+ `high_freq_factor` (`float`, *optional*):
+ Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
+ head_dim (`int`, *optional*):
+ The dimension of the head. If not specified, will default to `hidden_size // num_attention_heads`.
+
+ ```python
+ >>> from transformers import Qwen3VLMoeForConditionalGeneration, Qwen3VLMoeConfig
+
+ >>> # Initializing a Qwen3VLMoe style configuration
+ >>> configuration = Qwen3VLMoeConfig()
+
+ >>> # Initializing a model from the Qwen3-VL-30B-A3B style configuration
+ >>> model = Qwen3VLMoeForConditionalGeneration(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "qwen3_vl_moe_text"
+ base_config_key = "text_config"
+ keys_to_ignore_at_inference = ["past_key_values"]
+ # Default tensor parallel plan for base model `Qwen3VLMoe`
+ base_model_tp_plan = {
+ "layers.*.self_attn.q_proj": "colwise",
+ "layers.*.self_attn.k_proj": "colwise",
+ "layers.*.self_attn.v_proj": "colwise",
+ "layers.*.self_attn.o_proj": "rowwise",
+ "layers.*.mlp.gate_proj": "colwise",
+ "layers.*.mlp.up_proj": "colwise",
+ "layers.*.mlp.down_proj": "rowwise",
+ }
+ base_model_pp_plan = {
+ "embed_tokens": (["input_ids"], ["inputs_embeds"]),
+ "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
+ "norm": (["hidden_states"], ["hidden_states"]),
+ }
+
+ def __init__(
+ self,
+ vocab_size=151936,
+ hidden_size=2048,
+ intermediate_size=5632,
+ num_hidden_layers=24,
+ num_attention_heads=16,
+ num_key_value_heads=16,
+ hidden_act="silu",
+ max_position_embeddings=128000,
+ initializer_range=0.02,
+ rms_norm_eps=1e-6,
+ use_cache=True,
+ tie_word_embeddings=False,
+ rope_theta=5000000.0,
+ attention_bias=False,
+ attention_dropout=0.0,
+ decoder_sparse_step=1,
+ moe_intermediate_size=1408,
+ num_experts_per_tok=4,
+ num_experts=60,
+ norm_topk_prob=True,
+ router_aux_loss_coef=0.001,
+ mlp_only_layers=None,
+ rope_scaling=None,
+ head_dim=None,
+ **kwargs,
+ ):
+ self.vocab_size = vocab_size
+ self.max_position_embeddings = max_position_embeddings
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+
+ # for backward compatibility
+ if num_key_value_heads is None:
+ num_key_value_heads = num_attention_heads
+
+ self.num_key_value_heads = num_key_value_heads
+ self.hidden_act = hidden_act
+ self.initializer_range = initializer_range
+ self.rms_norm_eps = rms_norm_eps
+ self.use_cache = use_cache
+ self.rope_theta = rope_theta
+ self.attention_bias = attention_bias
+ self.attention_dropout = attention_dropout
+ self.rope_scaling = rope_scaling
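+ # head_dim falls back to hidden_size // num_attention_heads (2048 // 16 = 128 with the defaults above)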
+ self.head_dim = head_dim or hidden_size // num_attention_heads
+
+ rope_config_validation(self, ignore_keys={"mrope_section", "mrope_interleaved"})
+
+ # MoE arguments
+ self.decoder_sparse_step = decoder_sparse_step
+ self.moe_intermediate_size = moe_intermediate_size
+ self.num_experts_per_tok = num_experts_per_tok
+ self.num_experts = num_experts
+ self.norm_topk_prob = norm_topk_prob
+ self.router_aux_loss_coef = router_aux_loss_coef
+ self.mlp_only_layers = [] if mlp_only_layers is None else mlp_only_layers
+
+ super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
+
+
+class Qwen3VLMoeVisionConfig(PretrainedConfig):
+ model_type = "qwen3_vl_moe"
+ base_config_key = "vision_config"
+
+ def __init__(
+ self,
+ depth=27,
+ hidden_size=1152,
+ hidden_act="gelu_pytorch_tanh",
+ intermediate_size=4304,
+ num_heads=16,
+ in_channels=3,
+ patch_size=16,
+ spatial_merge_size=2,
+ temporal_patch_size=2,
+ out_hidden_size=3584,
+ num_position_embeddings=2304,
+ deepstack_visual_indexes=[8, 16, 24],
+ initializer_range=0.02,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+
+ self.depth = depth
+ self.hidden_size = hidden_size
+ self.hidden_act = hidden_act
+ self.intermediate_size = intermediate_size
+ self.num_heads = num_heads
+ self.in_channels = in_channels
+ self.patch_size = patch_size
+ self.spatial_merge_size = spatial_merge_size
+ self.temporal_patch_size = temporal_patch_size
+ self.out_hidden_size = out_hidden_size
+ self.num_position_embeddings = num_position_embeddings
+ self.initializer_range = initializer_range
+ self.deepstack_visual_indexes = deepstack_visual_indexes
+
+
+class Qwen3VLMoeConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`Qwen3VLMoeModel`]. It is used to instantiate a
+ Qwen3-VL-MOE model according to the specified arguments, defining the model architecture. Instantiating a configuration
+ with the defaults will yield a similar configuration to that of
+ Qwen3-VL-30B-A3B-Instruct [Qwen/Qwen3-VL-30B-A3B-Instruct](https://huggingface.co/Qwen/Qwen3-VL-30B-A3B-Instruct).
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+
+ Args:
+ text_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `Qwen3VLMoeTextConfig`):
+ The config object or dictionary of the text backbone.
+ vision_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `Qwen3VLMoeVisionConfig`):
+ The config object or dictionary of the vision backbone.
+ image_token_id (`int`, *optional*, defaults to 151655):
+ The image token index to encode the image prompt.
+ video_token_id (`int`, *optional*, defaults to 151656):
+ The video token index to encode the video prompt.
+ vision_start_token_id (`int`, *optional*, defaults to 151652):
+ The start token index to encode the vision (image/video) prompt.
+ vision_end_token_id (`int`, *optional*, defaults to 151653):
+ The end token index to encode the vision (image/video) prompt.
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+ Whether to tie the word embeddings.
+
+ ```python
+ >>> from transformers import Qwen3VLMoeForConditionalGeneration, Qwen3VLMoeConfig
+
+ >>> # Initializing a Qwen3-VL-MOE style configuration
+ >>> configuration = Qwen3VLMoeConfig()
+
+ >>> # Initializing a model from the Qwen3-VL-30B-A3B style configuration
+ >>> model = Qwen3VLMoeForConditionalGeneration(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "qwen3_vl_moe"
+ sub_configs = {"vision_config": Qwen3VLMoeVisionConfig, "text_config": Qwen3VLMoeTextConfig}
+ keys_to_ignore_at_inference = ["past_key_values"]
+
+ def __init__(
+ self,
+ text_config=None,
+ vision_config=None,
+ image_token_id=151655,
+ video_token_id=151656,
+ vision_start_token_id=151652,
+ vision_end_token_id=151653,
+ tie_word_embeddings=False,
+ **kwargs,
+ ):
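+ # Sub-configs may be passed as plain dicts (converted below) or omitted entirely, in which case the defaults are used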
+ if isinstance(vision_config, dict):
+ self.vision_config = self.sub_configs["vision_config"](**vision_config)
+ elif vision_config is None:
+ self.vision_config = self.sub_configs["vision_config"]()
+
+ if isinstance(text_config, dict):
+ self.text_config = self.sub_configs["text_config"](**text_config)
+ elif text_config is None:
+ self.text_config = self.sub_configs["text_config"]()
+
+ self.image_token_id = image_token_id
+ self.video_token_id = video_token_id
+ self.vision_start_token_id = vision_start_token_id
+ self.vision_end_token_id = vision_end_token_id
+ super().__init__(**kwargs, tie_word_embeddings=tie_word_embeddings)
+
+
+__all__ = ["Qwen3VLMoeConfig", "Qwen3VLMoeTextConfig"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py
new file mode 100644
index 0000000000000000000000000000000000000000..88e3f6e19f0eadbda8853b5c3b8f39fc7e61e580
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py
@@ -0,0 +1,1832 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_qwen3_vl_moe.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2025 The Qwen Team and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass
+from typing import Any, Callable, Optional, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ...activations import ACT2FN
+from ...cache_utils import Cache, DynamicCache
+from ...generation import GenerationMixin
+from ...integrations import use_kernel_forward_from_hub
+from ...masking_utils import create_causal_mask
+from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput
+from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
+from ...utils import TransformersKwargs, auto_docstring, is_torchdynamo_compiling
+from ...utils.deprecation import deprecate_kwarg
+from ...utils.generic import OutputRecorder, check_model_inputs
+from .configuration_qwen3_vl_moe import Qwen3VLMoeConfig, Qwen3VLMoeTextConfig, Qwen3VLMoeVisionConfig
+
+
+@use_kernel_forward_from_hub("RMSNorm")
+class Qwen3VLMoeTextRMSNorm(nn.Module):
+ def __init__(self, hidden_size, eps=1e-6):
+ """
+ Qwen3VLMoeTextRMSNorm is equivalent to T5LayerNorm
+ """
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+ self.variance_epsilon = eps
+
+ def forward(self, hidden_states):
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+ return self.weight * hidden_states.to(input_dtype)
+
+ def extra_repr(self):
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
+
+
+class Qwen3VLMoeTextExperts(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.num_experts = config.num_experts
+ self.intermediate_size = config.moe_intermediate_size
+ self.hidden_size = config.hidden_size
+ self.expert_dim = self.intermediate_size
+ self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_size, 2 * self.expert_dim))
+ self.down_proj = nn.Parameter(torch.empty((self.num_experts, self.expert_dim, self.hidden_size)))
+ self.act_fn = ACT2FN[config.hidden_act]
+
+ def forward(
+ self, hidden_states: torch.Tensor, routing_weights: torch.Tensor, router_indices: torch.Tensor
+ ) -> torch.Tensor:
+ """
+ When training, it is more efficient to loop over the experts and compute the output for each expert
+ separately, as otherwise the memory would explode.
+
+ For inference we can sacrifice some memory and compute the output for all experts at once by repeating the inputs.
+
+ Args:
+ hidden_states (torch.Tensor): (batch_size * token_num, hidden_size)
+ routing_weights (torch.Tensor): (batch_size * token_num, num_experts)
+ router_indices (torch.Tensor): (batch_size * token_num, top_k)
+ Returns:
+ torch.Tensor
+ """
+ batch_size = hidden_states.shape[0]
+ hidden_states = hidden_states.reshape(-1, self.hidden_size) # (num_tokens, hidden_size)
+ if self.training:
+ next_states = torch.zeros_like(hidden_states, dtype=hidden_states.dtype, device=hidden_states.device)
+ with torch.no_grad():
+ expert_mask = torch.nn.functional.one_hot(router_indices, num_classes=self.num_experts)
+ expert_mask = expert_mask.permute(2, 1, 0)
+ # we sum on the top_k and on the sequence length to get which experts
+ # are hit this time around
+ expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
+ for expert_idx in expert_hit[:]:
+ with torch.no_grad():
+ _, token_idx = torch.where(expert_mask[expert_idx[0]])
+ current_state = hidden_states[token_idx]
+ gate_up = current_state @ self.gate_up_proj[expert_idx]
+ gate, up = gate_up.chunk(2, dim=-1)
+ gated_output = up * self.act_fn(gate)
+ out = gated_output @ self.down_proj[expert_idx]
+ weighted_output = out[0] * routing_weights[token_idx, expert_idx, None]
+ next_states.index_add_(0, token_idx, weighted_output.to(hidden_states.dtype))
+ next_states = next_states.view(batch_size, -1, self.hidden_size)
+ else:
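+ # Batched inference path (illustrative shapes): hidden_states becomes
+ # (num_experts, num_tokens, hidden_size) after the repeat/view below; bmm with
+ # gate_up_proj (num_experts, hidden_size, 2 * expert_dim) and down_proj brings it back to
+ # hidden_size, and the per-expert outputs are summed weighted by their routing weights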
+ hidden_states = hidden_states.repeat(self.num_experts, 1)
+ hidden_states = hidden_states.view(self.num_experts, -1, self.hidden_size)
+ gate_up = torch.bmm(hidden_states, self.gate_up_proj)
+ gate, up = gate_up.chunk(2, dim=-1) # not supported for DTensors
+ next_states = torch.bmm((up * self.act_fn(gate)), self.down_proj)
+ next_states = next_states.reshape(self.num_experts, batch_size, -1, self.hidden_size)
+ next_states = (
+ next_states * routing_weights.transpose(0, 1).view(self.num_experts, batch_size, -1)[..., None]
+ )
+ next_states = next_states.sum(dim=0)
+ return next_states
+
+
+class Qwen3VLMoeTextSparseMoeBlock(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.hidden_size = config.hidden_size
+ self.num_experts = config.num_experts
+ self.top_k = config.num_experts_per_tok
+ self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False)
+ self.experts = Qwen3VLMoeTextExperts(config)
+
+ # since all the models use norm_topk_prob, we don't need to have an extra check for it
+ # self.norm_topk_prob = config.norm_topk_prob
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ batch_size = hidden_states.shape[0]
+ hidden_states = hidden_states.reshape(-1, self.hidden_size)
+ router_logits = self.gate(hidden_states)
+ routing_weights = torch.nn.functional.softmax(router_logits, dim=-1, dtype=torch.float)
+ routing_weights, router_indices = torch.topk(routing_weights, self.top_k, dim=-1)
+ routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)
+ routing_weights = routing_weights.to(hidden_states.dtype)
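+ # scatter the (renormalized) top-k weights back into a dense (num_tokens, num_experts) matrix;
+ # unselected experts keep weight 0, which is the layout Qwen3VLMoeTextExperts expects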
+ router_weights = torch.zeros_like(router_logits).scatter_(1, router_indices, routing_weights)
+ hidden_states = hidden_states.reshape(batch_size, -1, self.hidden_size)
+ routed_out = self.experts(hidden_states, router_weights, router_indices)
+ return routed_out, router_logits
+
+
+def rotate_half(x):
+ """Rotates half the hidden dims of the input."""
+ x1 = x[..., : x.shape[-1] // 2]
+ x2 = x[..., x.shape[-1] // 2 :]
+ return torch.cat((-x2, x1), dim=-1)
+
+
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+ """
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+ """
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+ if n_rep == 1:
+ return hidden_states
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
+def eager_attention_forward(
+ module: nn.Module,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_mask: Optional[torch.Tensor],
+ scaling: float,
+ dropout: float = 0.0,
+ **kwargs: Unpack[TransformersKwargs],
+):
+ key_states = repeat_kv(key, module.num_key_value_groups)
+ value_states = repeat_kv(value, module.num_key_value_groups)
+
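+ # With GQA, each key/value head is shared by num_key_value_groups query heads; repeat_kv
+ # expands k/v so the attention below reduces to plain multi-head matmuls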
+ attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
+ if attention_mask is not None:
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+ attn_weights = attn_weights + causal_mask
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+ attn_output = torch.matmul(attn_weights, value_states)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ return attn_output, attn_weights
+
+
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+ """Applies Rotary Position Embedding to the query and key tensors.
+
+ Args:
+ q (`torch.Tensor`): The query tensor.
+ k (`torch.Tensor`): The key tensor.
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
+ position_ids (`torch.Tensor`, *optional*):
+ Deprecated and unused.
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+ Returns:
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+ """
+ cos = cos.unsqueeze(unsqueeze_dim)
+ sin = sin.unsqueeze(unsqueeze_dim)
+ q_embed = (q * cos) + (rotate_half(q) * sin)
+ k_embed = (k * cos) + (rotate_half(k) * sin)
+ return q_embed, k_embed
+
+
+class Qwen3VLMoeTextAttention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(self, config: Qwen3VLMoeTextConfig, layer_idx: int):
+ super().__init__()
+ self.config = config
+ self.layer_idx = layer_idx
+ self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+ self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
+ self.scaling = self.head_dim**-0.5
+ self.attention_dropout = config.attention_dropout
+ self.is_causal = True
+
+ self.q_proj = nn.Linear(
+ config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
+ )
+ self.k_proj = nn.Linear(
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
+ )
+ self.v_proj = nn.Linear(
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
+ )
+ self.o_proj = nn.Linear(
+ config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
+ )
+ self.q_norm = Qwen3VLMoeTextRMSNorm(
+ self.head_dim, eps=config.rms_norm_eps
+ ) # unlike olmo, only on the head dim!
+ self.k_norm = Qwen3VLMoeTextRMSNorm(
+ self.head_dim, eps=config.rms_norm_eps
+ ) # thus post q_norm does not need reshape
+
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
+ attention_mask: Optional[torch.Tensor],
+ past_key_values: Optional[Cache] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs: Unpack[FlashAttentionKwargs],
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+ input_shape = hidden_states.shape[:-1]
+ hidden_shape = (*input_shape, -1, self.head_dim)
+
+ query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
+ key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
+ value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+
+ cos, sin = position_embeddings
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_values is not None:
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ attn_output, attn_weights = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ dropout=0.0 if not self.training else self.attention_dropout,
+ scaling=self.scaling,
+ **kwargs,
+ )
+
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+ attn_output = self.o_proj(attn_output)
+ return attn_output, attn_weights
+
+
+class Qwen3VLMoeTextMLP(nn.Module):
+ def __init__(self, config, intermediate_size=None):
+ super().__init__()
+ self.config = config
+ self.hidden_size = config.hidden_size
+ self.intermediate_size = intermediate_size if intermediate_size is not None else config.intermediate_size
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+ self.act_fn = ACT2FN[config.hidden_act]
+
+ def forward(self, x):
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+ return down_proj
+
+
+class Qwen3VLMoeTextDecoderLayer(GradientCheckpointingLayer):
+ def __init__(self, config: Qwen3VLMoeTextConfig, layer_idx: int):
+ super().__init__()
+ self.hidden_size = config.hidden_size
+
+ self.self_attn = Qwen3VLMoeTextAttention(config, layer_idx)
+
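+ # A layer gets the sparse MoE block when it is not listed in mlp_only_layers and
+ # (layer_idx + 1) is a multiple of decoder_sparse_step; with the defaults
+ # (decoder_sparse_step=1, mlp_only_layers=[]) every decoder layer is an MoE layer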
+ if (layer_idx not in config.mlp_only_layers) and (
+ config.num_experts > 0 and (layer_idx + 1) % config.decoder_sparse_step == 0
+ ):
+ self.mlp = Qwen3VLMoeTextSparseMoeBlock(config)
+ else:
+ self.mlp = Qwen3VLMoeTextMLP(config, intermediate_size=config.intermediate_size)
+
+ self.input_layernorm = Qwen3VLMoeTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.post_attention_layernorm = Qwen3VLMoeTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs: Unpack[FlashAttentionKwargs],
+ ) -> torch.FloatTensor:
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
+ `(batch, sequence_length)` where padding elements are indicated by 0.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ output_router_logits (`bool`, *optional*):
+ Whether or not to return the logits of all the routers. They are useful for computing the router loss,
+ and should not be returned during inference.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+ (see `past_key_values`).
+ past_key_values (`Cache`, *optional*): cached past key and value projection states
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+ Indices depicting the position of the input sequence tokens in the sequence.
+ position_embeddings (`tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
+ Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
+ with `head_dim` being the embedding dimension of each attention head.
+ kwargs (`dict`, *optional*):
+ Arbitrary kwargs to be ignored, used for FSDP and other methods that inject code
+ into the model
+ """
+ residual = hidden_states
+
+ hidden_states = self.input_layernorm(hidden_states)
+
+ # Self Attention
+ hidden_states, _ = self.self_attn(
+ hidden_states=hidden_states,
+ position_embeddings=position_embeddings,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ cache_position=cache_position,
+ **kwargs,
+ )
+ hidden_states = residual + hidden_states
+
+ # Fully Connected
+ residual = hidden_states
+ hidden_states = self.post_attention_layernorm(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+ # For the MoE layers, we need to unpack
+ if isinstance(hidden_states, tuple):
+ hidden_states, _ = hidden_states
+ hidden_states = residual + hidden_states
+
+ return hidden_states
+
+
+@auto_docstring
+class Qwen3VLMoePreTrainedModel(PreTrainedModel):
+ config: Qwen3VLMoeConfig
+ base_model_prefix = "model"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["Qwen3VLMoeTextDecoderLayer", "Qwen3VLMoeVisionBlock"]
+ _skip_keys_device_placement = ["past_key_values"]
+ _supports_flash_attn = True
+ _supports_sdpa = True
+ _supports_flex_attn = True
+ _can_compile_fullgraph = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported)
+ _supports_attention_backend = True
+ _can_record_outputs = {
+ "router_logits": OutputRecorder(Qwen3VLMoeTextSparseMoeBlock, index=1),
+ "hidden_states": Qwen3VLMoeTextDecoderLayer,
+ "attentions": Qwen3VLMoeTextAttention,
+ }
+
+ def _init_weights(self, module):
+ """Initialize the weights."""
+ super()._init_weights(module)
+ if hasattr(self.config, "initializer_range"):
+ std = self.config.initializer_range
+ else:
+ std = getattr(self.config.get_text_config(), "initializer_range", 0.02)
+ if isinstance(module, Qwen3VLMoeTextExperts):
+ module.gate_up_proj.data.normal_(mean=0.0, std=std)
+ module.down_proj.data.normal_(mean=0.0, std=std)
+
+
+class Qwen3VLMoeVisionMLP(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.hidden_size = config.hidden_size
+ self.intermediate_size = config.intermediate_size
+ self.linear_fc1 = nn.Linear(self.hidden_size, self.intermediate_size, bias=True)
+ self.linear_fc2 = nn.Linear(self.intermediate_size, self.hidden_size, bias=True)
+ self.act_fn = ACT2FN[config.hidden_act]
+
+ def forward(self, hidden_state):
+ return self.linear_fc2(self.act_fn(self.linear_fc1(hidden_state)))
+
+
+class Qwen3VLMoeVisionPatchEmbed(nn.Module):
+ def __init__(self, config) -> None:
+ super().__init__()
+ self.patch_size = config.patch_size
+ self.temporal_patch_size = config.temporal_patch_size
+ self.in_channels = config.in_channels
+ self.embed_dim = config.hidden_size
+
+ kernel_size = [self.temporal_patch_size, self.patch_size, self.patch_size]
+ self.proj = nn.Conv3d(self.in_channels, self.embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=True)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ target_dtype = self.proj.weight.dtype
+ hidden_states = hidden_states.view(
+ -1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size
+ )
+ hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).view(-1, self.embed_dim)
+ return hidden_states
+
+
+class Qwen3VLMoeVisionRotaryEmbedding(nn.Module):
+ inv_freq: torch.Tensor # fix linting for `register_buffer`
+
+ def __init__(self, dim: int, theta: float = 10000.0) -> None:
+ super().__init__()
+ inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
+
+ def forward(self, seqlen: int) -> torch.Tensor:
+ seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
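+ # Outer product of positions and inverse frequencies gives an angle table of shape (seqlen, dim // 2).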
+ freqs = torch.outer(seq, self.inv_freq)
+ return freqs
+
+
+class Qwen3VLMoeVisionPatchMerger(nn.Module):
+ def __init__(self, config: Qwen3VLMoeVisionConfig, use_postshuffle_norm=False) -> None:
+ super().__init__()
+ self.hidden_size = config.hidden_size * (config.spatial_merge_size**2)
+ self.use_postshuffle_norm = use_postshuffle_norm
+ self.norm = nn.LayerNorm(self.hidden_size if use_postshuffle_norm else config.hidden_size, eps=1e-6)
+ self.linear_fc1 = nn.Linear(self.hidden_size, self.hidden_size)
+ self.act_fn = nn.GELU()
+ self.linear_fc2 = nn.Linear(self.hidden_size, config.out_hidden_size)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = self.norm(x.view(-1, self.hidden_size) if self.use_postshuffle_norm else x).view(-1, self.hidden_size)
+ x = self.linear_fc2(self.act_fn(self.linear_fc1(x)))
+ return x
+
+
+def apply_rotary_pos_emb_vision(
+ q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
+) -> tuple[torch.Tensor, torch.Tensor]:
+ orig_q_dtype = q.dtype
+ orig_k_dtype = k.dtype
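+ # Apply the rotation in float32 for numerical stability, then cast back to the original dtypes below.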
+ q, k = q.float(), k.float()
+ cos, sin = cos.unsqueeze(-2).float(), sin.unsqueeze(-2).float()
+ q_embed = (q * cos) + (rotate_half(q) * sin)
+ k_embed = (k * cos) + (rotate_half(k) * sin)
+ q_embed = q_embed.to(orig_q_dtype)
+ k_embed = k_embed.to(orig_k_dtype)
+ return q_embed, k_embed
+
+
+class Qwen3VLMoeVisionAttention(nn.Module):
+ def __init__(self, config: Qwen3VLMoeVisionConfig) -> None:
+ super().__init__()
+ self.dim = config.hidden_size
+ self.num_heads = config.num_heads
+ self.head_dim = self.dim // self.num_heads
+ self.num_key_value_groups = 1 # needed for eager attention
+ self.qkv = nn.Linear(self.dim, self.dim * 3, bias=True)
+ self.proj = nn.Linear(self.dim, self.dim)
+ self.scaling = self.head_dim**-0.5
+ self.config = config
+ self.attention_dropout = 0.0
+ self.is_causal = False
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ cu_seqlens: torch.Tensor,
+ rotary_pos_emb: Optional[torch.Tensor] = None,
+ position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
+ **kwargs,
+ ) -> torch.Tensor:
+ seq_length = hidden_states.shape[0]
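+ # The fused QKV projection is reshaped and permuted to (3, seq_len, num_heads, head_dim), then unbound into q, k, v.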
+ query_states, key_states, value_states = (
+ self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
+ )
+ cos, sin = position_embeddings
+ query_states, key_states = apply_rotary_pos_emb_vision(query_states, key_states, cos, sin)
+
+ query_states = query_states.transpose(0, 1).unsqueeze(0)
+ key_states = key_states.transpose(0, 1).unsqueeze(0)
+ value_states = value_states.transpose(0, 1).unsqueeze(0)
+
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ if self.config._attn_implementation == "flash_attention_2":
+ # Flash Attention 2: Use cu_seqlens for variable length attention
+ max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
+ attn_output, _ = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask=None,
+ scaling=self.scaling,
+ dropout=0.0 if not self.training else self.attention_dropout,
+ cu_seq_lens_q=cu_seqlens,
+ cu_seq_lens_k=cu_seqlens,
+ max_length_q=max_seqlen,
+ max_length_k=max_seqlen,
+ is_causal=False,
+ **kwargs,
+ )
+ else:
+ # Other implementations: Process each chunk separately
+ lengths = cu_seqlens[1:] - cu_seqlens[:-1]
+ splits = [
+ torch.split(tensor, lengths.tolist(), dim=2) for tensor in (query_states, key_states, value_states)
+ ]
+
+ attn_outputs = [
+ attention_interface(
+ self,
+ q,
+ k,
+ v,
+ attention_mask=None,
+ scaling=self.scaling,
+ dropout=0.0 if not self.training else self.attention_dropout,
+ is_causal=False,
+ **kwargs,
+ )[0]
+ for q, k, v in zip(*splits)
+ ]
+ attn_output = torch.cat(attn_outputs, dim=1)
+
+ attn_output = attn_output.reshape(seq_length, -1).contiguous()
+ attn_output = self.proj(attn_output)
+ return attn_output
+
+
+class Qwen3VLMoeVisionBlock(GradientCheckpointingLayer):
+ def __init__(self, config, attn_implementation: str = "sdpa") -> None:
+ super().__init__()
+ self.norm1 = nn.LayerNorm(config.hidden_size, eps=1e-6)
+ self.norm2 = nn.LayerNorm(config.hidden_size, eps=1e-6)
+ self.attn = Qwen3VLMoeVisionAttention(config=config)
+ self.mlp = Qwen3VLMoeVisionMLP(config=config)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ cu_seqlens: torch.Tensor,
+ rotary_pos_emb: Optional[torch.Tensor] = None,
+ position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
+ **kwargs,
+ ) -> torch.Tensor:
+ hidden_states = hidden_states + self.attn(
+ self.norm1(hidden_states),
+ cu_seqlens=cu_seqlens,
+ rotary_pos_emb=rotary_pos_emb,
+ position_embeddings=position_embeddings,
+ **kwargs,
+ )
+ hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
+ return hidden_states
+
+
+class Qwen3VLMoeVisionModel(Qwen3VLMoePreTrainedModel):
+ config: Qwen3VLMoeVisionConfig
+ _no_split_modules = ["Qwen3VLMoeVisionBlock"]
+
+ def __init__(self, config, *inputs, **kwargs) -> None:
+ super().__init__(config, *inputs, **kwargs)
+ self.spatial_merge_size = config.spatial_merge_size
+ self.patch_size = config.patch_size
+ self.spatial_merge_unit = self.spatial_merge_size * self.spatial_merge_size
+
+ self.patch_embed = Qwen3VLMoeVisionPatchEmbed(
+ config=config,
+ )
+
+ self.pos_embed = nn.Embedding(config.num_position_embeddings, config.hidden_size)
+ self.num_grid_per_side = int(config.num_position_embeddings**0.5)
+
+ head_dim = config.hidden_size // config.num_heads
+ self.rotary_pos_emb = Qwen3VLMoeVisionRotaryEmbedding(head_dim // 2)
+
+ self.blocks = nn.ModuleList([Qwen3VLMoeVisionBlock(config) for _ in range(config.depth)])
+ self.merger = Qwen3VLMoeVisionPatchMerger(
+ config=config,
+ use_postshuffle_norm=False,
+ )
+
+ self.deepstack_visual_indexes = config.deepstack_visual_indexes
+ self.deepstack_merger_list = nn.ModuleList(
+ [
+ Qwen3VLMoeVisionPatchMerger(
+ config=config,
+ use_postshuffle_norm=True,
+ )
+ for _ in range(len(config.deepstack_visual_indexes))
+ ]
+ )
+
+ self.gradient_checkpointing = False
+
+ def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
+ merge_size = self.spatial_merge_size
+
+ max_hw = int(grid_thw[:, 1:].max().item())
+ freq_table = self.rotary_pos_emb(max_hw) # (max_hw, dim // 2)
+ device = freq_table.device
+
+ total_tokens = int(torch.prod(grid_thw, dim=1).sum().item())
+ pos_ids = torch.empty((total_tokens, 2), dtype=torch.long, device=device)
+
+ offset = 0
+ for num_frames, height, width in grid_thw:
+ merged_h, merged_w = height // merge_size, width // merge_size
+
+ block_rows = torch.arange(merged_h, device=device) # block row indices
+ block_cols = torch.arange(merged_w, device=device) # block col indices
+ intra_row = torch.arange(merge_size, device=device) # intra-block row offsets
+ intra_col = torch.arange(merge_size, device=device) # intra-block col offsets
+
+ # Compute full-resolution positions
+ row_idx = block_rows[:, None, None, None] * merge_size + intra_row[None, None, :, None]
+ col_idx = block_cols[None, :, None, None] * merge_size + intra_col[None, None, None, :]
+
+ row_idx = row_idx.expand(merged_h, merged_w, merge_size, merge_size).reshape(-1)
+ col_idx = col_idx.expand(merged_h, merged_w, merge_size, merge_size).reshape(-1)
+
+ coords = torch.stack((row_idx, col_idx), dim=-1)
+
+ if num_frames > 1:
+ coords = coords.repeat(num_frames, 1)
+
+ num_tokens = coords.shape[0]
+ pos_ids[offset : offset + num_tokens] = coords
+ offset += num_tokens
+
+ embeddings = freq_table[pos_ids] # lookup rotary embeddings
+ embeddings = embeddings.flatten(1)
+ return embeddings
+
+ def fast_pos_embed_interpolate(self, grid_thw):
+ grid_ts, grid_hs, grid_ws = grid_thw[:, 0], grid_thw[:, 1], grid_thw[:, 2]
+
+ idx_list = [[] for _ in range(4)]
+ weight_list = [[] for _ in range(4)]
+
+ for t, h, w in zip(grid_ts, grid_hs, grid_ws):
+ h_idxs = torch.linspace(0, self.num_grid_per_side - 1, h)
+ w_idxs = torch.linspace(0, self.num_grid_per_side - 1, w)
+
+ h_idxs_floor = h_idxs.int()
+ w_idxs_floor = w_idxs.int()
+ h_idxs_ceil = (h_idxs.int() + 1).clip(max=self.num_grid_per_side - 1)
+ w_idxs_ceil = (w_idxs.int() + 1).clip(max=self.num_grid_per_side - 1)
+
+ dh = h_idxs - h_idxs_floor
+ dw = w_idxs - w_idxs_floor
+
+ base_h = h_idxs_floor * self.num_grid_per_side
+ base_h_ceil = h_idxs_ceil * self.num_grid_per_side
+
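+ # Bilinear interpolation over the learned position-embedding grid: gather the four nearest
+ # grid corners (floor/ceil along h and w) and mix them with the matching weights below.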
+ indices = [
+ (base_h[None].T + w_idxs_floor[None]).flatten(),
+ (base_h[None].T + w_idxs_ceil[None]).flatten(),
+ (base_h_ceil[None].T + w_idxs_floor[None]).flatten(),
+ (base_h_ceil[None].T + w_idxs_ceil[None]).flatten(),
+ ]
+
+ weights = [
+ ((1 - dh)[None].T * (1 - dw)[None]).flatten(),
+ ((1 - dh)[None].T * dw[None]).flatten(),
+ (dh[None].T * (1 - dw)[None]).flatten(),
+ (dh[None].T * dw[None]).flatten(),
+ ]
+
+ for i in range(4):
+ idx_list[i].extend(indices[i].tolist())
+ weight_list[i].extend(weights[i].tolist())
+
+ idx_tensor = torch.tensor(idx_list, dtype=torch.long, device=self.pos_embed.weight.device)
+ weight_tensor = torch.tensor(
+ weight_list, dtype=self.pos_embed.weight.dtype, device=self.pos_embed.weight.device
+ )
+ pos_embeds = self.pos_embed(idx_tensor) * weight_tensor[:, :, None]
+ patch_pos_embeds = pos_embeds[0] + pos_embeds[1] + pos_embeds[2] + pos_embeds[3]
+
+ patch_pos_embeds = patch_pos_embeds.split([h * w for h, w in zip(grid_hs, grid_ws)])
+
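+ # Reorder the per-image position embeddings so that patches within the same merge window are contiguous,
+ # matching the patch ordering of the input hidden states.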
+ patch_pos_embeds_permute = []
+ merge_size = self.config.spatial_merge_size
+ for pos_embed, t, h, w in zip(patch_pos_embeds, grid_ts, grid_hs, grid_ws):
+ pos_embed = pos_embed.repeat(t, 1)
+ pos_embed = (
+ pos_embed.view(t, h // merge_size, merge_size, w // merge_size, merge_size, -1)
+ .permute(0, 1, 3, 2, 4, 5)
+ .flatten(0, 4)
+ )
+ patch_pos_embeds_permute.append(pos_embed)
+ patch_pos_embeds = torch.cat(patch_pos_embeds_permute)
+ return patch_pos_embeds
+
+ def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) -> torch.Tensor:
+ """
+ Args:
+ hidden_states (`torch.Tensor` of shape `(seq_len, hidden_size)`):
+ The final hidden states of the model.
+ grid_thw (`torch.Tensor` of shape `(num_images_or_videos, 3)`):
+ The temporal, height and width of feature shape of each image in LLM.
+
+ Returns:
+ `torch.Tensor`: hidden_states.
+ """
+ hidden_states = self.patch_embed(hidden_states)
+
+ pos_embeds = self.fast_pos_embed_interpolate(grid_thw)
+ hidden_states = hidden_states + pos_embeds
+
+ rotary_pos_emb = self.rot_pos_emb(grid_thw)
+
+ seq_len, _ = hidden_states.size()
+ hidden_states = hidden_states.reshape(seq_len, -1)
+ rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
+ emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
+ position_embeddings = (emb.cos(), emb.sin())
+
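+ # cu_seqlens holds cumulative per-frame token counts (h * w repeated over t for every image/video),
+ # so attention can be restricted to tokens that belong to the same frame.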
+ cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
+ dim=0,
+ # Select dtype based on the following factors:
+ # - FA2 requires that cu_seqlens_q must have dtype int32
+ # - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw
+ # See https://github.com/huggingface/transformers/pull/34852 for more information
+ dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
+ )
+ cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
+
+ deepstack_feature_lists = []
+ for layer_num, blk in enumerate(self.blocks):
+ hidden_states = blk(
+ hidden_states,
+ cu_seqlens=cu_seqlens,
+ position_embeddings=position_embeddings,
+ **kwargs,
+ )
+ if layer_num in self.deepstack_visual_indexes:
+ deepstack_feature = self.deepstack_merger_list[self.deepstack_visual_indexes.index(layer_num)](
+ hidden_states
+ )
+ deepstack_feature_lists.append(deepstack_feature)
+
+ hidden_states = self.merger(hidden_states)
+
+ return hidden_states, deepstack_feature_lists
+
+
+class Qwen3VLMoeTextRotaryEmbedding(nn.Module):
+ inv_freq: torch.Tensor # fix linting for `register_buffer`
+
+ def __init__(self, config: Qwen3VLMoeTextConfig, device=None):
+ super().__init__()
+ if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
+ self.rope_type = config.rope_scaling.get("rope_type", "default")
+ else:
+ self.rope_type = "default"
+ self.max_seq_len_cached = config.max_position_embeddings
+ self.original_max_seq_len = config.max_position_embeddings
+
+ self.config = config
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
+ self.original_inv_freq = self.inv_freq
+
+ self.mrope_section = config.rope_scaling.get("mrope_section", [24, 20, 20])
+
+ def apply_interleaved_mrope(self, freqs, mrope_section):
+ """Apply interleaved MRoPE to 3D rotary embeddings.
+ Reorganizes frequency layout from chunked [TTT...HHH...WWW] to
+ interleaved [THTHWHTHW...TT], preserving frequency continuity.
+ Args:
+ freqs: (3, bs, seq_len, head_dim // 2)
+ mrope_section: (3,)
+ Returns:
+ freqs_t: (bs, seq_len, head_dim // 2)
+ """
+ freqs_t = freqs[0] # just overwrite the first dimension T
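+ # Channels 1, 4, 7, ... (below mrope_section[1] * 3) take the H frequencies and channels 2, 5, 8, ...
+ # (below mrope_section[2] * 3) take the W frequencies; every remaining channel keeps the T frequencies.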
+ for dim, offset in enumerate((1, 2), start=1): # H, W
+ length = mrope_section[dim] * 3
+ idx = slice(offset, length, 3)
+ freqs_t[..., idx] = freqs[dim, ..., idx]
+ return freqs_t
+
+ @torch.no_grad()
+ @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
+ def forward(self, x, position_ids):
+ # In contrast to other models, Qwen3VLMoe has different position ids for the grids
+ # So we expand the inv_freq to shape (3, ...)
+ if position_ids.ndim == 2:
+ position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1)
+ inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1)
+ position_ids_expanded = position_ids[:, :, None, :].float() # shape (3, bs, 1, positions)
+
+ device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
+ with torch.autocast(device_type=device_type, enabled=False): # Force float32
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3)
+ freqs = self.apply_interleaved_mrope(freqs, self.mrope_section)
+ emb = torch.cat((freqs, freqs), dim=-1)
+ cos = emb.cos() * self.attention_scaling
+ sin = emb.sin() * self.attention_scaling
+
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
+@auto_docstring(
+ custom_intro=(
+ "Text part of Qwen3VLMoe, "
+ "not a pure text-only model, as DeepStack integrates visual features into the early hidden states."
+ )
+)
+class Qwen3VLMoeTextModel(Qwen3VLMoePreTrainedModel):
+ config: Qwen3VLMoeTextConfig
+ _no_split_modules = ["Qwen3VLMoeTextDecoderLayer"]
+
+ def __init__(self, config: Qwen3VLMoeTextConfig):
+ super().__init__(config)
+ self.padding_idx = config.pad_token_id
+ self.vocab_size = config.vocab_size
+
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+ self.layers = nn.ModuleList(
+ [Qwen3VLMoeTextDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+ )
+ self.norm = Qwen3VLMoeTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.rotary_emb = Qwen3VLMoeTextRotaryEmbedding(config=config)
+ self.gradient_checkpointing = False
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @check_model_inputs
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ # args for deepstack
+ visual_pos_masks: Optional[torch.Tensor] = None,
+ deepstack_visual_embeds: Optional[list[torch.Tensor]] = None,
+ **kwargs: Unpack[FlashAttentionKwargs],
+ ) -> Union[tuple, BaseModelOutputWithPast]:
+ r"""
+ visual_pos_masks (`torch.Tensor` of shape `(batch_size, seqlen)`, *optional*):
+ The mask of the visual positions.
+ deepstack_visual_embeds (`list[torch.Tensor]`, *optional*):
+ The deepstack visual embeddings. The shape is (num_layers, visual_seqlen, embed_dim).
+ The features are extracted from different visual encoder layers and added to the decoder
+ hidden states, following the DeepStack paper (https://arxiv.org/abs/2406.04334).
+ """
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ # torch.jit.trace() doesn't support cache objects in the output
+ if use_cache and past_key_values is None and not torch.jit.is_tracing():
+ past_key_values = DynamicCache(config=self.config)
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+
+ if cache_position is None:
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ cache_position = torch.arange(
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+ )
+
+ # the hard-coded `3` is for the temporal, height and width axes.
+ if position_ids is None:
+ position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1)
+ elif position_ids.ndim == 2:
+ position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1)
+
+ if position_ids.ndim == 3 and position_ids.shape[0] == 4:
+ text_position_ids = position_ids[0]
+ position_ids = position_ids[1:]
+ else:
+ text_position_ids = position_ids[0]
+
+ attention_mask = create_causal_mask(
+ config=self.config,
+ input_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ cache_position=cache_position,
+ past_key_values=past_key_values,
+ position_ids=text_position_ids,
+ )
+
+ hidden_states = inputs_embeds
+
+ # create position embeddings to be shared across the decoder layers
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+ # decoder layers
+ for layer_idx, decoder_layer in enumerate(self.layers):
+ layer_outputs = decoder_layer(
+ hidden_states,
+ attention_mask=attention_mask,
+ position_ids=text_position_ids,
+ past_key_values=past_key_values,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ **kwargs,
+ )
+ hidden_states = layer_outputs
+
+ # add visual features to the hidden states of the first several layers
+ if deepstack_visual_embeds is not None and layer_idx in range(len(deepstack_visual_embeds)):
+ hidden_states = self._deepstack_process(
+ hidden_states,
+ visual_pos_masks,
+ deepstack_visual_embeds[layer_idx],
+ )
+
+ hidden_states = self.norm(hidden_states)
+
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=past_key_values,
+ )
+
+ def _deepstack_process(
+ self, hidden_states: torch.Tensor, visual_pos_masks: torch.Tensor, visual_embeds: torch.Tensor
+ ):
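+ # Add the deepstack visual features onto the hidden states at the visual token positions only.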
+ visual_pos_masks = visual_pos_masks.to(hidden_states.device)
+ visual_embeds = visual_embeds.to(hidden_states.device, hidden_states.dtype)
+ local_this = hidden_states[visual_pos_masks, :].clone() + visual_embeds
+ hidden_states[visual_pos_masks, :] = local_this
+ return hidden_states
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Base class for Qwen3VLMoe causal language model (or autoregressive) outputs.
+ """
+)
+class Qwen3VLMoeCausalLMOutputWithPast(ModelOutput):
+ r"""
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Language modeling loss (for next-token prediction).
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
+ `past_key_values` input) to speed up sequential decoding.
+ rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
+ The rope index difference between sequence length and multimodal rope.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ logits: Optional[torch.FloatTensor] = None
+ past_key_values: Optional[Cache] = None
+ hidden_states: Optional[tuple[torch.FloatTensor]] = None
+ attentions: Optional[tuple[torch.FloatTensor]] = None
+ rope_deltas: Optional[torch.LongTensor] = None
+ aux_loss: Optional[torch.FloatTensor] = None
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Base class for Qwen3VLMoe model outputs, with hidden states and attentions.
+ """
+)
+class Qwen3VLMoeModelOutputWithPast(ModelOutput):
+ r"""
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
+ `past_key_values` input) to speed up sequential decoding.
+ rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
+ The rope index difference between sequence length and multimodal rope.
+ """
+
+ last_hidden_state: Optional[torch.FloatTensor] = None
+ past_key_values: Optional[Cache] = None
+ hidden_states: Optional[tuple[torch.FloatTensor]] = None
+ attentions: Optional[tuple[torch.FloatTensor]] = None
+ rope_deltas: Optional[torch.LongTensor] = None
+
+
+@auto_docstring
+class Qwen3VLMoeModel(Qwen3VLMoePreTrainedModel):
+ base_model_prefix = ""
+ _checkpoint_conversion_mapping = {}
+ # Reference: fix gemma3 grad acc #37208
+ accepts_loss_kwargs = False
+ config: Qwen3VLMoeConfig
+ _no_split_modules = ["Qwen3VLMoeTextDecoderLayer", "Qwen3VLMoeVisionBlock"]
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.visual = Qwen3VLMoeVisionModel._from_config(config.vision_config)
+ self.language_model = Qwen3VLMoeTextModel._from_config(config.text_config)
+ self.rope_deltas = None # cache rope_deltas here
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.language_model.get_input_embeddings()
+
+ def set_input_embeddings(self, value):
+ self.language_model.set_input_embeddings(value)
+
+ def set_decoder(self, decoder):
+ self.language_model = decoder
+
+ def get_decoder(self):
+ return self.language_model
+
+ def get_rope_index(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ image_grid_thw: Optional[torch.LongTensor] = None,
+ video_grid_thw: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ """Different from the original implementation, Qwen3VLMoe use timestamps rather than absolute time position ids."""
+
+ # Since we use timestamps to separate videos, the video_grid_thw should also be split
+ if video_grid_thw is not None:
+ video_grid_thw = torch.repeat_interleave(video_grid_thw, video_grid_thw[:, 0], dim=0)
+ video_grid_thw[:, 0] = 1
+
+ spatial_merge_size = self.config.vision_config.spatial_merge_size
+ image_token_id = self.config.image_token_id
+ video_token_id = self.config.video_token_id
+ vision_start_token_id = self.config.vision_start_token_id
+ mrope_position_deltas = []
+ if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None):
+ total_input_ids = input_ids
+ if attention_mask is None:
+ attention_mask = torch.ones_like(total_input_ids)
+ position_ids = torch.ones(
+ 3,
+ input_ids.shape[0],
+ input_ids.shape[1],
+ dtype=input_ids.dtype,
+ device=input_ids.device,
+ )
+ image_index, video_index = 0, 0
+ attention_mask = attention_mask.to(total_input_ids.device)
+ for i, input_ids in enumerate(total_input_ids):
+ input_ids = input_ids[attention_mask[i] == 1]
+ image_nums, video_nums = 0, 0
+ vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1)
+ vision_tokens = input_ids[vision_start_indices + 1]
+ image_nums = (vision_tokens == image_token_id).sum()
+ video_nums = (vision_tokens == video_token_id).sum()
+ input_tokens = input_ids.tolist()
+ llm_pos_ids_list: list = []
+ st = 0
+ remain_images, remain_videos = image_nums, video_nums
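+ # Walk through the sequence: each text span advances all three (t, h, w) position axes together,
+ # while each vision block receives 3D grid positions derived from its (t, h, w) shape.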
+ for _ in range(image_nums + video_nums):
+ if image_token_id in input_tokens and remain_images > 0:
+ ed_image = input_tokens.index(image_token_id, st)
+ else:
+ ed_image = len(input_tokens) + 1
+ if video_token_id in input_tokens and remain_videos > 0:
+ ed_video = input_tokens.index(video_token_id, st)
+ else:
+ ed_video = len(input_tokens) + 1
+ if ed_image < ed_video:
+ t, h, w = (
+ image_grid_thw[image_index][0],
+ image_grid_thw[image_index][1],
+ image_grid_thw[image_index][2],
+ )
+ image_index += 1
+ remain_images -= 1
+ ed = ed_image
+
+ else:
+ t, h, w = (
+ video_grid_thw[video_index][0],
+ video_grid_thw[video_index][1],
+ video_grid_thw[video_index][2],
+ )
+ video_index += 1
+ remain_videos -= 1
+ ed = ed_video
+ llm_grid_t, llm_grid_h, llm_grid_w = (
+ t.item(),
+ h.item() // spatial_merge_size,
+ w.item() // spatial_merge_size,
+ )
+ text_len = ed - st
+
+ st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
+ llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
+
+ # t_index is always 0 because llm_grid_t is always 1 (we use timestamps to encode the temporal information for videos)
+ t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten()
+ h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten()
+ w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten()
+ llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
+ st = ed + llm_grid_t * llm_grid_h * llm_grid_w
+
+ if st < len(input_tokens):
+ st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
+ text_len = len(input_tokens) - st
+ llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
+
+ llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
+ position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device)
+ mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i]))
+ mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1)
+ return position_ids, mrope_position_deltas
+ else:
+ if attention_mask is not None:
+ position_ids = attention_mask.long().cumsum(-1) - 1
+ position_ids.masked_fill_(attention_mask == 0, 1)
+ position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device)
+ max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0]
+ mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1]
+ else:
+ position_ids = (
+ torch.arange(input_ids.shape[1], device=input_ids.device)
+ .view(1, 1, -1)
+ .expand(3, input_ids.shape[0], -1)
+ )
+ mrope_position_deltas = torch.zeros(
+ [input_ids.shape[0], 1],
+ device=input_ids.device,
+ dtype=input_ids.dtype,
+ )
+
+ return position_ids, mrope_position_deltas
+
+ def get_video_features(
+ self, pixel_values_videos: torch.FloatTensor, video_grid_thw: Optional[torch.LongTensor] = None
+ ):
+ """
+ Encodes videos into continuous embeddings that can be forwarded to the language model. The deepstack visual features are also returned.
+
+ Args:
+ pixel_values_videos (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
+ The tensors corresponding to the input videos.
+ video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
+ The temporal, height and width of feature shape of each video in LLM.
+ """
+ # Same implementation as for images
+ return self.get_image_features(pixel_values_videos, video_grid_thw)
+
+ def get_image_features(self, pixel_values: torch.FloatTensor, image_grid_thw: Optional[torch.LongTensor] = None):
+ """
+ Encodes images into continuous embeddings that can be forwarded to the language model. The deepstack visual features are also returned.
+
+ Args:
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
+ The tensors corresponding to the input images.
+ image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
+ The temporal, height and width of feature shape of each image in LLM.
+ """
+ pixel_values = pixel_values.type(self.visual.dtype)
+ image_embeds, deepstack_image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
+ split_sizes = (image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist()
+ image_embeds = torch.split(image_embeds, split_sizes)
+ return image_embeds, deepstack_image_embeds
+
+ def get_placeholder_mask(
+ self,
+ input_ids: torch.LongTensor,
+ inputs_embeds: torch.FloatTensor,
+ image_features: Optional[torch.FloatTensor] = None,
+ video_features: Optional[torch.FloatTensor] = None,
+ ):
+ """
+ Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is
+ equal to the length of multimodal features. If the lengths are different, an error is raised.
+ """
+ if input_ids is None:
+ special_image_mask = inputs_embeds == self.get_input_embeddings()(
+ torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
+ )
+ special_image_mask = special_image_mask.all(-1)
+ special_video_mask = inputs_embeds == self.get_input_embeddings()(
+ torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device)
+ )
+ special_video_mask = special_video_mask.all(-1)
+ else:
+ special_image_mask = input_ids == self.config.image_token_id
+ special_video_mask = input_ids == self.config.video_token_id
+
+ n_image_tokens = special_image_mask.sum()
+ special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
+ if image_features is not None and inputs_embeds[special_image_mask].numel() != image_features.numel():
+ raise ValueError(
+ f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {image_features.shape[0]}"
+ )
+
+ n_video_tokens = special_video_mask.sum()
+ special_video_mask = special_video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
+ if video_features is not None and inputs_embeds[special_video_mask].numel() != video_features.numel():
+ raise ValueError(
+ f"Videos features and video tokens do not match: tokens: {n_video_tokens}, features {video_features.shape[0]}"
+ )
+
+ return special_image_mask, special_video_mask
+
+ @auto_docstring
+ @check_model_inputs
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ pixel_values: Optional[torch.Tensor] = None,
+ pixel_values_videos: Optional[torch.FloatTensor] = None,
+ image_grid_thw: Optional[torch.LongTensor] = None,
+ video_grid_thw: Optional[torch.LongTensor] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> Union[tuple, Qwen3VLMoeModelOutputWithPast]:
+ r"""
+ image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
+ The temporal, height and width of feature shape of each image in LLM.
+ video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
+ The temporal, height and width of feature shape of each video in LLM.
+ """
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ if inputs_embeds is None:
+ inputs_embeds = self.get_input_embeddings()(input_ids)
+
+ image_mask = None
+ video_mask = None
+
+ if pixel_values is not None:
+ image_embeds, deepstack_image_embeds = self.get_image_features(pixel_values, image_grid_thw)
+ image_embeds = torch.cat(image_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype)
+ image_mask, _ = self.get_placeholder_mask(
+ input_ids, inputs_embeds=inputs_embeds, image_features=image_embeds
+ )
+ inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
+
+ if pixel_values_videos is not None:
+ video_embeds, deepstack_video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw)
+ video_embeds = torch.cat(video_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype)
+ _, video_mask = self.get_placeholder_mask(
+ input_ids, inputs_embeds=inputs_embeds, video_features=video_embeds
+ )
+ inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
+
+ visual_pos_masks = None
+ deepstack_visual_embeds = None
+ if image_mask is not None and video_mask is not None:
+ # aggregate visual_pos_masks and deepstack_visual_embeds
+ image_mask = image_mask[..., 0]
+ video_mask = video_mask[..., 0]
+ visual_pos_masks = image_mask | video_mask
+ deepstack_visual_embeds = []
+ image_mask_joint = image_mask[visual_pos_masks]
+ video_mask_joint = video_mask[visual_pos_masks]
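+ # Merge image and video deepstack features into one per-layer tensor aligned with the joint visual mask.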
+ for img_embed, vid_embed in zip(deepstack_image_embeds, deepstack_video_embeds):
+ embed_joint = img_embed.new_zeros(visual_pos_masks.sum(), img_embed.shape[-1]).to(img_embed.device)
+ embed_joint[image_mask_joint, :] = img_embed
+ embed_joint[video_mask_joint, :] = vid_embed
+ deepstack_visual_embeds.append(embed_joint)
+ elif image_mask is not None:
+ image_mask = image_mask[..., 0]
+ visual_pos_masks = image_mask
+ deepstack_visual_embeds = deepstack_image_embeds
+ elif video_mask is not None:
+ video_mask = video_mask[..., 0]
+ visual_pos_masks = video_mask
+ deepstack_visual_embeds = deepstack_video_embeds
+
+ if position_ids is None:
+ attention_mask_tensor = (
+ attention_mask if not isinstance(attention_mask, dict) else attention_mask["full_attention"]
+ )
+ if attention_mask_tensor is not None and attention_mask_tensor.ndim == 4:
+ attention_mask_tensor = torch.diagonal(attention_mask_tensor[:, 0], dim1=1, dim2=2)
+ # Only apply conversion for floating point tensors (inverted masks)
+ if attention_mask_tensor.dtype.is_floating_point:
+ attention_mask_tensor = attention_mask_tensor / torch.finfo(attention_mask_tensor.dtype).min
+ attention_mask_tensor = (1.0 - attention_mask_tensor).int()
+
+ # Calculate RoPE index once per generation in the pre-fill stage only.
+ # When compiling, we can't check tensor values thus we check only input length
+ # It is safe to assume that `length!=1` means we're in pre-fill because compiled
+ # models currently cannot do assisted decoding
+ prefill_compiled_stage = is_torchdynamo_compiling() and (
+ (input_ids is not None and input_ids.shape[1] != 1)
+ or (inputs_embeds is not None and inputs_embeds.shape[1] != 1)
+ )
+ prefill_noncompiled_stage = not is_torchdynamo_compiling() and (
+ (cache_position is not None and cache_position[0] == 0)
+ or (past_key_values is None or past_key_values.get_seq_length() == 0)
+ )
+ if (prefill_compiled_stage or prefill_noncompiled_stage) or self.rope_deltas is None:
+ position_ids, rope_deltas = self.get_rope_index(
+ input_ids,
+ image_grid_thw,
+ video_grid_thw,
+ attention_mask=attention_mask_tensor,
+ )
+ self.rope_deltas = rope_deltas
+ # otherwise, reuse the previously calculated rope_deltas to get the correct position ids
+ else:
+ batch_size, seq_length, _ = inputs_embeds.shape
+ delta = (
+ (cache_position[0] + self.rope_deltas).to(inputs_embeds.device)
+ if cache_position is not None
+ else 0
+ )
+ position_ids = torch.arange(seq_length, device=inputs_embeds.device)
+ position_ids = position_ids.view(1, -1).expand(batch_size, -1)
+ if cache_position is not None: # otherwise `delta` is an int `0`
+ delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
+ position_ids = position_ids.add(delta)
+ position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
+
+ outputs = self.language_model(
+ input_ids=None,
+ position_ids=position_ids,
+ attention_mask=attention_mask,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ cache_position=cache_position,
+ visual_pos_masks=visual_pos_masks,
+ deepstack_visual_embeds=deepstack_visual_embeds,
+ **kwargs,
+ )
+
+ return Qwen3VLMoeModelOutputWithPast(
+ last_hidden_state=outputs.last_hidden_state,
+ past_key_values=outputs.past_key_values,
+ rope_deltas=self.rope_deltas,
+ )
+
+
+def load_balancing_loss_func(
+ gate_logits: Union[torch.Tensor, tuple[torch.Tensor], None],
+ num_experts: Optional[int] = None,
+ top_k=2,
+ attention_mask: Optional[torch.Tensor] = None,
+) -> Union[torch.Tensor, int]:
+ r"""
+ Computes the auxiliary load balancing loss as in Switch Transformer, implemented in PyTorch.
+
+ See Switch Transformer (https://huggingface.co/papers/2101.03961) for more details. This function implements the loss
+ function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between
+ experts is too unbalanced.
+
+ Args:
+ gate_logits:
+ Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of
+ shape [batch_size X sequence_length, num_experts].
+ num_experts:
+ Number of experts
+ top_k:
+ The number of experts to route per-token, can be also interpreted as the `top-k` routing
+ parameter.
+ attention_mask (`torch.Tensor`, *optional*):
+ The attention_mask used in forward function
+ shape [batch_size X sequence_length] if not None.
+
+ Returns:
+ The auxiliary loss.
+ """
+ if gate_logits is None or not isinstance(gate_logits, tuple):
+ return 0
+
+ if isinstance(gate_logits, tuple):
+ compute_device = gate_logits[0].device
+ concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0)
+
+ routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)
+
+ _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
+
+ expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)
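+ # expert_mask: one-hot over experts for each top-k selection; shape (num_layers * batch * seq_len, top_k, num_experts).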
+
+ if attention_mask is None:
+ # Compute the percentage of tokens routed to each expert
+ tokens_per_expert = torch.mean(expert_mask.float(), dim=0)
+
+ # Compute the average probability of routing to these experts
+ router_prob_per_expert = torch.mean(routing_weights, dim=0)
+ else:
+ batch_size, sequence_length = attention_mask.shape
+ num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length)
+
+ # Compute the mask that masks all padding tokens as 0, with the same shape as expert_mask
+ expert_attention_mask = (
+ attention_mask[None, :, :, None, None]
+ .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts))
+ .reshape(-1, top_k, num_experts)
+ .to(compute_device)
+ )
+
+ # Compute the percentage of tokens routed to each expert
+ tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum(
+ expert_attention_mask, dim=0
+ )
+
+ # Compute the mask that masks all padding tokens as 0, with the same shape as tokens_per_expert
+ router_per_expert_attention_mask = (
+ attention_mask[None, :, :, None]
+ .expand((num_hidden_layers, batch_size, sequence_length, num_experts))
+ .reshape(-1, num_experts)
+ .to(compute_device)
+ )
+
+ # Compute the average probability of routing to these experts
+ router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum(
+ router_per_expert_attention_mask, dim=0
+ )
+
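+ # Switch Transformer auxiliary loss: sum over experts of the token fraction routed to each expert times
+ # its mean router probability, scaled by num_experts; it is minimized when routing is balanced.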
+ overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
+ return overall_loss * num_experts
+
+
+class Qwen3VLMoeForConditionalGeneration(Qwen3VLMoePreTrainedModel, GenerationMixin):
+ _checkpoint_conversion_mapping = {}
+ _tied_weights_keys = ["lm_head.weight"]
+ # Reference: fix gemma3 grad acc #37208
+ accepts_loss_kwargs = False
+ config: Qwen3VLMoeConfig
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.model = Qwen3VLMoeModel(config)
+ self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
+
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.get_input_embeddings()
+
+ def set_input_embeddings(self, value):
+ self.model.set_input_embeddings(value)
+
+ def set_decoder(self, decoder):
+ self.model.set_decoder(decoder)
+
+ def get_decoder(self):
+ return self.model.get_decoder()
+
+ def get_video_features(
+ self, pixel_values_videos: torch.FloatTensor, video_grid_thw: Optional[torch.LongTensor] = None
+ ):
+ return self.model.get_video_features(pixel_values_videos, video_grid_thw)
+
+ def get_image_features(self, pixel_values: torch.FloatTensor, image_grid_thw: Optional[torch.LongTensor] = None):
+ return self.model.get_image_features(pixel_values, image_grid_thw)
+
+ # Make modules available through conditional class for BC
+ @property
+ def language_model(self):
+ return self.model.language_model
+
+ @property
+ def visual(self):
+ return self.model.visual
+
+ @check_model_inputs
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ pixel_values: Optional[torch.Tensor] = None,
+ pixel_values_videos: Optional[torch.FloatTensor] = None,
+ image_grid_thw: Optional[torch.LongTensor] = None,
+ video_grid_thw: Optional[torch.LongTensor] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ logits_to_keep: Union[int, torch.Tensor] = 0,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> Union[tuple, Qwen3VLMoeCausalLMOutputWithPast]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+ image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
+ The temporal, height and width of feature shape of each image in LLM.
+ video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
+ The temporal, height and width of feature shape of each video in LLM.
+
+ Example:
+ ```python
+ >>> from PIL import Image
+ >>> import requests
+ >>> from transformers import AutoProcessor, Qwen3VLMoeForConditionalGeneration
+
+ >>> model = Qwen3VLMoeForConditionalGeneration.from_pretrained("Qwen/Qwen3-VL-30B-A3B-Instruct", dtype="auto", device_map="auto")
+ >>> processor = AutoProcessor.from_pretrained("Qwen/Qwen3-VL-30B-A3B-Instruct")
+
+ >>> messages = [
+ {
+ "role": "user",
+ "content": [
+ {
+ "type": "image",
+ "image": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg",
+ },
+ {"type": "text", "text": "Describe this image in short."},
+ ],
+ }
+ ]
+
+ >>> # Preparation for inference
+ >>> inputs = processor.apply_chat_template(
+ messages,
+ tokenize=True,
+ add_generation_prompt=True,
+ return_dict=True,
+ return_tensors="pt"
+ )
+ >>> inputs = inputs.to(model.device)
+
+ >>> # Generate
+ >>> generated_ids = model.generate(**inputs, max_new_tokens=128)
+ >>> generated_ids_trimmed = [
+ out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
+ ]
+ >>> processor.batch_decode(generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ "A woman in a plaid shirt sits on a sandy beach at sunset, smiling as she gives a high-five to a yellow Labrador Retriever wearing a harness. The ocean waves roll in the background."
+ ```"""
+
+ outputs = self.model(
+ input_ids=input_ids,
+ pixel_values=pixel_values,
+ pixel_values_videos=pixel_values_videos,
+ image_grid_thw=image_grid_thw,
+ video_grid_thw=video_grid_thw,
+ position_ids=position_ids,
+ attention_mask=attention_mask,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+ hidden_states = outputs[0]
+
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)
+
+ aux_loss = None
+ if kwargs.get("output_router_logits", False):
+ aux_loss = load_balancing_loss_func(
+ outputs.router_logits,
+ self.config.text_config.num_experts,
+ self.config.text_config.num_experts_per_tok,
+ attention_mask,
+ )
+ if labels is not None:
+ loss += self.config.text_config.router_aux_loss_coef * aux_loss.to(
+ loss.device
+ ) # make sure to reside in the same device
+
+ return Qwen3VLMoeCausalLMOutputWithPast(
+ loss=loss,
+ aux_loss=aux_loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ rope_deltas=outputs.rope_deltas,
+ )
+
+ def prepare_inputs_for_generation(
+ self,
+ input_ids,
+ past_key_values=None,
+ attention_mask=None,
+ inputs_embeds=None,
+ cache_position=None,
+ position_ids=None,
+ use_cache=True,
+ pixel_values=None,
+ pixel_values_videos=None,
+ image_grid_thw=None,
+ video_grid_thw=None,
+ **kwargs,
+ ):
+ # Overwritten -- in specific circumstances we don't want to forward image inputs to the model
+
+ model_inputs = super().prepare_inputs_for_generation(
+ input_ids,
+ past_key_values=past_key_values,
+ attention_mask=attention_mask,
+ inputs_embeds=inputs_embeds,
+ cache_position=cache_position,
+ position_ids=position_ids,
+ pixel_values=pixel_values,
+ pixel_values_videos=pixel_values_videos,
+ image_grid_thw=image_grid_thw,
+ video_grid_thw=video_grid_thw,
+ use_cache=use_cache,
+ **kwargs,
+ )
+
+ # Qwen3VLMoe position_ids are prepared with rope_deltas in forward
+ model_inputs["position_ids"] = None
+
+ if cache_position[0] != 0:
+ model_inputs["pixel_values"] = None
+ model_inputs["pixel_values_videos"] = None
+
+ return model_inputs
+
+ def _get_image_nums_and_video_nums(
+ self,
+ input_ids: Optional[torch.LongTensor],
+ inputs_embeds: Optional[torch.Tensor] = None,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ """
+ Get the number of images and videos for each sample to calculate the separation length of the sample tensor.
+ These parameters are not passed through the processor to avoid unpredictable impacts from interface modifications.
+
+ Args:
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary.
+
+ Returns:
+ image_nums (`torch.LongTensor` of shape `(batch_size, num_images_sample)`)
+ video_nums (`torch.LongTensor` of shape `(batch_size, num_videos_sample)`)
+ """
+ image_token_id = self.config.image_token_id
+ video_token_id = self.config.video_token_id
+ vision_start_token_id = self.config.vision_start_token_id
+
+ if inputs_embeds is not None:
+ vision_start_mask = (
+ inputs_embeds
+ == self.get_input_embeddings()(
+ torch.tensor(vision_start_token_id, dtype=torch.long, device=inputs_embeds.device)
+ )
+ )[..., 0]
+ image_mask = (
+ inputs_embeds
+ == self.get_input_embeddings()(
+ torch.tensor(image_token_id, dtype=torch.long, device=inputs_embeds.device)
+ )
+ )[..., 0]
+ video_mask = (
+ inputs_embeds
+ == self.get_input_embeddings()(
+ torch.tensor(video_token_id, dtype=torch.long, device=inputs_embeds.device)
+ )
+ )[..., 0]
+ else:
+ vision_start_mask = input_ids == vision_start_token_id
+ image_mask = input_ids == image_token_id
+ video_mask = input_ids == video_token_id
+
+ vision_first_mask = torch.roll(vision_start_mask, shifts=1, dims=1)
+ image_nums = torch.sum(vision_first_mask & image_mask, dim=1)
+ video_nums = torch.sum(vision_first_mask & video_mask, dim=1)
+
+ return image_nums, video_nums
+
+ def _expand_inputs_for_generation(
+ self,
+ expand_size: int = 1,
+ is_encoder_decoder: bool = False,
+ input_ids: Optional[torch.LongTensor] = None,
+ **model_kwargs,
+ ) -> tuple[torch.LongTensor, dict[str, Any]]:
+ # Overwritten -- Support for expanding tensors without a batch size dimension
+ # e.g., pixel_values, image_grid_thw, pixel_values_videos, video_grid_thw, second_per_grid_t
+ # pixel_values.shape[0] is sum(seqlen_images for samples)
+ # image_grid_thw.shape[0] is sum(num_images for samples)
+
+ if expand_size == 1:
+ return input_ids, model_kwargs
+
+ visual_keys = ["pixel_values", "image_grid_thw", "pixel_values_videos", "video_grid_thw", "second_per_grid_ts"]
+
+ def _expand_dict_for_generation_visual(dict_to_expand):
+ image_grid_thw = model_kwargs.get("image_grid_thw", None)
+ video_grid_thw = model_kwargs.get("video_grid_thw", None)
+ image_nums, video_nums = self._get_image_nums_and_video_nums(
+ input_ids, inputs_embeds=model_kwargs.get("inputs_embeds", None)
+ )
+
+ def _repeat_interleave_samples(x, lengths, repeat_times):
+ samples = torch.split(x, lengths)
+ repeat_args = [repeat_times] + [1] * (x.dim() - 1)
+ result = torch.cat([sample.repeat(*repeat_args) for sample in samples], dim=0)
+ return result
+
+ for key in dict_to_expand:
+ if key == "pixel_values":
+ # split images into samples
+ samples = torch.split(image_grid_thw, list(image_nums))
+ # compute the sequence length of images for each sample
+ lengths = [torch.prod(sample, dim=1).sum() for sample in samples]
+ dict_to_expand[key] = _repeat_interleave_samples(
+ dict_to_expand[key], lengths=lengths, repeat_times=expand_size
+ )
+ elif key == "image_grid_thw":
+ # get the num of images for each sample
+ lengths = list(image_nums)
+ dict_to_expand[key] = _repeat_interleave_samples(
+ dict_to_expand[key], lengths=lengths, repeat_times=expand_size
+ )
+ elif key == "pixel_values_videos":
+ samples = torch.split(video_grid_thw, list(video_nums))
+ lengths = [torch.prod(sample, dim=1).sum() for sample in samples]
+ dict_to_expand[key] = _repeat_interleave_samples(
+ dict_to_expand[key], lengths=lengths, repeat_times=expand_size
+ )
+ elif key == "video_grid_thw":
+ lengths = list(video_nums)
+ dict_to_expand[key] = _repeat_interleave_samples(
+ dict_to_expand[key], lengths=lengths, repeat_times=expand_size
+ )
+ elif key == "second_per_grid_ts":
+ dict_to_expand[key] = _repeat_interleave_samples(
+ dict_to_expand[key], lengths=list(video_nums), repeat_times=expand_size
+ )
+ return dict_to_expand
+
+ def _expand_dict_for_generation(dict_to_expand):
+ for key in dict_to_expand:
+ if (
+ key != "cache_position"
+ and dict_to_expand[key] is not None
+ and isinstance(dict_to_expand[key], torch.Tensor)
+ and key not in visual_keys
+ ):
+ dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0)
+ return dict_to_expand
+
+ model_kwargs = _expand_dict_for_generation_visual(model_kwargs)
+
+ if input_ids is not None:
+ input_ids = input_ids.repeat_interleave(expand_size, dim=0)
+
+ model_kwargs = _expand_dict_for_generation(model_kwargs)
+
+ if is_encoder_decoder:
+ if model_kwargs.get("encoder_outputs") is None:
+ raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.")
+ model_kwargs["encoder_outputs"] = _expand_dict_for_generation(model_kwargs["encoder_outputs"])
+
+ return input_ids, model_kwargs
+
+
+__all__ = [
+ "Qwen3VLMoeVisionModel",
+ "Qwen3VLMoeForConditionalGeneration",
+ "Qwen3VLMoeModel",
+ "Qwen3VLMoePreTrainedModel",
+ "Qwen3VLMoeTextModel",
+]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py
new file mode 100644
index 0000000000000000000000000000000000000000..72d3452bdc50a244b6cf3fdd9bd76dda5b2fc0f4
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py
@@ -0,0 +1,551 @@
+# coding=utf-8
+# Copyright 2025 The Qwen Team and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch Qwen3-VL-MOE model."""
+
+from typing import Optional, Union
+
+import torch
+import torch.nn as nn
+
+from ...activations import ACT2FN
+from ...cache_utils import Cache
+from ...configuration_utils import PretrainedConfig
+from ...modeling_rope_utils import rope_config_validation
+from ...modeling_utils import PreTrainedModel
+from ...processing_utils import Unpack
+from ...utils import TransformersKwargs, logging
+from ..qwen3_moe.modeling_qwen3_moe import (
+ Qwen3MoeDecoderLayer,
+ Qwen3MoePreTrainedModel,
+ Qwen3MoeRMSNorm,
+ load_balancing_loss_func,
+)
+from ..qwen3_vl.configuration_qwen3_vl import Qwen3VLConfig, Qwen3VLVisionConfig
+from ..qwen3_vl.modeling_qwen3_vl import (
+ Qwen3VLCausalLMOutputWithPast,
+ Qwen3VLForConditionalGeneration,
+ Qwen3VLModel,
+ Qwen3VLTextAttention,
+ Qwen3VLTextModel,
+ Qwen3VLVisionModel,
+)
+
+
+logger = logging.get_logger(__name__)
+
+
+class Qwen3VLMoeTextConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`Qwen3VLMoeTextModel`]. It is used to instantiate a
+ Qwen3-VL-MOE model according to the specified arguments, defining the model architecture. Instantiating a configuration
+ with the defaults will yield a similar configuration to that of
+ Qwen3-VL-30B-A3B-Instruct [Qwen/Qwen3-VL-30B-A3B-Instruct](https://huggingface.co/Qwen/Qwen3-VL-30B-A3B-Instruct).
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 151936):
+ Vocabulary size of the Qwen3VLMoe text model. Defines the number of different tokens that can be represented by the
+ `input_ids` passed when calling [`Qwen3VLMoeTextModel`]
+ hidden_size (`int`, *optional*, defaults to 2048):
+ Dimension of the hidden representations.
+ intermediate_size (`int`, *optional*, defaults to 5632):
+ Dimension of the MLP representations.
+ num_hidden_layers (`int`, *optional*, defaults to 24):
+ Number of hidden layers in the Transformer encoder.
+ num_attention_heads (`int`, *optional*, defaults to 16):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ num_key_value_heads (`int`, *optional*, defaults to 16):
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
+ `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+ by meanpooling all the original heads within that group. For more details check out [this
+ paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `16`.
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+ The non-linear activation function (function or string) in the decoder.
+ max_position_embeddings (`int`, *optional*, defaults to 128000):
+ The maximum sequence length that this model might ever be used with.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ rms_norm_eps (`float`, *optional*, defaults to 1e-06):
+ The epsilon used by the rms normalization layers.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
+ relevant if `config.is_decoder=True`.
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+ Whether the model's input and output word embeddings should be tied.
+ rope_theta (`float`, *optional*, defaults to 5000000.0):
+ The base period of the RoPE embeddings.
+ attention_bias (`bool`, *optional*, defaults to `False`):
+ Whether to use a bias in the query, key, value and output projection layers during self-attention.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ decoder_sparse_step (`int`, *optional*, defaults to 1):
+ The frequency of the MoE layer.
+ moe_intermediate_size (`int`, *optional*, defaults to 1408):
+ Intermediate size of the routed expert.
+ num_experts_per_tok (`int`, *optional*, defaults to 4):
+ Number of selected experts.
+ num_experts (`int`, *optional*, defaults to 60):
+ Number of routed experts.
+ norm_topk_prob (`bool`, *optional*, defaults to `True`):
+ Whether to normalize the topk probabilities.
+ router_aux_loss_coef (`float`, *optional*, defaults to 0.001):
+ The aux loss factor for the total loss.
+ mlp_only_layers (`List[int]`, *optional*, defaults to `[]`):
+ Indicates which layers use `Qwen3VLMoeMLP` rather than `Qwen3VLMoeSparseMoeBlock`.
+ The list contains layer indices, from 0 to num_hidden_layers - 1.
+ If `mlp_only_layers` is empty, `decoder_sparse_step` is used to determine the sparsity.
+ rope_scaling (`Dict`, *optional*):
+ Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
+ and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
+ accordingly.
+ Expected contents:
+ `rope_type` (`str`):
+ The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
+ 'llama3'], with 'default' being the original RoPE implementation.
+ `factor` (`float`, *optional*):
+ Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
+ most scaling types, a `factor` of x will enable the model to handle sequences of length x *
+ original maximum pre-trained length.
+ `original_max_position_embeddings` (`int`, *optional*):
+ Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
+ pretraining.
+ `attention_factor` (`float`, *optional*):
+ Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
+ computation. If unspecified, it defaults to the value recommended by the implementation, using the
+ `factor` field to infer the suggested value.
+ `beta_fast` (`float`, *optional*):
+ Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
+ ramp function. If unspecified, it defaults to 32.
+ `beta_slow` (`float`, *optional*):
+ Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
+ ramp function. If unspecified, it defaults to 1.
+ `short_factor` (`List[float]`, *optional*):
+ Only used with 'longrope'. The scaling factor to be applied to short contexts (<
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+ size divided by the number of attention heads divided by 2
+ `long_factor` (`List[float]`, *optional*):
+ Only used with 'longrope'. The scaling factor to be applied to long contexts (>
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+ size divided by the number of attention heads divided by 2
+ `low_freq_factor` (`float`, *optional*):
+ Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
+ `high_freq_factor` (`float`, *optional*):
+ Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
+ head_dim (`int`, *optional*):
+ The dimension of the head. If not specified, will default to `hidden_size // num_attention_heads`.
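+
+ For illustration, a YaRN-style `rope_scaling` dictionary might look like the following (a minimal sketch;
+ released Qwen3-VL-MoE checkpoints may ship different rope settings):
+
+ ```python
+ >>> from transformers import Qwen3VLMoeTextConfig
+
+ >>> rope_scaling = {"rope_type": "yarn", "factor": 4.0, "original_max_position_embeddings": 32768}
+ >>> text_config = Qwen3VLMoeTextConfig(rope_scaling=rope_scaling)
+ ```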
+
+ ```python
+ >>> from transformers import Qwen3VLMoeForConditionalGeneration, Qwen3VLMoeConfig
+
+ >>> # Initializing a Qwen3VLMoe style configuration
+ >>> configuration = Qwen3VLMoeConfig()
+
+ >>> # Initializing a model from the Qwen3-VL-30B-A3B style configuration
+ >>> model = Qwen3VLMoeForConditionalGeneration(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "qwen3_vl_moe_text"
+ base_config_key = "text_config"
+ keys_to_ignore_at_inference = ["past_key_values"]
+ # Default tensor parallel plan for base model `Qwen3VLMoe`
+ base_model_tp_plan = {
+ "layers.*.self_attn.q_proj": "colwise",
+ "layers.*.self_attn.k_proj": "colwise",
+ "layers.*.self_attn.v_proj": "colwise",
+ "layers.*.self_attn.o_proj": "rowwise",
+ "layers.*.mlp.gate_proj": "colwise",
+ "layers.*.mlp.up_proj": "colwise",
+ "layers.*.mlp.down_proj": "rowwise",
+ }
+ base_model_pp_plan = {
+ "embed_tokens": (["input_ids"], ["inputs_embeds"]),
+ "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
+ "norm": (["hidden_states"], ["hidden_states"]),
+ }
+
+ def __init__(
+ self,
+ vocab_size=151936,
+ hidden_size=2048,
+ intermediate_size=5632,
+ num_hidden_layers=24,
+ num_attention_heads=16,
+ num_key_value_heads=16,
+ hidden_act="silu",
+ max_position_embeddings=128000,
+ initializer_range=0.02,
+ rms_norm_eps=1e-6,
+ use_cache=True,
+ tie_word_embeddings=False,
+ rope_theta=5000000.0,
+ attention_bias=False,
+ attention_dropout=0.0,
+ decoder_sparse_step=1,
+ moe_intermediate_size=1408,
+ num_experts_per_tok=4,
+ num_experts=60,
+ norm_topk_prob=True,
+ router_aux_loss_coef=0.001,
+ mlp_only_layers=None,
+ rope_scaling=None,
+ head_dim=None,
+ **kwargs,
+ ):
+ self.vocab_size = vocab_size
+ self.max_position_embeddings = max_position_embeddings
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+
+ # for backward compatibility
+ if num_key_value_heads is None:
+ num_key_value_heads = num_attention_heads
+
+ self.num_key_value_heads = num_key_value_heads
+ self.hidden_act = hidden_act
+ self.initializer_range = initializer_range
+ self.rms_norm_eps = rms_norm_eps
+ self.use_cache = use_cache
+ self.rope_theta = rope_theta
+ self.attention_bias = attention_bias
+ self.attention_dropout = attention_dropout
+ self.rope_scaling = rope_scaling
+ self.head_dim = head_dim or hidden_size // num_attention_heads
+
+ rope_config_validation(self, ignore_keys={"mrope_section", "mrope_interleaved"})
+
+ # MoE arguments
+ self.decoder_sparse_step = decoder_sparse_step
+ self.moe_intermediate_size = moe_intermediate_size
+ self.num_experts_per_tok = num_experts_per_tok
+ self.num_experts = num_experts
+ self.norm_topk_prob = norm_topk_prob
+ self.router_aux_loss_coef = router_aux_loss_coef
+ self.mlp_only_layers = [] if mlp_only_layers is None else mlp_only_layers
+
+ super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
+
+
+class Qwen3VLMoeVisionConfig(Qwen3VLVisionConfig):
+ pass
+
+
+class Qwen3VLMoeConfig(Qwen3VLConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`Qwen3VLMoeModel`]. It is used to instantiate a
+ Qwen3-VL-MOE model according to the specified arguments, defining the model architecture. Instantiating a configuration
+ with the defaults will yield a similar configuration to that of
+ Qwen3-VL-30B-A3B-Instruct [Qwen/Qwen3-VL-30B-A3B-Instruct](https://huggingface.co/Qwen/Qwen3-VL-30B-A3B-Instruct).
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+
+ Args:
+ text_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `Qwen3VLMoeTextConfig`):
+ The config object or dictionary of the text backbone.
+ vision_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `Qwen3VLMoeVisionConfig`):
+ The config object or dictionary of the vision backbone.
+ image_token_id (`int`, *optional*, defaults to 151655):
+ The image token index to encode the image prompt.
+ video_token_id (`int`, *optional*, defaults to 151656):
+ The video token index to encode the video prompt.
+ vision_start_token_id (`int`, *optional*, defaults to 151652):
+ The start token index to encode the beginning of the image or video prompt.
+ vision_end_token_id (`int`, *optional*, defaults to 151653):
+ The end token index to encode the end of the image or video prompt.
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+ Whether to tie the word embeddings.
+
+ ```python
+ >>> from transformers import Qwen3VLMoeForConditionalGeneration, Qwen3VLMoeConfig
+
+ >>> # Initializing a Qwen3-VL-MOE style configuration
+ >>> configuration = Qwen3VLMoeConfig()
+
+ >>> # Initializing a model from the Qwen3-VL-30B-A3B style configuration
+ >>> model = Qwen3VLMoeForConditionalGeneration(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "qwen3_vl_moe"
+ sub_configs = {"vision_config": Qwen3VLMoeVisionConfig, "text_config": Qwen3VLMoeTextConfig}
+
+
+class Qwen3VLMoeTextRMSNorm(Qwen3MoeRMSNorm):
+ pass
+
+
+class Qwen3VLMoeTextExperts(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.num_experts = config.num_experts
+ self.intermediate_size = config.moe_intermediate_size
+ self.hidden_size = config.hidden_size
+ self.expert_dim = self.intermediate_size
+ self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_size, 2 * self.expert_dim))
+ self.down_proj = nn.Parameter(torch.empty((self.num_experts, self.expert_dim, self.hidden_size)))
+ self.act_fn = ACT2FN[config.hidden_act]
+
+ def forward(
+ self, hidden_states: torch.Tensor, routing_weights: torch.Tensor, router_indices: torch.Tensor
+ ) -> torch.Tensor:
+ """
+ When training, it is more efficient to loop only over the experts that are actually selected and compute
+ the output for each of them, as materializing the activations of every expert would explode memory usage.
+
+ For inference, we can trade some memory for speed and compute the output of all experts at once by repeating the inputs.
+
+ Args:
+ hidden_states (torch.Tensor): (batch_size * token_num, hidden_size)
+ routing_weights (torch.Tensor): (batch_size * token_num, num_experts)
+ router_indices (torch.Tensor): (batch_size * token_num, top_k)
+ Returns:
+ torch.Tensor
+ """
+ batch_size = hidden_states.shape[0]
+ hidden_states = hidden_states.reshape(-1, self.hidden_size) # (num_tokens, hidden_size)
+ if self.training:
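+ # Training path: only the experts actually selected by the router are evaluated; each hit expert's
+ # weighted output is scattered back into `next_states` with index_add_.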
+ next_states = torch.zeros_like(hidden_states, dtype=hidden_states.dtype, device=hidden_states.device)
+ with torch.no_grad():
+ expert_mask = torch.nn.functional.one_hot(router_indices, num_classes=self.num_experts)
+ expert_mask = expert_mask.permute(2, 1, 0)
+ # we sum on the top_k and on the sequence length to get which experts
+ # are hit this time around
+ expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
+ for expert_idx in expert_hit[:]:
+ with torch.no_grad():
+ _, token_idx = torch.where(expert_mask[expert_idx[0]])
+ current_state = hidden_states[token_idx]
+ gate_up = current_state @ self.gate_up_proj[expert_idx]
+ gate, up = gate_up.chunk(2, dim=-1)
+ gated_output = up * self.act_fn(gate)
+ out = gated_output @ self.down_proj[expert_idx]
+ weighted_output = out[0] * routing_weights[token_idx, expert_idx, None]
+ next_states.index_add_(0, token_idx, weighted_output.to(hidden_states.dtype))
+ next_states = next_states.view(batch_size, -1, self.hidden_size)
+ else:
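+ # Inference path: replicate every token once per expert, run all experts in parallel with batched
+ # matmuls, then weight the per-expert outputs by the dense routing weights and sum over experts.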
+ hidden_states = hidden_states.repeat(self.num_experts, 1)
+ hidden_states = hidden_states.view(self.num_experts, -1, self.hidden_size)
+ gate_up = torch.bmm(hidden_states, self.gate_up_proj)
+ gate, up = gate_up.chunk(2, dim=-1) # not supported for DTensors
+ next_states = torch.bmm((up * self.act_fn(gate)), self.down_proj)
+ next_states = next_states.reshape(self.num_experts, batch_size, -1, self.hidden_size)
+ next_states = (
+ next_states * routing_weights.transpose(0, 1).view(self.num_experts, batch_size, -1)[..., None]
+ )
+ next_states = next_states.sum(dim=0)
+ return next_states
+
+
+class Qwen3VLMoeTextSparseMoeBlock(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.hidden_size = config.hidden_size
+ self.num_experts = config.num_experts
+ self.top_k = config.num_experts_per_tok
+ self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False)
+ self.experts = Qwen3VLMoeTextExperts(config)
+
+ # since all the models use norm_topk_prob, we don't need to have an extra check for it
+ # self.norm_topk_prob = config.norm_topk_prob
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ batch_size = hidden_states.shape[0]
+ hidden_states = hidden_states.reshape(-1, self.hidden_size)
+ router_logits = self.gate(hidden_states)
+ routing_weights = torch.nn.functional.softmax(router_logits, dim=-1, dtype=torch.float)
+ routing_weights, router_indices = torch.topk(routing_weights, self.top_k, dim=-1)
+ routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)
+ routing_weights = routing_weights.to(hidden_states.dtype)
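+ # Scatter the normalized top-k weights back into a dense (num_tokens, num_experts) matrix (zeros for
+ # unselected experts) so the experts module can consume them in both the training and inference paths.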
+ router_weights = torch.zeros_like(router_logits).scatter_(1, router_indices, routing_weights)
+ hidden_states = hidden_states.reshape(batch_size, -1, self.hidden_size)
+ routed_out = self.experts(hidden_states, router_weights, router_indices)
+ return routed_out, router_logits
+
+
+class Qwen3VLMoeTextAttention(Qwen3VLTextAttention):
+ pass
+
+
+class Qwen3VLMoeTextDecoderLayer(Qwen3MoeDecoderLayer):
+ pass
+
+
+class Qwen3VLMoePreTrainedModel(Qwen3MoePreTrainedModel):
+ config: Qwen3VLMoeConfig
+ _no_split_modules = ["Qwen3VLMoeTextDecoderLayer", "Qwen3VLMoeVisionBlock"]
+
+ def _init_weights(self, module):
+ """Initialize the weights."""
+ PreTrainedModel._init_weights(self, module)
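+ # The composite config may not expose `initializer_range` at the top level; fall back to the text
+ # sub-config (defaulting to 0.02) when initializing the expert projection weights below.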
+ if hasattr(self.config, "initializer_range"):
+ std = self.config.initializer_range
+ else:
+ std = getattr(self.config.get_text_config(), "initializer_range", 0.02)
+ if isinstance(module, Qwen3VLMoeTextExperts):
+ module.gate_up_proj.data.normal_(mean=0.0, std=std)
+ module.down_proj.data.normal_(mean=0.0, std=std)
+
+
+class Qwen3VLMoeVisionModel(Qwen3VLVisionModel):
+ pass
+
+
+class Qwen3VLMoeTextModel(Qwen3VLTextModel):
+ pass
+
+
+class Qwen3VLMoeCausalLMOutputWithPast(Qwen3VLCausalLMOutputWithPast):
+ aux_loss: Optional[torch.FloatTensor] = None
+
+
+class Qwen3VLMoeModel(Qwen3VLModel):
+ pass
+
+
+class Qwen3VLMoeForConditionalGeneration(Qwen3VLForConditionalGeneration):
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ pixel_values: Optional[torch.Tensor] = None,
+ pixel_values_videos: Optional[torch.FloatTensor] = None,
+ image_grid_thw: Optional[torch.LongTensor] = None,
+ video_grid_thw: Optional[torch.LongTensor] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ logits_to_keep: Union[int, torch.Tensor] = 0,
+ **kwargs: Unpack[TransformersKwargs],
+ ):
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+ image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
+ The temporal, height and width dimensions of the feature grid of each image fed to the LLM.
+ video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
+ The temporal, height and width dimensions of the feature grid of each video fed to the LLM.
+
+ Example:
+ ```python
+ >>> from PIL import Image
+ >>> import requests
+ >>> from transformers import AutoProcessor, Qwen3VLMoeForConditionalGeneration
+
+ >>> model = Qwen3VLMoeForConditionalGeneration.from_pretrained("Qwen/Qwen3-VL-30B-A3B-Instruct", dtype="auto", device_map="auto")
+ >>> processor = AutoProcessor.from_pretrained("Qwen/Qwen3-VL-30B-A3B-Instruct")
+
+ >>> messages = [
+ {
+ "role": "user",
+ "content": [
+ {
+ "type": "image",
+ "image": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg",
+ },
+ {"type": "text", "text": "Describe this image in short."},
+ ],
+ }
+ ]
+
+ >>> # Preparation for inference
+ >>> inputs = processor.apply_chat_template(
+ messages,
+ tokenize=True,
+ add_generation_prompt=True,
+ return_dict=True,
+ return_tensors="pt"
+ )
+ >>> inputs = inputs.to(model.device)
+
+ >>> # Generate
+ >>> generated_ids = model.generate(**inputs, max_new_tokens=128)
+ >>> generated_ids_trimmed = [
+ out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
+ ]
+ >>> processor.batch_decode(generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ "A woman in a plaid shirt sits on a sandy beach at sunset, smiling as she gives a high-five to a yellow Labrador Retriever wearing a harness. The ocean waves roll in the background."
+ ```"""
+
+ outputs = self.model(
+ input_ids=input_ids,
+ pixel_values=pixel_values,
+ pixel_values_videos=pixel_values_videos,
+ image_grid_thw=image_grid_thw,
+ video_grid_thw=video_grid_thw,
+ position_ids=position_ids,
+ attention_mask=attention_mask,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+ hidden_states = outputs[0]
+
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)
+
+ aux_loss = None
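+ # If requested, compute the MoE load-balancing auxiliary loss from the router logits and, when labels
+ # are provided, add it (scaled by `router_aux_loss_coef`) to the language-modeling loss.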
+ if kwargs.get("output_router_logits", False):
+ aux_loss = load_balancing_loss_func(
+ outputs.router_logits,
+ self.config.text_config.num_experts,
+ self.config.text_config.num_experts_per_tok,
+ attention_mask,
+ )
+ if labels is not None:
+ loss += self.config.text_config.router_aux_loss_coef * aux_loss.to(
+ loss.device
+ ) # make sure to reside in the same device
+
+ return Qwen3VLMoeCausalLMOutputWithPast(
+ loss=loss,
+ aux_loss=aux_loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ rope_deltas=outputs.rope_deltas,
+ )
+
+
+__all__ = [
+ "Qwen3VLMoeConfig",
+ "Qwen3VLMoeTextConfig",
+ "Qwen3VLMoeVisionModel",
+ "Qwen3VLMoeForConditionalGeneration",
+ "Qwen3VLMoeModel",
+ "Qwen3VLMoePreTrainedModel",
+ "Qwen3VLMoeTextModel",
+]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/rag/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/rag/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8a8f135ba454f08f1773239837a2db627ce075c6
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/rag/__init__.py
@@ -0,0 +1,30 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_rag import *
+ from .modeling_rag import *
+ from .modeling_tf_rag import *
+ from .retrieval_rag import *
+ from .tokenization_rag import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/rag/configuration_rag.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/rag/configuration_rag.py
new file mode 100644
index 0000000000000000000000000000000000000000..dca4eb04d3f0ad42432ee23d0471471ccb78dfc0
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/rag/configuration_rag.py
@@ -0,0 +1,186 @@
+# coding=utf-8
+# Copyright 2020, The RAG Authors and The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""RAG model configuration"""
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import add_start_docstrings
+
+
+RAG_CONFIG_DOC = r"""
+ [`RagConfig`] stores the configuration of a *RagModel*. Configuration objects inherit from [`PretrainedConfig`] and
+ can be used to control the model outputs. Read the documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ title_sep (`str`, *optional*, defaults to `" / "`):
+ Separator inserted between the title and the text of the retrieved document when calling [`RagRetriever`].
+ doc_sep (`str`, *optional*, defaults to `" // "`):
+ Separator inserted between the text of the retrieved document and the original input when calling
+ [`RagRetriever`].
+ n_docs (`int`, *optional*, defaults to 5):
+ Number of documents to retrieve.
+ max_combined_length (`int`, *optional*, defaults to 300):
+ Max length of contextualized input returned by [`~RagRetriever.__call__`].
+ retrieval_vector_size (`int`, *optional*, defaults to 768):
+ Dimensionality of the document embeddings indexed by [`RagRetriever`].
+ retrieval_batch_size (`int`, *optional*, defaults to 8):
+ Retrieval batch size, defined as the number of queries issued concurrently to the faiss index encapsulated by
+ [`RagRetriever`].
+ dataset (`str`, *optional*, defaults to `"wiki_dpr"`):
+ A dataset identifier of the indexed dataset in HuggingFace Datasets (list all available datasets and ids
+ using `datasets.list_datasets()`).
+ dataset_split (`str`, *optional*, defaults to `"train"`):
+ Which split of the `dataset` to load.
+ index_name (`str`, *optional*, defaults to `"compressed"`):
+ The index name of the index associated with the `dataset`. One can choose between `"legacy"`, `"exact"` and
+ `"compressed"`.
+ index_path (`str`, *optional*):
+ The path to the serialized faiss index on disk.
+ passages_path (`str`, *optional*):
+ A path to text passages compatible with the faiss index. Required if using
+ [`~models.rag.retrieval_rag.LegacyIndex`].
+ use_dummy_dataset (`bool`, *optional*, defaults to `False`):
+ Whether to load a "dummy" variant of the dataset specified by `dataset`.
+ label_smoothing (`float`, *optional*, defaults to 0.0):
+ Only relevant if `return_loss` is set to `True`. Controls the `epsilon` parameter value for label smoothing
+ in the loss calculation. If set to 0, no label smoothing is performed.
+ do_marginalize (`bool`, *optional*, defaults to `False`):
+ If `True`, the logits are marginalized over all documents by making use of
+ `torch.nn.functional.log_softmax`.
+ reduce_loss (`bool`, *optional*, defaults to `False`):
+ Whether or not to reduce the NLL loss using the `torch.Tensor.sum` operation.
+ do_deduplication (`bool`, *optional*, defaults to `True`):
+ Whether or not to deduplicate the generations from different context documents for a given input. Has to be
+ set to `False` if used while training with distributed backend.
+ exclude_bos_score (`bool`, *optional*, defaults to `False`):
+ Whether or not to disregard the BOS token when computing the loss.
+ output_retrieved (`bool`, *optional*, defaults to `False`):
+ If set to `True`, `retrieved_doc_embeds`, `retrieved_doc_ids`, `context_input_ids` and
+ `context_attention_mask` are returned. See returned tensors for more detail.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return the last key/values attentions (not used by all models).
+ forced_eos_token_id (`int`, *optional*):
+ The id of the token to force as the last generated token when `max_length` is reached. Usually set to
+ `eos_token_id`.
+"""
+
+
+@add_start_docstrings(RAG_CONFIG_DOC)
+class RagConfig(PretrainedConfig):
+ model_type = "rag"
+ has_no_defaults_at_init = True
+
+ def __init__(
+ self,
+ vocab_size=None,
+ is_encoder_decoder=True,
+ prefix=None,
+ bos_token_id=None,
+ pad_token_id=None,
+ eos_token_id=None,
+ decoder_start_token_id=None,
+ title_sep=" / ",
+ doc_sep=" // ",
+ n_docs=5,
+ max_combined_length=300,
+ retrieval_vector_size=768,
+ retrieval_batch_size=8,
+ dataset="wiki_dpr",
+ dataset_split="train",
+ index_name="compressed",
+ index_path=None,
+ passages_path=None,
+ use_dummy_dataset=False,
+ reduce_loss=False,
+ label_smoothing=0.0,
+ do_deduplication=True,
+ exclude_bos_score=False,
+ do_marginalize=False,
+ output_retrieved=False,
+ use_cache=True,
+ forced_eos_token_id=None,
+ dataset_revision=None,
+ **kwargs,
+ ):
+ super().__init__(
+ bos_token_id=bos_token_id,
+ pad_token_id=pad_token_id,
+ eos_token_id=eos_token_id,
+ decoder_start_token_id=decoder_start_token_id,
+ forced_eos_token_id=forced_eos_token_id,
+ is_encoder_decoder=is_encoder_decoder,
+ prefix=prefix,
+ vocab_size=vocab_size,
+ **kwargs,
+ )
+ if "question_encoder" not in kwargs or "generator" not in kwargs:
+ raise ValueError(
+ f"A configuration of type {self.model_type} cannot be instantiated because "
+ f"both `question_encoder` and `generator` sub-configurations were not passed, only {kwargs}"
+ )
+ question_encoder_config = kwargs.pop("question_encoder")
+ question_encoder_model_type = question_encoder_config.pop("model_type")
+ decoder_config = kwargs.pop("generator")
+ decoder_model_type = decoder_config.pop("model_type")
+
+ from ..auto.configuration_auto import AutoConfig
+
+ self.question_encoder = AutoConfig.for_model(question_encoder_model_type, **question_encoder_config)
+ self.generator = AutoConfig.for_model(decoder_model_type, **decoder_config)
+
+ self.reduce_loss = reduce_loss
+ self.label_smoothing = label_smoothing
+ self.exclude_bos_score = exclude_bos_score
+ self.do_marginalize = do_marginalize
+
+ self.title_sep = title_sep
+ self.doc_sep = doc_sep
+ self.n_docs = n_docs
+ self.max_combined_length = max_combined_length
+
+ self.dataset = dataset
+ self.dataset_split = dataset_split
+ self.index_name = index_name
+
+ self.retrieval_vector_size = retrieval_vector_size
+ self.retrieval_batch_size = retrieval_batch_size
+ self.passages_path = passages_path
+ self.index_path = index_path
+ self.use_dummy_dataset = use_dummy_dataset
+ self.dataset_revision = dataset_revision
+
+ self.output_retrieved = output_retrieved
+
+ self.do_deduplication = do_deduplication
+
+ self.use_cache = use_cache
+
+ if self.forced_eos_token_id is None:
+ self.forced_eos_token_id = getattr(self.generator, "forced_eos_token_id", None)
+
+ @classmethod
+ def from_question_encoder_generator_configs(
+ cls, question_encoder_config: PretrainedConfig, generator_config: PretrainedConfig, **kwargs
+ ) -> PretrainedConfig:
+ r"""
+ Instantiate a [`RagConfig`] (or a derived class) from a pre-trained question encoder model configuration and a
+ generator model configuration.
+
+ Returns:
+ [`RagConfig`]: An instance of a configuration object
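+
+ Example (a minimal sketch; any question-encoder/generator configuration pair supported by RAG should work):
+
+ ```python
+ >>> from transformers import BartConfig, DPRConfig, RagConfig
+
+ >>> config = RagConfig.from_question_encoder_generator_configs(DPRConfig(), BartConfig())
+ ```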
+ """
+ return cls(question_encoder=question_encoder_config.to_dict(), generator=generator_config.to_dict(), **kwargs)
+
+
+__all__ = ["RagConfig"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/rag/modeling_rag.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/rag/modeling_rag.py
new file mode 100644
index 0000000000000000000000000000000000000000..25c2d66dd7015899b7edde8a77fbe06ba40f68a8
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/rag/modeling_rag.py
@@ -0,0 +1,1665 @@
+# coding=utf-8
+# Copyright 2020, The RAG Authors and The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""RAG model implementation."""
+
+import copy
+from dataclasses import dataclass
+from typing import Callable, Optional, Union
+
+import torch
+from torch import nn
+
+from ...cache_utils import Cache, EncoderDecoderCache
+from ...configuration_utils import PretrainedConfig
+from ...generation import GenerationConfig, GenerationMixin, LogitsProcessorList, StoppingCriteriaList
+from ...modeling_outputs import ModelOutput
+from ...modeling_utils import PreTrainedModel
+from ...utils import auto_docstring, logging
+from .configuration_rag import RagConfig
+from .retrieval_rag import RagRetriever
+
+
+logger = logging.get_logger(__name__)
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Base class for retriever augmented marginalized models outputs.
+ """
+)
+class RetrievAugLMMarginOutput(ModelOutput):
+ r"""
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Language modeling loss.
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
+ Prediction scores of the language modeling head. The score is possibly marginalized over all documents for
+ each vocabulary token.
+ doc_scores (`torch.FloatTensor` of shape `(batch_size, config.n_docs)`):
+ Score between each retrieved document embeddings (see `retrieved_doc_embeds`) and
+ `question_encoder_last_hidden_state`.
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
+
+ Contains precomputed hidden-states (key and values in the attention blocks) of the decoder that can be used
+ (see `past_key_values` input) to speed up sequential decoding.
+ retrieved_doc_embeds (`torch.FloatTensor` of shape `(batch_size, config.n_docs, hidden_size)`, *optional*, returned when *output_retrieved=True*):
+ Embedded documents retrieved by the retriever. Is used with `question_encoder_last_hidden_state` to compute
+ the `doc_scores`.
+ retrieved_doc_ids (`torch.LongTensor` of shape `(batch_size, config.n_docs)`, *optional*, returned when *output_retrieved=True*):
+ The indexes of the embedded documents retrieved by the retriever.
+ context_input_ids (`torch.LongTensor` of shape `(batch_size * config.n_docs, config.max_combined_length)`, *optional*, returned when *output_retrieved=True*):
+ Input ids post-processed from the retrieved documents and the question encoder input_ids by the retriever.
+ context_attention_mask (`torch.LongTensor` of shape `(batch_size * config.n_docs, config.max_combined_length)`, *optional*, returned when *output_retrieved=True*):
+ Attention mask post-processed from the retrieved documents and the question encoder `input_ids` by the
+ retriever.
+ question_encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Sequence of hidden states at the output of the last layer of the question encoder pooled output of the
+ model.
+ question_enc_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings and one for the output of each layer) of
+ shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden states of the question encoder at the output of each layer plus the initial embedding outputs.
+ question_enc_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights of the question encoder, after the attention softmax, used to compute the weighted
+ average in the self-attention heads.
+ generator_enc_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Sequence of hidden-states at the output of the last layer of the generator encoder of the model.
+ generator_enc_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings and one for the output of each layer) of
+ shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden states of the generator encoder at the output of each layer plus the initial embedding outputs.
+ generator_enc_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights of the generator encoder, after the attention softmax, used to compute the weighted
+ average in the self-attention heads.
+ generator_dec_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings and one for the output of each layer) of
+ shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden states of the generator decoder at the output of each layer plus the initial embedding outputs.
+ generator_dec_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights of the generator decoder, after the attention softmax, used to compute the weighted
+ average in the self-attention heads.
+ generator_cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Cross-attentions weights of the generator decoder, after the attention softmax, used to compute the
+ weighted average in the cross-attention heads.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ logits: Optional[torch.FloatTensor] = None
+ doc_scores: Optional[torch.FloatTensor] = None
+ past_key_values: Optional[Cache] = None
+ retrieved_doc_embeds: Optional[torch.FloatTensor] = None
+ retrieved_doc_ids: Optional[torch.LongTensor] = None
+ context_input_ids: Optional[torch.LongTensor] = None
+ context_attention_mask: Optional[torch.LongTensor] = None
+ question_encoder_last_hidden_state: Optional[torch.FloatTensor] = None
+ question_enc_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+ question_enc_attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+ generator_enc_last_hidden_state: Optional[torch.FloatTensor] = None
+ generator_enc_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+ generator_enc_attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+ generator_dec_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+ generator_dec_attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+ generator_cross_attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+
+
+@dataclass
+@auto_docstring
+class RetrievAugLMOutput(ModelOutput):
+ r"""
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
+ Prediction scores of the language modeling head. The score is possibly marginalized over all documents for
+ each vocabulary token.
+ doc_scores (`torch.FloatTensor` of shape `(batch_size, config.n_docs)`):
+ Score between each retrieved document embeddings (see `retrieved_doc_embeds`) and
+ `question_encoder_last_hidden_state`.
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
+
+ Contains precomputed hidden-states (key and values in the attention blocks) of the decoder that can be used
+ (see `past_key_values` input) to speed up sequential decoding.
+ retrieved_doc_embeds (`torch.FloatTensor` of shape `(batch_size, config.n_docs, hidden_size)`, *optional*, returned when *output_retrieved=True*):
+ Embedded documents retrieved by the retriever. Is used with `question_encoder_last_hidden_state` to compute
+ the `doc_scores`.
+ retrieved_doc_ids (`torch.LongTensor` of shape `(batch_size, config.n_docs)`, *optional*, returned when *output_retrieved=True*):
+ The indexes of the embedded documents retrieved by the retriever.
+ context_input_ids (`torch.LongTensor` of shape `(batch_size * config.n_docs, config.max_combined_length)`, *optional*, returned when *output_retrieved=True*):
+ Input ids post-processed from the retrieved documents and the question encoder input_ids by the retriever.
+ context_attention_mask (`torch.LongTensor` of shape `(batch_size * config.n_docs, config.max_combined_length)`, *optional*, returned when *output_retrieved=True*):
+ Attention mask post-processed from the retrieved documents and the question encoder `input_ids` by the
+ retriever.
+ question_encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Sequence of hidden states at the output of the last layer of the question encoder pooled output of the
+ model.
+ question_enc_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings and one for the output of each layer) of
+ shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden states of the question encoder at the output of each layer plus the initial embedding outputs.
+ question_enc_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights of the question encoder, after the attention softmax, used to compute the weighted
+ average in the self-attention heads.
+ generator_enc_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Sequence of hidden-states at the output of the last layer of the generator encoder of the model.
+ generator_enc_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings and one for the output of each layer) of
+ shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden states of the generator encoder at the output of each layer plus the initial embedding outputs.
+ generator_enc_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights of the generator encoder, after the attention softmax, used to compute the weighted
+ average in the self-attention heads.
+ generator_dec_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings and one for the output of each layer) of
+ shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden states of the generator decoder at the output of each layer plus the initial embedding outputs.
+ generator_dec_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights of the generator decoder, after the attention softmax, used to compute the weighted
+ average in the self-attention heads.
+ generator_cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Cross-attentions weights of the generator decoder, after the attention softmax, used to compute the
+ weighted average in the cross-attention heads.
+ """
+
+ logits: Optional[torch.FloatTensor] = None
+ doc_scores: Optional[torch.FloatTensor] = None
+ past_key_values: Optional[Cache] = None
+ retrieved_doc_embeds: Optional[torch.FloatTensor] = None
+ retrieved_doc_ids: Optional[torch.LongTensor] = None
+ context_input_ids: Optional[torch.LongTensor] = None
+ context_attention_mask: Optional[torch.LongTensor] = None
+ question_encoder_last_hidden_state: Optional[torch.FloatTensor] = None
+ question_enc_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+ question_enc_attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+ generator_enc_last_hidden_state: Optional[torch.FloatTensor] = None
+ generator_enc_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+ generator_enc_attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+ generator_dec_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+ generator_dec_attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+ generator_cross_attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+
+
+@auto_docstring(
+ custom_intro="""
+ RAG models were released with the paper [Retrieval-Augmented Generation for Knowledge-Intensive NLP
+ Tasks](https://huggingface.co/papers/2005.11401) by Patrick Lewis, Ethan Perez, Aleksandra Piktus et al.
+
+ RAG is a retrieval-augmented model that encapsulates three components: a question encoder, a dataset retriever and a
+ generator. The question encoder and the generator are trainable, while the retriever is just an indexed dataset.
+ """
+)
+class RagPreTrainedModel(PreTrainedModel):
+ config: RagConfig
+ base_model_prefix = "rag"
+ _supports_flash_attn = True
+ _supports_sdpa = True
+
+ @classmethod
+ def from_pretrained_question_encoder_generator(
+ cls,
+ question_encoder_pretrained_model_name_or_path: Optional[str] = None,
+ generator_pretrained_model_name_or_path: Optional[str] = None,
+ retriever: RagRetriever = None,
+ **kwargs,
+ ) -> PreTrainedModel:
+ r"""
+ Instantiates a question encoder and a generator from one or two base classes of the library from pretrained
+ model checkpoints.
+
+ The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train
+ the model, you need to first set it back in training mode with `model.train()`.
+
+ Params:
+ question_encoder_pretrained_model_name_or_path (`str`, *optional*, defaults to `None`):
+ Information necessary to initiate the question encoder. Can be either:
+
+ - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
+ - A path to a *directory* containing model weights saved using
+ [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
+ - A path or url to a *tensorflow index checkpoint file* (e.g, `./tf_model/model.ckpt.index`). In
+ this case, `from_tf` should be set to `True` and a configuration object should be provided as
+ `config` argument. This loading path is slower than converting the TensorFlow checkpoint in a
+ PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
+
+ generator_pretrained_model_name_or_path (`str`, *optional*, defaults to `None`):
+ Information necessary to initiate the generator. Can be either:
+
+ - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
+ - A path to a *directory* containing model weights saved using
+ [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
+ - A path or url to a *tensorflow index checkpoint file* (e.g, `./tf_model/model.ckpt.index`). In
+ this case, `from_tf` should be set to `True` and a configuration object should be provided as
+ `config` argument. This loading path is slower than converting the TensorFlow checkpoint in a
+ PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
+
+ model_args (remaining positional arguments, *optional*):
+ All remaining positional arguments will be passed to the underlying model's `__init__` method.
+ retriever ([`RagRetriever`], *optional*):
+ The retriever to use.
+ kwargs (remaining dictionary of keyword arguments, *optional*):
+ Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
+ `output_attentions=True`).
+
+ - To update the question_encoder configuration, use the prefix *question_encoder_* for each
+ configuration parameter.
+ - To update the generator configuration, use the prefix *generator_* for each configuration parameter.
+ - To update the parent model configuration, do not use a prefix for each configuration parameter.
+
+ Behaves differently depending on whether a `config` is provided or automatically loaded.
+
+ Example:
+
+ ```python
+ >>> from transformers import RagModel
+
+ >>> # initialize a RAG from two pretrained models.
+ >>> model = RagModel.from_pretrained_question_encoder_generator(
+ ... "facebook/dpr-question_encoder-single-nq-base", "google-t5/t5-small"
+ ... )
+ >>> # saving model after fine-tuning
+ >>> model.save_pretrained("./rag")
+ >>> # load fine-tuned model
+ >>> model = RagModel.from_pretrained("./rag")
+ ```"""
+
+ kwargs_question_encoder = {
+ argument[len("question_encoder_") :]: value
+ for argument, value in kwargs.items()
+ if argument.startswith("question_encoder_")
+ }
+
+ kwargs_generator = {
+ argument[len("generator_") :]: value
+ for argument, value in kwargs.items()
+ if argument.startswith("generator_")
+ }
+
+ # remove question_encoder, generator kwargs from kwargs
+ for key in kwargs_question_encoder:
+ del kwargs["question_encoder_" + key]
+ for key in kwargs_generator:
+ del kwargs["generator_" + key]
+
+ # Load and initialize the question_encoder and generator
+ # The distinction between question_encoder and generator at the model level is made
+ # by the value of the flag `is_generator` that we need to set correctly.
+ question_encoder = kwargs_question_encoder.pop("model", None)
+ if question_encoder is None:
+ assert question_encoder_pretrained_model_name_or_path is not None, (
+ "If `model` is not defined as an argument, a `question_encoder_pretrained_model_name_or_path` has to"
+ " be defined"
+ )
+ from ..auto.modeling_auto import AutoModel
+
+ if "config" not in kwargs_question_encoder:
+ from ..auto.configuration_auto import AutoConfig
+
+ question_encoder_config, kwargs_question_encoder = AutoConfig.from_pretrained(
+ question_encoder_pretrained_model_name_or_path,
+ **kwargs_question_encoder,
+ return_unused_kwargs=True,
+ )
+ kwargs_question_encoder["config"] = question_encoder_config
+
+ question_encoder = AutoModel.from_pretrained(
+ question_encoder_pretrained_model_name_or_path, **kwargs_question_encoder
+ )
+
+ generator = kwargs_generator.pop("model", None)
+ if generator is None:
+ assert generator_pretrained_model_name_or_path is not None, (
+ "If `generator_model` is not defined as an argument, a `generator_pretrained_model_name_or_path` has"
+ " to be defined"
+ )
+ from ..auto.modeling_auto import AutoModelForSeq2SeqLM
+
+ if "config" not in kwargs_generator:
+ from ..auto.configuration_auto import AutoConfig
+
+ generator_config, kwargs_generator = AutoConfig.from_pretrained(
+ generator_pretrained_model_name_or_path, **kwargs_generator, return_unused_kwargs=True
+ )
+
+ kwargs_generator["config"] = generator_config
+
+ generator = AutoModelForSeq2SeqLM.from_pretrained(
+ generator_pretrained_model_name_or_path, **kwargs_generator
+ )
+
+ # instantiate config with corresponding kwargs
+ config = kwargs.get("config")
+ if config is None:
+ config = RagConfig.from_question_encoder_generator_configs(
+ question_encoder.config, generator.config, **kwargs
+ )
+
+ return cls(question_encoder=question_encoder, generator=generator, config=config, retriever=retriever)
+
+
+@auto_docstring
+class RagModel(RagPreTrainedModel):
+ def __init__(
+ self,
+ config: Optional[PretrainedConfig] = None,
+ question_encoder: Optional[PreTrainedModel] = None,
+ generator: Optional[PreTrainedModel] = None,
+ retriever: Optional[RagRetriever] = None, # or maybe just use a `set_retriever(...)` method
+ **kwargs,
+ ):
+ r"""
+ question_encoder (`PreTrainedModel`, *optional*):
+ The model responsible for encoding the question into hidden states for retrieval.
+ generator (`PreTrainedModel`, *optional*):
+ The model responsible for generating text based on retrieved documents.
+ retriever (`RagRetriever`, *optional*):
+ The component responsible for retrieving documents from a knowledge base given the encoded question.
+ """
+ assert config is not None or (question_encoder is not None and generator is not None), (
+ "Either a configuration or an question_encoder and a generator has to be provided."
+ )
+
+ if config is None:
+ config = RagConfig.from_question_encoder_generator_configs(
+ question_encoder.config, generator.config, **kwargs
+ )
+ else:
+ assert isinstance(config, self.config_class), f"config: {config} has to be of type {self.config_class}"
+ super().__init__(config)
+ if question_encoder is None:
+ from ..auto.modeling_auto import AutoModel
+
+ question_encoder = AutoModel.from_config(config.question_encoder)
+
+ if generator is None:
+ from ..auto.modeling_auto import AutoModelForSeq2SeqLM
+
+ generator = AutoModelForSeq2SeqLM.from_config(config.generator)
+
+ self.retriever = retriever
+ if self.retriever is not None:
+ assert isinstance(retriever, RagRetriever), (
+ f"`self.retriever` is of type {type(self.retriever)}, but should be of type `RagRetriever`"
+ )
+ self.retriever = retriever
+
+ self.question_encoder = question_encoder
+ self.generator = generator
+
+ self.ctx_encoder = None
+ self.context_encoder_training = False
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ encoder_outputs: Optional[tuple[tuple[torch.FloatTensor]]] = None,
+ decoder_input_ids: Optional[torch.LongTensor] = None,
+ decoder_attention_mask: Optional[torch.BoolTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ doc_scores: Optional[torch.FloatTensor] = None,
+ context_input_ids: Optional[torch.LongTensor] = None,
+ context_attention_mask: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ output_retrieved: Optional[bool] = None,
+ n_docs: Optional[int] = None,
+ ) -> Union[tuple[torch.Tensor], RetrievAugLMOutput]:
+ r"""
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. [`RagConfig`], used to initialize the model, specifies
+ which generator to use, it also specifies a compatible generator tokenizer. Use that tokenizer class to
+ obtain the indices.
+
+ [What are input IDs?](../glossary#input-ids)
+ encoder_outputs (`tuple(tuple(torch.FloatTensor))`, *optional*):
+ Tuple consists of (`generator_enc_last_hidden_state`, *optional*: `generator_enc_hidden_states`,
+ *optional*: `generator_enc_attentions`). `generator_enc_last_hidden_state` of shape `(batch_size, n_docs *
+ sequence_length, hidden_size)` is a sequence of hidden-states at the output of the last layer of the
+ generator's encoder.
+
+ Used by the ([`RagModel`]) model during decoding.
+ decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
+ Provide for generation tasks. `None` by default, construct as per instructions for the generator model
+ you're using with your RAG instance.
+ decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
+ Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
+ be used by default.
+ doc_scores (`torch.FloatTensor` of shape `(batch_size, config.n_docs)`):
+ Score between each retrieved document embeddings (see `retrieved_doc_embeds`) and
+ `question_encoder_last_hidden_state`. If the model is not initialized with a `retriever`, `doc_scores`
+ has to be provided to the forward pass. `doc_scores` can be computed via
+ `question_encoder_last_hidden_state` and `retrieved_doc_embeds`, see examples for more information.
+ context_input_ids (`torch.LongTensor` of shape `(batch_size * config.n_docs, config.max_combined_length)`, *optional*, returned when *output_retrieved=True*):
+ Input IDs post-processed from the retrieved documents and the question encoder `input_ids` by the
+ retriever. If the model was not initialized with a `retriever`, `context_input_ids` has to be provided to
+ the forward pass. `context_input_ids` are returned by [`~RagRetriever.__call__`].
+ context_attention_mask (`torch.LongTensor` of shape `(batch_size * config.n_docs, config.max_combined_length)`, *optional*, returned when *output_retrieved=True*):
+ Attention mask post-processed from the retrieved documents and the question encoder `input_ids` by the
+ retriever. If the model is not initialized with a `retriever`, `context_attention_mask` has to be
+ provided to the forward pass. `context_attention_mask` are returned by [`~RagRetriever.__call__`].
+ output_retrieved (`bool`, *optional*):
+ Whether or not to return the `retrieved_doc_embeds`, `retrieved_doc_ids`, `context_input_ids` and
+ `context_attention_mask`. See returned tensors for more detail.
+ n_docs (`int`, *optional*):
+ The number of documents to retrieve.
+
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, RagRetriever, RagModel
+ >>> import torch
+
+ >>> tokenizer = AutoTokenizer.from_pretrained("facebook/rag-token-base")
+ >>> retriever = RagRetriever.from_pretrained(
+ ... "facebook/rag-token-base", index_name="exact", use_dummy_dataset=True
+ ... )
+ >>> # initialize with RagRetriever to do everything in one forward call
+ >>> model = RagModel.from_pretrained("facebook/rag-token-base", retriever=retriever)
+
+ >>> inputs = tokenizer("How many people live in Paris?", return_tensors="pt")
+ >>> outputs = model(input_ids=inputs["input_ids"])
+ ```"""
+ n_docs = n_docs if n_docs is not None else self.config.n_docs
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ output_retrieved = output_retrieved if output_retrieved is not None else self.config.output_retrieved
+
+ # whether retriever has to be used
+ has_to_retrieve = (
+ self.retriever is not None
+ and (context_input_ids is None or context_attention_mask is None or doc_scores is None)
+ and encoder_outputs is None
+ )
+ # encoder_outputs are pre-computed during RAG-token generation
+ if encoder_outputs is None:
+ if has_to_retrieve:
+ question_enc_outputs = self.question_encoder(
+ input_ids, attention_mask=attention_mask, return_dict=True
+ )
+ question_encoder_last_hidden_state = question_enc_outputs[0] # hidden states of question encoder
+
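+ # the retrieval index works on CPU float32 numpy arrays, so detach and convert the question hidden states before querying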
+ retriever_outputs = self.retriever(
+ input_ids,
+ question_encoder_last_hidden_state.detach().to(device="cpu", dtype=torch.float32).numpy(),
+ prefix=self.generator.config.prefix,
+ n_docs=n_docs,
+ return_tensors="pt",
+ )
+ if self.context_encoder_training:
+ (
+ context_input_ids,
+ context_attention_mask,
+ retrieved_doc_embeds,
+ retrieved_doc_input_ids,
+ retrieved_doc_attention_mask,
+ retrieved_doc_ids,
+ ) = (
+ retriever_outputs["context_input_ids"],
+ retriever_outputs["context_attention_mask"],
+ retriever_outputs["retrieved_doc_embeds"],
+ retriever_outputs["tokenized_doc_ids"],
+ retriever_outputs["tokenized_doc_attention_mask"],
+ retriever_outputs["doc_ids"],
+ )
+
+ context_input_ids = context_input_ids.to(input_ids)
+ context_attention_mask = context_attention_mask.to(input_ids)
+
+ retrieved_doc_input_ids = retrieved_doc_input_ids.to(input_ids)
+ retrieved_doc_attention_mask = retrieved_doc_attention_mask.to(input_ids)
+ retrieved_doc_embeds = self.ctx_encoder(
+ retrieved_doc_input_ids, attention_mask=retrieved_doc_attention_mask, return_dict=True
+ ).pooler_output
+ retrieved_doc_embeds = retrieved_doc_embeds.view(
+ -1, n_docs, question_encoder_last_hidden_state.shape[1]
+ ) # reshaping
+
+ # compute doc_scores involving ctx_encoder
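+ # bmm of (batch_size, 1, hidden) with (batch_size, hidden, n_docs), squeezed to (batch_size, n_docs)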
+ doc_scores = torch.bmm(
+ question_encoder_last_hidden_state.unsqueeze(1), retrieved_doc_embeds.transpose(1, 2)
+ ).squeeze(1)
+
+ else:
+ context_input_ids, context_attention_mask, retrieved_doc_embeds, retrieved_doc_ids = (
+ retriever_outputs["context_input_ids"],
+ retriever_outputs["context_attention_mask"],
+ retriever_outputs["retrieved_doc_embeds"],
+ retriever_outputs["doc_ids"],
+ )
+
+ # set to correct device
+ retrieved_doc_embeds = retrieved_doc_embeds.to(question_encoder_last_hidden_state)
+ context_input_ids = context_input_ids.to(input_ids)
+ context_attention_mask = context_attention_mask.to(input_ids)
+
+ # compute doc_scores
+ doc_scores = torch.bmm(
+ question_encoder_last_hidden_state.unsqueeze(1), retrieved_doc_embeds.transpose(1, 2)
+ ).squeeze(1)
+ else:
+ assert context_input_ids is not None, (
+ "Make sure that `context_input_ids` are passed, if no `retriever` is set. Alternatively, you can"
+ " set a retriever using the `set_retriever(...)` function."
+ )
+ assert context_attention_mask is not None, (
+ "Make sure that `context_attention_mask` are passed, if no `retriever` is set. Alternatively, you"
+ " can set a retriever using the `set_retriever(...)` function."
+ )
+ assert doc_scores is not None, (
+ "Make sure that `doc_scores` are passed, if no `retriever` is set. Alternatively, you can set a"
+ " retriever using the `set_retriever(...)` function."
+ )
+
+ assert doc_scores is not None, (
+ "Make sure that `doc_scores` are passed when passing `encoder_outputs` to the forward function."
+ )
+
+ assert (doc_scores.shape[1] % n_docs) == 0, (
+ f"The second dimension of `doc_scores` should be a multiple of `n_docs`={n_docs}, but is"
+ f" {doc_scores.shape[1]}."
+ )
+
+ # Decoder inputs do not contain the context documents; repeat them so that each retrieved document gets a copy
+ if decoder_input_ids is not None:
+ decoder_input_ids = decoder_input_ids.repeat_interleave(n_docs, dim=0)
+
+ if decoder_attention_mask is not None:
+ decoder_attention_mask = decoder_attention_mask.repeat_interleave(n_docs, dim=0)
+
+ gen_outputs = self.generator(
+ input_ids=context_input_ids,
+ attention_mask=context_attention_mask,
+ encoder_outputs=encoder_outputs,
+ decoder_input_ids=decoder_input_ids,
+ decoder_attention_mask=decoder_attention_mask,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ return_dict=True,
+ )
+
+ if not has_to_retrieve:
+ question_encoder_last_hidden_state = None
+ question_enc_hidden_states = None
+ question_enc_attentions = None
+ retrieved_doc_embeds = None
+ retrieved_doc_ids = None
+ else:
+ question_enc_hidden_states = question_enc_outputs.hidden_states
+ question_enc_attentions = question_enc_outputs.attentions
+
+ if not has_to_retrieve or not output_retrieved:
+ # don't output retrieved docs
+ context_input_ids = None
+ context_attention_mask = None
+ retrieved_doc_embeds = None
+ retrieved_doc_ids = None
+
+ return RetrievAugLMOutput(
+ logits=gen_outputs.logits,
+ doc_scores=doc_scores,
+ past_key_values=gen_outputs.past_key_values,
+ context_input_ids=context_input_ids,
+ context_attention_mask=context_attention_mask,
+ retrieved_doc_embeds=retrieved_doc_embeds,
+ retrieved_doc_ids=retrieved_doc_ids,
+ question_encoder_last_hidden_state=question_encoder_last_hidden_state,
+ question_enc_hidden_states=question_enc_hidden_states,
+ question_enc_attentions=question_enc_attentions,
+ generator_enc_last_hidden_state=gen_outputs.encoder_last_hidden_state,
+ generator_enc_hidden_states=gen_outputs.encoder_hidden_states,
+ generator_enc_attentions=gen_outputs.encoder_attentions,
+ generator_dec_hidden_states=gen_outputs.decoder_hidden_states,
+ generator_dec_attentions=gen_outputs.decoder_attentions,
+ generator_cross_attentions=gen_outputs.cross_attentions,
+ )
+
+
+@auto_docstring(
+ custom_intro="""
+ A RAG-sequence model implementation. It performs RAG-sequence specific marginalization in the forward pass.
+ """
+)
+class RagSequenceForGeneration(RagPreTrainedModel):
+ def __init__(
+ self,
+ config: Optional[PretrainedConfig] = None,
+ question_encoder: Optional[PreTrainedModel] = None,
+ generator: Optional[PreTrainedModel] = None,
+ retriever: Optional[RagRetriever] = None,
+ **kwargs,
+ ):
+ r"""
+ question_encoder (`PreTrainedModel`, *optional*):
+ The model responsible for encoding the question into hidden states for retrieval.
+ generator (`PreTrainedModel`, *optional*):
+ The model responsible for generating text based on retrieved documents.
+ retriever (`RagRetriever`, *optional*):
+ The component responsible for retrieving documents from a knowledge base given the encoded question.
+ """
+ assert config is not None or (question_encoder is not None and generator is not None), (
+ "Either a configuration or an encoder and a generator has to be provided."
+ )
+
+ if config is None:
+ config = RagConfig.from_question_encoder_generator_configs(
+ question_encoder.config, generator.config, **kwargs
+ )
+ super().__init__(config)
+
+ # instantiate model
+ self.rag = RagModel(config=config, question_encoder=question_encoder, generator=generator, retriever=retriever)
+
+ def set_retriever(self, retriever: RagRetriever):
+ self.rag.retriever = retriever
+
+ def set_context_encoder_for_training(self, ctx_encoder: PreTrainedModel):
+ self.rag.context_encoder_training = True
+ self.rag.ctx_encoder = ctx_encoder
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ encoder_outputs: Optional[tuple[tuple[torch.Tensor]]] = None,
+ decoder_input_ids: Optional[torch.LongTensor] = None,
+ decoder_attention_mask: Optional[torch.BoolTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ context_input_ids: Optional[torch.LongTensor] = None,
+ context_attention_mask: Optional[torch.LongTensor] = None,
+ doc_scores: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ output_retrieved: Optional[bool] = None,
+ exclude_bos_score: Optional[bool] = None,
+ reduce_loss: Optional[bool] = None,
+ labels: Optional[torch.LongTensor] = None,
+ n_docs: Optional[int] = None,
+ **kwargs, # needs kwargs for generation
+ ) -> RetrievAugLMMarginOutput:
+ r"""
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. [`RagConfig`], used to initialize the model, specifies
+ which generator to use, it also specifies a compatible generator tokenizer. Use that tokenizer class to
+ obtain the indices.
+
+ [What are input IDs?](../glossary#input-ids)
+ encoder_outputs (`tuple(tuple(torch.FloatTensor))`, *optional*):
+ Tuple consists of (`generator_enc_last_hidden_state`, *optional*: `generator_enc_hidden_states`,
+ *optional*: `generator_enc_attentions`). `generator_enc_last_hidden_state` of shape `(batch_size, n_docs *
+ sequence_length, hidden_size)` is a sequence of hidden-states at the output of the last layer of the
+ generator's encoder.
+
+ Used by the [`RagModel`] model during decoding.
+ decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
+ Provide for generation tasks. `None` by default, construct as per instructions for the generator model
+ you're using with your RAG instance.
+ decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
+ Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
+ be used by default.
+ context_input_ids (`torch.LongTensor` of shape `(batch_size * config.n_docs, config.max_combined_length)`, *optional*, returned when *output_retrieved=True*):
+ Input IDs post-processed from the retrieved documents and the question encoder `input_ids` by the
+ retriever. If the model is not initialized with a `retriever`, `context_input_ids` has to be provided to
+ the forward pass. `context_input_ids` are returned by [`~RagRetriever.__call__`].
+ context_attention_mask (`torch.LongTensor` of shape `(batch_size * config.n_docs, config.max_combined_length)`, *optional*, returned when *output_retrieved=True*):
+ Attention mask post-processed from the retrieved documents and the question encoder `input_ids` by the
+ retriever. If the model is not initialized with a `retriever`, `context_attention_mask` has to be
+ provided to the forward pass. `context_attention_mask` are returned by [`~RagRetriever.__call__`].
+ doc_scores (`torch.FloatTensor` of shape `(batch_size, config.n_docs)`):
+ Score between each retrieved document embedding (see `retrieved_doc_embeds`) and
+ `question_encoder_last_hidden_state`. If the model is not initialized with a `retriever`, `doc_scores`
+ has to be provided to the forward pass. `doc_scores` can be computed via
+ `question_encoder_last_hidden_state` and `retrieved_doc_embeds`; see the examples for more information.
+ output_retrieved (`bool`, *optional*):
+ Whether or not to return the `retrieved_doc_embeds`, `retrieved_doc_ids`, `context_input_ids` and
+ `context_attention_mask`. See returned tensors for more detail.
+ exclude_bos_score (`bool`, *optional*):
+ Only relevant if `labels` is passed. If `True`, the score of the BOS token is disregarded when computing
+ the loss.
+ reduce_loss (`bool`, *optional*):
+ Only relevant if `labels` is passed. If `True`, the NLL loss is reduced using the `torch.Tensor.sum`
+ operation.
+ n_docs (`int`, *optional*):
+ The number of documents to retrieve.
+
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, RagRetriever, RagSequenceForGeneration
+ >>> import torch
+
+ >>> tokenizer = AutoTokenizer.from_pretrained("facebook/rag-sequence-nq")
+ >>> retriever = RagRetriever.from_pretrained(
+ ... "facebook/rag-sequence-nq", index_name="exact", use_dummy_dataset=True
+ ... )
+ >>> # initialize with RagRetriever to do everything in one forward call
+ >>> model = RagSequenceForGeneration.from_pretrained("facebook/rag-sequence-nq", retriever=retriever)
+
+ >>> inputs = tokenizer("How many people live in Paris?", return_tensors="pt")
+ >>> targets = tokenizer(text_target="In Paris, there are 10 million people.", return_tensors="pt")
+ >>> input_ids = inputs["input_ids"]
+ >>> labels = targets["input_ids"]
+ >>> outputs = model(input_ids=input_ids, labels=labels)
+
+ >>> # or use retriever separately
+ >>> model = RagSequenceForGeneration.from_pretrained("facebook/rag-sequence-nq", use_dummy_dataset=True)
+ >>> # 1. Encode
+ >>> question_hidden_states = model.question_encoder(input_ids)[0]
+ >>> # 2. Retrieve
+ >>> docs_dict = retriever(input_ids.numpy(), question_hidden_states.detach().numpy(), return_tensors="pt")
+ >>> doc_scores = torch.bmm(
+ ... question_hidden_states.unsqueeze(1), docs_dict["retrieved_doc_embeds"].float().transpose(1, 2)
+ ... ).squeeze(1)
+ >>> # 3. Forward to generator
+ >>> outputs = model(
+ ... context_input_ids=docs_dict["context_input_ids"],
+ ... context_attention_mask=docs_dict["context_attention_mask"],
+ ... doc_scores=doc_scores,
+ ... decoder_input_ids=labels,
+ ... )
+ ```"""
+ n_docs = n_docs if n_docs is not None else self.config.n_docs
+ exclude_bos_score = exclude_bos_score if exclude_bos_score is not None else self.config.exclude_bos_score
+ reduce_loss = reduce_loss if reduce_loss is not None else self.config.reduce_loss
+
+ if labels is not None:
+ if decoder_input_ids is None:
+ decoder_input_ids = labels
+ use_cache = False
+
+ outputs = self.rag(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ encoder_outputs=encoder_outputs,
+ decoder_input_ids=decoder_input_ids,
+ decoder_attention_mask=decoder_attention_mask,
+ context_input_ids=context_input_ids,
+ context_attention_mask=context_attention_mask,
+ doc_scores=doc_scores,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ output_retrieved=output_retrieved,
+ n_docs=n_docs,
+ )
+
+ loss = None
+ if labels is not None:
+ loss = self.get_nll(
+ outputs.logits,
+ outputs.doc_scores,
+ decoder_input_ids,
+ reduce_loss=reduce_loss,
+ epsilon=self.config.label_smoothing,
+ exclude_bos_score=exclude_bos_score,
+ n_docs=n_docs,
+ )
+
+ return RetrievAugLMMarginOutput(
+ loss=loss,
+ logits=outputs.logits,
+ doc_scores=outputs.doc_scores,
+ past_key_values=outputs.past_key_values,
+ context_input_ids=outputs.context_input_ids,
+ context_attention_mask=outputs.context_attention_mask,
+ retrieved_doc_embeds=outputs.retrieved_doc_embeds,
+ retrieved_doc_ids=outputs.retrieved_doc_ids,
+ question_encoder_last_hidden_state=outputs.question_encoder_last_hidden_state,
+ question_enc_hidden_states=outputs.question_enc_hidden_states,
+ question_enc_attentions=outputs.question_enc_attentions,
+ generator_enc_last_hidden_state=outputs.generator_enc_last_hidden_state,
+ generator_enc_hidden_states=outputs.generator_enc_hidden_states,
+ generator_enc_attentions=outputs.generator_enc_attentions,
+ generator_dec_hidden_states=outputs.generator_dec_hidden_states,
+ generator_dec_attentions=outputs.generator_dec_attentions,
+ generator_cross_attentions=outputs.generator_cross_attentions,
+ )
+
+ @property
+ def retriever(self):
+ return self.rag.retriever
+
+ @property
+ def generator(self):
+ return self.rag.generator
+
+ @property
+ def question_encoder(self):
+ return self.rag.question_encoder
+
+ @torch.no_grad()
+ def generate(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.LongTensor] = None,
+ context_input_ids: Optional[torch.LongTensor] = None,
+ context_attention_mask: Optional[torch.LongTensor] = None,
+ doc_scores: Optional[torch.FloatTensor] = None,
+ do_deduplication: Optional[bool] = None, # defaults to True
+ num_return_sequences: Optional[int] = None, # defaults to 1
+ num_beams: Optional[int] = None, # defaults to 1
+ n_docs: Optional[int] = None,
+ **model_kwargs,
+ ) -> torch.LongTensor:
+ """
+ Implements RAG sequence "thorough" decoding. Read the [`~generation.GenerationMixin.generate`] documentation
+ for more information on how to set other generate input parameters.
+
+ Args:
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ The sequence used as a prompt for the generation. If `input_ids` is not passed, then
+ `context_input_ids` has to be provided.
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ context_input_ids (`torch.LongTensor` of shape `(batch_size * config.n_docs, config.max_combined_length)`, *optional*, returned when *output_retrieved=True*):
+ Input IDs post-processed from the retrieved documents and the question encoder input_ids by the
+ retriever.
+ context_attention_mask (`torch.LongTensor` of shape `(batch_size * config.n_docs, config.max_combined_length)`, *optional*, returned when *output_retrieved=True*):
+ Attention mask post-processed from the retrieved documents and the question encoder `input_ids` by the
+ retriever.
+
+ If the model is not initialized with a `retriever` or `input_ids` is not given, `context_input_ids` and
+ `context_attention_mask` have to be provided to the forward pass. They are returned by
+ [`~RagRetriever.__call__`].
+ doc_scores (`torch.FloatTensor` of shape `(batch_size, config.n_docs)`):
+ Score between each retrieved document embeddings (see `retrieved_doc_embeds`) and
+ `question_encoder_last_hidden_state`.
+
+ If the model is not initialized with a `retriever` or `input_ids` is not given, `doc_scores` has to be
+ provided to the forward pass. `doc_scores` are returned by [`~RagRetriever.__call__`].
+ do_deduplication (`bool`, *optional*):
+ Whether or not to deduplicate the generations from different context documents for a given input. Has
+ to be set to `False` if used while training with distributed backend.
+ num_return_sequences (`int`, *optional*, defaults to 1):
+ The number of independently computed returned sequences for each element in the batch. Note that this
+ is not the value we pass to the `generator`'s [`~generation.GenerationMixin.generate`] function,
+ where we set `num_return_sequences` to `num_beams`.
+ num_beams (`int`, *optional*, defaults to 1):
+ Number of beams for beam search. 1 means no beam search.
+ n_docs (`int`, *optional*, defaults to `config.n_docs`):
+ Number of documents to retrieve and/or number of documents for which to generate an answer.
+ kwargs (`dict[str, Any]`, *optional*):
+ Additional kwargs will be passed to [`~generation.GenerationMixin.generate`].
+
+ Return:
+ `torch.LongTensor` of shape `(batch_size * num_return_sequences, sequence_length)`: The generated
+ sequences. The second dimension (sequence length) is either equal to `max_length` or shorter if all batches
+ finished early due to the `eos_token_id`.
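+
+ Example (a minimal sketch of end-to-end generation; the checkpoint name and dummy index mirror the forward example above):
+
+ ```python
+ >>> from transformers import AutoTokenizer, RagRetriever, RagSequenceForGeneration
+
+ >>> tokenizer = AutoTokenizer.from_pretrained("facebook/rag-sequence-nq")
+ >>> retriever = RagRetriever.from_pretrained(
+ ... "facebook/rag-sequence-nq", index_name="exact", use_dummy_dataset=True
+ ... )
+ >>> model = RagSequenceForGeneration.from_pretrained("facebook/rag-sequence-nq", retriever=retriever)
+
+ >>> inputs = tokenizer("How many people live in Paris?", return_tensors="pt")
+ >>> generated = model.generate(input_ids=inputs["input_ids"], num_beams=2, num_return_sequences=1)
+ >>> answers = tokenizer.batch_decode(generated, skip_special_tokens=True)
+ ```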
+ """
+
+ n_docs = n_docs if n_docs is not None else self.config.n_docs
+ do_deduplication = do_deduplication if do_deduplication is not None else self.config.do_deduplication
+ num_doc_return_sequences = (
+ num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences
+ )
+ num_beams = num_beams if num_beams is not None else self.config.num_beams
+
+ assert input_ids is not None or context_input_ids is not None, (
+ " At least one of input_ids or context_input_ids must be given"
+ )
+
+ if self.retriever is not None and context_input_ids is None:
+ question_hidden_states = self.question_encoder(input_ids, attention_mask=attention_mask)[0]
+ context_input_ids = self.retriever(
+ input_ids,
+ question_hidden_states.detach().to(device="cpu", dtype=torch.float32).numpy(),
+ prefix=self.generator.config.prefix,
+ n_docs=n_docs,
+ return_tensors="pt",
+ )["context_input_ids"]
+
+ # set to correct device
+ context_input_ids = context_input_ids.to(input_ids)
+
+ hypos = []
+ model_kwargs["num_beams"] = num_beams
+ model_kwargs["num_return_sequences"] = num_beams
+ model_kwargs["attention_mask"] = None
+
+ batch_size = input_ids.shape[0] if input_ids is not None else context_input_ids.shape[0] // n_docs
+
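+ # RAG-sequence "thorough" decoding: generate candidate answers from each retrieved document per input,
+ # then rescore every candidate against all documents and keep the `num_doc_return_sequences` best ones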
+ for index in range(batch_size):
+ # first, generate beams from documents:
+ generator_input_ids = context_input_ids[index * n_docs : (index + 1) * n_docs] # (n_docs, max_len)
+
+ output_sequences = self.generator.generate(
+ generator_input_ids,
+ **model_kwargs,
+ ) # n_docs * n_beam, tgt_len
+ if do_deduplication:
+ # do_deduplication, max_output_len
+ output_sequences = torch.stack(list({str(k.tolist()): k for k in output_sequences}.values()))
+
+ num_candidates = output_sequences.shape[
+ 0
+ ] # after deduplication, this number can be less than n_docs*n_beam
+
+ # then, run model forwards to get nll scores:
+ if input_ids is not None:
+ new_input_ids = input_ids[index : index + 1].repeat(num_candidates, 1)
+ outputs = self(new_input_ids, labels=output_sequences, exclude_bos_score=True)
+ else: # input_ids is None, need context_input_ids/mask and doc_scores
+ assert context_attention_mask is not None, (
+ "Make sure that `context_attention_mask` are passed, if no `input_ids` is set. Alternatively, you"
+ " can set a retriever using the `set_retriever(...)` function."
+ )
+ assert doc_scores is not None, (
+ "Make sure that `doc_scores` are passed, if no `input_ids` is set. Alternatively, you can set a"
+ " retriever using the `set_retriever(...)` function."
+ )
+
+ individual_input_ids = generator_input_ids.repeat(
+ num_candidates, 1
+ ) # (num_candidates*n_docs, max_len)
+
+ individual_attention_mask = context_attention_mask[index * n_docs : (index + 1) * n_docs]
+ individual_attention_mask = individual_attention_mask.repeat(num_candidates, 1)
+
+ individual_doc_scores = doc_scores[index : (index + 1), :] # doc_scores.shape = [batch, n_docs]
+ individual_doc_scores = individual_doc_scores.repeat(num_candidates, 1) # [num_candidates, n_docs]
+
+ outputs = self(
+ context_input_ids=individual_input_ids,
+ context_attention_mask=individual_attention_mask,
+ doc_scores=individual_doc_scores,
+ labels=output_sequences,
+ exclude_bos_score=True,
+ )
+
+ top_cand_inds = (-outputs["loss"]).topk(num_doc_return_sequences)[1]
+
+ # add hypothesis
+ hypos.append(output_sequences[top_cand_inds])
+
+ return self._cat_and_pad(hypos, pad_token_id=self.config.generator.pad_token_id)
+
+ def get_nll(
+ self, seq_logits, doc_scores, target, reduce_loss=False, epsilon=0.0, exclude_bos_score=False, n_docs=None
+ ):
+ # shift tokens left
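+ # so that position t is scored against the token generated at t+1; the final slot is filled with pad and masked out below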
+ target = torch.cat(
+ [target[:, 1:], target.new(target.shape[0], 1).fill_(self.config.generator.pad_token_id)], 1
+ )
+
+ n_docs = n_docs if n_docs is not None else self.config.n_docs
+
+ # bos_token_id is None for T5
+ bos_token_id = self.config.bos_token_id or self.config.generator.bos_token_id
+ use_bos = bos_token_id is not None and target[:, 0].eq(bos_token_id).all()
+
+ def _mask_pads(ll, smooth_obj):
+ pad_mask = target.eq(self.config.generator.pad_token_id)
+ if pad_mask.any():
+ ll.masked_fill_(pad_mask, 0.0)
+ smooth_obj.masked_fill_(pad_mask, 0.0)
+ return ll.squeeze(-1), smooth_obj.squeeze(-1)
+
+ # seq_logits dim = (batch*n_docs, tgt_len , #vocabs)
+ seq_logprobs = nn.functional.log_softmax(seq_logits, dim=-1).view(
+ seq_logits.shape[0] // n_docs, n_docs, -1, seq_logits.size(-1)
+ ) # batch_size x n_docs x tgt_len x #vocab_size
+ doc_logprobs = nn.functional.log_softmax(doc_scores, dim=1).unsqueeze(-1).unsqueeze(-1)
+
+ # RAG-sequence marginalization
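+ # the document log-prior is added once per sequence (at the second token position) so it is counted exactly once
+ # when token log-probs are summed; first-token (BOS) scores stay separate so `exclude_bos_score` can drop them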
+ first_token_scores = seq_logprobs[:, :, :1, :]
+ second_token_scores = seq_logprobs[:, :, 1:2, :]
+ remainder = seq_logprobs[:, :, 2:, :]
+ rag_logprobs = torch.cat([first_token_scores, second_token_scores + doc_logprobs, remainder], dim=2)
+
+ # calculate loss
+ target = target.unsqueeze(1).unsqueeze(-1).repeat(1, n_docs, 1, 1)
+ assert target.dim() == rag_logprobs.dim()
+
+ ll = rag_logprobs.gather(dim=-1, index=target)
+ smooth_obj = rag_logprobs.sum(dim=-1, keepdim=True) # total sum of all (normalised) logits
+
+ ll, smooth_obj = _mask_pads(ll, smooth_obj)
+
+ # sum over tokens, exclude bos while scoring
+ ll = ll[:, :, 1:].sum(2) if exclude_bos_score and use_bos else ll.sum(2)
+ smooth_obj = smooth_obj.sum(2)
+ ll = ll.logsumexp(1) # logsumexp over docs
+ smooth_obj = smooth_obj.logsumexp(1)
+
+ nll_loss = -ll
+ smooth_loss = -smooth_obj
+
+ if reduce_loss:
+ nll_loss = nll_loss.sum()
+ smooth_loss = smooth_loss.sum()
+
+ eps_i = epsilon / rag_logprobs.size(-1)
+ loss = (1.0 - epsilon) * nll_loss + eps_i * smooth_loss
+ return loss
+
+ @staticmethod
+ def _cat_and_pad(tensors, pad_token_id):
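+ # concatenate per-example hypothesis batches of different lengths into one tensor, right-padded with pad_token_id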
+ output = tensors[0].new(sum(t.shape[0] for t in tensors), max(t.shape[1] for t in tensors)).fill_(pad_token_id)
+ ind = 0
+ for t in tensors:
+ output[ind : ind + t.shape[0], : t.shape[1]] = t
+ ind += t.shape[0]
+ return output
+
+
+@auto_docstring(
+ custom_intro="""
+ A RAG-token model implementation. It performs RAG-token specific marginalization in the forward pass.
+ """
+)
+class RagTokenForGeneration(RagPreTrainedModel, GenerationMixin):
+ def __init__(
+ self,
+ config: Optional[PretrainedConfig] = None,
+ question_encoder: Optional[PreTrainedModel] = None,
+ generator: Optional[PreTrainedModel] = None,
+ retriever: Optional[RagRetriever] = None,
+ **kwargs,
+ ):
+ r"""
+ question_encoder (`PreTrainedModel`, *optional*):
+ The model responsible for encoding the question into hidden states for retrieval.
+ generator (`PreTrainedModel`, *optional*):
+ The model responsible for generating text based on retrieved documents.
+ retriever (`RagRetriever`, *optional*):
+ The component responsible for retrieving documents from a knowledge base given the encoded question.
+ """
+ assert config is not None or (question_encoder is not None and generator is not None), (
+ "Either a configuration or an encoder and a generator has to be provided."
+ )
+
+ if config is None:
+ config = RagConfig.from_question_encoder_generator_configs(
+ question_encoder.config, generator.config, **kwargs
+ )
+
+ super().__init__(config)
+
+ # instantiate model
+ self.rag = RagModel(config=config, question_encoder=question_encoder, generator=generator, retriever=retriever)
+
+ def set_retriever(self, retriever: RagRetriever):
+ self.rag.retriever = retriever
+
+ def set_context_encoder_for_training(self, ctx_encoder: PreTrainedModel):
+ self.rag.context_encoder_training = True
+ self.rag.ctx_encoder = ctx_encoder
+
+ def prepare_inputs_for_generation(
+ self,
+ decoder_input_ids,
+ past_key_values=None,
+ attention_mask=None,
+ use_cache=None,
+ encoder_outputs=None,
+ doc_scores=None,
+ n_docs=None,
+ **kwargs,
+ ):
+ # Overwritten -- `do_marginalize` is explicitly set in the output
+
+ if past_key_values is not None:
+ # if past is defined use only last decoder_input_ids
+ decoder_input_ids = decoder_input_ids[:, -1:]
+
+ return {
+ "input_ids": None,
+ "encoder_outputs": encoder_outputs,
+ "doc_scores": doc_scores,
+ "context_attention_mask": attention_mask,
+ "decoder_input_ids": decoder_input_ids,
+ "past_key_values": past_key_values,
+ "use_cache": use_cache,
+ "do_marginalize": True,
+ "n_docs": n_docs,
+ }
+
+ @property
+ def retriever(self):
+ return self.rag.retriever
+
+ @property
+ def generator(self):
+ return self.rag.generator
+
+ @property
+ def question_encoder(self):
+ return self.rag.question_encoder
+
+ @staticmethod
+ def _reorder_cache(past_key_values, beam_idx):
+ """Reorders cache for generation. BART-inspired but we need to take care of the extra dimension for docs"""
+
+ def _reorder_stacked(hidden_states, new_order):
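+ # the cache stores `n_docs` consecutive rows per beam hypothesis; regroup by doc, select beams, then flatten back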
+ n_docs = hidden_states.shape[0] // new_order.shape[0]
+ hidden_states = hidden_states.view(-1, n_docs, *hidden_states.shape[1:])
+ hidden_states = hidden_states.index_select(0, new_order)
+ result = hidden_states.view(-1, *hidden_states.shape[2:])
+ return result
+
+ reordered_past = ()
+ for layer_past in past_key_values:
+ # get the correct batch idx from decoder layer's batch dim for cross and self-attn
+ reordered_past += (
+ tuple(_reorder_stacked(past_state, beam_idx.to(past_state.device)) for past_state in layer_past),
+ )
+ if isinstance(past_key_values, EncoderDecoderCache):
+ reordered_past = EncoderDecoderCache.from_legacy_cache(reordered_past)
+
+ return reordered_past
+
+ def marginalize(self, seq_logits, doc_scores, n_docs=None):
+ n_docs = n_docs if n_docs is not None else self.config.n_docs
+
+ # RAG-token marginalization
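+ # per token: log p(y_t | x) = logsumexp over docs d of [ log p(d | x) + log p(y_t | x, d, y_<t) ]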
+ seq_logprobs = nn.functional.log_softmax(seq_logits, dim=-1).view(
+ seq_logits.shape[0] // n_docs, n_docs, -1, seq_logits.size(-1)
+ )
+ doc_logprobs = torch.log_softmax(doc_scores, dim=1)
+ log_prob_sum = seq_logprobs + doc_logprobs.unsqueeze(-1).unsqueeze(-1)
+ return torch.logsumexp(log_prob_sum, dim=1)
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ encoder_outputs: Optional[tuple[tuple[torch.Tensor]]] = None,
+ decoder_input_ids: Optional[torch.LongTensor] = None,
+ decoder_attention_mask: Optional[torch.BoolTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ context_input_ids: Optional[torch.LongTensor] = None,
+ context_attention_mask: Optional[torch.LongTensor] = None,
+ doc_scores: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ output_retrieved: Optional[bool] = None,
+ do_marginalize: Optional[bool] = None,
+ reduce_loss: Optional[bool] = None,
+ labels: Optional[torch.LongTensor] = None,
+ n_docs: Optional[int] = None,
+ **kwargs, # needs kwargs for generation
+ ) -> RetrievAugLMMarginOutput:
+ r"""
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. [`RagConfig`], used to initialize the model, specifies
+ which generator to use, it also specifies a compatible generator tokenizer. Use that tokenizer class to
+ obtain the indices.
+
+ [What are input IDs?](../glossary#input-ids)
+ encoder_outputs (`tuple(tuple(torch.FloatTensor))`, *optional*):
+ Tuple consists of (`generator_enc_last_hidden_state`, *optional*: `generator_enc_hidden_states`,
+ *optional*: `generator_enc_attentions`). `generator_enc_last_hidden_state` of shape `(batch_size, n_docs *
+ sequence_length, hidden_size)` is a sequence of hidden-states at the output of the last layer of the
+ generator's encoder.
+
+ Used by the [`RagModel`] model during decoding.
+ decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
+ Provide for generation tasks. `None` by default, construct as per instructions for the generator model
+ you're using with your RAG instance.
+ decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
+ Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
+ be used by default.
+ context_input_ids (`torch.LongTensor` of shape `(batch_size * config.n_docs, config.max_combined_length)`, *optional*, returned when *output_retrieved=True*):
+ Input IDs post-processed from the retrieved documents and the question encoder `input_ids` by the
+ retriever. If the model is not initialized with a `retriever`, `context_input_ids` has to be provided to
+ the forward pass. `context_input_ids` are returned by [`~RagRetriever.__call__`].
+ context_attention_mask (`torch.LongTensor` of shape `(batch_size * config.n_docs, config.max_combined_length)`, *optional*, returned when *output_retrieved=True*):
+ Attention mask post-processed from the retrieved documents and the question encoder `input_ids` by the
+ retriever. If the model is not initialized with a `retriever`, `context_attention_mask` has to be
+ provided to the forward pass. `context_attention_mask` are returned by [`~RagRetriever.__call__`].
+ doc_scores (`torch.FloatTensor` of shape `(batch_size, config.n_docs)`):
+ Score between each retrieved document embedding (see `retrieved_doc_embeds`) and
+ `question_encoder_last_hidden_state`. If the model is not initialized with a `retriever`, `doc_scores`
+ has to be provided to the forward pass. `doc_scores` can be computed via
+ `question_encoder_last_hidden_state` and `retrieved_doc_embeds`; see the examples for more information.
+ output_retrieved (`bool`, *optional*):
+ Whether or not to return the `retrieved_doc_embeds`, `retrieved_doc_ids`, `context_input_ids` and
+ `context_attention_mask`. See returned tensors for more detail.
+ do_marginalize (`bool`, *optional*):
+ If `True`, the logits are marginalized over all documents by making use of
+ `torch.nn.functional.log_softmax`.
+ reduce_loss (`bool`, *optional*):
+ Only relevant if `labels` is passed. If `True`, the NLL loss is reduced using the `torch.Tensor.sum`
+ operation.
+ n_docs (`int`, *optional*):
+ The number of documents to retrieve.
+
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, RagRetriever, RagTokenForGeneration
+ >>> import torch
+
+ >>> tokenizer = AutoTokenizer.from_pretrained("facebook/rag-token-nq")
+ >>> retriever = RagRetriever.from_pretrained(
+ ... "facebook/rag-token-nq", index_name="exact", use_dummy_dataset=True
+ ... )
+ >>> # initialize with RagRetriever to do everything in one forward call
+ >>> model = RagTokenForGeneration.from_pretrained("facebook/rag-token-nq", retriever=retriever)
+
+ >>> inputs = tokenizer("How many people live in Paris?", return_tensors="pt")
+ >>> targets = tokenizer(text_target="In Paris, there are 10 million people.", return_tensors="pt")
+ >>> input_ids = inputs["input_ids"]
+ >>> labels = targets["input_ids"]
+ >>> outputs = model(input_ids=input_ids, labels=labels)
+
+ >>> # or use retriever separately
+ >>> model = RagTokenForGeneration.from_pretrained("facebook/rag-token-nq", use_dummy_dataset=True)
+ >>> # 1. Encode
+ >>> question_hidden_states = model.question_encoder(input_ids)[0]
+ >>> # 2. Retrieve
+ >>> docs_dict = retriever(input_ids.numpy(), question_hidden_states.detach().numpy(), return_tensors="pt")
+ >>> doc_scores = torch.bmm(
+ ... question_hidden_states.unsqueeze(1), docs_dict["retrieved_doc_embeds"].float().transpose(1, 2)
+ ... ).squeeze(1)
+ >>> # 3. Forward to generator
+ >>> outputs = model(
+ ... context_input_ids=docs_dict["context_input_ids"],
+ ... context_attention_mask=docs_dict["context_attention_mask"],
+ ... doc_scores=doc_scores,
+ ... decoder_input_ids=labels,
+ ... )
+
+ >>> # or directly generate
+ >>> generated = model.generate(
+ ... context_input_ids=docs_dict["context_input_ids"],
+ ... context_attention_mask=docs_dict["context_attention_mask"],
+ ... doc_scores=doc_scores,
+ ... )
+ >>> generated_string = tokenizer.batch_decode(generated, skip_special_tokens=True)
+ ```"""
+ n_docs = n_docs if n_docs is not None else self.config.n_docs
+ do_marginalize = do_marginalize if do_marginalize is not None else self.config.do_marginalize
+ reduce_loss = reduce_loss if reduce_loss is not None else self.config.reduce_loss
+
+ if labels is not None:
+ if decoder_input_ids is None:
+ decoder_input_ids = labels
+ use_cache = False
+
+ outputs = self.rag(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ encoder_outputs=encoder_outputs,
+ decoder_input_ids=decoder_input_ids,
+ decoder_attention_mask=decoder_attention_mask,
+ context_input_ids=context_input_ids,
+ context_attention_mask=context_attention_mask,
+ doc_scores=doc_scores,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ output_retrieved=output_retrieved,
+ n_docs=n_docs,
+ )
+
+ loss = None
+ logits = outputs.logits
+ if labels is not None:
+ assert decoder_input_ids is not None
+ loss = self.get_nll(
+ outputs.logits,
+ outputs.doc_scores,
+ labels,
+ reduce_loss=reduce_loss,
+ epsilon=self.config.label_smoothing,
+ n_docs=n_docs,
+ )
+
+ if do_marginalize:
+ logits = self.marginalize(logits, outputs.doc_scores, n_docs)
+
+ return RetrievAugLMMarginOutput(
+ loss=loss,
+ logits=logits,
+ doc_scores=outputs.doc_scores,
+ past_key_values=outputs.past_key_values,
+ context_input_ids=outputs.context_input_ids,
+ context_attention_mask=outputs.context_attention_mask,
+ retrieved_doc_embeds=outputs.retrieved_doc_embeds,
+ retrieved_doc_ids=outputs.retrieved_doc_ids,
+ question_encoder_last_hidden_state=outputs.question_encoder_last_hidden_state,
+ question_enc_hidden_states=outputs.question_enc_hidden_states,
+ question_enc_attentions=outputs.question_enc_attentions,
+ generator_enc_last_hidden_state=outputs.generator_enc_last_hidden_state,
+ generator_enc_hidden_states=outputs.generator_enc_hidden_states,
+ generator_enc_attentions=outputs.generator_enc_attentions,
+ generator_dec_hidden_states=outputs.generator_dec_hidden_states,
+ generator_dec_attentions=outputs.generator_dec_attentions,
+ generator_cross_attentions=outputs.generator_cross_attentions,
+ )
+
+ @torch.no_grad()
+ def generate(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.LongTensor] = None,
+ context_input_ids: Optional[torch.LongTensor] = None,
+ context_attention_mask: Optional[torch.LongTensor] = None,
+ doc_scores: Optional[torch.FloatTensor] = None,
+ n_docs: Optional[int] = None,
+ generation_config: Optional[GenerationConfig] = None,
+ prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], list[int]]] = None,
+ logits_processor: Optional[LogitsProcessorList] = LogitsProcessorList(),
+ stopping_criteria: Optional[StoppingCriteriaList] = StoppingCriteriaList(),
+ **kwargs,
+ ) -> torch.LongTensor:
+ """
+ Implements RAG token decoding.
+
+ Args:
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ The sequence used as a prompt for the generation. If `input_ids` is not passed, then
+ `context_input_ids` has to be provided.
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ context_input_ids (`torch.LongTensor` of shape `(batch_size * config.n_docs, config.max_combined_length)`, *optional*, returned when *output_retrieved=True*):
+ Input IDs post-processed from the retrieved documents and the question encoder `input_ids` by the
+ retriever.
+
+ If the model is not initialized with a `retriever`, `context_input_ids` has to be provided to the
+ forward pass. `context_input_ids` are returned by [`~RagRetriever.__call__`].
+ context_attention_mask (`torch.LongTensor` of shape `(batch_size * config.n_docs, config.max_combined_length)`, *optional*, returned when *output_retrieved=True*):
+ Attention mask post-processed from the retrieved documents and the question encoder `input_ids` by the
+ retriever.
+
+ If the model is not initialized with a `retriever`, `context_attention_mask` has to be provided to the
+ forward pass. `context_attention_mask` is returned by [`~RagRetriever.__call__`].
+ doc_scores (`torch.FloatTensor` of shape `(batch_size, config.n_docs)`):
+ Score between each retrieved document embeddings (see `retrieved_doc_embeds`) and
+ `question_encoder_last_hidden_state`.
+
+ If the model is not initialized with a `retriever`, `doc_scores` has to be provided to the
+ forward pass. `doc_scores` are returned by [`~RagRetriever.__call__`].
+ n_docs (`int`, *optional*, defaults to `config.n_docs`):
+ Number of documents to retrieve and/or number of documents for which to generate an answer.
+ generation_config (`~generation.GenerationConfig`, *optional*):
+ The generation configuration to be used as base parametrization for the generation call. `**kwargs`
+ passed to generate matching the attributes of `generation_config` will override them. If
+ `generation_config` is not provided, the default will be used, which has the following loading
+ priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
+ configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
+ default values, whose documentation should be checked to parameterize generation.
+ prefix_allowed_tokens_fn (`Callable[[int, torch.Tensor], list[int]]`, *optional*):
+ If provided, this function constrains the beam search to allowed tokens only at each step. If not
+ provided, no constraint is applied. This function takes 2 arguments: the batch ID `batch_id` and
+ `input_ids`. It has to return a list with the allowed tokens for the next generation step conditioned on
+ the previously generated tokens `input_ids` and the batch ID `batch_id`. This argument is useful for
+ constrained generation conditioned on the prefix, as described in [Autoregressive Entity
+ Retrieval](https://huggingface.co/papers/2010.00904).
+ logits_processor (`LogitsProcessorList`, *optional*):
+ Custom logits processors that complement the default logits processors built from arguments and a
+ model's config. If a logit processor is passed that is already created with the arguments or a model's
+ config an error is thrown.
+ stopping_criteria (`StoppingCriteriaList`, *optional*):
+ Custom stopping criteria that complement the default stopping criteria built from arguments and a
+ model's config. If a stopping criterion is passed that is already created with the arguments or a
+ model's config an error is thrown.
+ kwargs (`dict[str, Any]`, *optional*):
+ Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be
+ forwarded to the `forward` function of the model.
+
+ Return:
+ `torch.LongTensor` of shape `(batch_size * num_return_sequences, sequence_length)`: The generated
+ sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter if all batches
+ finished early due to the `eos_token_id`.
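+
+ Example (a minimal sketch of end-to-end generation; the checkpoint name and dummy index mirror the forward example above):
+
+ ```python
+ >>> from transformers import AutoTokenizer, RagRetriever, RagTokenForGeneration
+
+ >>> tokenizer = AutoTokenizer.from_pretrained("facebook/rag-token-nq")
+ >>> retriever = RagRetriever.from_pretrained(
+ ... "facebook/rag-token-nq", index_name="exact", use_dummy_dataset=True
+ ... )
+ >>> model = RagTokenForGeneration.from_pretrained("facebook/rag-token-nq", retriever=retriever)
+
+ >>> inputs = tokenizer("How many people live in Paris?", return_tensors="pt")
+ >>> generated = model.generate(input_ids=inputs["input_ids"], num_beams=2)
+ >>> answers = tokenizer.batch_decode(generated, skip_special_tokens=True)
+ ```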
+ """
+ # Handle `generation_config` and kwargs that might update it
+ if generation_config is None:
+ generation_config = self.generation_config
+ generation_config = copy.deepcopy(generation_config)
+ model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs
+
+ kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None
+ self._prepare_special_tokens(generation_config, kwargs_has_attention_mask)
+
+ # set default parameters
+ n_docs = n_docs if n_docs is not None else self.config.n_docs
+
+ # retrieve docs
+ if self.retriever is not None and context_input_ids is None:
+ question_hidden_states = self.question_encoder(input_ids, attention_mask=attention_mask)[0]
+ out = self.retriever(
+ input_ids,
+ question_hidden_states.detach().to(device="cpu", dtype=torch.float32).numpy(),
+ prefix=self.generator.config.prefix,
+ n_docs=n_docs,
+ return_tensors="pt",
+ )
+ context_input_ids, context_attention_mask, retrieved_doc_embeds = (
+ out["context_input_ids"],
+ out["context_attention_mask"],
+ out["retrieved_doc_embeds"],
+ )
+
+ # set to correct device
+ retrieved_doc_embeds = retrieved_doc_embeds.to(question_hidden_states)
+ context_input_ids = context_input_ids.to(input_ids)
+ context_attention_mask = context_attention_mask.to(input_ids)
+
+ # compute doc_scores
+ doc_scores = torch.bmm(question_hidden_states.unsqueeze(1), retrieved_doc_embeds.transpose(1, 2)).squeeze(
+ 1
+ )
+
+ assert (context_input_ids.shape[0] % n_docs) == 0, (
+ f" The first dimension of `context_input_ids` should be a multiple of `n_docs`={n_docs}, but is"
+ f" {context_input_ids.shape[0]}."
+ )
+
+ # batch_size
+ batch_size = context_input_ids.shape[0] // n_docs
+
+ encoder = self.rag.generator.get_encoder()
+ encoder_outputs = encoder(input_ids=context_input_ids, attention_mask=context_attention_mask, return_dict=True)
+
+ input_ids = torch.full(
+ (batch_size * generation_config.num_beams, 1),
+ generation_config.decoder_start_token_id,
+ dtype=torch.long,
+ device=next(self.parameters()).device,
+ )
+ input_ids_seq_length = input_ids.shape[-1]
+ last_hidden_state = encoder_outputs["last_hidden_state"]
+
+ def extend_enc_output(tensor, num_beams=None):
+ # split into `batch_size`, `num_beams`, `num_docs`
+ tensor = tensor[None, None, :].reshape((batch_size, 1, n_docs) + tensor.shape[1:])
+ # repeat same last hidden states over `num_beams` dimension
+ tensor = tensor.expand((batch_size, num_beams, n_docs) + tensor.shape[3:])
+ # merge `batch_size`, `num_beams`, `num_docs` dims again
+ return tensor.reshape((batch_size * num_beams * n_docs,) + tensor.shape[3:])
+
+ # correctly extend last_hidden_state and attention mask
+ context_attention_mask = extend_enc_output(context_attention_mask, num_beams=generation_config.num_beams)
+ encoder_outputs["last_hidden_state"] = extend_enc_output(
+ last_hidden_state, num_beams=generation_config.num_beams
+ )
+
+ doc_scores = doc_scores.repeat_interleave(generation_config.num_beams, dim=0)
+
+ # define start_len & additional parameters
+ model_kwargs["doc_scores"] = doc_scores
+ model_kwargs["encoder_outputs"] = encoder_outputs
+ model_kwargs["attention_mask"] = context_attention_mask
+ model_kwargs["n_docs"] = n_docs
+
+ pre_processor = self._get_logits_processor(
+ generation_config=generation_config,
+ input_ids_seq_length=input_ids_seq_length,
+ encoder_input_ids=context_input_ids,
+ prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
+ logits_processor=logits_processor,
+ device=input_ids.device,
+ )
+
+ prepared_stopping_criteria = self._get_stopping_criteria(
+ generation_config=generation_config, stopping_criteria=stopping_criteria
+ )
+
+ self._prepare_cache_for_generation(
+ generation_config,
+ model_kwargs,
+ generation_mode=None,
+ batch_size=input_ids.shape[0],
+ max_cache_length=generation_config.max_length - 1,
+ )
+
+ if generation_config.num_beams == 1:
+ if generation_config.num_return_sequences > 1:
+ raise ValueError(
+ f"num_return_sequences has to be 1, but is {generation_config.num_return_sequences} when doing"
+ " greedy search."
+ )
+ return self._sample(
+ input_ids,
+ logits_processor=pre_processor,
+ stopping_criteria=prepared_stopping_criteria,
+ generation_config=generation_config,
+ synced_gpus=False,
+ streamer=None,
+ **model_kwargs,
+ )
+ elif generation_config.num_beams > 1:
+ if generation_config.num_return_sequences > generation_config.num_beams:
+ raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")
+
+ return self._beam_search(
+ input_ids,
+ logits_processor=pre_processor,
+ stopping_criteria=prepared_stopping_criteria,
+ generation_config=generation_config,
+ synced_gpus=False,
+ **model_kwargs,
+ )
+ else:
+ raise ValueError(
+ f"`num_beams` has to be an integer strictly superior to 0 (≥ 1), but is {generation_config.num_beams}"
+ )
+
+ # Auxiliary functions for beam search
+ def _temporary_reorder_cache(self, past_key_values, beam_idx):
+ # RAG should always use the legacy path even though the LM backbone (T5) uses new cache format
+ # because RAG expands input for doc-size internally. TODO: raushan, remove me when all models support
+ # new cache format
+ past_key_values = self._reorder_cache(past_key_values, beam_idx)
+ return past_key_values
+
+ def get_input_embeddings(self):
+ return self.rag.generator.get_input_embeddings()
+
+ def get_output_embeddings(self):
+ return self.rag.generator.get_output_embeddings()
+
+ def set_output_embeddings(self, new_embeddings):
+ return self.rag.generator.set_output_embeddings(new_embeddings)
+
+ def shift_tokens_right(self, input_ids, start_token_id=None):
+ """Shift input ids one token to the right, and pad with start_token_id"""
+ if start_token_id is None:
+ start_token_id = self.config.decoder_start_token_id
+ shifted_input_ids = input_ids.new_zeros(input_ids.shape)
+ shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
+ shifted_input_ids[:, 0] = start_token_id
+ return shifted_input_ids
+
+ def get_nll(self, seq_logits, doc_scores, target, reduce_loss=False, epsilon=0.0, n_docs=None):
+ n_docs = n_docs if n_docs is not None else self.config.n_docs
+ # shift tokens left
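+ # so that position t is scored against the token generated at t+1; the final slot is filled with pad and masked out below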
+ target = torch.cat(
+ [target[:, 1:], target.new(target.shape[0], 1).fill_(self.config.generator.pad_token_id)], 1
+ )
+
+ def _mask_pads(ll, smooth_obj):
+ pad_mask = target.eq(self.config.generator.pad_token_id)
+ if pad_mask.any():
+ ll.masked_fill_(pad_mask, 0.0)
+ smooth_obj.masked_fill_(pad_mask, 0.0)
+ return ll.squeeze(-1), smooth_obj.squeeze(-1)
+
+ rag_logprobs = self.marginalize(seq_logits, doc_scores, n_docs)
+
+ target = target.unsqueeze(-1)
+ assert target.dim() == rag_logprobs.dim()
+
+ ll = rag_logprobs.gather(dim=-1, index=target)
+ smooth_obj = rag_logprobs.sum(dim=-1, keepdim=True) # total sum of all (normalised) logits
+ ll, smooth_obj = _mask_pads(ll, smooth_obj)
+ ll = ll.sum(1) # sum over tokens
+ smooth_obj = smooth_obj.sum(1)
+
+ nll_loss = -ll
+ smooth_loss = -smooth_obj
+
+ if reduce_loss:
+ nll_loss = nll_loss.sum()
+ smooth_loss = smooth_loss.sum()
+
+ eps_i = epsilon / rag_logprobs.size(-1)
+ loss = (1.0 - epsilon) * nll_loss + eps_i * smooth_loss
+ return loss
+
+
+__all__ = ["RagModel", "RagPreTrainedModel", "RagSequenceForGeneration", "RagTokenForGeneration"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/rag/modeling_tf_rag.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/rag/modeling_tf_rag.py
new file mode 100644
index 0000000000000000000000000000000000000000..15538377287124a10ba9a19055d3f84bf5897144
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/rag/modeling_tf_rag.py
@@ -0,0 +1,1776 @@
+# coding=utf-8
+# Copyright 2020, The RAG Authors and The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""TFRAG model implementation."""
+
+from __future__ import annotations
+
+import copy
+from dataclasses import dataclass
+
+import numpy as np
+import tensorflow as tf
+
+from ...configuration_utils import PretrainedConfig
+from ...generation import TFLogitsProcessorList
+from ...modeling_tf_utils import (
+ TFCausalLanguageModelingLoss,
+ TFModelInputType,
+ TFPreTrainedModel,
+ keras,
+ shape_list,
+ unpack_inputs,
+)
+from ...utils import ModelOutput, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
+from .configuration_rag import RagConfig
+from .retrieval_rag import RagRetriever
+
+
+logger = logging.get_logger(__name__)
+
+_CONFIG_FOR_DOC = "RagConfig"
+
+
+@dataclass
+class TFRetrievAugLMMarginOutput(ModelOutput):
+ """
+ Base class for retriever augmented marginalized models outputs.
+
+ Args:
+ loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Language modeling loss.
+ logits (`tf.Tensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
+ Prediction scores of the language modeling head. The score is possibly marginalized over all documents for
+ each vocabulary token.
+ past_key_values (`list[tf.Tensor]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ List of `tf.Tensor` of length `config.n_layers`, with each tensor of shape `(2, batch_size, num_heads,
+ sequence_length, embed_size_per_head)`).
+
+ Contains precomputed hidden-states (key and values in the attention blocks) of the decoder that can be used
+ (see `past_key_values` input) to speed up sequential decoding.
+ doc_scores (`tf.Tensor` of shape `(batch_size, config.n_docs)`):
+ Score between each retrieved document embeddings (see `retrieved_doc_embeds`) and
+ `question_encoder_last_hidden_state`.
+ retrieved_doc_embeds (`tf.Tensor` of shape `(batch_size, config.n_docs, hidden_size)`, *optional*, returned when *output_retrieved=True*):
+ Embedded documents retrieved by the retriever. Is used with `question_encoder_last_hidden_state` to compute
+ the `doc_scores`.
+ retrieved_doc_ids (`tf.Tensor` (int32) of shape `(batch_size, config.n_docs)`, *optional*, returned when *output_retrieved=True*):
+ The indexes of the embedded documents retrieved by the retriever.
+ context_input_ids (`tf.Tensor` (int32) of shape `(batch_size * config.n_docs, config.max_combined_length)`, *optional*, returned when *output_retrieved=True*):
+ Input ids post-processed from the retrieved documents and the question encoder input_ids by the retriever.
+ context_attention_mask (`tf.Tensor` (int32) of shape `(batch_size * config.n_docs, config.max_combined_length)`, *optional*, returned when *output_retrieved=True*):
+ Attention mask post-processed from the retrieved documents and the question encoder `input_ids` by the
+ retriever.
+ question_encoder_last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Sequence of hidden states at the output of the last layer (pooled output) of the question encoder of the
+ model.
+ question_enc_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `tf.Tensor` (one for the output of the embeddings and one for the output of each layer) of shape
+ `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden states of the question encoder at the output of each layer plus the initial embedding outputs.
+ question_enc_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights of the question encoder, after the attention softmax, used to compute the weighted
+ average in the self-attention heads.
+ generator_enc_last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Sequence of hidden-states at the output of the last layer of the generator encoder of the model.
+ generator_enc_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `tf.Tensor` (one for the output of the embeddings and one for the output of each layer) of shape
+ `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden states of the generator encoder at the output of each layer plus the initial embedding outputs.
+ generator_enc_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights of the generator encoder, after the attention softmax, used to compute the weighted
+ average in the self-attention heads.
+ generator_dec_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `tf.Tensor` (one for the output of the embeddings and one for the output of each layer) of shape
+ `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden states of the generator decoder at the output of each layer plus the initial embedding outputs.
+ generator_dec_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights of the generator decoder, after the attention softmax, used to compute the weighted
+ average in the self-attention heads.
+ """
+
+ loss: tf.Tensor | None = None
+ logits: tf.Tensor | None = None
+ past_key_values: list[tf.Tensor] | None = None
+ doc_scores: tf.Tensor | None = None
+ retrieved_doc_embeds: tf.Tensor | None = None
+ retrieved_doc_ids: tf.Tensor | None = None
+ context_input_ids: tf.Tensor | None = None
+ context_attention_mask: tf.Tensor | None = None
+ question_encoder_last_hidden_state: tf.Tensor | None = None
+ question_enc_hidden_states: tuple[tf.Tensor, ...] | None = None
+ question_enc_attentions: tuple[tf.Tensor, ...] | None = None
+ generator_enc_last_hidden_state: tf.Tensor | None = None
+ generator_enc_hidden_states: tuple[tf.Tensor, ...] | None = None
+ generator_enc_attentions: tuple[tf.Tensor, ...] | None = None
+ generator_dec_hidden_states: tuple[tf.Tensor, ...] | None = None
+ generator_dec_attentions: tuple[tf.Tensor, ...] | None = None
+
+
+@dataclass
+class TFRetrievAugLMOutput(ModelOutput):
+ """
+ Args:
+ logits (`tf.Tensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
+ Prediction scores of the language modeling head. The score is possibly marginalized over all documents for
+ each vocabulary token.
+ past_key_values (`list[tf.Tensor]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ List of `tf.Tensor` of length `config.n_layers`, with each tensor of shape `(2, batch_size, num_heads,
+ sequence_length, embed_size_per_head)`.
+
+ Contains precomputed hidden-states (key and values in the attention blocks) of the decoder that can be used
+ (see `past_key_values` input) to speed up sequential decoding.
+ doc_scores (`tf.Tensor` of shape `(batch_size, config.n_docs)`):
+ Score between each retrieved document embedding (see `retrieved_doc_embeds`) and
+ `question_encoder_last_hidden_state`.
+ retrieved_doc_embeds (`tf.Tensor` of shape `(batch_size, config.n_docs, hidden_size)`, *optional*, returned when *output_retrieved=True*):
+ Embedded documents retrieved by the retriever. Is used with `question_encoder_last_hidden_state` to compute
+ the `doc_scores`.
+ retrieved_doc_ids (`tf.Tensor` of shape `(batch_size, config.n_docs)`, *optional*, returned when *output_retrieved=True*):
+ The indexes of the embedded documents retrieved by the retriever.
+ context_input_ids (`tf.Tensor` of shape `(batch_size * config.n_docs, config.max_combined_length)`, *optional*, returned when *output_retrieved=True*):
+ Input ids post-processed from the retrieved documents and the question encoder input_ids by the retriever.
+ context_attention_mask (`tf.Tensor` of shape `(batch_size * config.n_docs, config.max_combined_length)`, *optional*, returned when *output_retrieved=True*):
+ Attention mask post-processed from the retrieved documents and the question encoder `input_ids` by the
+ retriever.
+ question_encoder_last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Sequence of hidden states at the output of the last layer of the question encoder, i.e. the pooled
+ output of the model.
+ question_enc_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `tf.Tensor` (one for the output of the embeddings and one for the output of each layer) of shape
+ `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden states of the question encoder at the output of each layer plus the initial embedding outputs.
+ question_enc_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attention weights of the question encoder, after the attention softmax, used to compute the weighted
+ average in the self-attention heads.
+ generator_enc_last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Sequence of hidden-states at the output of the last layer of the generator encoder of the model.
+ generator_enc_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `tf.Tensor` (one for the output of the embeddings and one for the output of each layer) of shape
+ `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden states of the generator encoder at the output of each layer plus the initial embedding outputs.
+ generator_enc_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attention weights of the generator encoder, after the attention softmax, used to compute the weighted
+ average in the self-attention heads.
+ generator_dec_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `tf.Tensor` (one for the output of the embeddings and one for the output of each layer) of shape
+ `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden states of the generator decoder at the output of each layer plus the initial embedding outputs.
+ generator_dec_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attention weights of the generator decoder, after the attention softmax, used to compute the weighted
+ average in the self-attention heads.
+ """
+
+ logits: tf.Tensor | None = None
+ past_key_values: list[tf.Tensor] | None = None
+ doc_scores: tf.Tensor | None = None
+ retrieved_doc_embeds: tf.Tensor | None = None
+ retrieved_doc_ids: tf.Tensor | None = None
+ context_input_ids: tf.Tensor | None = None
+ context_attention_mask: tf.Tensor | None = None
+ question_encoder_last_hidden_state: tf.Tensor | None = None
+ question_enc_hidden_states: tuple[tf.Tensor, ...] | None = None
+ question_enc_attentions: tuple[tf.Tensor, ...] | None = None
+ generator_enc_last_hidden_state: tf.Tensor | None = None
+ generator_enc_hidden_states: tuple[tf.Tensor, ...] | None = None
+ generator_enc_attentions: tuple[tf.Tensor, ...] | None = None
+ generator_dec_hidden_states: tuple[tf.Tensor, ...] | None = None
+ generator_dec_attentions: tuple[tf.Tensor, ...] | None = None
+
+
+class TFRagPreTrainedModel(TFPreTrainedModel):
+ r"""
+ RAG models were released with the paper [Retrieval-Augmented Generation for Knowledge-Intensive NLP
+ Tasks](https://huggingface.co/papers/2005.11401) by Patrick Lewis, Ethan Perez, Aleksandra Piktus et al.
+
+ RAG is a retrieval-augmented model that encapsulates three components: a question encoder, a dataset retriever and a
+ generator. The encoder and generator are trainable, while the retriever is simply an indexed dataset.
+
+ """
+
+ config_class = RagConfig
+ base_model_prefix = "rag"
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
+
+ @classmethod
+ def from_pretrained_question_encoder_generator(
+ cls,
+ question_encoder_pretrained_model_name_or_path: str | None = None,
+ generator_pretrained_model_name_or_path: str | None = None,
+ retriever: RagRetriever | None = None,
+ *model_args,
+ **kwargs,
+ ) -> TFPreTrainedModel:
+ r"""
+ Instantiates a question encoder and a generator from one or two base classes of the library from pretrained
+ model checkpoints.
+
+ Params:
+ question_encoder_pretrained_model_name_or_path (`str`, *optional*):
+ Information necessary to initiate the question encoder. Can be either:
+
+ - A string with the *shortcut name* of a pretrained model to load from cache or download, e.g.,
+ `google-bert/bert-base-uncased`.
+ - A string with the *identifier name* of a pretrained model that was user-uploaded to our S3, e.g.,
+ `dbmdz/bert-base-german-cased`.
+ - A path to a *directory* containing model weights saved using
+ [`~TFPreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
+ - A path or url to a *pytorch index checkpoint file* (e.g., `./pt_model/`). In this case,
+ `question_encoder_from_pt` should be set to `True`.
+
+ generator_pretrained_model_name_or_path (`str`, *optional*, defaults to `None`):
+ Information necessary to initiate the generator. Can be either:
+
+ - A string with the *shortcut name* of a pretrained model to load from cache or download, e.g.,
+ `google-t5/t5-small`.
+ - A string with the *identifier name* of a pretrained model that was user-uploaded to our S3, e.g.,
+ `facebook/bart-base`.
+ - A path to a *directory* containing model weights saved using
+ [`~TFPreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
+ - A path or url to a *pytorch checkpoint file* (e.g., `./pt_model/`). In this case,
+ `generator_from_pt` should be set to `True`.
+
+ model_args (remaining positional arguments, *optional*):
+ All remaining positional arguments will be passed to the underlying model's `__init__` method.
+ retriever ([`RagRetriever`], *optional*):
+ The retriever to use.
+ kwargs (remaining dictionary of keyword arguments, *optional*):
+ Can be used to update the configuration object (after it has been loaded) and initiate the model (e.g.,
+ `output_attentions=True`).
+
+ - To update the question_encoder configuration, use the prefix *question_encoder_* for each
+ configuration parameter.
+ - To update the generator configuration, use the prefix *generator_* for each configuration parameter.
+ - To update the parent model configuration, do not use a prefix for each configuration parameter.
+
+ Behaves differently depending on whether a `config` is provided or automatically loaded.
+
+ Example:
+
+ ```python
+ >>> from transformers import RagRetriever, TFRagModel
+
+ >>> # initialize a RAG from two pretrained models.
+ >>> model = TFRagModel.from_pretrained_question_encoder_generator(
+ ... "facebook/dpr-question_encoder-single-nq-base", "google-t5/t5-small"
+ ... )
+ >>> # alternatively, initializing from PyTorch pretrained models can also be done
+ >>> model = TFRagModel.from_pretrained_question_encoder_generator(
+ ... "facebook/dpr-question_encoder-single-nq-base",
+ ... "facebook/bart-base",
+ ... generator_from_pt=True,
+ ... question_encoder_from_pt=True,
+ ... )
+
+ >>> # saving model after fine-tuning
+ >>> model.save_pretrained("./rag")
+
+ >>> # load retriever
+ >>> retriever = RagRetriever.from_pretrained(
+ ... "facebook/rag-token-base", index_name="exact", use_dummy_dataset=True
+ ... )
+ >>> # load fine-tuned model with retriever
+ >>> model = TFRagModel.from_pretrained("./rag", retriever=retriever)
+ ```"""
+
+ kwargs_question_encoder = {
+ argument[len("question_encoder_") :]: value
+ for argument, value in kwargs.items()
+ if argument.startswith("question_encoder_")
+ }
+
+ kwargs_generator = {
+ argument[len("generator_") :]: value
+ for argument, value in kwargs.items()
+ if argument.startswith("generator_")
+ }
+
+ # remove question_encoder, generator kwargs from kwargs
+ for key in kwargs_question_encoder:
+ del kwargs["question_encoder_" + key]
+ for key in kwargs_generator:
+ del kwargs["generator_" + key]
+
+ # Load and initialize the question_encoder and generator
+ # The distinction between question_encoder and generator at the model level is made
+ # by the value of the flag `is_generator` that we need to set correctly.
+ question_encoder = kwargs_question_encoder.pop("model", None)
+ if question_encoder is None:
+ assert question_encoder_pretrained_model_name_or_path is not None, (
+ "If `model` is not defined as an argument, a `question_encoder_pretrained_model_name_or_path` has to"
+ " be defined"
+ )
+
+ from ..auto.modeling_tf_auto import TFAutoModel
+
+ if "config" not in kwargs_question_encoder:
+ from ..auto.configuration_auto import AutoConfig
+
+ question_encoder_config = AutoConfig.from_pretrained(question_encoder_pretrained_model_name_or_path)
+ kwargs_question_encoder["config"] = question_encoder_config
+
+ question_encoder = TFAutoModel.from_pretrained(
+ question_encoder_pretrained_model_name_or_path,
+ name="question_encoder",
+ load_weight_prefix=cls.load_weight_prefix,
+ *model_args,
+ **kwargs_question_encoder,
+ )
+
+ generator = kwargs_generator.pop("generator", None)
+ if generator is None:
+ assert generator_pretrained_model_name_or_path is not None, (
+ "If `generator_model` is not defined as an argument, a `generator_pretrained_model_name_or_path` has"
+ " to be defined"
+ )
+
+ from ..auto.modeling_tf_auto import TFAutoModelForSeq2SeqLM
+
+ if "config" not in kwargs_generator:
+ from ..auto.configuration_auto import AutoConfig
+
+ generator_config = AutoConfig.from_pretrained(generator_pretrained_model_name_or_path)
+ kwargs_generator["config"] = generator_config
+
+ generator = TFAutoModelForSeq2SeqLM.from_pretrained(
+ generator_pretrained_model_name_or_path,
+ name="generator",
+ load_weight_prefix=cls.load_weight_prefix,
+ **kwargs_generator,
+ )
+
+ # instantiate config with corresponding kwargs
+ config = kwargs.get("config")
+ if config is None:
+ config = RagConfig.from_question_encoder_generator_configs(
+ question_encoder.config, generator.config, **kwargs
+ )
+
+ return cls(question_encoder=question_encoder, generator=generator, config=config, retriever=retriever)
+
+
+RAG_START_DOCSTRING = r"""
+
+ RAG is a sequence-to-sequence model which encapsulates two core components: a question encoder and a generator.
+ During a forward pass, we encode the input with the question encoder and pass it to the retriever to extract
+ relevant context documents. The documents are then prepended to the input. Such contextualized inputs are passed
+ to the generator.
+
+ The question encoder can be any *autoencoding* model, preferably [`TFDPRQuestionEncoder`], and the generator can be
+ any *seq2seq* model, preferably [`TFBartForConditionalGeneration`].
+
+ The model can be initialized with a [`RagRetriever`] for end-to-end generation or used in combination with the
+ outputs of a retriever in multiple steps (see the examples for more details). The model is compatible with any
+ *autoencoding* model as the `question_encoder` and any *seq2seq* model with a language model head as the `generator`.
+ It has been tested with [`TFDPRQuestionEncoder`] as the `question_encoder` and [`TFBartForConditionalGeneration`]
+ as the `generator`.
+
+ This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning
+ heads, etc.).
+
+ This model is also a Tensorflow [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model)
+ subclass. Use it as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to
+ general usage and behavior.
+
+ The model is in a developing state: it is currently fully supported in eager mode only, and may not be
+ exportable in the SavedModel format.
+
+ Args:
+ config ([`RagConfig`]):
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
+ load the weights associated with the model, only the configuration. Check out the
+ [`~TFPreTrainedModel.from_pretrained`] method to load the model weights.
+ question_encoder ([`TFPreTrainedModel`]):
+ An encoder model compatible with the faiss index encapsulated by the `retriever`.
+ generator ([`TFPreTrainedModel`]):
+ A seq2seq model used as the generator in the RAG architecture.
+ retriever ([`RagRetriever`]):
+ A retriever class encapsulating a faiss index queried to obtain context documents for current inputs.
+"""
+
+
+RAG_FORWARD_INPUTS_DOCSTRING = r"""
+ Args:
+ input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. [`RagConfig`], used to initialize the model, specifies
+ which generator to use; it also specifies a compatible generator tokenizer. Use that tokenizer class to
+ obtain the indices.
+ attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ encoder_outputs (`tuple(tuple(tf.Tensor))`, *optional*):
+ Tuple consists of (`generator_enc_last_hidden_state`, *optional*: `generator_enc_hidden_states`,
+ *optional*: `generator_enc_attentions`). `generator_enc_last_hidden_state` of shape `(batch_size, n_docs *
+ sequence_length, hidden_size)` is a sequence of hidden-states at the output of the last layer of the
+ generator's encoder.
+
+ Used by the [`TFRagModel`] model during decoding.
+ decoder_input_ids (`tf.Tensor` of shape `(batch_size, target_sequence_length)`, *optional*):
+ Provide for generation tasks. `None` by default, construct as per instructions for the generator model
+ you're using with your RAG instance.
+ decoder_attention_mask (`tf.Tensor` of shape `(batch_size, target_sequence_length)`, *optional*):
+ Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
+ be used by default.
+ past_key_values (`tuple(tuple(tf.Tensor))`):
+ Tuple consists of two elements: `encoder_outputs` of the RAG model (see `encoder_outputs`) and
+ `past_key_values` of the underlying generator. Can be used to speed up decoding. `past_key_values` are used
+ in the [`TFRagTokenForGeneration`] model during decoding.
+ doc_scores (`tf.Tensor` of shape `(batch_size, config.n_docs)`):
+ Score between each retrieved document embedding (see `retrieved_doc_embeds`) and
+ `question_encoder_last_hidden_state`. If the model is not initialized with a `retriever`, `doc_scores`
+ has to be provided to the forward pass. `doc_scores` can be computed via
+ `question_encoder_last_hidden_state` and `retrieved_doc_embeds`, see examples for more information.
+ context_input_ids (`tf.Tensor` of shape `(batch_size * config.n_docs, config.max_combined_length)`, *optional*, returned when *output_retrieved=True*):
+ Input IDs post-processed from the retrieved documents and the question encoder `input_ids` by the
+ retriever.
+
+ If the model is not initialized with a `retriever`, `context_input_ids` has to be provided to the
+ forward pass. `context_input_ids` are returned by [`~RagRetriever.__call__`].
+ context_attention_mask (`tf.Tensor` of shape `(batch_size * config.n_docs, config.max_combined_length)`, *optional*, returned when *output_retrieved=True*):
+ Attention mask post-processed from the retrieved documents and the question encoder `input_ids` by the
+ retriever.
+
+ If the model is not initialized with a `retriever`, `context_attention_mask` has to be provided to the
+ forward pass. `context_attention_mask` are returned by [`~RagRetriever.__call__`].
+ use_cache (`bool`, *optional*, defaults to `True`):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+ `past_key_values`).
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ output_retrieved (`bool`, *optional*):
+ Whether or not to return the `retrieved_doc_embeds`, `retrieved_doc_ids`, `context_input_ids` and
+ `context_attention_mask`. See returned tensors for more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`TFRetrievAugLMOutput`] instead of a plain tuple.
+ n_docs (`int`, *optional*, defaults to `config.n_docs`):
+ Number of documents to retrieve and/or number of documents for which to generate an answer.
+"""
+
+
+@add_start_docstrings_to_model_forward(RAG_START_DOCSTRING)
+class TFRagModel(TFRagPreTrainedModel):
+ load_weight_prefix = "tf_rag_model_1"
+
+ def __init__(
+ self,
+ config: PretrainedConfig | None = None,
+ question_encoder: TFPreTrainedModel | None = None,
+ generator: TFPreTrainedModel | None = None,
+ retriever: RagRetriever | None = None,
+ load_weight_prefix: str | None = None,
+ **kwargs,
+ ):
+ assert config is not None or (question_encoder is not None and generator is not None), (
+ "Either a configuration or an question_encoder and a generator has to be provided."
+ )
+
+ if config is None:
+ config = RagConfig.from_question_encoder_generator_configs(
+ question_encoder.config, generator.config, **kwargs
+ )
+ else:
+ assert isinstance(config, self.config_class), f"config: {config} has to be of type {self.config_class}"
+ super().__init__(config, **kwargs)
+
+ if question_encoder is None:
+ from ..auto.modeling_tf_auto import TFAutoModel
+
+ question_encoder = TFAutoModel.from_config(config.question_encoder, name="question_encoder")
+
+ if generator is None:
+ from ..auto.modeling_tf_auto import TFAutoModelForSeq2SeqLM
+
+ load_weight_prefix = load_weight_prefix if load_weight_prefix is not None else self.load_weight_prefix
+ generator = TFAutoModelForSeq2SeqLM.from_config(
+ config.generator, name="generator", load_weight_prefix=load_weight_prefix + "/generator"
+ )
+
+ self.retriever = retriever
+ if self.retriever is not None:
+ assert isinstance(retriever, RagRetriever), (
+ f"`self.retriever` is of type {type(self.retriever)}, but should be of type `RagRetriever`"
+ )
+ self.retriever = retriever
+
+ self.question_encoder = question_encoder
+ self.generator = generator
+
+ def set_retriever(self, retriever: RagRetriever):
+ self.retriever = retriever
+
+ @unpack_inputs
+ @add_start_docstrings_to_model_forward(RAG_FORWARD_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=TFRetrievAugLMOutput, config_class=_CONFIG_FOR_DOC)
+ def call(
+ self,
+ input_ids: TFModelInputType | None = None,
+ attention_mask: np.ndarray | tf.Tensor | None = None,
+ encoder_outputs: np.ndarray | tf.Tensor | None = None,
+ decoder_input_ids: np.ndarray | tf.Tensor | None = None,
+ decoder_attention_mask: np.ndarray | tf.Tensor | None = None,
+ past_key_values: tuple[tuple[np.ndarray | tf.Tensor]] | None = None,
+ doc_scores: np.ndarray | tf.Tensor | None = None,
+ context_input_ids: np.ndarray | tf.Tensor | None = None,
+ context_attention_mask: np.ndarray | tf.Tensor | None = None,
+ use_cache: bool | None = None,
+ output_attentions: bool | None = None,
+ output_hidden_states: bool | None = None,
+ output_retrieved: bool | None = None,
+ n_docs: int | None = None,
+ return_dict: bool | None = None,
+ training: bool = False,
+ **kwargs,
+ ) -> TFRetrievAugLMOutput:
+ r"""
+ Returns:
+
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, RagRetriever, TFRagModel
+
+ >>> tokenizer = AutoTokenizer.from_pretrained("facebook/rag-token-base")
+ >>> retriever = RagRetriever.from_pretrained(
+ ... "facebook/rag-token-base", index_name="exact", use_dummy_dataset=True
+ ... )
+ >>> # initialize with RagRetriever to do everything in one forward call
+ >>> model = TFRagModel.from_pretrained("facebook/rag-token-base", retriever=retriever, from_pt=True)
+
+ >>> input_dict = tokenizer.prepare_seq2seq_batch(
+ ... "How many people live in Paris?", "In Paris, there are 10 million people.", return_tensors="tf"
+ ... )
+ >>> input_ids = input_dict["input_ids"]
+ >>> outputs = model(input_ids)
+ ```"""
+ assert "decoder_cached_states" not in kwargs, (
+ "Please use past_key_values to cache intermediate outputs"
+ ) # from modeling_tf_bart.py
+
+ # aliasing to minimize code changing
+ n_docs = n_docs if n_docs is not None else self.config.n_docs
+
+ # whether retriever has to be used
+ has_to_retrieve = (
+ self.retriever is not None
+ and (context_input_ids is None or context_attention_mask is None or doc_scores is None)
+ and encoder_outputs is None
+ )
+
+ # encoder_outputs are pre-computed during RAG-token generation
+ if encoder_outputs is None:
+ if has_to_retrieve:
+ question_enc_outputs = self.question_encoder(
+ input_ids, attention_mask=attention_mask, return_dict=True, training=training
+ )
+ # see https://github.com/huggingface/transformers/blob/main/src/transformers/models/dpr/modeling_tf_dpr.py#L91
+ question_encoder_last_hidden_state = question_enc_outputs[
+ 0
+ ] # hidden states of question encoder => pooler_output
+
+ retriever_outputs = self.retriever(
+ input_ids,
+ question_encoder_last_hidden_state.numpy(),
+ prefix=self.generator.config.prefix,
+ n_docs=n_docs,
+ return_tensors="tf",
+ )
+ context_input_ids, context_attention_mask, retrieved_doc_embeds, retrieved_doc_ids = (
+ retriever_outputs["context_input_ids"],
+ retriever_outputs["context_attention_mask"],
+ retriever_outputs["retrieved_doc_embeds"],
+ retriever_outputs["doc_ids"],
+ )
+
+ context_input_ids = tf.cast(context_input_ids, tf.int32)
+ context_attention_mask = tf.cast(context_attention_mask, tf.int32)
+ retrieved_doc_embeds = tf.cast(retrieved_doc_embeds, tf.float32)
+ retrieved_doc_ids = tf.cast(retrieved_doc_ids, tf.int32)
+
+ # compute doc_scores
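+ # Shape walk-through of the computation below: question_encoder_last_hidden_state is (batch_size, hidden);
+ # expand_dims gives (batch_size, 1, hidden); retrieved_doc_embeds is (batch_size, n_docs, hidden), so the
+ # matmul with transpose_b yields (batch_size, 1, n_docs) and the squeeze produces doc_scores of shape
+ # (batch_size, n_docs), i.e. the inner product of the question embedding with each retrieved document embedding.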
+ doc_scores = tf.squeeze(
+ tf.matmul(
+ tf.expand_dims(question_encoder_last_hidden_state, axis=1),
+ retrieved_doc_embeds,
+ transpose_b=True,
+ ),
+ axis=1,
+ )
+
+ else:
+ assert context_input_ids is not None, (
+ "Make sure that `context_input_ids` are passed, if no `retriever` is set. Alternatively, you can"
+ " set a retriever using the `set_retriever(...)` function."
+ )
+ assert context_attention_mask is not None, (
+ "Make sure that `context_attention_mask` are passed, if no `retriever` is set. Alternatively, you"
+ " can set a retriever using the `set_retriever(...)` function."
+ )
+ assert doc_scores is not None, (
+ "Make sure that `doc_scores` are passed, if no `retriever` is set. Alternatively, you can set a"
+ " retriever using the `set_retriever(...)` function."
+ )
+
+ assert doc_scores is not None, (
+ "Make sure that `doc_scores` are passed when passing `encoder_outputs` to the forward function."
+ )
+
+ assert (doc_scores.shape[1] % n_docs) == 0, (
+ f" The first dimension of `context_input_ids` should be a multiple of `n_docs`={n_docs}, but is"
+ f" {context_input_ids.shape[0]}."
+ )
+
+ # Decoder input without context documents
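+ # Each decoder input row is repeated n_docs times so that row i * n_docs + j lines up with the j-th
+ # retrieved document of example i in `context_input_ids`, whose leading dimension is batch_size * n_docs.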
+ if decoder_input_ids is not None:
+ decoder_input_ids = tf.repeat(decoder_input_ids, n_docs, axis=0)
+
+ if decoder_attention_mask is not None:
+ decoder_attention_mask = tf.repeat(decoder_attention_mask, n_docs, axis=0)
+
+ gen_outputs = self.generator(
+ context_input_ids,
+ attention_mask=context_attention_mask,
+ encoder_outputs=encoder_outputs,
+ decoder_input_ids=decoder_input_ids,
+ decoder_attention_mask=decoder_attention_mask,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ return_dict=True,
+ training=training,
+ )
+
+ if not has_to_retrieve:
+ question_encoder_last_hidden_state = None
+ question_enc_hidden_states = None
+ question_enc_attentions = None
+ retrieved_doc_embeds = None
+ retrieved_doc_ids = None
+ else:
+ question_enc_hidden_states = question_enc_outputs.hidden_states
+ question_enc_attentions = question_enc_outputs.attentions
+
+ if not has_to_retrieve or not output_retrieved:
+ # don't output retrieved docs
+ context_input_ids = None
+ context_attention_mask = None
+ retrieved_doc_embeds = None
+ retrieved_doc_ids = None
+
+ return TFRetrievAugLMOutput(
+ logits=gen_outputs.logits,
+ doc_scores=doc_scores,
+ past_key_values=gen_outputs.past_key_values,
+ context_input_ids=context_input_ids,
+ context_attention_mask=context_attention_mask,
+ retrieved_doc_embeds=retrieved_doc_embeds,
+ retrieved_doc_ids=retrieved_doc_ids,
+ question_encoder_last_hidden_state=question_encoder_last_hidden_state,
+ question_enc_hidden_states=question_enc_hidden_states,
+ question_enc_attentions=question_enc_attentions,
+ generator_enc_last_hidden_state=gen_outputs.encoder_last_hidden_state,
+ generator_enc_hidden_states=gen_outputs.encoder_hidden_states,
+ generator_enc_attentions=gen_outputs.encoder_attentions,
+ generator_dec_hidden_states=gen_outputs.decoder_hidden_states,
+ generator_dec_attentions=gen_outputs.decoder_attentions,
+ )
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ with tf.name_scope(self.generator.name):
+ self.generator.build(None)
+ with tf.name_scope(self.question_encoder.name):
+ self.question_encoder.build(None)
+
+
+@add_start_docstrings_to_model_forward(
+ """
+ A TF RAG-token model implementation. It performs RAG-token specific marginalization in the forward pass.
+ """,
+ RAG_START_DOCSTRING,
+)
+class TFRagTokenForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss):
+ load_weight_prefix = "tf_rag_token_for_generation_1/rag"
+
+ def __init__(
+ self,
+ config: PretrainedConfig | None = None,
+ question_encoder: TFPreTrainedModel | None = None,
+ generator: TFPreTrainedModel | None = None,
+ retriever: RagRetriever | None = None,
+ **kwargs,
+ ):
+ assert config is not None or (question_encoder is not None and generator is not None), (
+ "Either a configuration or an encoder and a generator has to be provided."
+ )
+
+ if config is None:
+ config = RagConfig.from_question_encoder_generator_configs(
+ question_encoder.config, generator.config, **kwargs
+ )
+
+ super().__init__(config)
+
+ # instantiate model
+ self.rag = TFRagModel(
+ config=config,
+ question_encoder=question_encoder,
+ generator=generator,
+ retriever=retriever,
+ load_weight_prefix=self.load_weight_prefix,
+ name="rag",
+ )
+
+ def set_retriever(self, retriever: RagRetriever):
+ self.rag.retriever = retriever
+
+ # Adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_tf_bart.py
+ def prepare_inputs_for_generation(
+ self,
+ decoder_input_ids,
+ past_key_values=None,
+ attention_mask=None,
+ use_cache=None,
+ encoder_outputs=None,
+ doc_scores=None,
+ n_docs=None,
+ **kwargs,
+ ):
+ if past_key_values is not None:
+ # if past is defined use only last decoder_input_ids
+ decoder_input_ids = decoder_input_ids[:, -1:]
+
+ return {
+ "input_ids": None,
+ "encoder_outputs": encoder_outputs,
+ "doc_scores": doc_scores,
+ "context_attention_mask": attention_mask,
+ "decoder_input_ids": decoder_input_ids,
+ "past_key_values": past_key_values,
+ "use_cache": use_cache,
+ "do_marginalize": True,
+ "n_docs": n_docs,
+ }
+
+ @property
+ def retriever(self):
+ return self.rag.retriever
+
+ @property
+ def generator(self):
+ return self.rag.generator
+
+ @property
+ def question_encoder(self):
+ return self.rag.question_encoder
+
+ @staticmethod
+ def _gather_beams(nested, beam_indices, batch_axis=0):
+ """
+ RAG-specific `_gather_beams`: gathers the beam slices indexed by `beam_indices` into a new beam array. If
+ the nested tensor has a shape mismatch with the beam indices, it is assumed to be the cache. In that case,
+ the extra `n_docs` dimension is isolated and handled separately.
+ """
+
+ def gather_fn(tensor):
+ is_rag_cache = tensor.shape[0] != beam_indices.shape[0]
+ if is_rag_cache:
+ n_docs = tensor.shape[0] // beam_indices.shape[0]
+ batch_size = beam_indices.shape[0]
+ # reshapes into (batch size, num beams, n_docs, ...), the cache format expected by RAG
+ tensor = tf.reshape(tensor, (batch_size, -1, n_docs, *tensor.shape[2:]))
+
+ gathered_tensor = tf.gather(params=tensor, indices=beam_indices, axis=1, batch_dims=1)
+
+ if is_rag_cache:
+ # reshapes back into the shape expected by beam search
+ gathered_tensor = tf.reshape(gathered_tensor, (batch_size * n_docs, -1, *gathered_tensor.shape[3:]))
+
+ return gathered_tensor
+
+ return tf.nest.map_structure(gather_fn, nested)
+
+ def marginalize(self, seq_logits, doc_scores, n_docs=None):
+ n_docs = n_docs if n_docs is not None else self.config.n_docs
+
+ # RAG-token marginalization
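+ # In log-space, this computes log p(y_t | x) = logsumexp_z [ log p(z | x) + log p(y_t | x, z, y_<t) ]:
+ # seq_logits of shape (batch_size * n_docs, tgt_len, vocab) is reshaped to (batch_size, n_docs, tgt_len, vocab);
+ # doc_scores of shape (batch_size, n_docs) is log-softmaxed over docs and expanded to (batch_size, n_docs, 1, 1);
+ # the broadcast sum followed by logsumexp over the n_docs axis yields (batch_size, tgt_len, vocab).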
+ seq_logprobs = tf.nn.log_softmax(seq_logits, axis=-1)
+ seq_logprobs = tf.reshape(seq_logprobs, [seq_logits.shape[0] // n_docs, n_docs, -1, seq_logits.shape[-1]])
+ doc_logprobs = tf.nn.log_softmax(doc_scores, axis=1)
+ doc_logprobs = tf.expand_dims(doc_logprobs, axis=-1)
+ doc_logprobs = tf.expand_dims(doc_logprobs, axis=-1) # twice
+ log_prob_sum = seq_logprobs + doc_logprobs
+ return tf.reduce_logsumexp(log_prob_sum, axis=1)
+
+ @unpack_inputs
+ @add_start_docstrings_to_model_forward(RAG_FORWARD_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=TFRetrievAugLMMarginOutput, config_class=_CONFIG_FOR_DOC)
+ def call(
+ self,
+ input_ids: TFModelInputType | None = None,
+ attention_mask: np.ndarray | tf.Tensor | None = None,
+ decoder_input_ids: np.ndarray | tf.Tensor | None = None,
+ decoder_attention_mask: np.ndarray | tf.Tensor | None = None,
+ encoder_outputs: np.ndarray | tf.Tensor | None = None,
+ past_key_values: tuple[tuple[np.ndarray | tf.Tensor]] | None = None,
+ doc_scores: np.ndarray | tf.Tensor | None = None,
+ context_input_ids: np.ndarray | tf.Tensor | None = None,
+ context_attention_mask: np.ndarray | tf.Tensor | None = None,
+ use_cache: bool | None = None,
+ output_attentions: bool | None = None,
+ output_hidden_states: bool | None = None,
+ output_retrieved: bool | None = None,
+ n_docs: int | None = None,
+ do_marginalize: bool | None = None,
+ labels: np.ndarray | tf.Tensor | None = None,
+ reduce_loss: bool | None = None,
+ return_dict: bool | None = None,
+ training: bool = False,
+ **kwargs, # needs kwargs for generation
+ ) -> TFRetrievAugLMMarginOutput:
+ r"""
+ do_marginalize (`bool`, *optional*):
+ If `True`, the logits are marginalized over all documents by making use of
+ `tf.nn.log_softmax`.
+ labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the cross entropy classification loss according to the Rag-Token model formulation. See
+ https://huggingface.co/papers/2005.11401 Section 2.1 for details about the Rag-Token formulation. Indices
+ should be in `[0, ..., config.vocab_size - 1]`.
+ reduce_loss (`bool`, *optional*):
+ Only relevant if `labels` is passed. If `True`, the NLL loss is reduced using the `tf.Tensor.sum`
+ operation.
+ kwargs (`dict[str, any]`, *optional*, defaults to `{}`):
+ Legacy dictionary, which is required so that the model can use the *generate()* function.
+
+ Returns:
+
+ Example:
+
+ ```python
+ >>> import tensorflow as tf
+ >>> from transformers import AutoTokenizer, RagRetriever, TFRagTokenForGeneration
+
+ >>> tokenizer = AutoTokenizer.from_pretrained("facebook/rag-token-nq")
+ >>> retriever = RagRetriever.from_pretrained(
+ ... "facebook/rag-token-nq", index_name="exact", use_dummy_dataset=True
+ ... )
+ >>> # initialize with RagRetriever to do everything in one forward call
+ >>> model = TFRagTokenForGeneration.from_pretrained("facebook/rag-token-nq", retriever=retriever, from_pt=True)
+
+ >>> input_dict = tokenizer.prepare_seq2seq_batch(
+ ... "How many people live in Paris?", "In Paris, there are 10 million people.", return_tensors="tf"
+ ... )
+ >>> outputs = model(input_dict, output_retrieved=True)
+
+ >>> # or use retriever separately
+ >>> # 1. Encode
+ >>> input_ids = input_dict["input_ids"]
+ >>> question_hidden_states = model.question_encoder(input_ids)[0]
+ >>> # 2. Retrieve
+ >>> docs_dict = retriever(input_ids.numpy(), question_hidden_states.numpy(), return_tensors="tf")
+ >>> doc_scores = tf.squeeze(
+ ... tf.matmul(
+ ... tf.expand_dims(question_hidden_states, axis=1), docs_dict["retrieved_doc_embeds"], transpose_b=True
+ ... ),
+ ... axis=1,
+ ... )
+ >>> # 3. Forward to generator
+ >>> outputs = model(
+ ... inputs=None,
+ ... context_input_ids=docs_dict["context_input_ids"],
+ ... context_attention_mask=docs_dict["context_attention_mask"],
+ ... doc_scores=doc_scores,
+ ... decoder_input_ids=input_dict["labels"],
+ ... )
+
+ >>> # or directly generate
+ >>> generated = model.generate(
+ ... context_input_ids=docs_dict["context_input_ids"],
+ ... context_attention_mask=docs_dict["context_attention_mask"],
+ ... doc_scores=doc_scores,
+ ... )
+ >>> generated_string = tokenizer.batch_decode(generated, skip_special_tokens=True)
+ ```"""
+
+ assert "decoder_cached_states" not in kwargs, (
+ "Please use past_key_values to cache intermediate outputs"
+ ) # from modeling_tf_bart.py
+
+ do_marginalize = do_marginalize if do_marginalize else self.config.do_marginalize
+ reduce_loss = reduce_loss if reduce_loss else self.config.reduce_loss
+
+ if labels is not None:
+ if decoder_input_ids is None:
+ decoder_input_ids = labels
+ use_cache = False
+
+ outputs = self.rag(
+ input_ids,
+ attention_mask=attention_mask,
+ encoder_outputs=encoder_outputs,
+ decoder_input_ids=decoder_input_ids,
+ decoder_attention_mask=decoder_attention_mask,
+ context_input_ids=context_input_ids,
+ context_attention_mask=context_attention_mask,
+ doc_scores=doc_scores,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ output_retrieved=output_retrieved,
+ n_docs=n_docs,
+ training=training,
+ )
+
+ loss = None
+ logits = outputs.logits
+ if labels is not None:
+ assert decoder_input_ids is not None
+ loss = self.get_nll(
+ outputs.logits,
+ outputs.doc_scores,
+ labels,
+ reduce_loss=reduce_loss,
+ epsilon=self.config.label_smoothing,
+ n_docs=n_docs,
+ )
+
+ if do_marginalize:
+ logits = self.marginalize(logits, outputs.doc_scores, n_docs)
+
+ return TFRetrievAugLMMarginOutput(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ doc_scores=outputs.doc_scores,
+ context_input_ids=outputs.context_input_ids,
+ context_attention_mask=outputs.context_attention_mask,
+ retrieved_doc_embeds=outputs.retrieved_doc_embeds,
+ retrieved_doc_ids=outputs.retrieved_doc_ids,
+ question_encoder_last_hidden_state=outputs.question_encoder_last_hidden_state,
+ question_enc_hidden_states=outputs.question_enc_hidden_states,
+ question_enc_attentions=outputs.question_enc_attentions,
+ generator_enc_last_hidden_state=outputs.generator_enc_last_hidden_state,
+ generator_enc_hidden_states=outputs.generator_enc_hidden_states,
+ generator_enc_attentions=outputs.generator_enc_attentions,
+ generator_dec_hidden_states=outputs.generator_dec_hidden_states,
+ generator_dec_attentions=outputs.generator_dec_attentions,
+ )
+
+ def generate(
+ self,
+ input_ids: TFModelInputType | None = None,
+ attention_mask: tf.Tensor | None = None,
+ context_input_ids=None,
+ context_attention_mask=None,
+ doc_scores=None,
+ n_docs=None,
+ generation_config=None,
+ logits_processor=TFLogitsProcessorList(),
+ **kwargs,
+ ):
+ """
+ Implements TFRAG token decoding.
+
+ Args:
+ input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ The sequence used as a prompt for the generation. If `input_ids` is not passed, then
+ `context_input_ids` has to be provided.
+ attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ context_input_ids (`tf.Tensor` of shape `(batch_size * config.n_docs, config.max_combined_length)`, *optional*, returned when *output_retrieved=True*):
+ Input IDs post-processed from the retrieved documents and the question encoder `input_ids` by the
+ retriever.
+
+ If the model is not initialized with a `retriever`, `context_input_ids` has to be provided to the
+ forward pass. `context_input_ids` are returned by [`~RagRetriever.__call__`].
+ context_attention_mask (`tf.Tensor` of shape `(batch_size * config.n_docs, config.max_combined_length)`, *optional*, returned when *output_retrieved=True*):
+ Attention mask post-processed from the retrieved documents and the question encoder `input_ids` by the
+ retriever.
+
+ If the model is not initialized with a `retriever`, `context_attention_mask` has to be provided to the
+ forward pass. `context_attention_mask` are returned by [`~RagRetriever.__call__`].
+ doc_scores (`tf.Tensor` of shape `(batch_size, config.n_docs)`):
+ Score between each retrieved document embedding (see `retrieved_doc_embeds`) and
+ `question_encoder_last_hidden_state`.
+
+ If the model is not initialized with a `retriever`, `doc_scores` has to be provided to the forward pass.
+ `doc_scores` can be computed via `question_encoder_last_hidden_state` and `retrieved_doc_embeds`.
+ n_docs (`int`, *optional*, defaults to `config.n_docs`):
+ Number of documents to retrieve and/or number of documents for which to generate an answer.
+ generation_config (`~generation.GenerationConfig`, *optional*):
+ The generation configuration to be used as base parametrization for the generation call. `**kwargs`
+ passed to generate matching the attributes of `generation_config` will override them. If
+ `generation_config` is not provided, the default will be used, which had the following loading
+ priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
+ configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
+ default values, whose documentation should be checked to parameterize generation.
+ logits_processor (`TFLogitsProcessorList`, *optional*):
+ Custom logits processors that complement the default logits processors built from arguments and a
+ model's config. If a logits processor is passed that is already created with the arguments or a model's
+ config, an error is thrown.
+ kwargs (`dict[str, Any]`, *optional*):
+ Ad hoc parametrization of `generation_config` and/or additional model-specific kwargs that will be
+ forwarded to the `forward` function of the model.
+
+ Return:
+ `tf.Tensor` of shape `(batch_size * num_return_sequences, sequence_length)`: The generated sequences. The
+ second dimension (sequence_length) is either equal to `max_length` or shorter if all batches finished early
+ due to the `eos_token_id`.
+ """
+ # Handle `generation_config` and kwargs that might update it
+ if generation_config is None:
+ generation_config = self.generation_config
+ generation_config = copy.deepcopy(generation_config)
+ model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs
+
+ # set default parameters
+ n_docs = n_docs if n_docs is not None else self.config.n_docs
+
+ # retrieve docs
+ if self.retriever is not None and context_input_ids is None:
+ question_hidden_states = self.question_encoder(input_ids, attention_mask=attention_mask)[0]
+ out = self.retriever(
+ input_ids,
+ question_hidden_states.numpy().astype(np.float32),
+ prefix=self.generator.config.prefix,
+ n_docs=n_docs,
+ return_tensors="tf",
+ )
+ context_input_ids, context_attention_mask, retrieved_doc_embeds = (
+ out["context_input_ids"],
+ out["context_attention_mask"],
+ out["retrieved_doc_embeds"],
+ )
+
+ context_input_ids = tf.cast(context_input_ids, tf.int32)
+ context_attention_mask = tf.cast(context_attention_mask, tf.int32)
+ retrieved_doc_embeds = tf.cast(retrieved_doc_embeds, tf.float32)
+
+ # compute doc_scores
+ doc_scores = tf.matmul(
+ tf.expand_dims(question_hidden_states, axis=1), retrieved_doc_embeds, transpose_b=True
+ )
+ doc_scores = tf.squeeze(doc_scores, axis=1)
+
+ assert (context_input_ids.shape[0] % n_docs) == 0, (
+ f" The first dimension of `context_input_ids` should be a multiple of `n_docs`={n_docs}, but is"
+ f" {context_input_ids.shape[0]}."
+ )
+
+ batch_size = context_input_ids.shape[0] // n_docs
+
+ encoder = self.rag.generator.get_encoder()
+ encoder_outputs = encoder(
+ input_ids=context_input_ids,
+ attention_mask=context_attention_mask,
+ output_attentions=generation_config.output_attentions,
+ output_hidden_states=generation_config.output_hidden_states,
+ return_dict=True,
+ )
+
+ decoder_input_ids = tf.fill(
+ (batch_size * generation_config.num_beams, 1),
+ tf.cast(generation_config.decoder_start_token_id, tf.int32),
+ )
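+ # decoder_input_ids starts as a (batch_size * num_beams, 1) tensor holding only the decoder start token;
+ # generation then extends it one token at a time.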
+ last_hidden_state = encoder_outputs["last_hidden_state"]
+
+ def extend_enc_output(tensor, num_beams=None):
+ """
+ Broadcast the tensor over `num_beams` replicas, preserving order.
+ Input: tensor of shape `(batch_size * n_docs, d)`.
+ Output: tensor of shape `(batch_size * num_beams * n_docs, d)`.
+ """
+
+ # expand batch_size & num_beam dimensions
+ d_shape_list = tensor.shape[1:]
+
+ # split n_docs dimensions
+ new_shape = (batch_size, 1, n_docs) + d_shape_list
+ tensor = tf.reshape(tensor, new_shape)
+
+ # repeat same last hidden states over `num_beams` dimension
+ new_shape = (batch_size, num_beams, n_docs) + d_shape_list
+ tensor = tf.broadcast_to(tensor, new_shape)
+
+ # merge `batch_size`, `num_beams`, `num_docs` dims again
+ new_shape = (batch_size * num_beams * n_docs,) + d_shape_list
+ return tf.reshape(tensor, new_shape)
+
+ # correctly extend last_hidden_state and attention mask
+ context_attention_mask = extend_enc_output(context_attention_mask, num_beams=generation_config.num_beams)
+ encoder_outputs["last_hidden_state"] = extend_enc_output(
+ last_hidden_state, num_beams=generation_config.num_beams
+ )
+
+ doc_scores = tf.repeat(doc_scores, generation_config.num_beams, axis=0)
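+ # doc_scores of shape (batch_size, n_docs) is repeated along the batch axis to (batch_size * num_beams, n_docs),
+ # so every beam hypothesis carries the document scores of its source example.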
+
+ # define start_len & additional parameters
+ model_kwargs["doc_scores"] = doc_scores
+ model_kwargs["encoder_outputs"] = encoder_outputs
+ model_kwargs["attention_mask"] = context_attention_mask
+ model_kwargs["n_docs"] = n_docs
+
+ pre_processor = self._get_logits_processor(
+ generation_config=generation_config,
+ input_ids_seq_length=tf.shape(decoder_input_ids)[-1],
+ logits_processor=logits_processor,
+ )
+
+ if generation_config.num_beams == 1:
+ return self.greedy_search(
+ input_ids=decoder_input_ids,
+ max_length=generation_config.max_length,
+ pad_token_id=generation_config.pad_token_id,
+ eos_token_id=generation_config.eos_token_id,
+ logits_processor=pre_processor,
+ output_attentions=generation_config.output_attentions,
+ output_hidden_states=generation_config.output_hidden_states,
+ output_scores=generation_config.output_scores,
+ return_dict_in_generate=generation_config.return_dict_in_generate,
+ **model_kwargs,
+ )
+ elif generation_config.num_beams > 1:
+ if generation_config.num_beams < generation_config.num_return_sequences:
+ raise ValueError(
+ "Beam search decoding cannot return more sequences than it has beams. Please set num_beams >="
+ f" num_return_sequences, got {generation_config.num_beams} and"
+ f" {generation_config.num_return_sequences} (respectively)"
+ )
+
+ def unflatten_beam_dim(tensor):
+ """Unflattens the first, flat batch*beam dimension of a non-scalar array."""
+ shape = shape_list(tensor)
+ return tf.reshape(tensor, [-1, generation_config.num_beams] + shape[1:])
+
+ decoder_input_ids = unflatten_beam_dim(decoder_input_ids)
+ model_kwargs["attention_mask"] = unflatten_beam_dim(model_kwargs["attention_mask"])
+ model_kwargs["encoder_outputs"]["last_hidden_state"] = unflatten_beam_dim(
+ model_kwargs["encoder_outputs"]["last_hidden_state"]
+ )
+
+ return self.beam_search(
+ input_ids=decoder_input_ids,
+ max_length=generation_config.max_length,
+ pad_token_id=generation_config.pad_token_id,
+ eos_token_id=generation_config.eos_token_id,
+ logits_processor=pre_processor,
+ output_attentions=generation_config.output_attentions,
+ output_hidden_states=generation_config.output_hidden_states,
+ output_scores=generation_config.output_scores,
+ return_dict_in_generate=generation_config.return_dict_in_generate,
+ **model_kwargs,
+ )
+ else:
+ raise ValueError(
+ f"`num_beams` has to be an integer strictly superior to 0 (≥ 1), but is {generation_config.num_beams}"
+ )
+
+ def get_input_embeddings(self):
+ return self.rag.generator.get_input_embeddings()
+
+ def get_output_embeddings(self):
+ return self.rag.generator.get_output_embeddings()
+
+ # Adapted from tf_t5's & tf_bart's _shift_right
+ def shift_tokens_right(self, input_ids, start_token_id=None):
+ """Shift input ids one token to the right, and pad with start_token_id"""
+
+ if start_token_id is None:
+ start_token_id = self.generator.config.decoder_start_token_id
+ assert start_token_id is not None, (
+ "self.generator.config.decoder_start_token_id has to be defined. In Rag we commonly use Bart as"
+ " generator, see Bart docs for more information"
+ )
+
+ pad_token_id = self.generator.config.pad_token_id
+ assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined."
+
+ start_tokens = tf.fill((shape_list(input_ids)[0], 1), tf.cast(start_token_id, input_ids.dtype))
+ shifted_input_ids = tf.concat([start_tokens, input_ids[:, :-1]], -1)
+
+ # replace possible -100 values in labels by `pad_token_id`
+ shifted_input_ids = tf.where(
+ shifted_input_ids == -100,
+ tf.fill(shape_list(shifted_input_ids), tf.cast(pad_token_id, input_ids.dtype)),
+ shifted_input_ids,
+ )
+
+ # "Verify that `labels` has only positive values and -100"
+ assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.cast(0, shifted_input_ids.dtype))
+
+ # Make sure the assertion op is called by wrapping the result in an identity no-op
+ with tf.control_dependencies([assert_gte0]):
+ shifted_input_ids = tf.identity(shifted_input_ids)
+
+ return shifted_input_ids
+
+ # nll stands for 'negative log likelihood'
+ def get_nll(self, seq_logits, doc_scores, target, reduce_loss=False, epsilon=0.0, n_docs=None):
+ n_docs = n_docs if n_docs is not None else self.config.n_docs
+ # shift tokens left (as in the original PyTorch version)
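+ # With `decoder_input_ids` set to the unshifted labels, shifting the target one position to the left
+ # (and padding the final slot) makes the logits at step t score the token at step t + 1, i.e. the usual
+ # teacher-forced next-token objective.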
+
+ target = tf.concat(
+ [target[:, 1:], tf.fill([target.shape[0], 1], tf.cast(self.config.generator.pad_token_id, target.dtype))],
+ axis=1,
+ )
+ rag_logprobs = self.marginalize(seq_logits, doc_scores, n_docs)
+ loss = self.hf_compute_loss(target, rag_logprobs, from_logits=True, reduce_loss=reduce_loss)
+
+ return loss
+
+ # Adapted from modeling_tf_bart; adds smooth_loss to match the PyTorch version
+ def hf_compute_loss(self, labels, y_pred, smooth_epsilon=0.0, from_logits=True, reduce_loss=False):
+ """CrossEntropyLoss that ignores pad tokens"""
+ # Matt: As written, this loss is not XLA-compatible, but it's doing some very weird things
+ # and I don't feel comfortable converting it.
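+ # Label-smoothed NLL, intended to match the PyTorch RAG loss: with eps = smooth_epsilon and V = vocab size,
+ # loss = (1 - eps) * nll_loss + (eps / V) * sum(-log_probs), computed below over non-pad positions only.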
+ loss_fn = keras.losses.SparseCategoricalCrossentropy(
+ from_logits=True,
+ reduction=keras.losses.Reduction.SUM,
+ )
+
+ if from_logits is False: # convert probabilities to log-probabilities
+ eps = 1e-9
+ y_pred = tf.clip_by_value(y_pred, clip_value_min=eps, clip_value_max=1 - eps)
+ y_pred = tf.math.log(y_pred)
+
+ logits = y_pred
+ melted_labels = tf.reshape(labels, (-1,))
+ active_loss = tf.not_equal(melted_labels, self.config.generator.pad_token_id)
+
+ reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, logits.shape[2])), active_loss)
+ labels = tf.boolean_mask(melted_labels, active_loss)
+ nll_loss = loss_fn(labels, reduced_logits)
+
+ smooth_loss = -tf.reduce_sum(reduced_logits, axis=-1)
+ smooth_loss = tf.reduce_sum(smooth_loss) # sum and squeeze like torch
+ eps_i = smooth_epsilon / reduced_logits.shape[-1]
+
+ loss = (1.0 - smooth_epsilon) * nll_loss + eps_i * smooth_loss
+
+ return loss
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "rag", None) is not None:
+ with tf.name_scope(self.rag.name):
+ self.rag.build(None)
+
+
+@add_start_docstrings_to_model_forward(
+ """
+ A TF RAG-sequence model implementation. It performs RAG-sequence specific marginalization in the forward pass.
+ """,
+ RAG_START_DOCSTRING,
+)
+class TFRagSequenceForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss):
+ load_weight_prefix = "tf_rag_sequence_for_generation_1/rag"
+
+ def __init__(
+ self,
+ config: PretrainedConfig | None = None,
+ question_encoder: TFPreTrainedModel | None = None,
+ generator: TFPreTrainedModel | None = None,
+ retriever: RagRetriever | None = None,
+ **kwargs,
+ ):
+ assert config is not None or (question_encoder is not None and generator is not None), (
+ "Either a configuration or an encoder and a generator has to be provided."
+ )
+
+ if config is None:
+ config = RagConfig.from_question_encoder_generator_configs(
+ question_encoder.config, generator.config, **kwargs
+ )
+
+ super().__init__(config)
+
+ # instantiate model
+ self.rag = TFRagModel(
+ config=config,
+ question_encoder=question_encoder,
+ generator=generator,
+ retriever=retriever,
+ load_weight_prefix=self.load_weight_prefix,
+ name="rag",
+ )
+
+ def set_retriever(self, retriever: RagRetriever):
+ self.rag.retriever = retriever
+
+ @property
+ def retriever(self):
+ return self.rag.retriever
+
+ @property
+ def generator(self):
+ return self.rag.generator
+
+ @property
+ def question_encoder(self):
+ return self.rag.question_encoder
+
+ @unpack_inputs
+ @add_start_docstrings_to_model_forward(RAG_FORWARD_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=TFRetrievAugLMMarginOutput, config_class=_CONFIG_FOR_DOC)
+ def call(
+ self,
+ input_ids: TFModelInputType | None = None,
+ attention_mask: np.ndarray | tf.Tensor | None = None,
+ decoder_input_ids: np.ndarray | tf.Tensor | None = None,
+ decoder_attention_mask: np.ndarray | tf.Tensor | None = None,
+ encoder_outputs: np.ndarray | tf.Tensor | None = None,
+ past_key_values: tuple[tuple[np.ndarray | tf.Tensor]] | None = None,
+ doc_scores: np.ndarray | tf.Tensor | None = None,
+ context_input_ids: np.ndarray | tf.Tensor | None = None,
+ context_attention_mask: np.ndarray | tf.Tensor | None = None,
+ use_cache: bool | None = None,
+ output_attentions: bool | None = None,
+ output_hidden_states: bool | None = None,
+ output_retrieved: bool | None = None,
+ n_docs: int | None = None,
+ exclude_bos_score: bool | None = None,
+ labels: np.ndarray | tf.Tensor | None = None,
+ reduce_loss: bool | None = None,
+ return_dict: bool | None = None,
+ training: bool = False,
+ **kwargs, # needs kwargs for generation
+ ) -> tuple[tf.Tensor] | TFRetrievAugLMMarginOutput:
+ r"""
+ exclude_bos_score (`bool`, *optional*):
+ Only relevant if `labels` is passed. If `True`, the score of the BOS token is disregarded when computing
+ the loss.
+ labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the cross entropy classification loss according to the Rag-Sequence model formulation.
+ See https://huggingface.co/papers/2005.11401 Section 2.1 for details about the Rag-Sequence formulation.
+ Indices should be in `[0, ..., config.vocab_size - 1]`.
+ reduce_loss (`bool`, *optional*):
+ Only relevant if `labels` is passed. If `True`, the NLL loss is reduced using the `tf.Tensor.sum`
+ operation.
+ kwargs (`dict[str, any]`, *optional*, defaults to `{}`):
+ Legacy dictionary, which is required so that the model can use the *generate()* function.
+
+ Returns:
+
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, RagRetriever, TFRagSequenceForGeneration
+
+ >>> tokenizer = AutoTokenizer.from_pretrained("facebook/rag-sequence-nq")
+ >>> retriever = RagRetriever.from_pretrained(
+ ... "facebook/rag-sequence-nq", index_name="exact", use_dummy_dataset=True
+ ... )
+ >>> # initialize with RagRetriever to do everything in one forward call
+ >>> model = TFRagSequenceForGeneration.from_pretrained(
+ ... "facebook/rag-sequence-nq", retriever=retriever, from_pt=True
+ ... )
+
+ >>> input_dict = tokenizer.prepare_seq2seq_batch(
+ ... "How many people live in Paris?", "In Paris, there are 10 million people.", return_tensors="tf"
+ ... )
+ >>> outputs = model(input_dict, output_retrieved=True)
+
+ >>> # or use retriever separately
+ >>> # 1. Encode
+ >>> input_ids = input_dict["input_ids"]
+ >>> question_hidden_states = model.question_encoder(input_ids)[0]
+ >>> # 2. Retrieve
+ >>> docs_dict = retriever(input_ids.numpy(), question_hidden_states.numpy(), return_tensors="tf")
+ >>> doc_scores = tf.squeeze(
+ ... tf.matmul(
+ ... tf.expand_dims(question_hidden_states, axis=1), docs_dict["retrieved_doc_embeds"], transpose_b=True
+ ... ),
+ ... axis=1,
+ ... )
+ >>> # 3. Forward to generator
+ >>> outputs = model(
+ ... inputs=None,
+ ... context_input_ids=docs_dict["context_input_ids"],
+ ... context_attention_mask=docs_dict["context_attention_mask"],
+ ... doc_scores=doc_scores,
+ ... decoder_input_ids=input_dict["labels"],
+ ... )
+
+ >>> # or directly generate
+ >>> generated = model.generate(
+ ... context_input_ids=docs_dict["context_input_ids"],
+ ... context_attention_mask=docs_dict["context_attention_mask"],
+ ... doc_scores=doc_scores,
+ ... )
+ >>> generated_string = tokenizer.batch_decode(generated, skip_special_tokens=True)
+ ```"""
+
+ assert "decoder_cached_states" not in kwargs, (
+ "Please use past_key_values to cache intermediate outputs"
+ ) # from modeling_tf_bart.py
+
+ exclude_bos_score = exclude_bos_score if exclude_bos_score else self.config.exclude_bos_score
+ reduce_loss = reduce_loss if reduce_loss else self.config.reduce_loss
+
+ if labels is not None:
+ if decoder_input_ids is None:
+ decoder_input_ids = labels
+ use_cache = False
+
+ outputs = self.rag(
+ input_ids,
+ attention_mask=attention_mask,
+ encoder_outputs=encoder_outputs,
+ decoder_input_ids=decoder_input_ids,
+ decoder_attention_mask=decoder_attention_mask,
+ context_input_ids=context_input_ids,
+ context_attention_mask=context_attention_mask,
+ doc_scores=doc_scores,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ output_retrieved=output_retrieved,
+ n_docs=n_docs,
+ training=training,
+ )
+
+ loss = None
+ if labels is not None:
+ loss = self.get_nll(
+ outputs.logits,
+ outputs.doc_scores,
+ labels,
+ reduce_loss=reduce_loss,
+ epsilon=self.config.label_smoothing,
+ n_docs=n_docs,
+ )
+
+ return TFRetrievAugLMMarginOutput(
+ loss=loss,
+ logits=outputs.logits,
+ doc_scores=outputs.doc_scores,
+ past_key_values=outputs.past_key_values,
+ context_input_ids=outputs.context_input_ids,
+ context_attention_mask=outputs.context_attention_mask,
+ retrieved_doc_embeds=outputs.retrieved_doc_embeds,
+ retrieved_doc_ids=outputs.retrieved_doc_ids,
+ question_encoder_last_hidden_state=outputs.question_encoder_last_hidden_state,
+ question_enc_hidden_states=outputs.question_enc_hidden_states,
+ question_enc_attentions=outputs.question_enc_attentions,
+ generator_enc_last_hidden_state=outputs.generator_enc_last_hidden_state,
+ generator_enc_hidden_states=outputs.generator_enc_hidden_states,
+ generator_enc_attentions=outputs.generator_enc_attentions,
+ generator_dec_hidden_states=outputs.generator_dec_hidden_states,
+ generator_dec_attentions=outputs.generator_dec_attentions,
+ )
+
+ def get_nll(
+ self, seq_logits, doc_scores, target, reduce_loss=False, epsilon=0.0, exclude_bos_score=False, n_docs=None
+ ):
+ # shift tokens left
+ target = tf.concat(
+ [target[:, 1:], tf.fill([target.shape[0], 1], tf.cast(self.config.generator.pad_token_id, target.dtype))],
+ axis=1,
+ )
+
+ # bos_token_id is None for T5
+ bos_token_id = self.config.bos_token_id or self.config.generator.bos_token_id
+ n_docs = n_docs if n_docs is not None else self.config.n_docs
+ equal_bos_token_id_all = tf.reduce_all(tf.equal(target[:, 0], bos_token_id))
+ use_bos = bos_token_id is not None and equal_bos_token_id_all
+
+ def _mask_pads(ll, smooth_obj):
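+            # Zero out the log-likelihood and smoothing terms at padding positions so they do not contribute
+            # to the loss.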
+ pad_mask = tf.equal(target, tf.cast(self.config.generator.pad_token_id, target.dtype))
+ if tf.reduce_any(pad_mask):
+ ll = tf.where(pad_mask, 0.0, ll)
+ smooth_obj = tf.where(pad_mask, 0.0, smooth_obj)
+ return tf.squeeze(ll, axis=-1), tf.squeeze(smooth_obj, axis=-1)
+
+ # seq_logits.shape = (batch*n_docs, tgt_len , vocabs)
+ seq_logprobs = tf.nn.log_softmax(seq_logits, axis=-1)
+ seq_logprobs = tf.reshape(
+ seq_logprobs, (seq_logits.shape[0] // n_docs, n_docs, -1, seq_logits.shape[-1])
+ ) # (batch_size, n_docs, tgt_len, vocabs)
+ doc_logprobs = tf.nn.log_softmax(doc_scores, axis=1)
+ doc_logprobs = tf.expand_dims(doc_logprobs, axis=-1)
+ doc_logprobs = tf.expand_dims(doc_logprobs, axis=-1) # done twice to get 4-D
+
+ # RAG-sequence marginalization
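+        # The document log-probability is added once per sequence, at the second token position, so that the BOS
+        # score at position 0 can optionally be left out of the marginalized score via `exclude_bos_score`.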
+ first_token_scores = seq_logprobs[:, :, :1, :]
+ second_token_scores = seq_logprobs[:, :, 1:2, :]
+ remainder = seq_logprobs[:, :, 2:, :]
+ rag_logprobs = tf.concat([first_token_scores, second_token_scores + doc_logprobs, remainder], axis=2)
+
+ # calculate loss
+ target = tf.expand_dims(target, axis=1) # n_docs dimension
+ target = tf.expand_dims(target, axis=-1) # logits dimension
+ target = tf.repeat(target, n_docs, axis=1)
+ assert len(target.shape) == len(rag_logprobs.shape)
+
+ # last-axis gathering only - use 2D-reshape-trick for Torch's style nD gathering
+ def torch_gather(param, id_tensor):
+ # 2d-gather torch equivalent: https://stackoverflow.com/questions/52129909/tensorflow-equivalent-of-torch-gather
+ def gather2d(target, id_tensor):
+ idx = tf.stack([tf.range(tf.shape(id_tensor)[0], dtype=id_tensor.dtype), id_tensor[:, 0]], axis=-1)
+ result = tf.gather_nd(target, idx)
+ return tf.expand_dims(result, axis=-1)
+
+ target = tf.reshape(param, (-1, param.shape[-1])) # reshape 2D
+ target_shape = id_tensor.shape
+
+ id_tensor = tf.reshape(id_tensor, (-1, 1)) # also 2D-index
+ result = gather2d(target, id_tensor)
+ return tf.reshape(result, target_shape)
+
+ ll = torch_gather(rag_logprobs, id_tensor=target)
+ smooth_obj = tf.reduce_sum(rag_logprobs, axis=-1, keepdims=True) # total sum of all (normalised) logits
+
+ ll, smooth_obj = _mask_pads(ll, smooth_obj)
+
+ # sum over tokens, exclude bos while scoring
+ if exclude_bos_score and use_bos:
+ ll = tf.reduce_sum(ll[:, :, 1:], axis=2)
+ else:
+ ll = tf.reduce_sum(ll, axis=2)
+
+ smooth_obj = tf.reduce_sum(smooth_obj, axis=2)
+ ll = tf.math.reduce_logsumexp(ll, axis=1) # logsumexp over docs
+ smooth_obj = tf.math.reduce_logsumexp(smooth_obj, axis=1)
+
+ nll_loss = -ll
+ smooth_loss = -smooth_obj
+
+ if reduce_loss:
+ nll_loss = tf.reduce_sum(nll_loss)
+ smooth_loss = tf.reduce_sum(smooth_loss)
+
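+        # Standard label smoothing: mix the NLL of the target tokens with the (negative) total log-probability
+        # mass, spreading `epsilon` uniformly over the vocabulary.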
+ eps_i = epsilon / rag_logprobs.shape[-1]
+ loss = (1.0 - epsilon) * nll_loss + eps_i * smooth_loss
+ return loss
+
+ def generate(
+ self,
+ input_ids: TFModelInputType | None = None,
+ attention_mask: tf.Tensor | None = None,
+ context_input_ids=None,
+ context_attention_mask=None,
+ doc_scores=None,
+ do_deduplication=None, # defaults to True
+ num_return_sequences=None, # defaults to 1
+ num_beams=None, # defaults to 1
+ n_docs=None,
+ **model_kwargs,
+ ):
+ """
+        Implements RAG sequence "thorough" decoding. Read the [`~generation.GenerationMixin.generate`] documentation
+        for more information on how to set other generate input parameters.
+
+ Args:
+ input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ The sequence used as a prompt for the generation. If `input_ids` is not passed, then
+ `context_input_ids` has to be provided.
+ attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+                - 1 for tokens that are **not masked**,
+                - 0 for tokens that are **masked**.
+
+                [What are attention masks?](../glossary#attention-mask)
+ context_input_ids (`tf.Tensor` of shape `(batch_size * config.n_docs, config.max_combined_length)`, *optional*, returned when *output_retrieved=True*):
+ Input IDs post-processed from the retrieved documents and the question encoder input_ids by the
+ retriever.
+ context_attention_mask (`tf.Tensor` of shape `(batch_size * config.n_docs, config.max_combined_length)`, *optional*, returned when *output_retrieved=True*):
+ Attention mask post-processed from the retrieved documents and the question encoder `input_ids` by the
+                retriever. If the model is not initialized with a `retriever` or `input_ids` is not given,
+ `context_input_ids` and `context_attention_mask` have to be provided to the forward pass. They are
+ returned by [`~RagRetriever.__call__`].
+ doc_scores (`tf.Tensor` of shape `(batch_size, config.n_docs)`):
+ Score between each retrieved document embeddings (see `retrieved_doc_embeds`) and
+                `question_encoder_last_hidden_state`. If the model is not initialized with a `retriever` or
+ `input_ids` is not given, `doc_scores` has to be provided to the forward pass. `doc_scores` are
+ returned by [`~RagRetriever.__call__`].
+ do_deduplication (`bool`, *optional*):
+ Whether or not to deduplicate the generations from different context documents for a given input. Has
+ to be set to `False` if used while training with distributed backend.
+            num_return_sequences (`int`, *optional*, defaults to 1):
+                The number of independently computed returned sequences for each element in the batch. Note that this
+                is not the value we pass to the `generator`'s [`~generation.GenerationMixin.generate`] function,
+                where we set `num_return_sequences` to `num_beams`.
+ num_beams (`int`, *optional*, defaults to 1):
+ Number of beams for beam search. 1 means no beam search.
+            n_docs (`int`, *optional*, defaults to `config.n_docs`):
+ Number of documents to retrieve and/or number of documents for which to generate an answer.
+ kwargs (`dict[str, Any]`, *optional*):
+ Additional kwargs will be passed to [`~generation.GenerationMixin.generate`]
+
+ Return:
+ `tf.Tensor` of shape `(batch_size * num_return_sequences, sequence_length)`: The generated sequences. The
+ second dimension (sequence length) is either equal to `max_length` or shorter if all batches finished early
+ due to the `eos_token_id`.
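+
+        Example (a minimal sketch; the checkpoints and flags simply mirror the `call` example above):
+
+        ```python
+        >>> from transformers import AutoTokenizer, RagRetriever, TFRagSequenceForGeneration
+
+        >>> tokenizer = AutoTokenizer.from_pretrained("facebook/rag-sequence-nq")
+        >>> retriever = RagRetriever.from_pretrained(
+        ...     "facebook/rag-sequence-nq", index_name="exact", use_dummy_dataset=True
+        ... )
+        >>> model = TFRagSequenceForGeneration.from_pretrained(
+        ...     "facebook/rag-sequence-nq", retriever=retriever, from_pt=True
+        ... )
+        >>> input_ids = tokenizer("How many people live in Paris?", return_tensors="tf").input_ids
+        >>> generated = model.generate(input_ids)
+        >>> answers = tokenizer.batch_decode(generated, skip_special_tokens=True)
+        ```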
+ """
+
+ n_docs = n_docs if n_docs is not None else self.config.n_docs
+ do_deduplication = do_deduplication if do_deduplication is not None else self.config.do_deduplication
+ num_doc_return_sequences = (
+ num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences
+ )
+ num_beams = num_beams if num_beams is not None else self.config.num_beams
+
+ assert input_ids is not None or context_input_ids is not None, (
+ " At least one of input_ids or context_input_ids must be given"
+ )
+
+ if self.retriever is not None and context_input_ids is None:
+ question_hidden_states = self.question_encoder(input_ids, attention_mask=attention_mask)[0]
+ context_input_ids = self.retriever(
+ input_ids,
+ question_hidden_states.numpy(),
+ prefix=self.generator.config.prefix,
+ n_docs=n_docs,
+ return_tensors="tf",
+ )["context_input_ids"]
+
+ hypos = []
+ model_kwargs["num_beams"] = num_beams
+ model_kwargs["num_return_sequences"] = num_beams # put here so that not confused with num_doc_return_sequences
+ model_kwargs["attention_mask"] = None
+
+ batch_size = input_ids.shape[0] if input_ids is not None else context_input_ids.shape[0] // n_docs
+
+ for index in range(batch_size):
+ # first, generate beams from documents:
+ generator_input_ids = context_input_ids[index * n_docs : (index + 1) * n_docs] # (n_docs, max_len)
+
+ output_sequences = self.generator.generate(
+ generator_input_ids,
+ **model_kwargs,
+ ) # n_docs * n_beam, tgt_len
+ if do_deduplication:
+                # do_deduplication -- for TF, this works in eager mode only!
+ output_sequences = tf.stack(list({str(k.numpy().tolist()): k for k in output_sequences}.values()))
+
+            num_candidates = output_sequences.shape[0]  # after deduplication, this can be less than n_docs * n_beam
+
+ # then, run model forwards to get nll scores:
+ if input_ids is not None:
+ new_input_ids = tf.tile(input_ids[index : index + 1], (num_candidates, 1))
+ outputs = self(new_input_ids, labels=output_sequences, exclude_bos_score=True)
+ else: # input_ids is None, need context_input_ids/mask and doc_scores
+ assert context_attention_mask is not None, (
+ "Make sure that `context_attention_mask` are passed, if no `input_ids` is set. Alternatively, you"
+ " can set a retriever using the `set_retriever(...)` function."
+ )
+ assert doc_scores is not None, (
+ "Make sure that `doc_scores` are passed, if no `input_ids` is set. Alternatively, you can set a"
+ " retriever using the `set_retriever(...)` function."
+ )
+
+ individual_input_ids = tf.tile(
+ generator_input_ids, (num_candidates, 1)
+ ) # (num_candidates*n_docs, max_len)
+
+ individual_attention_mask = context_attention_mask[index * n_docs : (index + 1) * n_docs]
+ individual_attention_mask = tf.tile(individual_attention_mask, (num_candidates, 1))
+
+ individual_doc_scores = doc_scores[index : (index + 1), :] # doc_scores.shape = [batch, n_docs]
+ individual_doc_scores = tf.tile(individual_doc_scores, (num_candidates, 1)) # [num_candidates, n_docs]
+
+ outputs = self(
+ input_ids=None,
+ context_input_ids=individual_input_ids,
+ context_attention_mask=individual_attention_mask,
+ doc_scores=individual_doc_scores,
+ labels=output_sequences,
+ exclude_bos_score=True,
+ )
+
+ top_cand_inds = tf.math.top_k((-outputs["loss"]), k=num_doc_return_sequences)[1]
+
+ # add hypothesis
+ hypos.append(tf.gather(output_sequences, top_cand_inds))
+
+ return self._cat_and_pad(hypos, pad_token_id=self.config.generator.pad_token_id)
+
+ @staticmethod
+ def _cat_and_pad(tensors, pad_token_id):
+ # used by generate(): tensors is a (batched) list of (candidates, len); len is varied across batch
+
+ # Initialize padded tensor with shape ( all_candidates , max_candidate_length ),
+ # where all_candidates counted from all inputs
+ new_shape = sum([t.shape[0] for t in tensors]), max([t.shape[1] for t in tensors])
+ output = tf.fill(new_shape, pad_token_id)
+
+ # Normal tensor doesn't support slice assignment, so we need tf.Variable
+ output = tf.Variable(output)
+
+ # Assign, and then convert back to tensor
+ ind = 0
+ for t in tensors:
+ output[ind : ind + t.shape[0], : t.shape[1]].assign(t)
+ ind += t.shape[0]
+
+ output = tf.convert_to_tensor(output)
+ return tf.cast(output, tensors[0][0][0].dtype)
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "rag", None) is not None:
+ with tf.name_scope(self.rag.name):
+ self.rag.build(None)
+
+
+__all__ = ["TFRagModel", "TFRagPreTrainedModel", "TFRagSequenceForGeneration", "TFRagTokenForGeneration"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/rag/retrieval_rag.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/rag/retrieval_rag.py
new file mode 100644
index 0000000000000000000000000000000000000000..1c4548cba6f118db47be0c8ac57d179409cd679d
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/rag/retrieval_rag.py
@@ -0,0 +1,679 @@
+# coding=utf-8
+# Copyright 2020, The RAG Authors and The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""RAG Retriever model implementation."""
+
+import os
+import pickle
+import time
+from collections.abc import Iterable
+from typing import Optional
+
+import numpy as np
+
+from ...tokenization_utils import PreTrainedTokenizer
+from ...tokenization_utils_base import BatchEncoding
+from ...utils import cached_file, is_datasets_available, is_faiss_available, logging, requires_backends, strtobool
+from .configuration_rag import RagConfig
+from .tokenization_rag import RagTokenizer
+
+
+if is_datasets_available():
+ from datasets import Dataset, load_dataset, load_from_disk
+
+if is_faiss_available():
+ import faiss
+
+
+logger = logging.get_logger(__name__)
+
+
+LEGACY_INDEX_PATH = "https://storage.googleapis.com/huggingface-nlp/datasets/wiki_dpr/"
+
+
+class Index:
+ """
+ A base class for the Indices encapsulated by the [`RagRetriever`].
+ """
+
+ def get_doc_dicts(self, doc_ids: np.ndarray) -> list[dict]:
+ """
+ Returns a list of dictionaries, containing titles and text of the retrieved documents.
+
+ Args:
+ doc_ids (`np.ndarray` of shape `(batch_size, n_docs)`):
+ A tensor of document indices.
+ """
+ raise NotImplementedError
+
+ def get_top_docs(self, question_hidden_states: np.ndarray, n_docs=5) -> tuple[np.ndarray, np.ndarray]:
+ """
+ For each query in the batch, retrieves `n_docs` documents.
+
+ Args:
+ question_hidden_states (`np.ndarray` of shape `(batch_size, vector_size)`):
+ An array of query vectors.
+ n_docs (`int`):
+ The number of docs retrieved per query.
+
+ Returns:
+ `np.ndarray` of shape `(batch_size, n_docs)`: A tensor of indices of retrieved documents. `np.ndarray` of
+            shape `(batch_size, n_docs, vector_size)`: A tensor of vector representations of retrieved documents.
+ """
+ raise NotImplementedError
+
+ def is_initialized(self):
+ """
+ Returns `True` if index is already initialized.
+ """
+ raise NotImplementedError
+
+ def init_index(self):
+ """
+ A function responsible for loading the index into memory. Should be called only once per training run of a RAG
+ model. E.g. if the model is trained on multiple GPUs in a distributed setup, only one of the workers will load
+ the index.
+ """
+ raise NotImplementedError
+
+
+class LegacyIndex(Index):
+ """
+ An index which can be deserialized from the files built using https://github.com/facebookresearch/DPR. We use
+ default faiss index parameters as specified in that repository.
+
+ Args:
+ vector_size (`int`):
+ The dimension of indexed vectors.
+ index_path (`str`):
+ A path to a *directory* containing index files compatible with [`~models.rag.retrieval_rag.LegacyIndex`]
+ """
+
+ INDEX_FILENAME = "hf_bert_base.hnswSQ8_correct_phi_128.c_index"
+ PASSAGE_FILENAME = "psgs_w100.tsv.pkl"
+
+ def __init__(self, vector_size, index_path):
+ requires_backends(self, ["faiss"])
+ self.index_id_to_db_id = []
+ self.index_path = index_path
+ self.passages = self._load_passages()
+ self.vector_size = vector_size
+ self.index = None
+ self._index_initialized = False
+
+ def _resolve_path(self, index_path, filename):
+ is_local = os.path.isdir(index_path)
+ try:
+ # Load from URL or cache if already cached
+ resolved_archive_file = cached_file(index_path, filename)
+ except OSError:
+ msg = (
+ f"Can't load '{filename}'. Make sure that:\n\n"
+ f"- '{index_path}' is a correct remote path to a directory containing a file named {filename}\n\n"
+ f"- or '{index_path}' is the correct path to a directory containing a file named {filename}.\n\n"
+ )
+ raise OSError(msg)
+ if is_local:
+ logger.info(f"loading file {resolved_archive_file}")
+ else:
+ logger.info(f"loading file {filename} from cache at {resolved_archive_file}")
+ return resolved_archive_file
+
+ def _load_passages(self):
+ logger.info(f"Loading passages from {self.index_path}")
+ passages_path = self._resolve_path(self.index_path, self.PASSAGE_FILENAME)
+ if not strtobool(os.environ.get("TRUST_REMOTE_CODE", "False")):
+ raise ValueError(
+ "This part uses `pickle.load` which is insecure and will execute arbitrary code that is potentially "
+ "malicious. It's recommended to never unpickle data that could have come from an untrusted source, or "
+ "that could have been tampered with. If you already verified the pickle data and decided to use it, "
+ "you can set the environment variable `TRUST_REMOTE_CODE` to `True` to allow it."
+ )
+ with open(passages_path, "rb") as passages_file:
+ passages = pickle.load(passages_file)
+ return passages
+
+ def _deserialize_index(self):
+ logger.info(f"Loading index from {self.index_path}")
+ resolved_index_path = self._resolve_path(self.index_path, self.INDEX_FILENAME + ".index.dpr")
+ self.index = faiss.read_index(resolved_index_path)
+ resolved_meta_path = self._resolve_path(self.index_path, self.INDEX_FILENAME + ".index_meta.dpr")
+ if not strtobool(os.environ.get("TRUST_REMOTE_CODE", "False")):
+ raise ValueError(
+ "This part uses `pickle.load` which is insecure and will execute arbitrary code that is potentially "
+ "malicious. It's recommended to never unpickle data that could have come from an untrusted source, or "
+ "that could have been tampered with. If you already verified the pickle data and decided to use it, "
+ "you can set the environment variable `TRUST_REMOTE_CODE` to `True` to allow it."
+ )
+ with open(resolved_meta_path, "rb") as metadata_file:
+ self.index_id_to_db_id = pickle.load(metadata_file)
+ assert len(self.index_id_to_db_id) == self.index.ntotal, (
+ "Deserialized index_id_to_db_id should match faiss index size"
+ )
+
+ def is_initialized(self):
+ return self._index_initialized
+
+ def init_index(self):
+ index = faiss.IndexHNSWFlat(self.vector_size + 1, 512)
+ index.hnsw.efSearch = 128
+ index.hnsw.efConstruction = 200
+ self.index = index
+ self._deserialize_index()
+ self._index_initialized = True
+
+ def get_doc_dicts(self, doc_ids: np.ndarray):
+ doc_list = []
+ for doc_ids_i in doc_ids:
+ ids = [str(int(doc_id)) for doc_id in doc_ids_i]
+ docs = [self.passages[doc_id] for doc_id in ids]
+ doc_list.append(docs)
+ doc_dicts = []
+ for docs in doc_list:
+ doc_dict = {}
+ doc_dict["title"] = [doc[1] for doc in docs]
+ doc_dict["text"] = [doc[0] for doc in docs]
+ doc_dicts.append(doc_dict)
+ return doc_dicts
+
+ def get_top_docs(self, question_hidden_states: np.ndarray, n_docs=5) -> tuple[np.ndarray, np.ndarray]:
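+        # The legacy DPR index is built with one extra auxiliary dimension (see `init_index`, which creates an
+        # `IndexHNSWFlat` of size `vector_size + 1`), so each query vector is padded with a zero column first.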
+ aux_dim = np.zeros(len(question_hidden_states), dtype="float32").reshape(-1, 1)
+ query_nhsw_vectors = np.hstack((question_hidden_states, aux_dim))
+ _, docs_ids = self.index.search(query_nhsw_vectors, n_docs)
+ vectors = [[self.index.reconstruct(int(doc_id))[:-1] for doc_id in doc_ids] for doc_ids in docs_ids]
+ ids = [[int(self.index_id_to_db_id[doc_id]) for doc_id in doc_ids] for doc_ids in docs_ids]
+ return np.array(ids), np.array(vectors)
+
+
+class HFIndexBase(Index):
+ def __init__(self, vector_size, dataset, index_initialized=False):
+ requires_backends(self, ["faiss"])
+ self.vector_size = vector_size
+ self.dataset = dataset
+ self._index_initialized = index_initialized
+ self._check_dataset_format(with_index=index_initialized)
+ dataset.set_format("numpy", columns=["embeddings"], output_all_columns=True, dtype="float32")
+
+ def _check_dataset_format(self, with_index: bool):
+ if not isinstance(self.dataset, Dataset):
+ raise TypeError(f"Dataset should be a datasets.Dataset object, but got {type(self.dataset)}")
+ if len({"title", "text", "embeddings"} - set(self.dataset.column_names)) > 0:
+ raise ValueError(
+ "Dataset should be a dataset with the following columns: "
+ "title (str), text (str) and embeddings (arrays of dimension vector_size), "
+ f"but got columns {self.dataset.column_names}"
+ )
+ if with_index and "embeddings" not in self.dataset.list_indexes():
+ raise ValueError(
+ "Missing faiss index in the dataset. Make sure you called `dataset.add_faiss_index` to compute it "
+ "or `dataset.load_faiss_index` to load one from the disk."
+ )
+
+ def init_index(self):
+ raise NotImplementedError()
+
+ def is_initialized(self):
+ return self._index_initialized
+
+ def get_doc_dicts(self, doc_ids: np.ndarray) -> list[dict]:
+ return [self.dataset[doc_ids[i].tolist()] for i in range(doc_ids.shape[0])]
+
+ def get_top_docs(self, question_hidden_states: np.ndarray, n_docs=5) -> tuple[np.ndarray, np.ndarray]:
+ _, ids = self.dataset.search_batch("embeddings", question_hidden_states, n_docs)
+ docs = [self.dataset[[i for i in indices if i >= 0]] for indices in ids]
+ vectors = [doc["embeddings"] for doc in docs]
+ for i in range(len(vectors)):
+ if len(vectors[i]) < n_docs:
+ vectors[i] = np.vstack([vectors[i], np.zeros((n_docs - len(vectors[i]), self.vector_size))])
+ return np.array(ids), np.array(vectors) # shapes (batch_size, n_docs) and (batch_size, n_docs, d)
+
+
+class CanonicalHFIndex(HFIndexBase):
+ """
+    A wrapper around an instance of [`~datasets.Dataset`]. If `index_path` is set to `None`, we load the pre-computed
+    index available with the [`~datasets.arrow_dataset.Dataset`]; otherwise, we load the index from the indicated path
+    on disk.
+
+ Args:
+        vector_size (`int`): The dimension of the passage embeddings used by the index.
+        dataset_name (`str`, *optional*, defaults to `wiki_dpr`):
+            A dataset identifier of the indexed dataset on HuggingFace AWS bucket (list all available datasets and ids
+            with `datasets.list_datasets()`).
+        dataset_split (`str`, *optional*, defaults to `train`):
+            Which split of the `dataset` to load.
+        index_name (`str`, *optional*):
+            The index_name of the index associated with the `dataset`. The index loaded from `index_path` will be saved
+            under this name.
+        index_path (`str`, *optional*, defaults to `None`):
+            The path to the serialized faiss index on disk.
+        use_dummy_dataset (`bool`, *optional*, defaults to `False`):
+            If `True`, use the dummy configuration of the dataset for tests.
+ """
+
+ def __init__(
+ self,
+ vector_size: int,
+ dataset_name: str = "wiki_dpr",
+ dataset_split: str = "train",
+ index_name: Optional[str] = None,
+ index_path: Optional[str] = None,
+ use_dummy_dataset=False,
+ dataset_revision=None,
+ ):
+ requires_backends(self, ["faiss"])
+ if int(index_path is None) + int(index_name is None) != 1:
+ raise ValueError("Please provide `index_name` or `index_path`.")
+ self.dataset_name = dataset_name
+ self.dataset_split = dataset_split
+ self.index_name = index_name
+ self.index_path = index_path
+ self.use_dummy_dataset = use_dummy_dataset
+ self.dataset_revision = dataset_revision
+ logger.info(f"Loading passages from {self.dataset_name}")
+ dataset = load_dataset(
+ self.dataset_name,
+ with_index=False,
+ split=self.dataset_split,
+ dummy=self.use_dummy_dataset,
+ revision=dataset_revision,
+ )
+ super().__init__(vector_size, dataset, index_initialized=False)
+
+ def init_index(self):
+ if self.index_path is not None:
+ logger.info(f"Loading index from {self.index_path}")
+ self.dataset.load_faiss_index("embeddings", file=self.index_path)
+ else:
+ logger.info(f"Loading index from {self.dataset_name} with index name {self.index_name}")
+ self.dataset = load_dataset(
+ self.dataset_name,
+ with_embeddings=True,
+ with_index=True,
+ split=self.dataset_split,
+ index_name=self.index_name,
+ dummy=self.use_dummy_dataset,
+ revision=self.dataset_revision,
+ )
+ self.dataset.set_format("numpy", columns=["embeddings"], output_all_columns=True)
+ self._index_initialized = True
+
+
+class CustomHFIndex(HFIndexBase):
+ """
+    A wrapper around an instance of [`~datasets.Dataset`]. The dataset and the index are both loaded from the
+ indicated paths on disk.
+
+ Args:
+        vector_size (`int`): The dimension of the passage embeddings used by the index.
+ dataset_path (`str`):
+ The path to the serialized dataset on disk. The dataset should have 3 columns: title (str), text (str) and
+ embeddings (arrays of dimension vector_size)
+        index_path (`str`):
+ The path to the serialized faiss index on disk.
+ """
+
+ def __init__(self, vector_size: int, dataset, index_path=None):
+ requires_backends(self, ["faiss"])
+ super().__init__(vector_size, dataset, index_initialized=index_path is None)
+ self.index_path = index_path
+
+ @classmethod
+ def load_from_disk(cls, vector_size, dataset_path, index_path):
+ logger.info(f"Loading passages from {dataset_path}")
+ if dataset_path is None or index_path is None:
+ raise ValueError(
+ "Please provide `dataset_path` and `index_path` after calling `dataset.save_to_disk(dataset_path)` "
+ "and `dataset.get_index('embeddings').save(index_path)`."
+ )
+ dataset = load_from_disk(dataset_path)
+ return cls(vector_size=vector_size, dataset=dataset, index_path=index_path)
+
+ def init_index(self):
+ if not self.is_initialized():
+ logger.info(f"Loading index from {self.index_path}")
+ self.dataset.load_faiss_index("embeddings", file=self.index_path)
+ self._index_initialized = True
+
+
+class RagRetriever:
+ """
+ Retriever used to get documents from vector queries. It retrieves the documents embeddings as well as the documents
+ contents, and it formats them to be used with a RagModel.
+
+ Args:
+ config ([`RagConfig`]):
+ The configuration of the RAG model this Retriever is used with. Contains parameters indicating which
+ `Index` to build. You can load your own custom dataset with `config.index_name="custom"` or use a canonical
+ one (default) from the datasets library with `config.index_name="wiki_dpr"` for example.
+ question_encoder_tokenizer ([`PreTrainedTokenizer`]):
+ The tokenizer that was used to tokenize the question. It is used to decode the question and then use the
+ generator_tokenizer.
+ generator_tokenizer ([`PreTrainedTokenizer`]):
+ The tokenizer used for the generator part of the RagModel.
+ index ([`~models.rag.retrieval_rag.Index`], optional, defaults to the one defined by the configuration):
+ If specified, use this index instead of the one built using the configuration
+
+ Examples:
+
+ ```python
+ >>> # To load the default "wiki_dpr" dataset with 21M passages from wikipedia (index name is 'compressed' or 'exact')
+ >>> from transformers import RagRetriever
+
+ >>> retriever = RagRetriever.from_pretrained(
+ ... "facebook/dpr-ctx_encoder-single-nq-base", dataset="wiki_dpr", index_name="compressed"
+ ... )
+
+ >>> # To load your own indexed dataset built with the datasets library. More info on how to build the indexed dataset in examples/rag/use_own_knowledge_dataset.py
+ >>> from transformers import RagRetriever
+
+ >>> dataset = (
+ ... ...
+    ... )  # dataset must be a datasets.Dataset object with columns "title", "text" and "embeddings", and it must have a faiss index
+ >>> retriever = RagRetriever.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base", indexed_dataset=dataset)
+
+ >>> # To load your own indexed dataset built with the datasets library that was saved on disk. More info in examples/rag/use_own_knowledge_dataset.py
+ >>> from transformers import RagRetriever
+
+ >>> dataset_path = "path/to/my/dataset" # dataset saved via *dataset.save_to_disk(...)*
+ >>> index_path = "path/to/my/index" # index saved via *dataset.get_index("embeddings").save(...)*
+ >>> retriever = RagRetriever.from_pretrained(
+ ... "facebook/dpr-ctx_encoder-single-nq-base",
+ ... index_name="custom",
+ ... passages_path=dataset_path,
+ ... index_path=index_path,
+ ... )
+
+ >>> # To load the legacy index built originally for Rag's paper
+ >>> from transformers import RagRetriever
+
+ >>> retriever = RagRetriever.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base", index_name="legacy")
+ ```"""
+
+ def __init__(self, config, question_encoder_tokenizer, generator_tokenizer, index=None, init_retrieval=True):
+ self._init_retrieval = init_retrieval
+ requires_backends(self, ["datasets"])
+ super().__init__()
+ self.index = index or self._build_index(config)
+ self.generator_tokenizer = generator_tokenizer
+ self.question_encoder_tokenizer = question_encoder_tokenizer
+
+ self.n_docs = config.n_docs
+ self.batch_size = config.retrieval_batch_size
+
+ self.config = config
+ if self._init_retrieval:
+ self.init_retrieval()
+
+ self.ctx_encoder_tokenizer = None
+ self.return_tokenized_docs = False
+
+ @staticmethod
+ def _build_index(config):
+ if config.index_name == "legacy":
+ return LegacyIndex(
+ config.retrieval_vector_size,
+ config.index_path or LEGACY_INDEX_PATH,
+ )
+ elif config.index_name == "custom":
+ return CustomHFIndex.load_from_disk(
+ vector_size=config.retrieval_vector_size,
+ dataset_path=config.passages_path,
+ index_path=config.index_path,
+ )
+ else:
+ return CanonicalHFIndex(
+ vector_size=config.retrieval_vector_size,
+ dataset_name=config.dataset,
+ dataset_split=config.dataset_split,
+ index_name=config.index_name,
+ index_path=config.index_path,
+ use_dummy_dataset=config.use_dummy_dataset,
+ dataset_revision=config.dataset_revision,
+ )
+
+ @classmethod
+ def from_pretrained(cls, retriever_name_or_path, indexed_dataset=None, **kwargs):
+ requires_backends(cls, ["datasets"])
+ config = kwargs.pop("config", None) or RagConfig.from_pretrained(retriever_name_or_path, **kwargs)
+ rag_tokenizer = RagTokenizer.from_pretrained(retriever_name_or_path, config=config)
+ question_encoder_tokenizer = rag_tokenizer.question_encoder
+ generator_tokenizer = rag_tokenizer.generator
+ if indexed_dataset is not None:
+ config.index_name = "custom"
+ index = CustomHFIndex(config.retrieval_vector_size, indexed_dataset)
+ else:
+ index = cls._build_index(config)
+ return cls(
+ config,
+ question_encoder_tokenizer=question_encoder_tokenizer,
+ generator_tokenizer=generator_tokenizer,
+ index=index,
+ )
+
+ def save_pretrained(self, save_directory):
+ if isinstance(self.index, CustomHFIndex):
+ if self.config.index_path is None:
+ index_path = os.path.join(save_directory, "hf_dataset_index.faiss")
+ self.index.dataset.get_index("embeddings").save(index_path)
+ self.config.index_path = index_path
+ if self.config.passages_path is None:
+ passages_path = os.path.join(save_directory, "hf_dataset")
+ # datasets don't support save_to_disk with indexes right now
+ faiss_index = self.index.dataset._indexes.pop("embeddings")
+ self.index.dataset.save_to_disk(passages_path)
+ self.index.dataset._indexes["embeddings"] = faiss_index
+ self.config.passages_path = passages_path
+ self.config.save_pretrained(save_directory)
+ rag_tokenizer = RagTokenizer(
+ question_encoder=self.question_encoder_tokenizer,
+ generator=self.generator_tokenizer,
+ )
+ rag_tokenizer.save_pretrained(save_directory)
+
+ def init_retrieval(self):
+ """
+ Retriever initialization function. It loads the index into memory.
+ """
+
+ logger.info("initializing retrieval")
+ self.index.init_index()
+
+ def postprocess_docs(self, docs, input_strings, prefix, n_docs, return_tensors=None):
+ r"""
+ Postprocessing retrieved `docs` and combining them with `input_strings`.
+
+ Args:
+ docs (`dict`):
+ Retrieved documents.
+ input_strings (`str`):
+ Input strings decoded by `preprocess_query`.
+ prefix (`str`):
+ Prefix added at the beginning of each input, typically used with T5-based models.
+
+ Return:
+ `tuple(tensors)`: a tuple consisting of two elements: contextualized `input_ids` and a compatible
+ `attention_mask`.
+ """
+
+ def cat_input_and_doc(doc_title, doc_text, input_string, prefix):
+ # TODO(Patrick): if we train more RAG models, I want to put the input first to take advantage of effortless truncation
+ # TODO(piktus): better handling of truncation
+ doc_title = doc_title.removeprefix('"').removesuffix('"')
+ if prefix is None:
+ prefix = ""
+ out = (prefix + doc_title + self.config.title_sep + doc_text + self.config.doc_sep + input_string).replace(
+ " ", " "
+ )
+ return out
+
+ rag_input_strings = [
+ cat_input_and_doc(
+ docs[i]["title"][j],
+ docs[i]["text"][j],
+ input_strings[i],
+ prefix,
+ )
+ for i in range(len(docs))
+ for j in range(n_docs)
+ ]
+
+ contextualized_inputs = self.generator_tokenizer.batch_encode_plus(
+ rag_input_strings,
+ max_length=self.config.max_combined_length,
+ return_tensors=return_tensors,
+ padding="max_length",
+ truncation=True,
+ )
+
+ return contextualized_inputs["input_ids"], contextualized_inputs["attention_mask"]
+
+ def _chunk_tensor(self, t: Iterable, chunk_size: int) -> list[Iterable]:
+ return [t[i : i + chunk_size] for i in range(0, len(t), chunk_size)]
+
+ def _main_retrieve(self, question_hidden_states: np.ndarray, n_docs: int) -> tuple[np.ndarray, np.ndarray]:
+ question_hidden_states_batched = self._chunk_tensor(question_hidden_states, self.batch_size)
+ ids_batched = []
+ vectors_batched = []
+ for question_hidden_states in question_hidden_states_batched:
+ start_time = time.time()
+ ids, vectors = self.index.get_top_docs(question_hidden_states, n_docs)
+ logger.debug(
+ f"index search time: {time.time() - start_time} sec, batch size {question_hidden_states.shape}"
+ )
+ ids_batched.extend(ids)
+ vectors_batched.extend(vectors)
+ return (
+ np.array(ids_batched),
+ np.array(vectors_batched),
+ ) # shapes (batch_size, n_docs) and (batch_size, n_docs, d)
+
+ def retrieve(self, question_hidden_states: np.ndarray, n_docs: int) -> tuple[np.ndarray, np.ndarray, list[dict]]:
+ """
+ Retrieves documents for specified `question_hidden_states`.
+
+ Args:
+ question_hidden_states (`np.ndarray` of shape `(batch_size, vector_size)`):
+ A batch of query vectors to retrieve with.
+ n_docs (`int`):
+ The number of docs retrieved per query.
+
+ Return:
+ `tuple[np.ndarray, np.ndarray, list[dict]]`: A tuple with the following objects:
+
+ - **retrieved_doc_embeds** (`np.ndarray` of shape `(batch_size, n_docs, dim)`) -- The retrieval embeddings
+ of the retrieved docs per query.
+ - **doc_ids** (`np.ndarray` of shape `(batch_size, n_docs)`) -- The ids of the documents in the index
+            - **doc_dicts** (`list[dict]`) -- The retrieved document contents (e.g. `title` and `text`) per query.
+ """
+
+ doc_ids, retrieved_doc_embeds = self._main_retrieve(question_hidden_states, n_docs)
+ return retrieved_doc_embeds, doc_ids, self.index.get_doc_dicts(doc_ids)
+
+ def set_ctx_encoder_tokenizer(self, ctx_encoder_tokenizer: PreTrainedTokenizer):
+ # used in end2end retriever training
+ self.ctx_encoder_tokenizer = ctx_encoder_tokenizer
+ self.return_tokenized_docs = True
+
+ def __call__(
+ self,
+ question_input_ids: list[list[int]],
+ question_hidden_states: np.ndarray,
+ prefix=None,
+ n_docs=None,
+ return_tensors=None,
+ ) -> BatchEncoding:
+ """
+ Retrieves documents for specified `question_hidden_states`.
+
+ Args:
+            question_input_ids (`list[list[int]]`):
+                Batch of input ids.
+            question_hidden_states (`np.ndarray` of shape `(batch_size, vector_size)`):
+ A batch of query vectors to retrieve with.
+ prefix (`str`, *optional*):
+ The prefix used by the generator's tokenizer.
+ n_docs (`int`, *optional*):
+ The number of docs retrieved per query.
+ return_tensors (`str` or [`~utils.TensorType`], *optional*, defaults to "pt"):
+ If set, will return tensors instead of list of python integers. Acceptable values are:
+
+ - `'tf'`: Return TensorFlow `tf.constant` objects.
+ - `'pt'`: Return PyTorch `torch.Tensor` objects.
+ - `'np'`: Return Numpy `np.ndarray` objects.
+
+ Returns: [`BatchEncoding`]: A [`BatchEncoding`] with the following fields:
+
+ - **context_input_ids** -- List of token ids to be fed to a model.
+
+ [What are input IDs?](../glossary#input-ids)
+
+ - **context_attention_mask** -- List of indices specifying which tokens should be attended to by the model
+ (when `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names`).
+
+ [What are attention masks?](../glossary#attention-mask)
+
+ - **retrieved_doc_embeds** -- List of embeddings of the retrieved documents
+ - **doc_ids** -- List of ids of the retrieved documents
+ """
+
+ n_docs = n_docs if n_docs is not None else self.n_docs
+ prefix = prefix if prefix is not None else self.config.generator.prefix
+ retrieved_doc_embeds, doc_ids, docs = self.retrieve(question_hidden_states, n_docs)
+
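+        # Decode the question ids back to text so that each retrieved passage can be concatenated with its question
+        # and re-encoded with the generator's tokenizer.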
+ input_strings = self.question_encoder_tokenizer.batch_decode(question_input_ids, skip_special_tokens=True)
+ context_input_ids, context_attention_mask = self.postprocess_docs(
+ docs, input_strings, prefix, n_docs, return_tensors=return_tensors
+ )
+
+ if self.return_tokenized_docs:
+ retrieved_doc_text = []
+ retrieved_doc_title = []
+
+ for b_idx in range(len(docs)):
+ for doc_idx in range(n_docs):
+ retrieved_doc_text.append(docs[b_idx]["text"][doc_idx])
+ retrieved_doc_title.append(docs[b_idx]["title"][doc_idx])
+
+ tokenized_docs = self.ctx_encoder_tokenizer(
+ retrieved_doc_title,
+ retrieved_doc_text,
+ truncation=True,
+ padding="longest",
+ return_tensors=return_tensors,
+ )
+
+ return BatchEncoding(
+ {
+ "context_input_ids": context_input_ids,
+ "context_attention_mask": context_attention_mask,
+ "retrieved_doc_embeds": retrieved_doc_embeds,
+ "doc_ids": doc_ids,
+ "tokenized_doc_ids": tokenized_docs["input_ids"],
+ "tokenized_doc_attention_mask": tokenized_docs["attention_mask"],
+ },
+ tensor_type=return_tensors,
+ )
+
+ else:
+ return BatchEncoding(
+ {
+ "context_input_ids": context_input_ids,
+ "context_attention_mask": context_attention_mask,
+ "retrieved_doc_embeds": retrieved_doc_embeds,
+ "doc_ids": doc_ids,
+ },
+ tensor_type=return_tensors,
+ )
+
+
+__all__ = ["RagRetriever"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/rag/tokenization_rag.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/rag/tokenization_rag.py
new file mode 100644
index 0000000000000000000000000000000000000000..217dd2d82df6ff48af8f7952b40f8088fb056b9d
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/rag/tokenization_rag.py
@@ -0,0 +1,124 @@
+# coding=utf-8
+# Copyright 2020, The RAG Authors and The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tokenization classes for RAG."""
+
+import os
+import warnings
+from typing import Optional
+
+from ...tokenization_utils_base import BatchEncoding
+from ...utils import logging
+from .configuration_rag import RagConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+class RagTokenizer:
+ def __init__(self, question_encoder, generator):
+ self.question_encoder = question_encoder
+ self.generator = generator
+ self.current_tokenizer = self.question_encoder
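+        # Inputs (questions) are tokenized with the question encoder's tokenizer by default, while `decode` and
+        # `batch_decode` always use the generator's tokenizer.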
+
+ def save_pretrained(self, save_directory):
+ if os.path.isfile(save_directory):
+ raise ValueError(f"Provided path ({save_directory}) should be a directory, not a file")
+ os.makedirs(save_directory, exist_ok=True)
+ question_encoder_path = os.path.join(save_directory, "question_encoder_tokenizer")
+ generator_path = os.path.join(save_directory, "generator_tokenizer")
+ self.question_encoder.save_pretrained(question_encoder_path)
+ self.generator.save_pretrained(generator_path)
+
+ @classmethod
+ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
+ # dynamically import AutoTokenizer
+ from ..auto.tokenization_auto import AutoTokenizer
+
+ config = kwargs.pop("config", None)
+
+ if config is None:
+ config = RagConfig.from_pretrained(pretrained_model_name_or_path)
+
+ question_encoder = AutoTokenizer.from_pretrained(
+ pretrained_model_name_or_path, config=config.question_encoder, subfolder="question_encoder_tokenizer"
+ )
+ generator = AutoTokenizer.from_pretrained(
+ pretrained_model_name_or_path, config=config.generator, subfolder="generator_tokenizer"
+ )
+ return cls(question_encoder=question_encoder, generator=generator)
+
+ def __call__(self, *args, **kwargs):
+ return self.current_tokenizer(*args, **kwargs)
+
+ def batch_decode(self, *args, **kwargs):
+ return self.generator.batch_decode(*args, **kwargs)
+
+ def decode(self, *args, **kwargs):
+ return self.generator.decode(*args, **kwargs)
+
+ def _switch_to_input_mode(self):
+ self.current_tokenizer = self.question_encoder
+
+ def _switch_to_target_mode(self):
+ self.current_tokenizer = self.generator
+
+ def prepare_seq2seq_batch(
+ self,
+ src_texts: list[str],
+ tgt_texts: Optional[list[str]] = None,
+ max_length: Optional[int] = None,
+ max_target_length: Optional[int] = None,
+ padding: str = "longest",
+ return_tensors: Optional[str] = None,
+ truncation: bool = True,
+ **kwargs,
+ ) -> BatchEncoding:
+ warnings.warn(
+ "`prepare_seq2seq_batch` is deprecated and will be removed in version 5 of 🤗 Transformers. Use the "
+ "regular `__call__` method to prepare your inputs and the tokenizer under the `with_target_tokenizer` "
+ "context manager to prepare your targets. See the documentation of your specific tokenizer for more "
+ "details",
+ FutureWarning,
+ )
+ if max_length is None:
+ max_length = self.current_tokenizer.model_max_length
+ model_inputs = self(
+ src_texts,
+ add_special_tokens=True,
+ return_tensors=return_tensors,
+ max_length=max_length,
+ padding=padding,
+ truncation=truncation,
+ **kwargs,
+ )
+ if tgt_texts is None:
+ return model_inputs
+ # Process tgt_texts
+ if max_target_length is None:
+ max_target_length = self.current_tokenizer.model_max_length
+ labels = self(
+ text_target=tgt_texts,
+ add_special_tokens=True,
+ return_tensors=return_tensors,
+ padding=padding,
+ max_length=max_target_length,
+ truncation=truncation,
+ **kwargs,
+ )
+ model_inputs["labels"] = labels["input_ids"]
+ return model_inputs
+
+
+__all__ = ["RagTokenizer"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/recurrent_gemma/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/recurrent_gemma/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ab9335fc4beff8b178336ccb546bdecb4cd45171
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/recurrent_gemma/__init__.py
@@ -0,0 +1,27 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_recurrent_gemma import *
+ from .modeling_recurrent_gemma import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/recurrent_gemma/configuration_recurrent_gemma.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/recurrent_gemma/configuration_recurrent_gemma.py
new file mode 100644
index 0000000000000000000000000000000000000000..ef2a08699123a5f6e6ef8d226f299904fdb7d445
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/recurrent_gemma/configuration_recurrent_gemma.py
@@ -0,0 +1,161 @@
+# coding=utf-8
+# Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""RecurrentGemma model configuration"""
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class RecurrentGemmaConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`RecurrentGemmaModel`]. It is used to instantiate a RecurrentGemma
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
+ defaults will yield a similar configuration to that of the RecurrentGemma-7B.
+
+ e.g. [google/recurrentgemma-2b](https://huggingface.co/google/recurrentgemma-2b)
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+
+ Args:
+ num_hidden_layers (`int`, *optional*, defaults to 26):
+ The number of hidden layers in the model.
+ vocab_size (`int`, *optional*, defaults to 256000):
+ Vocabulary size of the RecurrentGemma model. Defines the number of
+ different tokens that can be represented by the
+ `inputs_ids` passed when calling [`RecurrentGemmaModel`]
+ hidden_size (`int`, *optional*, defaults to 2560):
+ Dimension of the hidden representations.
+ intermediate_size (`int`, *optional*, defaults to 7680):
+ Dimension of the MLP representations.
+ num_attention_heads (`int`, *optional*, defaults to 10):
+ The number of heads for the attention block and the number of
+ heads/blocks for the block-diagonal layers used in the RG-LRU gates.
+ This number must divide `hidden_size` and `lru_width`.
+ lru_width (`int` or `None`, *optional*):
+            Dimension of the hidden representations of the RG-LRU. If `None`,
+            this will be set to `hidden_size`.
+ attention_window_size (`int`, *optional*, defaults to 2048):
+ The size of the attention window used in the attention block.
+ conv1d_width (`int`, *optional*, defaults to 4):
+ The kernel size of conv1d layers used in the recurrent blocks.
+ logits_soft_cap (`float`, *optional*, defaults to 30.0):
+ The value at which the logits should be soft-capped to after the transformer and LM-head computation in the Causal LM architecture.
+ rms_norm_eps (`float`, *optional*, defaults to 1e-06):
+ The epsilon used by the rms normalization layers.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether the model should return the last key/values
+ attentions (not used by all models). Only
+ relevant if `config.is_decoder=True`.
+ pad_token_id (`int`, *optional*, defaults to 0):
+ Padding token id.
+ eos_token_id (`int`, *optional*, defaults to 1):
+ End of stream token id.
+ bos_token_id (`int`, *optional*, defaults to 2):
+ Beginning of stream token id.
+        hidden_activation (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
+ The hidden activation used in the recurrent block as well as the MLP layer of the decoder layers.
+ partial_rotary_factor (`float`, *optional*, defaults to 0.5):
+ The partial rotary factor used in the initialization of the rotary embeddings.
+ rope_theta (`float`, *optional*, defaults to 10000.0):
+ The base period of the RoPE embeddings.
+ block_types (`list[str]`, *optional*, defaults to `('recurrent', 'recurrent', 'attention')`):
+            List of alternating blocks that will be repeated to initialize the `temporal_block` layer.
+        attention_dropout (`float`, *optional*, defaults to 0.0): Dropout value to use after the attention softmax.
+        num_key_value_heads (`int`, *optional*): Number of key-value heads used for Grouped Query Attention (GQA). Defaults to `num_attention_heads` if not set.
+        attention_bias (`bool`, *optional*, defaults to `False`): Whether the q, k, v linear projections of the attention layer should have a bias.
+        w_init_variance_scale (`float`, *optional*, defaults to 0.01): Weight initialization variance.
+ ```python
+ >>> from transformers import RecurrentGemmaModel, RecurrentGemmaConfig
+
+ >>> # Initializing a RecurrentGemma recurrentgemma-2b style configuration
+ >>> configuration = RecurrentGemmaConfig()
+
+ >>> # Initializing a model from the recurrentgemma-2b style configuration
+ >>> model = RecurrentGemmaModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "recurrent_gemma"
+
+ def __init__(
+ self,
+ num_hidden_layers=26,
+ vocab_size=256000,
+ hidden_size=2560,
+ intermediate_size=3 * 2560,
+ num_attention_heads=10,
+ lru_width=None,
+ attention_window_size=2048,
+ conv1d_width=4,
+ logits_soft_cap=30.0,
+ rms_norm_eps=1e-6,
+ use_cache=True,
+ pad_token_id=0,
+ eos_token_id=1,
+ bos_token_id=2,
+ hidden_activation="gelu_pytorch_tanh",
+ partial_rotary_factor=0.5,
+ rope_theta=10000.0,
+ block_types=("recurrent", "recurrent", "attention"),
+ attention_dropout=0.0,
+ num_key_value_heads=None,
+ attention_bias=False,
+ w_init_variance_scale=0.01,
+ **kwargs,
+ ):
+ self.num_hidden_layers = num_hidden_layers
+ self.vocab_size = vocab_size
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.num_attention_heads = num_attention_heads
+ self.lru_width = lru_width if lru_width is not None else hidden_size
+ self.attention_window_size = attention_window_size
+ self.conv1d_width = conv1d_width
+ self.logits_soft_cap = logits_soft_cap
+ self.rms_norm_eps = rms_norm_eps
+ self.use_cache = use_cache
+ self.rope_theta = rope_theta
+ self.partial_rotary_factor = partial_rotary_factor
+ self.block_types = list(block_types)
+ self.hidden_activation = hidden_activation
+ self.head_dim = self.hidden_size // self.num_attention_heads
+ self.num_key_value_heads = num_key_value_heads if num_key_value_heads is not None else num_attention_heads
+ if self.num_key_value_heads > self.num_attention_heads:
+ raise ValueError("The number of `num_key_value_heads` must be smaller than `num_attention_heads`")
+ self.attention_dropout = attention_dropout
+ self.attention_bias = attention_bias
+ self.w_init_variance_scale = w_init_variance_scale
+ self.final_w_init_variance_scale = 2.0 / self.num_hidden_layers
+ super().__init__(
+ pad_token_id=pad_token_id,
+ bos_token_id=bos_token_id,
+ eos_token_id=eos_token_id,
+ **kwargs,
+ )
+
+ @property
+ def layers_block_type(self):
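+        # Repeat the configured block pattern (e.g. recurrent, recurrent, attention) and truncate it so that it
+        # covers exactly `num_hidden_layers` layers.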
+ return (self.block_types * 100)[: self.num_hidden_layers]
+
+
+__all__ = ["RecurrentGemmaConfig"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py
new file mode 100644
index 0000000000000000000000000000000000000000..88364515459af62e3bdc74d41c9c1dfe97c88ba2
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py
@@ -0,0 +1,785 @@
+# coding=utf-8
+# Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved.
+#
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch RecurrentGemma model."""
+
+import math
+from typing import Optional, Union
+
+import torch
+from torch import nn
+
+from ...activations import ACT2FN
+from ...generation import GenerationMixin
+from ...modeling_attn_mask_utils import AttentionMaskConverter
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_outputs import BaseModelOutputWithNoAttention, CausalLMOutput
+from ...modeling_utils import PreTrainedModel
+from ...utils import auto_docstring, logging
+from ...utils.import_utils import is_torchdynamo_compiling
+from .configuration_recurrent_gemma import RecurrentGemmaConfig
+
+
+logger = logging.get_logger(__name__)
+_MAX_SQRT_GRADIENT = 1000.0
+
+
+# Copied from transformers.models.gemma.modeling_gemma.GemmaRMSNorm with Gemma->RecurrentGemma
+class RecurrentGemmaRMSNorm(nn.Module):
+ def __init__(self, dim: int, eps: float = 1e-6):
+ super().__init__()
+ self.eps = eps
+ self.weight = nn.Parameter(torch.zeros(dim))
+
+ def _norm(self, x):
+ return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
+
+ def forward(self, x):
+ output = self._norm(x.float())
+ # Llama does x.to(float16) * w whilst RecurrentGemma is (x * w).to(float16)
+ # See https://github.com/huggingface/transformers/pull/29402
+ output = output * (1.0 + self.weight.float())
+ return output.type_as(x)
+
+ def extra_repr(self):
+ return f"{tuple(self.weight.shape)}, eps={self.eps}"
+
+
+class RecurrentGemmaRotaryEmbedding(nn.Module):
+ inv_freq: torch.Tensor # fix linting for `register_buffer`
+
+ def __init__(self, dim, base=10000, device=None):
+ super().__init__()
+ self.dim = dim
+ self.base = base
+ inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim))
+ self.register_buffer("inv_freq", tensor=inv_freq, persistent=False)
+
+ @torch.no_grad()
+ def forward(self, x, position_ids, seq_len=None):
+ # x: [bs, num_attention_heads, seq_len, head_size]
+ self.inv_freq.to(x.device)
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
+ position_ids_expanded = position_ids[:, None, :].float()
+ # Force float32 since bfloat16 loses precision on long contexts
+ # See https://github.com/huggingface/transformers/pull/29285
+ device_type = x.device.type
+ device_type = device_type if device_type != "mps" else "cpu"
+ with torch.autocast(device_type=device_type, enabled=False):
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+ emb = torch.cat((freqs, freqs), dim=-1)
+ cos = emb.cos()
+ sin = emb.sin()
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
+# Copied from transformers.models.llama.modeling_llama.rotate_half
+def rotate_half(x):
+ """Rotates half the hidden dims of the input."""
+ x1 = x[..., : x.shape[-1] // 2]
+ x2 = x[..., x.shape[-1] // 2 :]
+ return torch.cat((-x2, x1), dim=-1)
+
+
+# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+ """Applies Rotary Position Embedding to the query and key tensors.
+
+ Args:
+ q (`torch.Tensor`): The query tensor.
+ k (`torch.Tensor`): The key tensor.
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
+ position_ids (`torch.Tensor`, *optional*):
+ Deprecated and unused.
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+ Returns:
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+ """
+ cos = cos.unsqueeze(unsqueeze_dim)
+ sin = sin.unsqueeze(unsqueeze_dim)
+ q_embed = (q * cos) + (rotate_half(q) * sin)
+ k_embed = (k * cos) + (rotate_half(k) * sin)
+ return q_embed, k_embed
+
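+# Illustrative sketch (not part of the original module): how the rotary helpers above are typically
+# driven. All shapes below are assumptions chosen for the example only.
+#
+#   >>> import torch
+#   >>> rope = RecurrentGemmaRotaryEmbedding(dim=8)
+#   >>> x = torch.randn(1, 2, 4, 8)                  # [batch, heads, seq_len, head_dim]
+#   >>> position_ids = torch.arange(4).unsqueeze(0)  # [batch, seq_len]
+#   >>> cos, sin = rope(x, position_ids)             # each has shape [batch, seq_len, dim]
+#   >>> q_rot, k_rot = apply_rotary_pos_emb(x, x, cos, sin)
+#   >>> q_rot.shape
+#   torch.Size([1, 2, 4, 8])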
+
+# Copied from transformers.models.llama.modeling_llama.repeat_kv
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+ """
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+ """
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+ if n_rep == 1:
+ return hidden_states
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
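+# Illustrative sketch (assumed toy shapes): `repeat_kv` expands grouped key/value heads so they line up
+# with the query heads before attention.
+#
+#   >>> kv = torch.randn(1, 2, 5, 16)   # [batch, num_key_value_heads, seq_len, head_dim]
+#   >>> repeat_kv(kv, n_rep=4).shape    # 2 KV heads repeated 4 times -> 8 attention heads
+#   torch.Size([1, 8, 5, 16])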
+
+class RecurrentGemmaSdpaAttention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(self, config: RecurrentGemmaConfig):
+ super().__init__()
+ self.config = config
+ self.attention_dropout = config.attention_dropout
+ self.hidden_size = config.hidden_size
+ self.num_attention_heads = config.num_attention_heads
+ self.head_dim = config.head_dim
+ self.num_key_value_heads = config.num_key_value_heads
+ self.num_key_value_groups = self.num_attention_heads // self.num_key_value_heads
+ self.partial_rotary_factor = config.partial_rotary_factor
+
+ self.q_proj = nn.Linear(self.hidden_size, self.num_attention_heads * self.head_dim, bias=config.attention_bias)
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
+ self.o_proj = nn.Linear(self.num_attention_heads * self.head_dim, self.hidden_size, bias=True)
+ self.rotary_emb = RecurrentGemmaRotaryEmbedding(
+ int(self.partial_rotary_factor * self.head_dim),
+ base=config.rope_theta,
+ )
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ use_cache: bool = False,
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ query_states = query_states.view(bsz, q_len, self.num_attention_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ cos, sin = self.rotary_emb(value_states, position_ids)
+
+ # Partial rotary embedding
+ query_rot, query_pass = torch.chunk(query_states, int(1 / self.partial_rotary_factor), dim=-1)
+ key_rot, key_pass = torch.chunk(key_states, int(1 / self.partial_rotary_factor), dim=-1)
+ query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids)
+ query_states = torch.cat((query_rot, query_pass), dim=-1)
+ key_states = torch.cat((key_rot, key_pass), dim=-1)
+
+ if use_cache and hasattr(self, "key_states"):
+ cache_kwargs = {"cache_position": cache_position}
+ key_states, value_states = self._update_cache(key_states, value_states, **cache_kwargs)
+
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ causal_mask = attention_mask
+ if attention_mask is not None:
+ causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
+
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
+ query_states.contiguous(),
+ key_states.contiguous(),
+ value_states.contiguous(),
+ attn_mask=causal_mask, # required so SDPA applies the sliding-window causal mask
+ dropout_p=self.attention_dropout if self.training else 0.0,
+ scale=self.head_dim**-0.5,
+ )
+
+ attn_output = attn_output.transpose(1, 2).contiguous()
+ attn_output = attn_output.view(bsz, q_len, self.hidden_size)
+ attn_output = self.o_proj(attn_output)
+ return attn_output
+
+ def _setup_cache(self, batch_size, device, dtype=None):
+ if dtype is None and self.config.dtype is not None:
+ dtype = self.config.dtype
+ dtype = dtype if dtype is not None else torch.float32
+ cache_shape = (batch_size, self.num_key_value_heads, self.config.attention_window_size, self.head_dim)
+ self.value_states = torch.zeros(cache_shape, dtype=dtype, device=device)
+ self.key_states = torch.zeros(cache_shape, dtype=dtype, device=device)
+
+ @torch.no_grad()
+ def _update_cache(self, key_states, value_states, **cache_kwargs):
+ """
+ torch.compile compatible sliding window.
+ Computes the `indices` from `cache_position >= self.config.attention_window_size - 1`.
+ `to_shift` only becomes true once the position moves past the attention window. For example, with `attention_window_size == 64`:
+
+ indices = (slicing + to_shift[-1].int()-1) % self.config.attention_window_size
+ tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
+ 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36,
+ 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54,
+ 55, 56, 57, 58, 59, 60, 61, 62, 63, 0])
+
+ We reorder the cache with these indices, then always write at `cache_position` (clamped to `attention_window_size`).
+ """
+ cache_position = cache_kwargs.get("cache_position")
+ if cache_position.shape[0] > self.config.attention_window_size:
+ # The prompt is longer than the window: keep only the last `attention_window_size` positions (int indexing may force a device sync under compile; prefer tensor indexing)
+ k_out = key_states[:, :, -self.config.attention_window_size :, :]
+ v_out = value_states[:, :, -self.config.attention_window_size :, :]
+ else:
+ slicing = torch.ones(
+ self.config.attention_window_size, dtype=torch.long, device=value_states.device
+ ).cumsum(0)
+ cache_position = cache_position.clamp(0, self.config.attention_window_size - 1)
+ to_shift = cache_position >= self.config.attention_window_size - 1
+ indices = (slicing + to_shift[-1].int() - 1) % self.config.attention_window_size
+
+ k_out, v_out = self.key_states.to(key_states.device), self.value_states.to(value_states.device)
+ k_out = k_out[:, :, indices]
+ v_out = v_out[:, :, indices]
+
+ k_out[:, :, cache_position] = key_states.to(k_out.dtype)
+ v_out[:, :, cache_position] = value_states.to(v_out.dtype)
+
+ self.key_states, self.value_states = k_out, v_out
+ return k_out, v_out
+
+
+class SqrtBoundDerivative(torch.autograd.Function):
+ """Computes a square root with a gradient clipped at `_MAX_SQRT_GRADIENT`."""
+
+ @staticmethod
+ def forward(ctx, x: torch.Tensor) -> torch.Tensor:
+ """The forward pass, which is a normal `sqrt`."""
+ ctx.save_for_backward(x)
+ return torch.sqrt(x)
+
+ @staticmethod
+ def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
+ """The backward pass, which clips the `sqrt` gradient."""
+ (x,) = ctx.saved_tensors
+ clipped_x_times_4 = torch.clip(4.0 * x, min=1 / (_MAX_SQRT_GRADIENT**2))
+ return grad_output / torch.sqrt(clipped_x_times_4)
+
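+# Illustrative sketch (hypothetical values): `SqrtBoundDerivative` matches `torch.sqrt` in the forward
+# pass but clips the backward gradient at `_MAX_SQRT_GRADIENT`, which keeps training stable in bfloat16
+# when the argument is close to zero.
+#
+#   >>> x = torch.tensor([0.0, 1e-12, 4.0], requires_grad=True)
+#   >>> SqrtBoundDerivative.apply(x).sum().backward()
+#   >>> x.grad   # approximately [1000., 1000., 0.25]; the first two entries are clipped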
+
+class RecurrentGemmaRglru(nn.Module):
+ """A Real-Gated Linear Recurrent Unit (RG-LRU) layer."""
+
+ def __init__(self, config):
+ super().__init__()
+ self.num_attention_heads = config.num_attention_heads
+ self.block_width = config.lru_width // self.num_attention_heads
+
+ self.recurrent_param = nn.Parameter(torch.empty([config.lru_width]))
+ self.input_gate_weight = nn.Parameter(
+ torch.empty([self.num_attention_heads, self.block_width, self.block_width])
+ )
+ self.input_gate_bias = nn.Parameter(torch.empty([self.num_attention_heads, self.block_width]))
+
+ self.recurrent_gate_weight = nn.Parameter(
+ torch.empty([self.num_attention_heads, self.block_width, self.block_width])
+ )
+ self.recurrent_gate_bias = nn.Parameter(torch.empty([self.num_attention_heads, self.block_width]))
+ self.recurrent_states = None
+
+ def forward(
+ self,
+ activations: torch.Tensor,
+ position_ids: torch.Tensor,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ batch_size, seq_len, lru_width = activations.shape
+ reset = position_ids[:, :, None] == 0
+
+ reshape_act = activations.reshape(batch_size * seq_len, self.num_attention_heads, self.block_width)
+ reshape_act = reshape_act.permute(1, 0, 2)
+
+ res = torch.baddbmm(self.input_gate_bias[:, None, :], reshape_act, self.input_gate_weight)
+ input_gate = torch.sigmoid(res.transpose(0, 1).reshape(batch_size, seq_len, lru_width))
+
+ res = torch.baddbmm(self.recurrent_gate_bias[:, None, :], reshape_act, self.recurrent_gate_weight)
+ recurrent_gate = torch.sigmoid(res.transpose(0, 1).reshape(batch_size, seq_len, lru_width))
+
+ # Compute the parameter `A` of the recurrence.
+ log_recurrent_gate = -8.0 * recurrent_gate * nn.functional.softplus(self.recurrent_param)
+ recurrent_gate = torch.exp(log_recurrent_gate)
+ a_square = torch.exp(2 * log_recurrent_gate)
+
+ # Gate the input.
+ gated_inputs = activations * input_gate
+
+ # Apply gamma normalization to the input. We need to clip the derivatives of
+ # `sqrt` in order to prevent NaNs during training in bfloat16. TODO a bit annoying
+ multiplier = 1
+ tracing = isinstance(activations, torch.fx.Proxy) or is_torchdynamo_compiling()
+ if not torch.jit.is_tracing() and not tracing:
+ multiplier = SqrtBoundDerivative.apply(1 - a_square)
+ multiplier = reset + ~reset * multiplier
+ normalized_x = gated_inputs * multiplier.type(activations.dtype)
+
+ hidden_states, recurrent_states = self._rnn_scan(
+ hidden_states=normalized_x,
+ recurrent_gate=recurrent_gate,
+ reset=reset,
+ recurrent_states=self.recurrent_states,
+ )
+ self.recurrent_states = recurrent_states
+ return hidden_states
+
+ # TODO refactor
+ def _rnn_scan(
+ self,
+ hidden_states: torch.Tensor,
+ recurrent_gate: torch.Tensor,
+ reset: torch.Tensor,
+ recurrent_states: Union[torch.Tensor, None],
+ acc_dtype: torch.dtype = torch.float32,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ """Runs the recurrence of a linear RNN.
+
+ Args:
+ hidden_states: The input sequence.
+ recurrent_gate: The diagonal of the recurrence matrix `A`.
+ reset: Indicator of document boundaries, e.g. when to reset the hidden state
+ of the RNN.
+ recurrent_states: The initial hidden state.
+ acc_dtype: The data type for the accumulation.
+
+ Returns:
+ The output of the linear recurrence.
+ """
+ # Multiply `a` by the reset.
+ recurrent_gate = recurrent_gate * ~reset
+
+ if hidden_states.shape[1] == 1:
+ # Using scan in sampling mode.
+ if recurrent_states is None: # same here, when decoding you always have cache
+ return hidden_states, hidden_states[:, 0].type(acc_dtype)
+
+ else:
+ contextualized_states = recurrent_gate.type(acc_dtype) * recurrent_states[:, None].to(
+ recurrent_gate.device
+ )
+ contextualized_states += hidden_states.type(acc_dtype)
+ return contextualized_states.type(hidden_states.dtype), contextualized_states[:, -1]
+
+ else:
+ # Using scan in linear mode.
+ if recurrent_states is None:
+ recurrent_states = torch.zeros(hidden_states[:, 0].shape, dtype=acc_dtype, device=hidden_states.device)
+
+ contextualized_states = torch.zeros_like(hidden_states)
+ for t in range(hidden_states.shape[1]):
+ recurrent_states = recurrent_gate[:, t].type(acc_dtype) * recurrent_states.to(recurrent_gate.device)
+ recurrent_states = recurrent_states + hidden_states[:, t].type(acc_dtype)
+ contextualized_states[:, t] = recurrent_states.type(hidden_states.dtype)
+
+ return contextualized_states, recurrent_states
+
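+# Illustrative sketch (toy config values, not real defaults): the RG-LRU above implements the diagonal
+# linear recurrence h_t = a_t * h_{t-1} + x_t, where a_t = exp(-8 * softplus(recurrent_param) * r_t)
+# comes from the recurrent gate and, away from sequence starts, the gated input is rescaled by sqrt(1 - a_t**2).
+#
+#   >>> config = RecurrentGemmaConfig(num_attention_heads=2, lru_width=8)  # assumed toy sizes
+#   >>> rglru = RecurrentGemmaRglru(config)                                # weights are uninitialized here
+#   >>> activations = torch.randn(1, 5, 8)                                 # [batch, seq_len, lru_width]
+#   >>> position_ids = torch.arange(5).unsqueeze(0)
+#   >>> rglru(activations, position_ids).shape                             # only the shape matters in this sketch
+#   torch.Size([1, 5, 8])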
+
+class RecurrentGemmaRecurrentBlock(nn.Module):
+ """Griffin and Hawk's recurrent block."""
+
+ def __init__(self, config):
+ super().__init__()
+ self.lru_width = config.lru_width
+ self.hidden_size = config.hidden_size
+ self.linear_y = nn.Linear(in_features=config.hidden_size, out_features=config.lru_width)
+ self.linear_x = nn.Linear(in_features=config.hidden_size, out_features=config.lru_width)
+ self.linear_out = nn.Linear(in_features=config.lru_width, out_features=config.hidden_size)
+ self.conv1d_width = config.conv1d_width
+ self.conv_1d = nn.Conv1d(
+ config.lru_width,
+ config.lru_width,
+ kernel_size=config.conv1d_width,
+ groups=config.lru_width,
+ padding=config.conv1d_width - 1,
+ )
+ self.rg_lru = RecurrentGemmaRglru(config)
+ self.act_fn = ACT2FN[config.hidden_activation]
+
+ self.conv1d_state = None
+
+ def forward(
+ self,
+ input_states: torch.Tensor,
+ position_ids: torch.Tensor,
+ attention_mask: torch.Tensor,
+ cache_position: torch.Tensor,
+ use_cache: bool = True,
+ ) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
+ _, seq_len, _ = input_states.shape
+
+ y_branch = self.linear_y(input_states)
+ y_branch = self.act_fn(y_branch)
+
+ x_branch = self.linear_x(input_states)
+ x_branch = x_branch.transpose(1, 2)
+
+ if use_cache:
+ if cache_position.shape[0] != 1: # prefill
+ self.conv1d_state = nn.functional.pad(x_branch, (self.conv1d_width - x_branch.shape[-1] - 1, 0))
+ x_branch = self.conv_1d(x_branch)[..., :seq_len]
+ else: # decoding
+ conv_state = torch.cat((self.conv1d_state, x_branch), -1)
+ x_branch = torch.sum(conv_state * self.conv_1d.weight[:, 0, :], dim=-1) + self.conv_1d.bias
+ x_branch = x_branch.unsqueeze(-1)
+ self.conv1d_state = conv_state[:, :, 1:]
+ else:
+ x_branch = self.conv_1d(x_branch)[..., :seq_len]
+
+ x_branch = self.rg_lru(x_branch.transpose(1, 2), position_ids)
+
+ hidden_states = x_branch * y_branch
+ hidden_states = self.linear_out(hidden_states)
+ return hidden_states
+
+ def _setup_cache(self, batch, device, dtype):
+ # recurrent_states always computed in full precision
+ self.rg_lru.recurrent_states = torch.zeros((batch, self.lru_width), device=device, dtype=torch.float32)
+ self.conv1d_state = torch.zeros((batch, self.hidden_size, self.conv1d_width - 1), device=device, dtype=dtype)
+
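+# Illustrative sketch (hypothetical toy sizes): during single-token decoding the recurrent block avoids
+# re-running the full causal Conv1d. It keeps the last `conv1d_width - 1` inputs in `conv1d_state` and
+# computes one output step as a dot product with the depthwise kernel, as in the decoding branch above.
+#
+#   >>> width, channels = 4, 3
+#   >>> conv = nn.Conv1d(channels, channels, kernel_size=width, groups=channels, padding=width - 1)
+#   >>> state = torch.randn(1, channels, width)   # cached inputs including the newest token
+#   >>> step = torch.sum(state * conv.weight[:, 0, :], dim=-1) + conv.bias
+#   >>> step.shape
+#   torch.Size([1, 3])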
+
+TEMPORAL_BLOCK_CLASSES = {"recurrent": RecurrentGemmaRecurrentBlock, "attention": RecurrentGemmaSdpaAttention}
+
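+# Illustrative note: `config.layers_block_type` picks, per layer, whether temporal mixing uses the
+# RG-LRU recurrent block or local sliding-window attention. RecurrentGemma checkpoints typically
+# interleave two recurrent layers with one attention layer (pattern assumed here for the example).
+#
+#   >>> layers_block_type = ["recurrent", "recurrent", "attention"] * 2   # hypothetical 6-layer model
+#   >>> [TEMPORAL_BLOCK_CLASSES[t].__name__ for t in layers_block_type[:3]]
+#   ['RecurrentGemmaRecurrentBlock', 'RecurrentGemmaRecurrentBlock', 'RecurrentGemmaSdpaAttention']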
+
+class RecurrentGemmaMlp(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.hidden_size = config.hidden_size
+ self.intermediate_size = config.intermediate_size // 2
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=True)
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=True)
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=True)
+ self.act_fn = ACT2FN[config.hidden_activation]
+
+ def forward(self, hidden_states):
+ gate = self.act_fn(self.gate_proj(hidden_states))
+ return self.down_proj(gate * self.up_proj(hidden_states))
+
+
+class RecurrentGemmaDecoderLayer(GradientCheckpointingLayer):
+ """Griffin and Hawk's residual block."""
+
+ def __init__(self, config, layer_idx):
+ super().__init__()
+ self.temporal_pre_norm = RecurrentGemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.temporal_block = TEMPORAL_BLOCK_CLASSES[config.layers_block_type[layer_idx]](config)
+ self.channel_pre_norm = RecurrentGemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.mlp_block = RecurrentGemmaMlp(config)
+
+ def forward(
+ self,
+ activations: torch.Tensor,
+ position_ids: torch.Tensor,
+ attention_mask: torch.Tensor,
+ cache_position: Optional[torch.Tensor] = None,
+ use_cache: Optional[bool] = None,
+ ) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
+ raw_activations = activations
+ inputs_normalized = self.temporal_pre_norm(raw_activations) # RMSNorm introduces slight numerical differences
+
+ hidden_states = self.temporal_block(
+ inputs_normalized, position_ids, attention_mask, cache_position=cache_position, use_cache=use_cache
+ )
+
+ residual = hidden_states + raw_activations
+
+ hidden_states = self.channel_pre_norm(residual)
+ hidden_states = self.mlp_block(hidden_states)
+
+ hidden_states = hidden_states + residual
+ return hidden_states
+
+
+@auto_docstring
+class RecurrentGemmaPreTrainedModel(PreTrainedModel):
+ config: RecurrentGemmaConfig
+ base_model_prefix = "model"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["RecurrentGemmaDecoderLayer"]
+ _skip_keys_device_placement = ["cache"]
+ _supports_flash_attn = False
+ _supports_sdpa = False # we can't compare with eager for now
+
+ def _init_weights(self, module):
+ std = math.sqrt(self.config.w_init_variance_scale / self.config.conv1d_width)
+ if isinstance(module, nn.Conv1d):
+ torch.nn.init.normal_(module.weight, mean=0.0, std=std)
+ torch.nn.init.zeros_(module.bias)
+ elif isinstance(module, RecurrentGemmaSdpaAttention):
+ torch.nn.init.normal_(module.q_proj.weight, mean=0.0, std=math.sqrt(1.0 / self.config.hidden_size))
+ torch.nn.init.normal_(module.k_proj.weight, mean=0.0, std=math.sqrt(1.0 / self.config.hidden_size))
+ torch.nn.init.normal_(module.v_proj.weight, mean=0.0, std=math.sqrt(1.0 / self.config.hidden_size))
+
+ std = math.sqrt(self.config.final_w_init_variance_scale / self.config.hidden_size)
+ torch.nn.init.normal_(module.o_proj.weight, mean=0.0, std=std)
+ elif isinstance(module, RecurrentGemmaRecurrentBlock):
+ torch.nn.init.zeros_(module.linear_x.bias)
+ torch.nn.init.normal_(module.linear_x.weight, mean=0.0, std=math.sqrt(1.0 / self.config.hidden_size))
+
+ torch.nn.init.zeros_(module.linear_y.bias)
+ torch.nn.init.normal_(module.linear_y.weight, mean=0.0, std=math.sqrt(1.0 / self.config.hidden_size))
+
+ std = math.sqrt(self.config.final_w_init_variance_scale / self.config.lru_width)
+ torch.nn.init.normal_(module.linear_out.weight, mean=0.0, std=std)
+ torch.nn.init.zeros_(module.linear_out.bias)
+ elif isinstance(module, RecurrentGemmaRglru):
+ std = math.sqrt(
+ self.config.w_init_variance_scale / (self.config.lru_width // self.config.num_attention_heads)
+ )
+ torch.nn.init.normal_(module.input_gate_weight, mean=0.0, std=std)
+ torch.nn.init.normal_(module.recurrent_gate_weight, mean=0.0, std=std)
+ torch.nn.init.zeros_(module.input_gate_bias)
+ torch.nn.init.zeros_(module.recurrent_gate_bias)
+
+ module.recurrent_param.data.uniform_(0.9**2 + 1e-8, 0.999**2 + 1e-8)
+ module.recurrent_param.data.log_().mul_(0.5)
+ module.recurrent_param.data.neg_().exp_().sub_(1.0).log_()
+ elif isinstance(module, nn.Linear):
+ torch.nn.init.normal_(module.weight, mean=0.0, std=std)
+ if getattr(module, "bias", None) is not None:
+ torch.nn.init.zeros_(module.bias)
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+
+ # Initialize with zeros so the effective scale starts at 1, since this RMSNorm applies (1 + weight)
+ elif isinstance(module, RecurrentGemmaRMSNorm):
+ module.weight.data.zero_()
+
+ def _setup_cache(self, config, batch, device, dtype):
+ layers = getattr(self, "model", self).layers
+ for layer in layers:
+ layer.temporal_block._setup_cache(batch, device, dtype)
+
+ def reset_cache(self, batch, device, dtype):
+ pass
+
+
+@auto_docstring
+class RecurrentGemmaModel(RecurrentGemmaPreTrainedModel):
+ def __init__(self, config: RecurrentGemmaConfig):
+ super().__init__(config)
+ self.padding_idx = config.pad_token_id
+ self.vocab_size = config.vocab_size
+
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+ self.layers = nn.ModuleList(
+ [RecurrentGemmaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+ )
+ self.final_norm = RecurrentGemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.gradient_checkpointing = False
+
+ self.register_buffer(
+ "normalizer", torch.tensor(self.config.hidden_size**0.5, dtype=torch.bfloat16), persistent=False
+ )
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[tuple, BaseModelOutputWithNoAttention]:
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ if self.gradient_checkpointing and self.training and use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
+ )
+ use_cache = False
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+
+ hidden_states = inputs_embeds
+
+ if use_cache and inputs_embeds.shape[1] != 1: # TODO let's maybe only call in the `generate`?
+ self._setup_cache(self.config, hidden_states.shape[0], hidden_states.device, hidden_states.dtype)
+
+ if cache_position is None:
+ cache_position = torch.arange(hidden_states.shape[1], device=hidden_states.device)
+ if position_ids is None:
+ position_ids = cache_position.unsqueeze(0)
+
+ causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)
+
+ hidden_states = hidden_states * self.normalizer.type(hidden_states.dtype)
+
+ all_hidden_states = () if output_hidden_states else None
+ for i, residual_block in enumerate(self.layers):
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+ hidden_states = residual_block(hidden_states, position_ids, causal_mask, cache_position, use_cache)
+
+ hidden_states = self.final_norm(hidden_states)
+
+ # add hidden states from the last decoder layer
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, all_hidden_states] if v is not None)
+
+ return BaseModelOutputWithNoAttention(
+ last_hidden_state=hidden_states,
+ hidden_states=all_hidden_states,
+ )
+
+ # Ignore copy
+ def _update_causal_mask(self, attention_mask, input_tensor, cache_position):
+ dtype, device = input_tensor.dtype, input_tensor.device
+ min_dtype = torch.finfo(dtype).min
+ sequence_length = input_tensor.shape[1]
+ target_length = max(self.config.attention_window_size, sequence_length)
+
+ diagonal = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
+ causal_mask = diagonal
+ if sequence_length != 1:
+ causal_mask = torch.triu(diagonal, diagonal=-1)
+
+ causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
+ causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
+ if attention_mask is not None:
+ causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
+ if attention_mask.dim() == 2:
+ # Crop the attention mask to the target length.
+ attention_mask = attention_mask[:, -target_length:]
+ mask_length = attention_mask.shape[-1]
+ padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
+ causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)
+
+ if attention_mask is not None and attention_mask.device.type in ["cuda", "xpu", "npu"]:
+ # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
+ # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
+ # Details: https://github.com/pytorch/pytorch/issues/110213
+ causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
+
+ return causal_mask
+
+
+# TODO: re-enable check: Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM with LLAMA->RECURRENTGEMMA,Llama->RecurrentGemma,llama->gemma
+@auto_docstring
+class RecurrentGemmaForCausalLM(RecurrentGemmaPreTrainedModel, GenerationMixin):
+ _tied_weights_keys = ["lm_head.weight"]
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.model = RecurrentGemmaModel(config)
+ self.vocab_size = config.vocab_size
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @auto_docstring
+ # Ignore copy
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ use_cache: Optional[bool] = None,
+ **kwargs,
+ ) -> Union[tuple, CausalLMOutput]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, RecurrentGemmaForCausalLM
+
+ >>> model = RecurrentGemmaForCausalLM.from_pretrained("google/recurrentgemma-2b")
+ >>> tokenizer = AutoTokenizer.from_pretrained("google/recurrentgemma-2b")
+
+ >>> prompt = "What is your favorite condiment?"
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+ >>> # Generate
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ "What is your favorite condiment?"
+ ```"""
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ output_hidden_states = True
+ outputs = self.model(
+ input_ids=input_ids,
+ position_ids=position_ids,
+ cache_position=cache_position,
+ attention_mask=attention_mask,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ hidden_states = outputs[0]
+ logits = self.lm_head(hidden_states)
+
+ # Soft-cap the logits. TODO: remove if soft-capping is always applied.
+ # if self.config.logits_soft_cap is not None:
+ cap = self.config.logits_soft_cap
+ logits = nn.functional.tanh(logits / cap) * cap
+
+ loss = None
+ if labels is not None:
+ # Upcast to float if we need to compute the loss to avoid potential precision issues
+ logits = logits.float()
+ loss = self.loss_function(
+ logits,
+ labels,
+ vocab_size=self.config.vocab_size,
+ **kwargs,
+ )
+
+ if not return_dict:
+ output = (logits,) + outputs[1:]
+ return (loss,) + output if loss is not None else output
+
+ return CausalLMOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ )
+
+
+__all__ = ["RecurrentGemmaForCausalLM", "RecurrentGemmaModel", "RecurrentGemmaPreTrainedModel"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/rt_detr/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/rt_detr/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..cd1fac5b679dc5ca527dcc0a6a45b9a873daa3dc
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/rt_detr/__init__.py
@@ -0,0 +1,33 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_rt_detr import *
+ from .configuration_rt_detr_resnet import *
+ from .image_processing_rt_detr import *
+ from .image_processing_rt_detr_fast import *
+ from .modeling_rt_detr import *
+ from .modeling_rt_detr_resnet import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/rt_detr/configuration_rt_detr.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/rt_detr/configuration_rt_detr.py
new file mode 100644
index 0000000000000000000000000000000000000000..994d4a6fd6f02a9c11817c3ad1bdcdd83a14bb24
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/rt_detr/configuration_rt_detr.py
@@ -0,0 +1,372 @@
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""RT-DETR model configuration"""
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+from ...utils.backbone_utils import verify_backbone_config_arguments
+from ..auto import CONFIG_MAPPING
+from .configuration_rt_detr_resnet import RTDetrResNetConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+class RTDetrConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`RTDetrModel`]. It is used to instantiate a
+ RT-DETR model according to the specified arguments, defining the model architecture. Instantiating a configuration
+ with the defaults will yield a similar configuration to that of the RT-DETR
+ [PekingU/rtdetr_r50vd](https://huggingface.co/PekingU/rtdetr_r50vd) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ initializer_range (`float`, *optional*, defaults to 0.01):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ initializer_bias_prior_prob (`float`, *optional*):
+ The prior probability used by the bias initializer to initialize biases for `enc_score_head` and `class_embed`.
+ If `None`, `prior_prob` is computed as `prior_prob = 1 / (num_labels + 1)` while initializing model weights.
+ layer_norm_eps (`float`, *optional*, defaults to 1e-05):
+ The epsilon used by the layer normalization layers.
+ batch_norm_eps (`float`, *optional*, defaults to 1e-05):
+ The epsilon used by the batch normalization layers.
+ backbone_config (`Dict`, *optional*, defaults to `RTDetrResNetConfig()`):
+ The configuration of the backbone model.
+ backbone (`str`, *optional*):
+ Name of backbone to use when `backbone_config` is `None`. If `use_pretrained_backbone` is `True`, this
+ will load the corresponding pretrained weights from the timm or transformers library. If `use_pretrained_backbone`
+ is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights.
+ use_pretrained_backbone (`bool`, *optional*, defaults to `False`):
+ Whether to use pretrained weights for the backbone.
+ use_timm_backbone (`bool`, *optional*, defaults to `False`):
+ Whether to load `backbone` from the timm library. If `False`, the backbone is loaded from the transformers
+ library.
+ freeze_backbone_batch_norms (`bool`, *optional*, defaults to `True`):
+ Whether to freeze the batch normalization layers in the backbone.
+ backbone_kwargs (`dict`, *optional*):
+ Keyword arguments to be passed to AutoBackbone when loading from a checkpoint
+ e.g. `{'out_indices': (0, 1, 2, 3)}`. Cannot be specified if `backbone_config` is set.
+ encoder_hidden_dim (`int`, *optional*, defaults to 256):
+ Dimension of the layers in hybrid encoder.
+ encoder_in_channels (`list`, *optional*, defaults to `[512, 1024, 2048]`):
+ Multi-level feature inputs for the encoder.
+ feat_strides (`list[int]`, *optional*, defaults to `[8, 16, 32]`):
+ Strides used in each feature map.
+ encoder_layers (`int`, *optional*, defaults to 1):
+ Total number of layers to be used by the encoder.
+ encoder_ffn_dim (`int`, *optional*, defaults to 1024):
+ Dimension of the "intermediate" (often named feed-forward) layer in decoder.
+ encoder_attention_heads (`int`, *optional*, defaults to 8):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ dropout (`float`, *optional*, defaults to 0.0):
+ The ratio for all dropout layers.
+ activation_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for activations inside the fully connected layer.
+ encode_proj_layers (`list[int]`, *optional*, defaults to `[2]`):
+ Indexes of the projected layers to be used in the encoder.
+ positional_encoding_temperature (`int`, *optional*, defaults to 10000):
+ The temperature parameter used to create the positional encodings.
+ encoder_activation_function (`str`, *optional*, defaults to `"gelu"`):
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+ `"relu"`, `"silu"` and `"gelu_new"` are supported.
+ activation_function (`str`, *optional*, defaults to `"silu"`):
+ The non-linear activation function (function or string) in the general layer. If string, `"gelu"`,
+ `"relu"`, `"silu"` and `"gelu_new"` are supported.
+ eval_size (`tuple[int, int]`, *optional*):
+ Height and width used to compute the effective height and width of the position embeddings after taking
+ into account the stride.
+ normalize_before (`bool`, *optional*, defaults to `False`):
+ Determine whether to apply layer normalization in the transformer encoder layer before self-attention and
+ feed-forward modules.
+ hidden_expansion (`float`, *optional*, defaults to 1.0):
+ Expansion ratio to enlarge the dimension size of RepVGGBlock and CSPRepLayer.
+ d_model (`int`, *optional*, defaults to 256):
+ Dimension of the layers, excluding the hybrid encoder.
+ num_queries (`int`, *optional*, defaults to 300):
+ Number of object queries.
+ decoder_in_channels (`list`, *optional*, defaults to `[256, 256, 256]`):
+ Multi-level feature dimensions for the decoder.
+ decoder_ffn_dim (`int`, *optional*, defaults to 1024):
+ Dimension of the "intermediate" (often named feed-forward) layer in decoder.
+ num_feature_levels (`int`, *optional*, defaults to 3):
+ The number of input feature levels.
+ decoder_n_points (`int`, *optional*, defaults to 4):
+ The number of sampled keys in each feature level for each attention head in the decoder.
+ decoder_layers (`int`, *optional*, defaults to 6):
+ Number of decoder layers.
+ decoder_attention_heads (`int`, *optional*, defaults to 8):
+ Number of attention heads for each attention layer in the Transformer decoder.
+ decoder_activation_function (`str`, *optional*, defaults to `"relu"`):
+ The non-linear activation function (function or string) in the decoder. If string, `"gelu"`,
+ `"relu"`, `"silu"` and `"gelu_new"` are supported.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ num_denoising (`int`, *optional*, defaults to 100):
+ The total number of denoising tasks or queries to be used for contrastive denoising.
+ label_noise_ratio (`float`, *optional*, defaults to 0.5):
+ The fraction of denoising labels to which random noise should be added.
+ box_noise_scale (`float`, *optional*, defaults to 1.0):
+ Scale or magnitude of noise to be added to the bounding boxes.
+ learn_initial_query (`bool`, *optional*, defaults to `False`):
+ Indicates whether the initial query embeddings for the decoder should be learned during training.
+ anchor_image_size (`tuple[int, int]`, *optional*):
+ Height and width of the input image used during evaluation to generate the bounding box anchors. If `None`, anchors are generated automatically.
+ disable_custom_kernels (`bool`, *optional*, defaults to `True`):
+ Whether to disable custom kernels.
+ with_box_refine (`bool`, *optional*, defaults to `True`):
+ Whether to apply iterative bounding box refinement, where each decoder layer refines the bounding boxes
+ based on the predictions from the previous layer.
+ is_encoder_decoder (`bool`, *optional*, defaults to `True`):
+ Whether the architecture has an encoder decoder structure.
+ matcher_alpha (`float`, *optional*, defaults to 0.25):
+ Parameter alpha used by the Hungarian Matcher.
+ matcher_gamma (`float`, *optional*, defaults to 2.0):
+ Parameter gamma used by the Hungarian Matcher.
+ matcher_class_cost (`float`, *optional*, defaults to 2.0):
+ The relative weight of the class loss used by the Hungarian Matcher.
+ matcher_bbox_cost (`float`, *optional*, defaults to 5.0):
+ The relative weight of the bounding box loss used by the Hungarian Matcher.
+ matcher_giou_cost (`float`, *optional*, defaults to 2.0):
+ The relative weight of the giou loss used by the Hungarian Matcher.
+ use_focal_loss (`bool`, *optional*, defaults to `True`):
+ Parameter informing if focal loss should be used.
+ auxiliary_loss (`bool`, *optional*, defaults to `True`):
+ Whether auxiliary decoding losses (loss at each decoder layer) are to be used.
+ focal_loss_alpha (`float`, *optional*, defaults to 0.75):
+ Parameter alpha used to compute the focal loss.
+ focal_loss_gamma (`float`, *optional*, defaults to 2.0):
+ Parameter gamma used to compute the focal loss.
+ weight_loss_vfl (`float`, *optional*, defaults to 1.0):
+ Relative weight of the varifocal loss in the object detection loss.
+ weight_loss_bbox (`float`, *optional*, defaults to 5.0):
+ Relative weight of the L1 bounding box loss in the object detection loss.
+ weight_loss_giou (`float`, *optional*, defaults to 2.0):
+ Relative weight of the generalized IoU loss in the object detection loss.
+ eos_coefficient (`float`, *optional*, defaults to 0.0001):
+ Relative classification weight of the 'no-object' class in the object detection loss.
+
+ Examples:
+
+ ```python
+ >>> from transformers import RTDetrConfig, RTDetrModel
+
+ >>> # Initializing a RT-DETR configuration
+ >>> configuration = RTDetrConfig()
+
+ >>> # Initializing a model (with random weights) from the configuration
+ >>> model = RTDetrModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "rt_detr"
+ layer_types = ["basic", "bottleneck"]
+ attribute_map = {
+ "hidden_size": "d_model",
+ "num_attention_heads": "encoder_attention_heads",
+ }
+
+ def __init__(
+ self,
+ initializer_range=0.01,
+ initializer_bias_prior_prob=None,
+ layer_norm_eps=1e-5,
+ batch_norm_eps=1e-5,
+ # backbone
+ backbone_config=None,
+ backbone=None,
+ use_pretrained_backbone=False,
+ use_timm_backbone=False,
+ freeze_backbone_batch_norms=True,
+ backbone_kwargs=None,
+ # encoder HybridEncoder
+ encoder_hidden_dim=256,
+ encoder_in_channels=[512, 1024, 2048],
+ feat_strides=[8, 16, 32],
+ encoder_layers=1,
+ encoder_ffn_dim=1024,
+ encoder_attention_heads=8,
+ dropout=0.0,
+ activation_dropout=0.0,
+ encode_proj_layers=[2],
+ positional_encoding_temperature=10000,
+ encoder_activation_function="gelu",
+ activation_function="silu",
+ eval_size=None,
+ normalize_before=False,
+ hidden_expansion=1.0,
+ # decoder RTDetrTransformer
+ d_model=256,
+ num_queries=300,
+ decoder_in_channels=[256, 256, 256],
+ decoder_ffn_dim=1024,
+ num_feature_levels=3,
+ decoder_n_points=4,
+ decoder_layers=6,
+ decoder_attention_heads=8,
+ decoder_activation_function="relu",
+ attention_dropout=0.0,
+ num_denoising=100,
+ label_noise_ratio=0.5,
+ box_noise_scale=1.0,
+ learn_initial_query=False,
+ anchor_image_size=None,
+ disable_custom_kernels=True,
+ with_box_refine=True,
+ is_encoder_decoder=True,
+ # Loss
+ matcher_alpha=0.25,
+ matcher_gamma=2.0,
+ matcher_class_cost=2.0,
+ matcher_bbox_cost=5.0,
+ matcher_giou_cost=2.0,
+ use_focal_loss=True,
+ auxiliary_loss=True,
+ focal_loss_alpha=0.75,
+ focal_loss_gamma=2.0,
+ weight_loss_vfl=1.0,
+ weight_loss_bbox=5.0,
+ weight_loss_giou=2.0,
+ eos_coefficient=1e-4,
+ **kwargs,
+ ):
+ self.initializer_range = initializer_range
+ self.initializer_bias_prior_prob = initializer_bias_prior_prob
+ self.layer_norm_eps = layer_norm_eps
+ self.batch_norm_eps = batch_norm_eps
+ # backbone
+ if backbone_config is None and backbone is None:
+ logger.info(
+ "`backbone_config` and `backbone` are `None`. Initializing the config with the default `RTDetr-ResNet` backbone."
+ )
+ backbone_config = RTDetrResNetConfig(
+ num_channels=3,
+ embedding_size=64,
+ hidden_sizes=[256, 512, 1024, 2048],
+ depths=[3, 4, 6, 3],
+ layer_type="bottleneck",
+ hidden_act="relu",
+ downsample_in_first_stage=False,
+ downsample_in_bottleneck=False,
+ out_features=None,
+ out_indices=[2, 3, 4],
+ )
+ elif isinstance(backbone_config, dict):
+ backbone_model_type = backbone_config.pop("model_type")
+ config_class = CONFIG_MAPPING[backbone_model_type]
+ backbone_config = config_class.from_dict(backbone_config)
+
+ verify_backbone_config_arguments(
+ use_timm_backbone=use_timm_backbone,
+ use_pretrained_backbone=use_pretrained_backbone,
+ backbone=backbone,
+ backbone_config=backbone_config,
+ backbone_kwargs=backbone_kwargs,
+ )
+
+ self.backbone_config = backbone_config
+ self.backbone = backbone
+ self.use_pretrained_backbone = use_pretrained_backbone
+ self.use_timm_backbone = use_timm_backbone
+ self.freeze_backbone_batch_norms = freeze_backbone_batch_norms
+ self.backbone_kwargs = backbone_kwargs
+ # encoder
+ self.encoder_hidden_dim = encoder_hidden_dim
+ self.encoder_in_channels = encoder_in_channels
+ self.feat_strides = feat_strides
+ self.encoder_attention_heads = encoder_attention_heads
+ self.encoder_ffn_dim = encoder_ffn_dim
+ self.dropout = dropout
+ self.activation_dropout = activation_dropout
+ self.encode_proj_layers = encode_proj_layers
+ self.encoder_layers = encoder_layers
+ self.positional_encoding_temperature = positional_encoding_temperature
+ self.eval_size = eval_size
+ self.normalize_before = normalize_before
+ self.encoder_activation_function = encoder_activation_function
+ self.activation_function = activation_function
+ self.hidden_expansion = hidden_expansion
+ # decoder
+ self.d_model = d_model
+ self.num_queries = num_queries
+ self.decoder_ffn_dim = decoder_ffn_dim
+ self.decoder_in_channels = decoder_in_channels
+ self.num_feature_levels = num_feature_levels
+ self.decoder_n_points = decoder_n_points
+ self.decoder_layers = decoder_layers
+ self.decoder_attention_heads = decoder_attention_heads
+ self.decoder_activation_function = decoder_activation_function
+ self.attention_dropout = attention_dropout
+ self.num_denoising = num_denoising
+ self.label_noise_ratio = label_noise_ratio
+ self.box_noise_scale = box_noise_scale
+ self.learn_initial_query = learn_initial_query
+ self.anchor_image_size = anchor_image_size
+ self.auxiliary_loss = auxiliary_loss
+ self.disable_custom_kernels = disable_custom_kernels
+ self.with_box_refine = with_box_refine
+ # Loss
+ self.matcher_alpha = matcher_alpha
+ self.matcher_gamma = matcher_gamma
+ self.matcher_class_cost = matcher_class_cost
+ self.matcher_bbox_cost = matcher_bbox_cost
+ self.matcher_giou_cost = matcher_giou_cost
+ self.use_focal_loss = use_focal_loss
+ self.focal_loss_alpha = focal_loss_alpha
+ self.focal_loss_gamma = focal_loss_gamma
+ self.weight_loss_vfl = weight_loss_vfl
+ self.weight_loss_bbox = weight_loss_bbox
+ self.weight_loss_giou = weight_loss_giou
+ self.eos_coefficient = eos_coefficient
+ super().__init__(is_encoder_decoder=is_encoder_decoder, **kwargs)
+
+ @property
+ def num_attention_heads(self) -> int:
+ return self.encoder_attention_heads
+
+ @property
+ def hidden_size(self) -> int:
+ return self.d_model
+
+ @property
+ def sub_configs(self):
+ return (
+ {"backbone_config": type(self.backbone_config)}
+ if getattr(self, "backbone_config", None) is not None
+ else {}
+ )
+
+ @classmethod
+ def from_backbone_configs(cls, backbone_config: PretrainedConfig, **kwargs):
+ """Instantiate a [`RTDetrConfig`] (or a derived class) from a pre-trained backbone model configuration and DETR model
+ configuration.
+
+ Args:
+ backbone_config ([`PretrainedConfig`]):
+ The backbone configuration.
+
+ Returns:
+ [`RTDetrConfig`]: An instance of a configuration object
+ """
+ return cls(
+ backbone_config=backbone_config,
+ **kwargs,
+ )
+
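+# Illustrative sketch (assumed usage): building an `RTDetrConfig` around a custom backbone configuration.
+# The smaller `depths` and `num_queries` values are hypothetical, chosen only for the example.
+#
+#   >>> backbone_config = RTDetrResNetConfig(depths=[2, 2, 2, 2])
+#   >>> config = RTDetrConfig.from_backbone_configs(backbone_config, num_queries=100)
+#   >>> config.num_queries
+#   100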
+
+__all__ = ["RTDetrConfig"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/rt_detr/configuration_rt_detr_resnet.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/rt_detr/configuration_rt_detr_resnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..73b9517ab1498fe7fb91fd8f67a0827a55cacc7d
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/rt_detr/configuration_rt_detr_resnet.py
@@ -0,0 +1,114 @@
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""RT-DETR ResNet model configuration"""
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+from ...utils.backbone_utils import BackboneConfigMixin, get_aligned_output_features_output_indices
+
+
+logger = logging.get_logger(__name__)
+
+
+class RTDetrResNetConfig(BackboneConfigMixin, PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`RTDetrResnetBackbone`]. It is used to instantiate a
+ ResNet model according to the specified arguments, defining the model architecture. Instantiating a configuration
+ with the defaults will yield a similar configuration to that of the ResNet
+ [microsoft/resnet-50](https://huggingface.co/microsoft/resnet-50) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ num_channels (`int`, *optional*, defaults to 3):
+ The number of input channels.
+ embedding_size (`int`, *optional*, defaults to 64):
+ Dimensionality (hidden size) for the embedding layer.
+ hidden_sizes (`list[int]`, *optional*, defaults to `[256, 512, 1024, 2048]`):
+ Dimensionality (hidden size) at each stage.
+ depths (`list[int]`, *optional*, defaults to `[3, 4, 6, 3]`):
+ Depth (number of layers) for each stage.
+ layer_type (`str`, *optional*, defaults to `"bottleneck"`):
+ The layer to use, it can be either `"basic"` (used for smaller models, like resnet-18 or resnet-34) or
+ `"bottleneck"` (used for larger models like resnet-50 and above).
+ hidden_act (`str`, *optional*, defaults to `"relu"`):
+ The non-linear activation function in each block. If string, `"gelu"`, `"relu"`, `"selu"` and `"gelu_new"`
+ are supported.
+ downsample_in_first_stage (`bool`, *optional*, defaults to `False`):
+ If `True`, the first stage will downsample the inputs using a `stride` of 2.
+ downsample_in_bottleneck (`bool`, *optional*, defaults to `False`):
+ If `True`, the first conv 1x1 in ResNetBottleNeckLayer will downsample the inputs using a `stride` of 2.
+ out_features (`list[str]`, *optional*):
+ If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc.
+ (depending on how many stages the model has). If unset and `out_indices` is set, will default to the
+ corresponding stages. If unset and `out_indices` is unset, will default to the last stage. Must be in the
+ same order as defined in the `stage_names` attribute.
+ out_indices (`list[int]`, *optional*):
+ If used as backbone, list of indices of features to output. Can be any of 0, 1, 2, etc. (depending on how
+ many stages the model has). If unset and `out_features` is set, will default to the corresponding stages.
+ If unset and `out_features` is unset, will default to the last stage. Must be in the
+ same order as defined in the `stage_names` attribute.
+
+ Example:
+ ```python
+ >>> from transformers import RTDetrResNetConfig, RTDetrResnetBackbone
+
+ >>> # Initializing a ResNet resnet-50 style configuration
+ >>> configuration = RTDetrResNetConfig()
+
+ >>> # Initializing a model (with random weights) from the resnet-50 style configuration
+ >>> model = RTDetrResnetBackbone(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```
+ """
+
+ model_type = "rt_detr_resnet"
+ layer_types = ["basic", "bottleneck"]
+
+ def __init__(
+ self,
+ num_channels=3,
+ embedding_size=64,
+ hidden_sizes=[256, 512, 1024, 2048],
+ depths=[3, 4, 6, 3],
+ layer_type="bottleneck",
+ hidden_act="relu",
+ downsample_in_first_stage=False,
+ downsample_in_bottleneck=False,
+ out_features=None,
+ out_indices=None,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ if layer_type not in self.layer_types:
+ raise ValueError(f"layer_type={layer_type} is not one of {','.join(self.layer_types)}")
+ self.num_channels = num_channels
+ self.embedding_size = embedding_size
+ self.hidden_sizes = hidden_sizes
+ self.depths = depths
+ self.layer_type = layer_type
+ self.hidden_act = hidden_act
+ self.downsample_in_first_stage = downsample_in_first_stage
+ self.downsample_in_bottleneck = downsample_in_bottleneck
+ self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, len(depths) + 1)]
+ self._out_features, self._out_indices = get_aligned_output_features_output_indices(
+ out_features=out_features, out_indices=out_indices, stage_names=self.stage_names
+ )
+
+
+__all__ = ["RTDetrResNetConfig"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/rt_detr/image_processing_rt_detr.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/rt_detr/image_processing_rt_detr.py
new file mode 100644
index 0000000000000000000000000000000000000000..de61a8019047dbc97761c151d345d06d16cdca4d
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/rt_detr/image_processing_rt_detr.py
@@ -0,0 +1,1103 @@
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Image processor class for RT-DETR."""
+
+import pathlib
+from collections.abc import Iterable
+from typing import Any, Callable, Optional, Union
+
+import numpy as np
+
+from ...feature_extraction_utils import BatchFeature
+from ...image_processing_utils import BaseImageProcessor, get_size_dict
+from ...image_transforms import (
+ PaddingMode,
+ center_to_corners_format,
+ corners_to_center_format,
+ pad,
+ rescale,
+ resize,
+ to_channel_dimension_format,
+)
+from ...image_utils import (
+ IMAGENET_DEFAULT_MEAN,
+ IMAGENET_DEFAULT_STD,
+ AnnotationFormat,
+ AnnotationType,
+ ChannelDimension,
+ ImageInput,
+ PILImageResampling,
+ get_image_size,
+ infer_channel_dimension_format,
+ is_scaled_image,
+ make_flat_list_of_images,
+ to_numpy_array,
+ valid_images,
+ validate_annotations,
+ validate_preprocess_arguments,
+)
+from ...utils import (
+ filter_out_non_signature_kwargs,
+ is_flax_available,
+ is_jax_tensor,
+ is_tf_available,
+ is_tf_tensor,
+ is_torch_available,
+ is_torch_tensor,
+ logging,
+ requires_backends,
+)
+from ...utils.generic import TensorType
+
+
+if is_torch_available():
+ import torch
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+SUPPORTED_ANNOTATION_FORMATS = (AnnotationFormat.COCO_DETECTION,)
+
+
+# Copied from transformers.models.detr.image_processing_detr.get_size_with_aspect_ratio
+def get_size_with_aspect_ratio(image_size, size, max_size=None) -> tuple[int, int]:
+ """
+ Computes the output image size given the input image size and the desired output size.
+
+ Args:
+ image_size (`tuple[int, int]`):
+ The input image size.
+ size (`int`):
+ The desired output size.
+ max_size (`int`, *optional*):
+ The maximum allowed output size.
+ """
+ height, width = image_size
+ raw_size = None
+ if max_size is not None:
+ min_original_size = float(min((height, width)))
+ max_original_size = float(max((height, width)))
+ if max_original_size / min_original_size * size > max_size:
+ raw_size = max_size * min_original_size / max_original_size
+ size = int(round(raw_size))
+
+ if (height <= width and height == size) or (width <= height and width == size):
+ oh, ow = height, width
+ elif width < height:
+ ow = size
+ if max_size is not None and raw_size is not None:
+ oh = int(raw_size * height / width)
+ else:
+ oh = int(size * height / width)
+ else:
+ oh = size
+ if max_size is not None and raw_size is not None:
+ ow = int(raw_size * width / height)
+ else:
+ ow = int(size * width / height)
+
+ return (oh, ow)
+
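+# Illustrative sketch (hypothetical numbers): the shorter edge is scaled to `size` while keeping the
+# aspect ratio, unless that would push the longer edge past `max_size`.
+#
+#   >>> get_size_with_aspect_ratio((480, 640), size=800, max_size=1000)   # (height, width) input
+#   (750, 1000)
+#   >>> get_size_with_aspect_ratio((480, 640), size=800)
+#   (800, 1066)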
+
+# Copied from transformers.models.detr.image_processing_detr.get_resize_output_image_size
+def get_resize_output_image_size(
+ input_image: np.ndarray,
+ size: Union[int, tuple[int, int], list[int]],
+ max_size: Optional[int] = None,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+) -> tuple[int, int]:
+ """
+ Computes the output image size given the input image size and the desired output size. If the desired output size
+ is a tuple or list, the output image size is returned as is. If the desired output size is an integer, the output
+ image size is computed by keeping the aspect ratio of the input image size.
+
+ Args:
+ input_image (`np.ndarray`):
+ The image to resize.
+ size (`int` or `tuple[int, int]` or `list[int]`):
+ The desired output size.
+ max_size (`int`, *optional*):
+ The maximum allowed output size.
+ input_data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format of the input image. If not provided, it will be inferred from the input image.
+ """
+ image_size = get_image_size(input_image, input_data_format)
+ if isinstance(size, (list, tuple)):
+ return size
+
+ return get_size_with_aspect_ratio(image_size, size, max_size)
+
+
+# Copied from transformers.models.detr.image_processing_detr.get_image_size_for_max_height_width
+def get_image_size_for_max_height_width(
+ input_image: np.ndarray,
+ max_height: int,
+ max_width: int,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+) -> tuple[int, int]:
+ """
+ Computes the output image size given the input image and the maximum allowed height and width, keeping the
+ aspect ratio. Important: even if image_height < max_height and image_width < max_width, the image will be
+ resized so that at least one of its edges is equal to max_height or max_width.
+ For example:
+ - input_size: (100, 200), max_height: 50, max_width: 50 -> output_size: (25, 50)
+ - input_size: (100, 200), max_height: 200, max_width: 500 -> output_size: (200, 400)
+ Args:
+ input_image (`np.ndarray`):
+ The image to resize.
+ max_height (`int`):
+ The maximum allowed height.
+ max_width (`int`):
+ The maximum allowed width.
+ input_data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format of the input image. If not provided, it will be inferred from the input image.
+ """
+ image_size = get_image_size(input_image, input_data_format)
+ height, width = image_size
+ height_scale = max_height / height
+ width_scale = max_width / width
+ min_scale = min(height_scale, width_scale)
+ new_height = int(height * min_scale)
+ new_width = int(width * min_scale)
+ return new_height, new_width
+
+
+# Copied from transformers.models.detr.image_processing_detr.get_numpy_to_framework_fn
+def get_numpy_to_framework_fn(arr) -> Callable:
+ """
+ Returns a function that converts a numpy array to the framework of the input array.
+
+ Args:
+ arr (`np.ndarray`): The array to convert.
+ """
+ if isinstance(arr, np.ndarray):
+ return np.array
+ if is_tf_available() and is_tf_tensor(arr):
+ import tensorflow as tf
+
+ return tf.convert_to_tensor
+ if is_torch_available() and is_torch_tensor(arr):
+ import torch
+
+ return torch.tensor
+ if is_flax_available() and is_jax_tensor(arr):
+ import jax.numpy as jnp
+
+ return jnp.array
+ raise ValueError(f"Cannot convert arrays of type {type(arr)}")
+
+
+# Copied from transformers.models.detr.image_processing_detr.safe_squeeze
+def safe_squeeze(arr: np.ndarray, axis: Optional[int] = None) -> np.ndarray:
+ """
+ Squeezes an array, but only if the axis specified has dim 1.
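+
+ Example (a small illustrative sketch; the second call leaves the array untouched because the axis does not have dim 1):
+
+ ```python
+ >>> import numpy as np
+ >>> from transformers.models.rt_detr.image_processing_rt_detr import safe_squeeze
+ >>> safe_squeeze(np.zeros((2, 1, 3)), axis=1).shape
+ (2, 3)
+ >>> safe_squeeze(np.zeros((2, 2, 3)), axis=1).shape
+ (2, 2, 3)
+ ```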
+ """
+ if axis is None:
+ return arr.squeeze()
+
+ try:
+ return arr.squeeze(axis=axis)
+ except ValueError:
+ return arr
+
+
+# Copied from transformers.models.detr.image_processing_detr.normalize_annotation
+def normalize_annotation(annotation: dict, image_size: tuple[int, int]) -> dict:
+ image_height, image_width = image_size
+ norm_annotation = {}
+ for key, value in annotation.items():
+ if key == "boxes":
+ boxes = value
+ boxes = corners_to_center_format(boxes)
+ boxes /= np.asarray([image_width, image_height, image_width, image_height], dtype=np.float32)
+ norm_annotation[key] = boxes
+ else:
+ norm_annotation[key] = value
+ return norm_annotation
+
+
+# Copied from transformers.models.detr.image_processing_detr.max_across_indices
+def max_across_indices(values: Iterable[Any]) -> list[Any]:
+ """
+ Return the maximum value across all indices of an iterable of values.
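+
+ Example (illustrative, e.g. taking the element-wise maximum over a list of image shapes):
+
+ ```python
+ >>> from transformers.models.rt_detr.image_processing_rt_detr import max_across_indices
+ >>> max_across_indices([(3, 100, 120), (3, 80, 150)])
+ [3, 100, 150]
+ ```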
+ """
+ return [max(values_i) for values_i in zip(*values)]
+
+
+# Copied from transformers.models.detr.image_processing_detr.get_max_height_width
+def get_max_height_width(
+ images: list[np.ndarray], input_data_format: Optional[Union[str, ChannelDimension]] = None
+) -> list[int]:
+ """
+ Get the maximum height and width across all images in a batch.
+ """
+ if input_data_format is None:
+ input_data_format = infer_channel_dimension_format(images[0])
+
+ if input_data_format == ChannelDimension.FIRST:
+ _, max_height, max_width = max_across_indices([img.shape for img in images])
+ elif input_data_format == ChannelDimension.LAST:
+ max_height, max_width, _ = max_across_indices([img.shape for img in images])
+ else:
+ raise ValueError(f"Invalid channel dimension format: {input_data_format}")
+ return (max_height, max_width)
+
+
+# Copied from transformers.models.detr.image_processing_detr.make_pixel_mask
+def make_pixel_mask(
+ image: np.ndarray, output_size: tuple[int, int], input_data_format: Optional[Union[str, ChannelDimension]] = None
+) -> np.ndarray:
+ """
+ Make a pixel mask for the image, where 1 indicates a valid pixel and 0 indicates padding.
+
+ Args:
+ image (`np.ndarray`):
+ Image to make the pixel mask for.
+ output_size (`tuple[int, int]`):
+ Output size of the mask.
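+
+ Example (an illustrative sketch with a dummy channels-first image smaller than the target mask size):
+
+ ```python
+ >>> import numpy as np
+ >>> from transformers.models.rt_detr.image_processing_rt_detr import make_pixel_mask
+ >>> mask = make_pixel_mask(np.zeros((3, 4, 6)), output_size=(5, 8))
+ >>> mask.shape, int(mask.sum())
+ ((5, 8), 24)
+ ```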
+ """
+ input_height, input_width = get_image_size(image, channel_dim=input_data_format)
+ mask = np.zeros(output_size, dtype=np.int64)
+ mask[:input_height, :input_width] = 1
+ return mask
+
+
+def prepare_coco_detection_annotation(
+ image,
+ target,
+ return_segmentation_masks: bool = False,
+ input_data_format: Optional[Union[ChannelDimension, str]] = None,
+):
+ """
+ Convert the target in COCO format into the format expected by RT-DETR.
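+
+ Example (a minimal sketch; the COCO-style `target` below is made up for illustration):
+
+ ```python
+ >>> import numpy as np
+ >>> from transformers.models.rt_detr.image_processing_rt_detr import prepare_coco_detection_annotation
+ >>> image = np.zeros((3, 100, 200))
+ >>> target = {"image_id": 1, "annotations": [{"category_id": 3, "area": 1200.0, "bbox": [10, 20, 30, 40], "iscrowd": 0}]}
+ >>> new_target = prepare_coco_detection_annotation(image, target)
+ >>> new_target["boxes"].tolist()  # COCO xywh converted to absolute xyxy
+ [[10.0, 20.0, 40.0, 60.0]]
+ ```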
+ """
+ image_height, image_width = get_image_size(image, channel_dim=input_data_format)
+
+ image_id = target["image_id"]
+ image_id = np.asarray([image_id], dtype=np.int64)
+
+ # Get all COCO annotations for the given image.
+ annotations = target["annotations"]
+ annotations = [obj for obj in annotations if "iscrowd" not in obj or obj["iscrowd"] == 0]
+
+ classes = [obj["category_id"] for obj in annotations]
+ classes = np.asarray(classes, dtype=np.int64)
+
+ # for conversion to coco api
+ area = np.asarray([obj["area"] for obj in annotations], dtype=np.float32)
+ iscrowd = np.asarray([obj.get("iscrowd", 0) for obj in annotations], dtype=np.int64)
+
+ boxes = [obj["bbox"] for obj in annotations]
+ # guard against no boxes via resizing
+ boxes = np.asarray(boxes, dtype=np.float32).reshape(-1, 4)
+ boxes[:, 2:] += boxes[:, :2]
+ boxes[:, 0::2] = boxes[:, 0::2].clip(min=0, max=image_width)
+ boxes[:, 1::2] = boxes[:, 1::2].clip(min=0, max=image_height)
+
+ keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0])
+
+ new_target = {}
+ new_target["image_id"] = image_id
+ new_target["class_labels"] = classes[keep]
+ new_target["boxes"] = boxes[keep]
+ new_target["area"] = area[keep]
+ new_target["iscrowd"] = iscrowd[keep]
+ new_target["orig_size"] = np.asarray([int(image_height), int(image_width)], dtype=np.int64)
+
+ if annotations and "keypoints" in annotations[0]:
+ keypoints = [obj["keypoints"] for obj in annotations]
+ # Converting the filtered keypoints list to a numpy array
+ keypoints = np.asarray(keypoints, dtype=np.float32)
+ # Apply the keep mask here to filter the relevant annotations
+ keypoints = keypoints[keep]
+ num_keypoints = keypoints.shape[0]
+ keypoints = keypoints.reshape((-1, 3)) if num_keypoints else keypoints
+ new_target["keypoints"] = keypoints
+
+ return new_target
+
+
+# Copied from transformers.models.detr.image_processing_detr.resize_annotation
+def resize_annotation(
+ annotation: dict[str, Any],
+ orig_size: tuple[int, int],
+ target_size: tuple[int, int],
+ threshold: float = 0.5,
+ resample: PILImageResampling = PILImageResampling.NEAREST,
+):
+ """
+ Resizes an annotation to a target size.
+
+ Args:
+ annotation (`dict[str, Any]`):
+ The annotation dictionary.
+ orig_size (`tuple[int, int]`):
+ The original size of the input image.
+ target_size (`tuple[int, int]`):
+ The target size of the image, as returned by the preprocessing `resize` step.
+ threshold (`float`, *optional*, defaults to 0.5):
+ The threshold used to binarize the segmentation masks.
+ resample (`PILImageResampling`, defaults to `PILImageResampling.NEAREST`):
+ The resampling filter to use when resizing the masks.
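+
+ Example (an illustrative sketch; only `boxes` and `area` are shown, segmentation masks are omitted):
+
+ ```python
+ >>> import numpy as np
+ >>> from transformers.models.rt_detr.image_processing_rt_detr import resize_annotation
+ >>> annotation = {"boxes": np.array([[10.0, 20.0, 40.0, 60.0]], dtype=np.float32), "area": np.array([1200.0])}
+ >>> resized = resize_annotation(annotation, orig_size=(100, 200), target_size=(50, 100))
+ >>> resized["boxes"].tolist()
+ [[5.0, 10.0, 20.0, 30.0]]
+ ```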
+ """
+ ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(target_size, orig_size))
+ ratio_height, ratio_width = ratios
+
+ new_annotation = {}
+ new_annotation["size"] = target_size
+
+ for key, value in annotation.items():
+ if key == "boxes":
+ boxes = value
+ scaled_boxes = boxes * np.asarray([ratio_width, ratio_height, ratio_width, ratio_height], dtype=np.float32)
+ new_annotation["boxes"] = scaled_boxes
+ elif key == "area":
+ area = value
+ scaled_area = area * (ratio_width * ratio_height)
+ new_annotation["area"] = scaled_area
+ elif key == "masks":
+ masks = value[:, None]
+ masks = np.array([resize(mask, target_size, resample=resample) for mask in masks])
+ masks = masks.astype(np.float32)
+ masks = masks[:, 0] > threshold
+ new_annotation["masks"] = masks
+ elif key == "size":
+ new_annotation["size"] = target_size
+ else:
+ new_annotation[key] = value
+
+ return new_annotation
+
+
+class RTDetrImageProcessor(BaseImageProcessor):
+ r"""
+ Constructs a RT-DETR image processor.
+
+ Args:
+ format (`str`, *optional*, defaults to `AnnotationFormat.COCO_DETECTION`):
+ Data format of the annotations. Only "coco_detection" is supported by this image processor.
+ do_resize (`bool`, *optional*, defaults to `True`):
+ Controls whether to resize the image's (height, width) dimensions to the specified `size`. Can be
+ overridden by the `do_resize` parameter in the `preprocess` method.
+ size (`dict[str, int]`, *optional*, defaults to `{"height": 640, "width": 640}`):
+ Size of the image's `(height, width)` dimensions after resizing. Can be overridden by the `size` parameter
+ in the `preprocess` method. Available options are:
+ - `{"height": int, "width": int}`: The image will be resized to the exact size `(height, width)`.
+ Do NOT keep the aspect ratio.
+ - `{"shortest_edge": int, "longest_edge": int}`: The image will be resized to a maximum size respecting
+ the aspect ratio and keeping the shortest edge less or equal to `shortest_edge` and the longest edge
+ less or equal to `longest_edge`.
+ - `{"max_height": int, "max_width": int}`: The image will be resized to the maximum size respecting the
+ aspect ratio and keeping the height less or equal to `max_height` and the width less or equal to
+ `max_width`.
+ resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
+ Resampling filter to use if resizing the image.
+ do_rescale (`bool`, *optional*, defaults to `True`):
+ Controls whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the
+ `do_rescale` parameter in the `preprocess` method.
+ rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
+ Scale factor to use if rescaling the image. Can be overridden by the `rescale_factor` parameter in the
+ `preprocess` method.
+ do_normalize (`bool`, *optional*, defaults to `False`):
+ Controls whether to normalize the image. Can be overridden by the `do_normalize` parameter in the
+ `preprocess` method.
+ image_mean (`float` or `list[float]`, *optional*, defaults to `IMAGENET_DEFAULT_MEAN`):
+ Mean values to use when normalizing the image. Can be a single value or a list of values, one for each
+ channel. Can be overridden by the `image_mean` parameter in the `preprocess` method.
+ image_std (`float` or `list[float]`, *optional*, defaults to `IMAGENET_DEFAULT_STD`):
+ Standard deviation values to use when normalizing the image. Can be a single value or a list of values, one
+ for each channel. Can be overridden by the `image_std` parameter in the `preprocess` method.
+ do_convert_annotations (`bool`, *optional*, defaults to `True`):
+ Controls whether to convert the annotations to the format expected by the DETR model. Converts the
+ bounding boxes to the format `(center_x, center_y, width, height)` and in the range `[0, 1]`.
+ Can be overridden by the `do_convert_annotations` parameter in the `preprocess` method.
+ do_pad (`bool`, *optional*, defaults to `False`):
+ Controls whether to pad the image. Can be overridden by the `do_pad` parameter in the `preprocess`
+ method. If `True`, padding will be applied to the bottom and right of the image with zeros.
+ If `pad_size` is provided, the image will be padded to the specified dimensions.
+ Otherwise, the image will be padded to the maximum height and width of the batch.
+ pad_size (`dict[str, int]`, *optional*):
+ The size `{"height": int, "width": int}` to pad the images to. Must be larger than any image size
+ provided for preprocessing. If `pad_size` is not provided, images will be padded to the largest
+ height and width in the batch.
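+
+ Example (an illustrative sketch on a random dummy image, using the default 640x640 `size` and `do_pad=False`):
+
+ ```python
+ >>> import numpy as np
+ >>> from transformers import RTDetrImageProcessor
+
+ >>> image_processor = RTDetrImageProcessor()
+ >>> image = np.random.randint(0, 256, (480, 640, 3), dtype=np.uint8)
+ >>> inputs = image_processor(images=image, return_tensors="np")
+ >>> inputs["pixel_values"].shape
+ (1, 3, 640, 640)
+ ```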
+ """
+
+ model_input_names = ["pixel_values", "pixel_mask"]
+
+ def __init__(
+ self,
+ format: Union[str, AnnotationFormat] = AnnotationFormat.COCO_DETECTION,
+ do_resize: bool = True,
+ size: Optional[dict[str, int]] = None,
+ resample: PILImageResampling = PILImageResampling.BILINEAR,
+ do_rescale: bool = True,
+ rescale_factor: Union[int, float] = 1 / 255,
+ do_normalize: bool = False,
+ image_mean: Optional[Union[float, list[float]]] = None,
+ image_std: Optional[Union[float, list[float]]] = None,
+ do_convert_annotations: bool = True,
+ do_pad: bool = False,
+ pad_size: Optional[dict[str, int]] = None,
+ **kwargs,
+ ) -> None:
+ size = size if size is not None else {"height": 640, "width": 640}
+ size = get_size_dict(size, default_to_square=False)
+
+ if do_convert_annotations is None:
+ do_convert_annotations = do_normalize
+
+ super().__init__(**kwargs)
+ self.format = format
+ self.do_resize = do_resize
+ self.size = size
+ self.resample = resample
+ self.do_rescale = do_rescale
+ self.rescale_factor = rescale_factor
+ self.do_normalize = do_normalize
+ self.do_convert_annotations = do_convert_annotations
+ self.image_mean = image_mean if image_mean is not None else IMAGENET_DEFAULT_MEAN
+ self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD
+ self.do_pad = do_pad
+ self.pad_size = pad_size
+
+ def prepare_annotation(
+ self,
+ image: np.ndarray,
+ target: dict,
+ format: Optional[AnnotationFormat] = None,
+ return_segmentation_masks: Optional[bool] = None,
+ masks_path: Optional[Union[str, pathlib.Path]] = None,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ ) -> dict:
+ """
+ Prepare an annotation for feeding into the RT-DETR model.
+ """
+ format = format if format is not None else self.format
+
+ if format == AnnotationFormat.COCO_DETECTION:
+ return_segmentation_masks = False if return_segmentation_masks is None else return_segmentation_masks
+ target = prepare_coco_detection_annotation(
+ image, target, return_segmentation_masks, input_data_format=input_data_format
+ )
+ else:
+ raise ValueError(f"Format {format} is not supported.")
+ return target
+
+ # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.resize
+ def resize(
+ self,
+ image: np.ndarray,
+ size: dict[str, int],
+ resample: PILImageResampling = PILImageResampling.BILINEAR,
+ data_format: Optional[ChannelDimension] = None,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ **kwargs,
+ ) -> np.ndarray:
+ """
+ Resize the image to the given size. Size can be `min_size` (scalar) or `(height, width)` tuple. If size is an
+ int, smaller edge of the image will be matched to this number.
+
+ Args:
+ image (`np.ndarray`):
+ Image to resize.
+ size (`dict[str, int]`):
+ Size of the image's `(height, width)` dimensions after resizing. Available options are:
+ - `{"height": int, "width": int}`: The image will be resized to the exact size `(height, width)`.
+ Do NOT keep the aspect ratio.
+ - `{"shortest_edge": int, "longest_edge": int}`: The image will be resized to a maximum size respecting
+ the aspect ratio and keeping the shortest edge less or equal to `shortest_edge` and the longest edge
+ less or equal to `longest_edge`.
+ - `{"max_height": int, "max_width": int}`: The image will be resized to the maximum size respecting the
+ aspect ratio and keeping the height less or equal to `max_height` and the width less or equal to
+ `max_width`.
+ resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
+ Resampling filter to use if resizing the image.
+ data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format for the output image. If unset, the channel dimension format of the input
+ image is used.
+ input_data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format of the input image. If not provided, it will be inferred.
+ """
+ if "max_size" in kwargs:
+ logger.warning_once(
+ "The `max_size` parameter is deprecated and will be removed in v4.26. "
+ "Please specify it in `size['longest_edge']` instead.",
+ )
+ max_size = kwargs.pop("max_size")
+ else:
+ max_size = None
+ size = get_size_dict(size, max_size=max_size, default_to_square=False)
+ if "shortest_edge" in size and "longest_edge" in size:
+ new_size = get_resize_output_image_size(
+ image, size["shortest_edge"], size["longest_edge"], input_data_format=input_data_format
+ )
+ elif "max_height" in size and "max_width" in size:
+ new_size = get_image_size_for_max_height_width(
+ image, size["max_height"], size["max_width"], input_data_format=input_data_format
+ )
+ elif "height" in size and "width" in size:
+ new_size = (size["height"], size["width"])
+ else:
+ raise ValueError(
+ "Size must contain 'height' and 'width' keys, 'shortest_edge' and 'longest_edge' keys, or"
+ f" 'max_height' and 'max_width' keys. Got {size.keys()}."
+ )
+ image = resize(
+ image,
+ size=new_size,
+ resample=resample,
+ data_format=data_format,
+ input_data_format=input_data_format,
+ **kwargs,
+ )
+ return image
+
+ # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.resize_annotation
+ def resize_annotation(
+ self,
+ annotation,
+ orig_size,
+ size,
+ resample: PILImageResampling = PILImageResampling.NEAREST,
+ ) -> dict:
+ """
+ Resize the annotation to match the resized image. If size is an int, smaller edge of the mask will be matched
+ to this number.
+ """
+ return resize_annotation(annotation, orig_size=orig_size, target_size=size, resample=resample)
+
+ # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.rescale
+ def rescale(
+ self,
+ image: np.ndarray,
+ rescale_factor: float,
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ ) -> np.ndarray:
+ """
+ Rescale the image by the given factor. image = image * rescale_factor.
+
+ Args:
+ image (`np.ndarray`):
+ Image to rescale.
+ rescale_factor (`float`):
+ The value to use for rescaling.
+ data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format for the output image. If unset, the channel dimension format of the input
+ image is used. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ input_data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format for the input image. If unset, is inferred from the input image. Can be
+ one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ """
+ return rescale(image, rescale_factor, data_format=data_format, input_data_format=input_data_format)
+
+ # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.normalize_annotation
+ def normalize_annotation(self, annotation: dict, image_size: tuple[int, int]) -> dict:
+ """
+ Normalize the boxes in the annotation from `[top_left_x, top_left_y, bottom_right_x, bottom_right_y]` to
+ `[center_x, center_y, width, height]` format and from absolute to relative pixel values.
+ """
+ return normalize_annotation(annotation, image_size=image_size)
+
+ # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor._update_annotation_for_padded_image
+ def _update_annotation_for_padded_image(
+ self,
+ annotation: dict,
+ input_image_size: tuple[int, int],
+ output_image_size: tuple[int, int],
+ padding,
+ update_bboxes,
+ ) -> dict:
+ """
+ Update the annotation for a padded image.
+ """
+ new_annotation = {}
+ new_annotation["size"] = output_image_size
+
+ for key, value in annotation.items():
+ if key == "masks":
+ masks = value
+ masks = pad(
+ masks,
+ padding,
+ mode=PaddingMode.CONSTANT,
+ constant_values=0,
+ input_data_format=ChannelDimension.FIRST,
+ )
+ masks = safe_squeeze(masks, 1)
+ new_annotation["masks"] = masks
+ elif key == "boxes" and update_bboxes:
+ boxes = value
+ boxes *= np.asarray(
+ [
+ input_image_size[1] / output_image_size[1],
+ input_image_size[0] / output_image_size[0],
+ input_image_size[1] / output_image_size[1],
+ input_image_size[0] / output_image_size[0],
+ ]
+ )
+ new_annotation["boxes"] = boxes
+ elif key == "size":
+ new_annotation["size"] = output_image_size
+ else:
+ new_annotation[key] = value
+ return new_annotation
+
+ # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor._pad_image
+ def _pad_image(
+ self,
+ image: np.ndarray,
+ output_size: tuple[int, int],
+ annotation: Optional[dict[str, Any]] = None,
+ constant_values: Union[float, Iterable[float]] = 0,
+ data_format: Optional[ChannelDimension] = None,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ update_bboxes: bool = True,
+ ) -> np.ndarray:
+ """
+ Pad an image with zeros to the given size.
+ """
+ input_height, input_width = get_image_size(image, channel_dim=input_data_format)
+ output_height, output_width = output_size
+
+ pad_bottom = output_height - input_height
+ pad_right = output_width - input_width
+ padding = ((0, pad_bottom), (0, pad_right))
+ padded_image = pad(
+ image,
+ padding,
+ mode=PaddingMode.CONSTANT,
+ constant_values=constant_values,
+ data_format=data_format,
+ input_data_format=input_data_format,
+ )
+ if annotation is not None:
+ annotation = self._update_annotation_for_padded_image(
+ annotation, (input_height, input_width), (output_height, output_width), padding, update_bboxes
+ )
+ return padded_image, annotation
+
+ # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.pad
+ def pad(
+ self,
+ images: list[np.ndarray],
+ annotations: Optional[Union[AnnotationType, list[AnnotationType]]] = None,
+ constant_values: Union[float, Iterable[float]] = 0,
+ return_pixel_mask: bool = True,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ data_format: Optional[ChannelDimension] = None,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ update_bboxes: bool = True,
+ pad_size: Optional[dict[str, int]] = None,
+ ) -> BatchFeature:
+ """
+ Pads a batch of images with zeros at the bottom and right to the largest height and width in the batch (or to
+ `pad_size` if provided) and optionally returns the corresponding pixel mask.
+
+ Args:
+ images (list[`np.ndarray`]):
+ Images to pad.
+ annotations (`AnnotationType` or `list[AnnotationType]`, *optional*):
+ Annotations to transform according to the padding that is applied to the images.
+ constant_values (`float` or `Iterable[float]`, *optional*):
+ The value to use for the padding if `mode` is `"constant"`.
+ return_pixel_mask (`bool`, *optional*, defaults to `True`):
+ Whether to return a pixel mask.
+ return_tensors (`str` or `TensorType`, *optional*):
+ The type of tensors to return. Can be one of:
+ - Unset: Return a list of `np.ndarray`.
+ - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
+ - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
+ data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format of the image. If not provided, it will be the same as the input image.
+ input_data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format of the input image. If not provided, it will be inferred.
+ update_bboxes (`bool`, *optional*, defaults to `True`):
+ Whether to update the bounding boxes in the annotations to match the padded images. If the
+ bounding boxes have not been converted to relative coordinates and `(centre_x, centre_y, width, height)`
+ format, the bounding boxes will not be updated.
+ pad_size (`dict[str, int]`, *optional*):
+ The size `{"height": int, "width": int}` to pad the images to. Must be larger than any image size
+ provided for preprocessing. If `pad_size` is not provided, images will be padded to the largest
+ height and width in the batch.
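+
+ Example (an illustrative sketch with two dummy channels-first images of different sizes, padded to the batch maximum):
+
+ ```python
+ >>> import numpy as np
+ >>> from transformers import RTDetrImageProcessor
+
+ >>> image_processor = RTDetrImageProcessor()
+ >>> images = [np.zeros((3, 100, 120)), np.zeros((3, 80, 150))]
+ >>> encoding = image_processor.pad(images, return_tensors="np")
+ >>> encoding["pixel_values"].shape, encoding["pixel_mask"].shape
+ ((2, 3, 100, 150), (2, 100, 150))
+ ```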
+ """
+ pad_size = pad_size if pad_size is not None else self.pad_size
+ if pad_size is not None:
+ padded_size = (pad_size["height"], pad_size["width"])
+ else:
+ padded_size = get_max_height_width(images, input_data_format=input_data_format)
+
+ annotation_list = annotations if annotations is not None else [None] * len(images)
+ padded_images = []
+ padded_annotations = []
+ for image, annotation in zip(images, annotation_list):
+ padded_image, padded_annotation = self._pad_image(
+ image,
+ padded_size,
+ annotation,
+ constant_values=constant_values,
+ data_format=data_format,
+ input_data_format=input_data_format,
+ update_bboxes=update_bboxes,
+ )
+ padded_images.append(padded_image)
+ padded_annotations.append(padded_annotation)
+
+ data = {"pixel_values": padded_images}
+
+ if return_pixel_mask:
+ masks = [
+ make_pixel_mask(image=image, output_size=padded_size, input_data_format=input_data_format)
+ for image in images
+ ]
+ data["pixel_mask"] = masks
+
+ encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors)
+
+ if annotations is not None:
+ encoded_inputs["labels"] = [
+ BatchFeature(annotation, tensor_type=return_tensors) for annotation in padded_annotations
+ ]
+
+ return encoded_inputs
+
+ @filter_out_non_signature_kwargs()
+ def preprocess(
+ self,
+ images: ImageInput,
+ annotations: Optional[Union[AnnotationType, list[AnnotationType]]] = None,
+ return_segmentation_masks: Optional[bool] = None,
+ masks_path: Optional[Union[str, pathlib.Path]] = None,
+ do_resize: Optional[bool] = None,
+ size: Optional[dict[str, int]] = None,
+ resample=None, # PILImageResampling
+ do_rescale: Optional[bool] = None,
+ rescale_factor: Optional[Union[int, float]] = None,
+ do_normalize: Optional[bool] = None,
+ do_convert_annotations: Optional[bool] = None,
+ image_mean: Optional[Union[float, list[float]]] = None,
+ image_std: Optional[Union[float, list[float]]] = None,
+ do_pad: Optional[bool] = None,
+ format: Optional[Union[str, AnnotationFormat]] = None,
+ return_tensors: Optional[Union[TensorType, str]] = None,
+ data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ pad_size: Optional[dict[str, int]] = None,
+ ) -> BatchFeature:
+ """
+ Preprocess an image or a batch of images so that it can be used by the model.
+
+ Args:
+ images (`ImageInput`):
+ Image or batch of images to preprocess. Expects a single or batch of images with pixel values ranging
+ from 0 to 255. If passing in images with pixel values between 0 and 1, set `do_rescale=False`.
+ annotations (`AnnotationType` or `list[AnnotationType]`, *optional*):
+ List of annotations associated with the image or batch of images. If annotation is for object
+ detection, the annotations should be a dictionary with the following keys:
+ - "image_id" (`int`): The image id.
+ - "annotations" (`list[Dict]`): List of annotations for an image. Each annotation should be a
+ dictionary. An image can have no annotations, in which case the list should be empty.
+ If annotation is for segmentation, the annotations should be a dictionary with the following keys:
+ - "image_id" (`int`): The image id.
+ - "segments_info" (`list[Dict]`): List of segments for an image. Each segment should be a dictionary.
+ An image can have no segments, in which case the list should be empty.
+ - "file_name" (`str`): The file name of the image.
+ return_segmentation_masks (`bool`, *optional*, defaults to self.return_segmentation_masks):
+ Whether to return segmentation masks.
+ masks_path (`str` or `pathlib.Path`, *optional*):
+ Path to the directory containing the segmentation masks.
+ do_resize (`bool`, *optional*, defaults to self.do_resize):
+ Whether to resize the image.
+ size (`dict[str, int]`, *optional*, defaults to self.size):
+ Size of the image's `(height, width)` dimensions after resizing. Available options are:
+ - `{"height": int, "width": int}`: The image will be resized to the exact size `(height, width)`.
+ Do NOT keep the aspect ratio.
+ - `{"shortest_edge": int, "longest_edge": int}`: The image will be resized to a maximum size respecting
+ the aspect ratio and keeping the shortest edge less or equal to `shortest_edge` and the longest edge
+ less or equal to `longest_edge`.
+ - `{"max_height": int, "max_width": int}`: The image will be resized to the maximum size respecting the
+ aspect ratio and keeping the height less or equal to `max_height` and the width less or equal to
+ `max_width`.
+ resample (`PILImageResampling`, *optional*, defaults to self.resample):
+ Resampling filter to use when resizing the image.
+ do_rescale (`bool`, *optional*, defaults to self.do_rescale):
+ Whether to rescale the image.
+ rescale_factor (`float`, *optional*, defaults to self.rescale_factor):
+ Rescale factor to use when rescaling the image.
+ do_normalize (`bool`, *optional*, defaults to self.do_normalize):
+ Whether to normalize the image.
+ do_convert_annotations (`bool`, *optional*, defaults to self.do_convert_annotations):
+ Whether to convert the annotations to the format expected by the model. Converts the bounding
+ boxes from the format `(top_left_x, top_left_y, width, height)` to `(center_x, center_y, width, height)`
+ and in relative coordinates.
+ image_mean (`float` or `list[float]`, *optional*, defaults to self.image_mean):
+ Mean to use when normalizing the image.
+ image_std (`float` or `list[float]`, *optional*, defaults to self.image_std):
+ Standard deviation to use when normalizing the image.
+ do_pad (`bool`, *optional*, defaults to self.do_pad):
+ Whether to pad the image. If `True`, padding will be applied to the bottom and right of
+ the image with zeros. If `pad_size` is provided, the image will be padded to the specified
+ dimensions. Otherwise, the image will be padded to the maximum height and width of the batch.
+ format (`str` or `AnnotationFormat`, *optional*, defaults to self.format):
+ Format of the annotations.
+ return_tensors (`str` or `TensorType`, *optional*, defaults to self.return_tensors):
+ Type of tensors to return. If `None`, will return the list of images.
+ data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
+ The channel dimension format for the output image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - Unset: Use the channel dimension format of the input image.
+ input_data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
+ from the input image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
+ pad_size (`dict[str, int]`, *optional*):
+ The size `{"height": int, "width": int}` to pad the images to. Must be larger than any image size
+ provided for preprocessing. If `pad_size` is not provided, images will be padded to the largest
+ height and width in the batch.
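+
+ Example (a minimal sketch with a dummy image and a made-up COCO-style annotation):
+
+ ```python
+ >>> import numpy as np
+ >>> from transformers import RTDetrImageProcessor
+
+ >>> image_processor = RTDetrImageProcessor()
+ >>> image = np.zeros((480, 640, 3), dtype=np.uint8)
+ >>> annotations = {"image_id": 0, "annotations": [{"category_id": 1, "area": 600.0, "bbox": [10, 10, 20, 30], "iscrowd": 0}]}
+ >>> encoding = image_processor.preprocess(image, annotations=annotations, return_tensors="np")
+ >>> sorted(encoding["labels"][0].keys())
+ ['area', 'boxes', 'class_labels', 'image_id', 'iscrowd', 'orig_size', 'size']
+ ```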
+ """
+ do_resize = self.do_resize if do_resize is None else do_resize
+ size = self.size if size is None else size
+ size = get_size_dict(size=size, default_to_square=True)
+ resample = self.resample if resample is None else resample
+ do_rescale = self.do_rescale if do_rescale is None else do_rescale
+ rescale_factor = self.rescale_factor if rescale_factor is None else rescale_factor
+ do_normalize = self.do_normalize if do_normalize is None else do_normalize
+ image_mean = self.image_mean if image_mean is None else image_mean
+ image_std = self.image_std if image_std is None else image_std
+ do_convert_annotations = (
+ self.do_convert_annotations if do_convert_annotations is None else do_convert_annotations
+ )
+ do_pad = self.do_pad if do_pad is None else do_pad
+ pad_size = self.pad_size if pad_size is None else pad_size
+ format = self.format if format is None else format
+
+ images = make_flat_list_of_images(images)
+
+ if not valid_images(images):
+ raise ValueError(
+ "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
+ "torch.Tensor, tf.Tensor or jax.ndarray."
+ )
+
+ # Padding arguments are not validated here: pad() pads to `pad_size` if given, otherwise to the maximum (height, width) of the batch.
+
+ validate_preprocess_arguments(
+ do_rescale=do_rescale,
+ rescale_factor=rescale_factor,
+ do_normalize=do_normalize,
+ image_mean=image_mean,
+ image_std=image_std,
+ do_resize=do_resize,
+ size=size,
+ resample=resample,
+ )
+
+ if annotations is not None and isinstance(annotations, dict):
+ annotations = [annotations]
+
+ if annotations is not None and len(images) != len(annotations):
+ raise ValueError(
+ f"The number of images ({len(images)}) and annotations ({len(annotations)}) do not match."
+ )
+
+ format = AnnotationFormat(format)
+ if annotations is not None:
+ validate_annotations(format, SUPPORTED_ANNOTATION_FORMATS, annotations)
+
+ # All transformations expect numpy arrays
+ images = [to_numpy_array(image) for image in images]
+
+ if do_rescale and is_scaled_image(images[0]):
+ logger.warning_once(
+ "It looks like you are trying to rescale already rescaled images. If the input"
+ " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
+ )
+
+ if input_data_format is None:
+ # We assume that all images have the same channel dimension format.
+ input_data_format = infer_channel_dimension_format(images[0])
+
+ # prepare (COCO annotations as a list of Dict -> DETR target as a single Dict per image)
+ if annotations is not None:
+ prepared_images = []
+ prepared_annotations = []
+ for image, target in zip(images, annotations):
+ target = self.prepare_annotation(
+ image,
+ target,
+ format,
+ return_segmentation_masks=return_segmentation_masks,
+ masks_path=masks_path,
+ input_data_format=input_data_format,
+ )
+ prepared_images.append(image)
+ prepared_annotations.append(target)
+ images = prepared_images
+ annotations = prepared_annotations
+ del prepared_images, prepared_annotations
+
+ # transformations
+ if do_resize:
+ if annotations is not None:
+ resized_images, resized_annotations = [], []
+ for image, target in zip(images, annotations):
+ orig_size = get_image_size(image, input_data_format)
+ resized_image = self.resize(
+ image, size=size, resample=resample, input_data_format=input_data_format
+ )
+ resized_annotation = self.resize_annotation(
+ target, orig_size, get_image_size(resized_image, input_data_format)
+ )
+ resized_images.append(resized_image)
+ resized_annotations.append(resized_annotation)
+ images = resized_images
+ annotations = resized_annotations
+ del resized_images, resized_annotations
+ else:
+ images = [
+ self.resize(image, size=size, resample=resample, input_data_format=input_data_format)
+ for image in images
+ ]
+
+ if do_rescale:
+ images = [self.rescale(image, rescale_factor, input_data_format=input_data_format) for image in images]
+
+ if do_normalize:
+ images = [
+ self.normalize(image, image_mean, image_std, input_data_format=input_data_format) for image in images
+ ]
+
+ if do_convert_annotations and annotations is not None:
+ annotations = [
+ self.normalize_annotation(annotation, get_image_size(image, input_data_format))
+ for annotation, image in zip(annotations, images)
+ ]
+
+ if do_pad:
+ # Pads images and returns their mask: {'pixel_values': ..., 'pixel_mask': ...}
+ encoded_inputs = self.pad(
+ images,
+ annotations=annotations,
+ return_pixel_mask=True,
+ data_format=data_format,
+ input_data_format=input_data_format,
+ update_bboxes=do_convert_annotations,
+ return_tensors=return_tensors,
+ pad_size=pad_size,
+ )
+ else:
+ images = [
+ to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
+ for image in images
+ ]
+ encoded_inputs = BatchFeature(data={"pixel_values": images}, tensor_type=return_tensors)
+ if annotations is not None:
+ encoded_inputs["labels"] = [
+ BatchFeature(annotation, tensor_type=return_tensors) for annotation in annotations
+ ]
+
+ return encoded_inputs
+
+ def post_process_object_detection(
+ self,
+ outputs,
+ threshold: float = 0.5,
+ target_sizes: Optional[Union[TensorType, list[tuple]]] = None,
+ use_focal_loss: bool = True,
+ ):
+ """
+ Converts the raw output of [`DetrForObjectDetection`] into final bounding boxes in (top_left_x, top_left_y,
+ bottom_right_x, bottom_right_y) format. Only supports PyTorch.
+
+ Args:
+ outputs ([`DetrObjectDetectionOutput`]):
+ Raw outputs of the model.
+ threshold (`float`, *optional*, defaults to 0.5):
+ Score threshold to keep object detection predictions.
+ target_sizes (`torch.Tensor` or `list[tuple[int, int]]`, *optional*):
+ Tensor of shape `(batch_size, 2)` or list of tuples (`tuple[int, int]`) containing the target size
+ `(height, width)` of each image in the batch. If unset, predictions will not be resized.
+ use_focal_loss (`bool`, *optional*, defaults to `True`):
+ Variable informing if the focal loss was used to predict the outputs. If `True`, a sigmoid is applied
+ to compute the scores of each detection, otherwise, a softmax function is used.
+
+ Returns:
+ `list[Dict]`: A list of dictionaries, each dictionary containing the scores, labels and boxes for an image
+ in the batch as predicted by the model.
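+
+ Example (an illustrative sketch; `SimpleNamespace` stands in for the model output, and the logits/boxes are made up):
+
+ ```python
+ >>> import torch
+ >>> from types import SimpleNamespace
+ >>> from transformers import RTDetrImageProcessor
+
+ >>> image_processor = RTDetrImageProcessor()
+ >>> outputs = SimpleNamespace(
+ ...     logits=torch.tensor([[[2.0, -2.0], [-3.0, -3.0]]]),  # (batch_size, num_queries, num_classes)
+ ...     pred_boxes=torch.tensor([[[0.5, 0.5, 0.2, 0.2], [0.1, 0.1, 0.05, 0.05]]]),  # relative (cx, cy, w, h)
+ ... )
+ >>> results = image_processor.post_process_object_detection(outputs, threshold=0.5, target_sizes=[(100, 100)])
+ >>> results[0]["boxes"].shape, results[0]["labels"].tolist()
+ (torch.Size([1, 4]), [0])
+ ```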
+ """
+ requires_backends(self, ["torch"])
+ out_logits, out_bbox = outputs.logits, outputs.pred_boxes
+ # convert from relative cxcywh to absolute xyxy
+ boxes = center_to_corners_format(out_bbox)
+ if target_sizes is not None:
+ if len(out_logits) != len(target_sizes):
+ raise ValueError(
+ "Make sure that you pass in as many target sizes as the batch dimension of the logits"
+ )
+ if isinstance(target_sizes, list):
+ img_h, img_w = torch.as_tensor(target_sizes).unbind(1)
+ else:
+ img_h, img_w = target_sizes.unbind(1)
+ scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1).to(boxes.device)
+ boxes = boxes * scale_fct[:, None, :]
+
+ num_top_queries = out_logits.shape[1]
+ num_classes = out_logits.shape[2]
+
+ if use_focal_loss:
+ scores = torch.nn.functional.sigmoid(out_logits)
+ scores, index = torch.topk(scores.flatten(1), num_top_queries, axis=-1)
+ labels = index % num_classes
+ index = index // num_classes
+ boxes = boxes.gather(dim=1, index=index.unsqueeze(-1).repeat(1, 1, boxes.shape[-1]))
+ else:
+ scores = torch.nn.functional.softmax(out_logits)[:, :, :-1]
+ scores, labels = scores.max(dim=-1)
+ if scores.shape[1] > num_top_queries:
+ scores, index = torch.topk(scores, num_top_queries, dim=-1)
+ labels = torch.gather(labels, dim=1, index=index)
+ boxes = torch.gather(boxes, dim=1, index=index.unsqueeze(-1).tile(1, 1, boxes.shape[-1]))
+
+ results = []
+ for score, label, box in zip(scores, labels, boxes):
+ results.append(
+ {
+ "scores": score[score > threshold],
+ "labels": label[score > threshold],
+ "boxes": box[score > threshold],
+ }
+ )
+
+ return results
+
+
+__all__ = ["RTDetrImageProcessor"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/rt_detr/image_processing_rt_detr_fast.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/rt_detr/image_processing_rt_detr_fast.py
new file mode 100644
index 0000000000000000000000000000000000000000..9aae271deaccbab9784dd7e02059eba29d32d5f2
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/rt_detr/image_processing_rt_detr_fast.py
@@ -0,0 +1,559 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/rt_detr/modular_rt_detr.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_rt_detr.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+import pathlib
+from typing import Any, Optional, Union
+
+import torch
+from torchvision.transforms.v2 import functional as F
+
+from ...image_processing_utils import BatchFeature
+from ...image_processing_utils_fast import (
+ BaseImageProcessorFast,
+ DefaultFastImageProcessorKwargs,
+ SizeDict,
+ get_image_size_for_max_height_width,
+ get_max_height_width,
+ safe_squeeze,
+)
+from ...image_transforms import center_to_corners_format, corners_to_center_format
+from ...image_utils import (
+ IMAGENET_DEFAULT_MEAN,
+ IMAGENET_DEFAULT_STD,
+ AnnotationFormat,
+ AnnotationType,
+ ChannelDimension,
+ ImageInput,
+ PILImageResampling,
+ get_image_size,
+ validate_annotations,
+)
+from ...processing_utils import Unpack
+from ...utils import TensorType, auto_docstring, requires_backends
+from ...utils.import_utils import requires
+from .image_processing_rt_detr import get_size_with_aspect_ratio
+
+
+class RTDetrFastImageProcessorKwargs(DefaultFastImageProcessorKwargs):
+ r"""
+ format (`str`, *optional*, defaults to `AnnotationFormat.COCO_DETECTION`):
+ Data format of the annotations. One of "coco_detection" or "coco_panoptic".
+ do_convert_annotations (`bool`, *optional*, defaults to `True`):
+ Controls whether to convert the annotations to the format expected by the RT-DETR model. Converts the
+ bounding boxes to the format `(center_x, center_y, width, height)` and in the range `[0, 1]`.
+ Can be overridden by the `do_convert_annotations` parameter in the `preprocess` method.
+ return_segmentation_masks (`bool`, *optional*, defaults to `False`):
+ Whether to return segmentation masks.
+ """
+
+ format: Optional[Union[str, AnnotationFormat]]
+ do_convert_annotations: Optional[bool]
+ return_segmentation_masks: Optional[bool]
+
+
+SUPPORTED_ANNOTATION_FORMATS = (AnnotationFormat.COCO_DETECTION, AnnotationFormat.COCO_PANOPTIC)
+
+
+def prepare_coco_detection_annotation(
+ image,
+ target,
+ return_segmentation_masks: bool = False,
+ input_data_format: Optional[Union[ChannelDimension, str]] = None,
+):
+ """
+ Convert the target in COCO format into the format expected by RT-DETR.
+ """
+ image_height, image_width = image.size()[-2:]
+
+ image_id = target["image_id"]
+ image_id = torch.as_tensor([image_id], dtype=torch.int64, device=image.device)
+
+ # Get all COCO annotations for the given image.
+ annotations = target["annotations"]
+ classes = []
+ area = []
+ boxes = []
+ keypoints = []
+ for obj in annotations:
+ if "iscrowd" not in obj or obj["iscrowd"] == 0:
+ classes.append(obj["category_id"])
+ area.append(obj["area"])
+ boxes.append(obj["bbox"])
+ if "keypoints" in obj:
+ keypoints.append(obj["keypoints"])
+
+ classes = torch.as_tensor(classes, dtype=torch.int64, device=image.device)
+ area = torch.as_tensor(area, dtype=torch.float32, device=image.device)
+ iscrowd = torch.zeros_like(classes, dtype=torch.int64, device=image.device)
+ # guard against no boxes via resizing
+ boxes = torch.as_tensor(boxes, dtype=torch.float32, device=image.device).reshape(-1, 4)
+ boxes[:, 2:] += boxes[:, :2]
+ boxes[:, 0::2] = boxes[:, 0::2].clip(min=0, max=image_width)
+ boxes[:, 1::2] = boxes[:, 1::2].clip(min=0, max=image_height)
+
+ keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0])
+
+ new_target = {
+ "image_id": image_id,
+ "class_labels": classes[keep],
+ "boxes": boxes[keep],
+ "area": area[keep],
+ "iscrowd": iscrowd[keep],
+ "orig_size": torch.as_tensor([int(image_height), int(image_width)], dtype=torch.int64, device=image.device),
+ }
+
+ if keypoints:
+ keypoints = torch.as_tensor(keypoints, dtype=torch.float32, device=image.device)
+ # Apply the keep mask here to filter the relevant annotations
+ keypoints = keypoints[keep]
+ num_keypoints = keypoints.shape[0]
+ keypoints = keypoints.reshape((-1, 3)) if num_keypoints else keypoints
+ new_target["keypoints"] = keypoints
+
+ return new_target
+
+
+@auto_docstring
+@requires(backends=("torchvision", "torch"))
+class RTDetrImageProcessorFast(BaseImageProcessorFast):
+ resample = PILImageResampling.BILINEAR
+ image_mean = IMAGENET_DEFAULT_MEAN
+ image_std = IMAGENET_DEFAULT_STD
+ format = AnnotationFormat.COCO_DETECTION
+ do_resize = True
+ do_rescale = True
+ do_normalize = False
+ do_pad = False
+ size = {"height": 640, "width": 640}
+ default_to_square = False
+ model_input_names = ["pixel_values", "pixel_mask"]
+ valid_kwargs = RTDetrFastImageProcessorKwargs
+ do_convert_annotations = True
+
+ def __init__(self, **kwargs: Unpack[RTDetrFastImageProcessorKwargs]) -> None:
+ # Backwards compatibility
+ do_convert_annotations = kwargs.get("do_convert_annotations")
+ do_normalize = kwargs.get("do_normalize")
+ if do_convert_annotations is None and getattr(self, "do_convert_annotations", None) is None:
+ self.do_convert_annotations = do_normalize if do_normalize is not None else self.do_normalize
+
+ super().__init__(**kwargs)
+
+ def prepare_annotation(
+ self,
+ image: torch.Tensor,
+ target: dict,
+ format: Optional[AnnotationFormat] = None,
+ return_segmentation_masks: Optional[bool] = None,
+ masks_path: Optional[Union[str, pathlib.Path]] = None,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ ) -> dict:
+ """
+ Prepare an annotation for feeding into the RT-DETR model.
+ """
+ format = format if format is not None else self.format
+
+ if format == AnnotationFormat.COCO_DETECTION:
+ return_segmentation_masks = False if return_segmentation_masks is None else return_segmentation_masks
+ target = prepare_coco_detection_annotation(
+ image, target, return_segmentation_masks, input_data_format=input_data_format
+ )
+ else:
+ raise ValueError(f"Format {format} is not supported.")
+ return target
+
+ def resize(
+ self,
+ image: torch.Tensor,
+ size: SizeDict,
+ interpolation: Optional["F.InterpolationMode"] = None,
+ **kwargs,
+ ) -> torch.Tensor:
+ """
+ Resize the image to the given size. Size can be `min_size` (scalar) or `(height, width)` tuple. If size is an
+ int, smaller edge of the image will be matched to this number.
+
+ Args:
+ image (`torch.Tensor`):
+ Image to resize.
+ size (`SizeDict`):
+ Size of the image's `(height, width)` dimensions after resizing. Available options are:
+ - `{"height": int, "width": int}`: The image will be resized to the exact size `(height, width)`.
+ Do NOT keep the aspect ratio.
+ - `{"shortest_edge": int, "longest_edge": int}`: The image will be resized to a maximum size respecting
+ the aspect ratio and keeping the shortest edge less or equal to `shortest_edge` and the longest edge
+ less or equal to `longest_edge`.
+ - `{"max_height": int, "max_width": int}`: The image will be resized to the maximum size respecting the
+ aspect ratio and keeping the height less or equal to `max_height` and the width less or equal to
+ `max_width`.
+ interpolation (`InterpolationMode`, *optional*, defaults to `InterpolationMode.BILINEAR`):
+ Resampling filter to use if resizing the image.
+ """
+ interpolation = interpolation if interpolation is not None else F.InterpolationMode.BILINEAR
+ if size.shortest_edge and size.longest_edge:
+ # Resize the image so that the shortest edge or the longest edge is of the given size
+ # while maintaining the aspect ratio of the original image.
+ new_size = get_size_with_aspect_ratio(
+ image.size()[-2:],
+ size["shortest_edge"],
+ size["longest_edge"],
+ )
+ elif size.max_height and size.max_width:
+ new_size = get_image_size_for_max_height_width(image.size()[-2:], size["max_height"], size["max_width"])
+ elif size.height and size.width:
+ new_size = (size["height"], size["width"])
+ else:
+ raise ValueError(
+ "Size must contain 'height' and 'width' keys, 'shortest_edge' and 'longest_edge' keys, or"
+ f" 'max_height' and 'max_width' keys. Got {size.keys()}."
+ )
+
+ image = F.resize(
+ image,
+ size=new_size,
+ interpolation=interpolation,
+ **kwargs,
+ )
+ return image
+
+ def resize_annotation(
+ self,
+ annotation: dict[str, Any],
+ orig_size: tuple[int, int],
+ target_size: tuple[int, int],
+ threshold: float = 0.5,
+ interpolation: Optional["F.InterpolationMode"] = None,
+ ):
+ """
+ Resizes an annotation to a target size.
+
+ Args:
+ annotation (`dict[str, Any]`):
+ The annotation dictionary.
+ orig_size (`tuple[int, int]`):
+ The original size of the input image.
+ target_size (`tuple[int, int]`):
+ The target size of the image, as returned by the preprocessing `resize` step.
+ threshold (`float`, *optional*, defaults to 0.5):
+ The threshold used to binarize the segmentation masks.
+ interpolation (`InterpolationMode`, *optional*, defaults to `F.InterpolationMode.NEAREST_EXACT`):
+ The resampling filter to use when resizing the masks.
+ """
+ interpolation = interpolation if interpolation is not None else F.InterpolationMode.NEAREST_EXACT
+ ratio_height, ratio_width = [target / orig for target, orig in zip(target_size, orig_size)]
+
+ new_annotation = {}
+ new_annotation["size"] = target_size
+
+ for key, value in annotation.items():
+ if key == "boxes":
+ boxes = value
+ scaled_boxes = boxes * torch.as_tensor(
+ [ratio_width, ratio_height, ratio_width, ratio_height], dtype=torch.float32, device=boxes.device
+ )
+ new_annotation["boxes"] = scaled_boxes
+ elif key == "area":
+ area = value
+ scaled_area = area * (ratio_width * ratio_height)
+ new_annotation["area"] = scaled_area
+ elif key == "masks":
+ masks = value[:, None]
+ masks = [F.resize(mask, target_size, interpolation=interpolation) for mask in masks]
+ masks = torch.stack(masks).to(torch.float32)
+ masks = masks[:, 0] > threshold
+ new_annotation["masks"] = masks
+ elif key == "size":
+ new_annotation["size"] = target_size
+ else:
+ new_annotation[key] = value
+
+ return new_annotation
+
+ def normalize_annotation(self, annotation: dict, image_size: tuple[int, int]) -> dict:
+ image_height, image_width = image_size
+ norm_annotation = {}
+ for key, value in annotation.items():
+ if key == "boxes":
+ boxes = value
+ boxes = corners_to_center_format(boxes)
+ boxes /= torch.as_tensor(
+ [image_width, image_height, image_width, image_height], dtype=torch.float32, device=boxes.device
+ )
+ norm_annotation[key] = boxes
+ else:
+ norm_annotation[key] = value
+ return norm_annotation
+
+ def _update_annotation_for_padded_image(
+ self,
+ annotation: dict,
+ input_image_size: tuple[int, int],
+ output_image_size: tuple[int, int],
+ padding,
+ update_bboxes,
+ ) -> dict:
+ """
+ Update the annotation for a padded image.
+ """
+ new_annotation = {}
+ new_annotation["size"] = output_image_size
+ ratio_height, ratio_width = (input / output for output, input in zip(output_image_size, input_image_size))
+
+ for key, value in annotation.items():
+ if key == "masks":
+ masks = value
+ masks = F.pad(
+ masks,
+ padding,
+ fill=0,
+ )
+ masks = safe_squeeze(masks, 1)
+ new_annotation["masks"] = masks
+ elif key == "boxes" and update_bboxes:
+ boxes = value
+ boxes *= torch.as_tensor([ratio_width, ratio_height, ratio_width, ratio_height], device=boxes.device)
+ new_annotation["boxes"] = boxes
+ elif key == "size":
+ new_annotation["size"] = output_image_size
+ else:
+ new_annotation[key] = value
+ return new_annotation
+
+ def pad(
+ self,
+ image: torch.Tensor,
+ padded_size: tuple[int, int],
+ annotation: Optional[dict[str, Any]] = None,
+ update_bboxes: bool = True,
+ fill: int = 0,
+ ):
+ original_size = image.size()[-2:]
+ padding_bottom = padded_size[0] - original_size[0]
+ padding_right = padded_size[1] - original_size[1]
+ if padding_bottom < 0 or padding_right < 0:
+ raise ValueError(
+ f"Padding dimensions are negative. Please make sure that the padded size is larger than the "
+ f"original size. Got padded size: {padded_size}, original size: {original_size}."
+ )
+ if original_size != padded_size:
+ padding = [0, 0, padding_right, padding_bottom]
+ image = F.pad(image, padding, fill=fill)
+ if annotation is not None:
+ annotation = self._update_annotation_for_padded_image(
+ annotation, original_size, padded_size, padding, update_bboxes
+ )
+
+ # Make a pixel mask for the image, where 1 indicates a valid pixel and 0 indicates padding.
+ pixel_mask = torch.zeros(padded_size, dtype=torch.int64, device=image.device)
+ pixel_mask[: original_size[0], : original_size[1]] = 1
+
+ return image, pixel_mask, annotation
+
+ @auto_docstring
+ def preprocess(
+ self,
+ images: ImageInput,
+ annotations: Optional[Union[AnnotationType, list[AnnotationType]]] = None,
+ masks_path: Optional[Union[str, pathlib.Path]] = None,
+ **kwargs: Unpack[RTDetrFastImageProcessorKwargs],
+ ) -> BatchFeature:
+ r"""
+ annotations (`AnnotationType` or `list[AnnotationType]`, *optional*):
+ List of annotations associated with the image or batch of images. If annotation is for object
+ detection, the annotations should be a dictionary with the following keys:
+ - "image_id" (`int`): The image id.
+ - "annotations" (`list[Dict]`): List of annotations for an image. Each annotation should be a
+ dictionary. An image can have no annotations, in which case the list should be empty.
+ If annotation is for segmentation, the annotations should be a dictionary with the following keys:
+ - "image_id" (`int`): The image id.
+ - "segments_info" (`list[Dict]`): List of segments for an image. Each segment should be a dictionary.
+ An image can have no segments, in which case the list should be empty.
+ - "file_name" (`str`): The file name of the image.
+ masks_path (`str` or `pathlib.Path`, *optional*):
+ Path to the directory containing the segmentation masks.
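+
+ Example (an illustrative sketch with a random torch image; the defaults resize to 640x640):
+
+ ```python
+ >>> import torch
+ >>> from transformers import RTDetrImageProcessorFast
+
+ >>> image_processor = RTDetrImageProcessorFast()
+ >>> image = torch.randint(0, 256, (3, 480, 640), dtype=torch.uint8)
+ >>> inputs = image_processor(images=image, return_tensors="pt")
+ >>> inputs["pixel_values"].shape
+ torch.Size([1, 3, 640, 640])
+ ```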
+ """
+ return super().preprocess(images, annotations, masks_path, **kwargs)
+
+ def _preprocess(
+ self,
+ images: list["torch.Tensor"],
+ annotations: Optional[Union[AnnotationType, list[AnnotationType]]],
+ masks_path: Optional[Union[str, pathlib.Path]],
+ return_segmentation_masks: bool,
+ do_resize: bool,
+ size: SizeDict,
+ interpolation: Optional["F.InterpolationMode"],
+ do_rescale: bool,
+ rescale_factor: float,
+ do_normalize: bool,
+ do_convert_annotations: bool,
+ image_mean: Optional[Union[float, list[float]]],
+ image_std: Optional[Union[float, list[float]]],
+ do_pad: bool,
+ pad_size: Optional[SizeDict],
+ format: Optional[Union[str, AnnotationFormat]],
+ return_tensors: Optional[Union[str, TensorType]],
+ **kwargs,
+ ) -> BatchFeature:
+ """
+ Preprocess an image or a batch of images so that it can be used by the model.
+ """
+
+ if annotations is not None and isinstance(annotations, dict):
+ annotations = [annotations]
+
+ if annotations is not None and len(images) != len(annotations):
+ raise ValueError(
+ f"The number of images ({len(images)}) and annotations ({len(annotations)}) do not match."
+ )
+
+ format = AnnotationFormat(format)
+ if annotations is not None:
+ validate_annotations(format, SUPPORTED_ANNOTATION_FORMATS, annotations)
+
+ data = {}
+ processed_images = []
+ processed_annotations = []
+ pixel_masks = [] # Initialize pixel_masks here
+ for image, annotation in zip(images, annotations if annotations is not None else [None] * len(images)):
+ # prepare (COCO annotations as a list of Dict -> DETR target as a single Dict per image)
+ if annotations is not None:
+ annotation = self.prepare_annotation(
+ image,
+ annotation,
+ format,
+ return_segmentation_masks=return_segmentation_masks,
+ masks_path=masks_path,
+ input_data_format=ChannelDimension.FIRST,
+ )
+
+ if do_resize:
+ resized_image = self.resize(image, size=size, interpolation=interpolation)
+ if annotations is not None:
+ annotation = self.resize_annotation(
+ annotation,
+ orig_size=image.size()[-2:],
+ target_size=resized_image.size()[-2:],
+ )
+ image = resized_image
+ # Fused rescale and normalize
+ image = self.rescale_and_normalize(image, do_rescale, rescale_factor, do_normalize, image_mean, image_std)
+ if do_convert_annotations and annotations is not None:
+ annotation = self.normalize_annotation(annotation, get_image_size(image, ChannelDimension.FIRST))
+
+ processed_images.append(image)
+ processed_annotations.append(annotation)
+ images = processed_images
+ annotations = processed_annotations if annotations is not None else None
+
+ if do_pad:
+ # depends on all resized image shapes so we need another loop
+ if pad_size is not None:
+ padded_size = (pad_size.height, pad_size.width)
+ else:
+ padded_size = get_max_height_width(images)
+
+ padded_images = []
+ padded_annotations = []
+ for image, annotation in zip(images, annotations if annotations is not None else [None] * len(images)):
+ # Pads images and returns their mask: {'pixel_values': ..., 'pixel_mask': ...}
+ if padded_size == image.size()[-2:]:
+ padded_images.append(image)
+ pixel_masks.append(torch.ones(padded_size, dtype=torch.int64, device=image.device))
+ padded_annotations.append(annotation)
+ continue
+ image, pixel_mask, annotation = self.pad(
+ image, padded_size, annotation=annotation, update_bboxes=do_convert_annotations
+ )
+ padded_images.append(image)
+ padded_annotations.append(annotation)
+ pixel_masks.append(pixel_mask)
+ images = padded_images
+ annotations = padded_annotations if annotations is not None else None
+ data.update({"pixel_mask": torch.stack(pixel_masks, dim=0)})
+
+ data.update({"pixel_values": torch.stack(images, dim=0)})
+ encoded_inputs = BatchFeature(data, tensor_type=return_tensors)
+ if annotations is not None:
+ encoded_inputs["labels"] = [
+ BatchFeature(annotation, tensor_type=return_tensors) for annotation in annotations
+ ]
+ return encoded_inputs
+
+ def post_process_object_detection(
+ self,
+ outputs,
+ threshold: float = 0.5,
+ target_sizes: Union[TensorType, list[tuple]] = None,
+ use_focal_loss: bool = True,
+ ):
+ """
+ Converts the raw output of [`RTDetrForObjectDetection`] into final bounding boxes in (top_left_x, top_left_y,
+ bottom_right_x, bottom_right_y) format. Only supports PyTorch.
+
+ Args:
+ outputs ([`RTDetrObjectDetectionOutput`]):
+ Raw outputs of the model.
+ threshold (`float`, *optional*, defaults to 0.5):
+ Score threshold to keep object detection predictions.
+ target_sizes (`torch.Tensor` or `list[tuple[int, int]]`, *optional*):
+ Tensor of shape `(batch_size, 2)` or list of tuples (`tuple[int, int]`) containing the target size
+ `(height, width)` of each image in the batch. If unset, predictions will not be resized.
+ use_focal_loss (`bool`, *optional*, defaults to `True`):
+ Whether the focal loss was used to predict the outputs. If `True`, a sigmoid is applied to compute
+ the scores of each detection; otherwise, a softmax function is used.
+
+ Returns:
+ `list[Dict]`: A list of dictionaries, each dictionary containing the scores, labels and boxes for an image
+ in the batch as predicted by the model.
+ """
+ requires_backends(self, ["torch"])
+ out_logits, out_bbox = outputs.logits, outputs.pred_boxes
+ # convert from relative cxcywh to absolute xyxy
+ boxes = center_to_corners_format(out_bbox)
+ if target_sizes is not None:
+ if len(out_logits) != len(target_sizes):
+ raise ValueError(
+ "Make sure that you pass in as many target sizes as the batch dimension of the logits"
+ )
+ if isinstance(target_sizes, list):
+ img_h, img_w = torch.as_tensor(target_sizes).unbind(1)
+ else:
+ img_h, img_w = target_sizes.unbind(1)
+ scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1).to(boxes.device)
+ boxes = boxes * scale_fct[:, None, :]
+
+ num_top_queries = out_logits.shape[1]
+ num_classes = out_logits.shape[2]
+
+ if use_focal_loss:
+ scores = torch.nn.functional.sigmoid(out_logits)
+ scores, index = torch.topk(scores.flatten(1), num_top_queries, axis=-1)
+ labels = index % num_classes
+ index = index // num_classes
+ boxes = boxes.gather(dim=1, index=index.unsqueeze(-1).repeat(1, 1, boxes.shape[-1]))
+ else:
+ scores = torch.nn.functional.softmax(out_logits, dim=-1)[:, :, :-1]  # softmax over the class dimension
+ scores, labels = scores.max(dim=-1)
+ if scores.shape[1] > num_top_queries:
+ scores, index = torch.topk(scores, num_top_queries, dim=-1)
+ labels = torch.gather(labels, dim=1, index=index)
+ boxes = torch.gather(boxes, dim=1, index=index.unsqueeze(-1).tile(1, 1, boxes.shape[-1]))
+
+ results = []
+ for score, label, box in zip(scores, labels, boxes):
+ results.append(
+ {
+ "scores": score[score > threshold],
+ "labels": label[score > threshold],
+ "boxes": box[score > threshold],
+ }
+ )
+
+ return results
+
+
+__all__ = ["RTDetrImageProcessorFast"]
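+
+
+# --- Editorial usage sketch (not part of the upstream transformers module) --------------------------
+# Minimal illustration of how `post_process_object_detection` turns raw detection outputs into
+# thresholded boxes in absolute pixel coordinates. The random logits/boxes below only stand in for a
+# real `RTDetrForObjectDetection` forward pass, and the helper name `_demo_post_process` is hypothetical.
+def _demo_post_process(processor: "RTDetrImageProcessorFast"):
+ from types import SimpleNamespace
+
+ outputs = SimpleNamespace(
+ logits=torch.randn(2, 300, 80),  # (batch_size, num_queries, num_classes)
+ pred_boxes=torch.rand(2, 300, 4),  # normalized (center_x, center_y, width, height)
+ )
+ results = processor.post_process_object_detection(
+ outputs, threshold=0.5, target_sizes=[(480, 640), (480, 640)]
+ )
+ # each entry holds "scores", "labels" and "boxes" in absolute (x0, y0, x1, y1) coordinates
+ return results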
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/rt_detr/modeling_rt_detr.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/rt_detr/modeling_rt_detr.py
new file mode 100644
index 0000000000000000000000000000000000000000..1d4b64496969acb596b037841c745b109f3d11b5
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/rt_detr/modeling_rt_detr.py
@@ -0,0 +1,2013 @@
+# coding=utf-8
+# Copyright 2024 Baidu Inc and The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch RT-DETR model."""
+
+import math
+import warnings
+from dataclasses import dataclass
+from functools import partial
+from typing import Optional, Union
+
+import torch
+import torch.nn.functional as F
+from torch import Tensor, nn
+
+from ...activations import ACT2CLS, ACT2FN
+from ...image_transforms import center_to_corners_format, corners_to_center_format
+from ...integrations import use_kernel_forward_from_hub
+from ...modeling_outputs import BaseModelOutput
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import compile_compatible_method_lru_cache
+from ...utils import (
+ ModelOutput,
+ auto_docstring,
+ logging,
+ torch_int,
+)
+from ...utils.backbone_utils import load_backbone
+from .configuration_rt_detr import RTDetrConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+# TODO: Replace all occurrences of the checkpoint with the final one
+
+
+@use_kernel_forward_from_hub("MultiScaleDeformableAttention")
+# Copied from transformers.models.deformable_detr.modeling_deformable_detr.MultiScaleDeformableAttention
+class MultiScaleDeformableAttention(nn.Module):
+ def forward(
+ self,
+ value: Tensor,
+ value_spatial_shapes: Tensor,
+ value_spatial_shapes_list: list[tuple],
+ level_start_index: Tensor,
+ sampling_locations: Tensor,
+ attention_weights: Tensor,
+ im2col_step: int,
+ ):
+ batch_size, _, num_heads, hidden_dim = value.shape
+ _, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape
+ value_list = value.split([height * width for height, width in value_spatial_shapes_list], dim=1)
+ sampling_grids = 2 * sampling_locations - 1
+ sampling_value_list = []
+ for level_id, (height, width) in enumerate(value_spatial_shapes_list):
+ # batch_size, height*width, num_heads, hidden_dim
+ # -> batch_size, height*width, num_heads*hidden_dim
+ # -> batch_size, num_heads*hidden_dim, height*width
+ # -> batch_size*num_heads, hidden_dim, height, width
+ value_l_ = (
+ value_list[level_id]
+ .flatten(2)
+ .transpose(1, 2)
+ .reshape(batch_size * num_heads, hidden_dim, height, width)
+ )
+ # batch_size, num_queries, num_heads, num_points, 2
+ # -> batch_size, num_heads, num_queries, num_points, 2
+ # -> batch_size*num_heads, num_queries, num_points, 2
+ sampling_grid_l_ = sampling_grids[:, :, :, level_id].transpose(1, 2).flatten(0, 1)
+ # batch_size*num_heads, hidden_dim, num_queries, num_points
+ sampling_value_l_ = nn.functional.grid_sample(
+ value_l_,
+ sampling_grid_l_,
+ mode="bilinear",
+ padding_mode="zeros",
+ align_corners=False,
+ )
+ sampling_value_list.append(sampling_value_l_)
+ # (batch_size, num_queries, num_heads, num_levels, num_points)
+ # -> (batch_size, num_heads, num_queries, num_levels, num_points)
+ # -> (batch_size, num_heads, 1, num_queries, num_levels*num_points)
+ attention_weights = attention_weights.transpose(1, 2).reshape(
+ batch_size * num_heads, 1, num_queries, num_levels * num_points
+ )
+ output = (
+ (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights)
+ .sum(-1)
+ .view(batch_size, num_heads * hidden_dim, num_queries)
+ )
+ return output.transpose(1, 2).contiguous()
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Base class for outputs of the RTDetrDecoder. This class adds two attributes to
+ BaseModelOutputWithCrossAttentions, namely:
+ - a stacked tensor of intermediate decoder hidden states (i.e. the output of each decoder layer)
+ - a stacked tensor of intermediate reference points.
+ """
+)
+class RTDetrDecoderOutput(ModelOutput):
+ r"""
+ intermediate_hidden_states (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, hidden_size)`):
+ Stacked intermediate hidden states (output of each layer of the decoder).
+ intermediate_logits (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, sequence_length, config.num_labels)`):
+ Stacked intermediate logits (logits of each layer of the decoder).
+ intermediate_reference_points (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, sequence_length, hidden_size)`):
+ Stacked intermediate reference points (reference points of each layer of the decoder).
+ intermediate_predicted_corners (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`):
+ Stacked intermediate predicted corners (predicted corners of each layer of the decoder).
+ initial_reference_points (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`):
+ Stacked initial reference points (initial reference points of each layer of the decoder).
+ cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` and `config.add_cross_attention=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`. Attentions weights of the decoder's cross-attention layer, after the attention softmax,
+ used to compute the weighted average in the cross-attention heads.
+ """
+
+ last_hidden_state: Optional[torch.FloatTensor] = None
+ intermediate_hidden_states: Optional[torch.FloatTensor] = None
+ intermediate_logits: Optional[torch.FloatTensor] = None
+ intermediate_reference_points: Optional[torch.FloatTensor] = None
+ intermediate_predicted_corners: Optional[torch.FloatTensor] = None
+ initial_reference_points: Optional[torch.FloatTensor] = None
+ hidden_states: Optional[tuple[torch.FloatTensor]] = None
+ attentions: Optional[tuple[torch.FloatTensor]] = None
+ cross_attentions: Optional[tuple[torch.FloatTensor]] = None
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Base class for outputs of the RT-DETR encoder-decoder model.
+ """
+)
+class RTDetrModelOutput(ModelOutput):
+ r"""
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`):
+ Sequence of hidden-states at the output of the last layer of the decoder of the model.
+ intermediate_hidden_states (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, hidden_size)`):
+ Stacked intermediate hidden states (output of each layer of the decoder).
+ intermediate_logits (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, sequence_length, config.num_labels)`):
+ Stacked intermediate logits (logits of each layer of the decoder).
+ intermediate_reference_points (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`):
+ Stacked intermediate reference points (reference points of each layer of the decoder).
+ intermediate_predicted_corners (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`):
+ Stacked intermediate predicted corners (predicted corners of each layer of the decoder).
+ initial_reference_points (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`):
+ Initial reference points used for the first decoder layer.
+ init_reference_points (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`):
+ Initial reference points sent through the Transformer decoder.
+ enc_topk_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_labels)`):
+ Predicted bounding boxes scores where the top `config.two_stage_num_proposals` scoring bounding boxes are
+ picked as region proposals in the encoder stage. Output of bounding box binary classification (i.e.
+ foreground and background).
+ enc_topk_bboxes (`torch.FloatTensor` of shape `(batch_size, sequence_length, 4)`):
+ Logits of predicted bounding boxes coordinates in the encoder stage.
+ enc_outputs_class (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_labels)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`):
+ Predicted bounding boxes scores where the top `config.two_stage_num_proposals` scoring bounding boxes are
+ picked as region proposals in the first stage. Output of bounding box binary classification (i.e.
+ foreground and background).
+ enc_outputs_coord_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, 4)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`):
+ Logits of predicted bounding boxes coordinates in the first stage.
+ denoising_meta_values (`dict`):
+ Extra dictionary for the denoising related values.
+ """
+
+ last_hidden_state: Optional[torch.FloatTensor] = None
+ intermediate_hidden_states: Optional[torch.FloatTensor] = None
+ intermediate_logits: Optional[torch.FloatTensor] = None
+ intermediate_reference_points: Optional[torch.FloatTensor] = None
+ intermediate_predicted_corners: Optional[torch.FloatTensor] = None
+ initial_reference_points: Optional[torch.FloatTensor] = None
+ decoder_hidden_states: Optional[tuple[torch.FloatTensor]] = None
+ decoder_attentions: Optional[tuple[torch.FloatTensor]] = None
+ cross_attentions: Optional[tuple[torch.FloatTensor]] = None
+ encoder_last_hidden_state: Optional[torch.FloatTensor] = None
+ encoder_hidden_states: Optional[tuple[torch.FloatTensor]] = None
+ encoder_attentions: Optional[tuple[torch.FloatTensor]] = None
+ init_reference_points: Optional[torch.FloatTensor] = None
+ enc_topk_logits: Optional[torch.FloatTensor] = None
+ enc_topk_bboxes: Optional[torch.FloatTensor] = None
+ enc_outputs_class: Optional[torch.FloatTensor] = None
+ enc_outputs_coord_logits: Optional[torch.FloatTensor] = None
+ denoising_meta_values: Optional[dict] = None
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Output type of [`RTDetrForObjectDetection`].
+ """
+)
+class RTDetrObjectDetectionOutput(ModelOutput):
+ r"""
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` are provided):
+ Total loss as a linear combination of a negative log-likelihood (cross-entropy) for class prediction and a
+ bounding box loss. The latter is defined as a linear combination of the L1 loss and the generalized
+ scale-invariant IoU loss.
+ loss_dict (`Dict`, *optional*):
+ A dictionary containing the individual losses. Useful for logging.
+ logits (`torch.FloatTensor` of shape `(batch_size, num_queries, num_classes + 1)`):
+ Classification logits (including no-object) for all queries.
+ pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`):
+ Normalized boxes coordinates for all queries, represented as (center_x, center_y, width, height). These
+ values are normalized in [0, 1], relative to the size of each individual image in the batch (disregarding
+ possible padding). You can use [`~RTDetrImageProcessor.post_process_object_detection`] to retrieve the
+ unnormalized (absolute) bounding boxes.
+ auxiliary_outputs (`list[Dict]`, *optional*):
+ Optional, only returned when auxiliary losses are activated (i.e. `config.auxiliary_loss` is set to `True`)
+ and labels are provided. It is a list of dictionaries containing the two above keys (`logits` and
+ `pred_boxes`) for each decoder layer.
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`):
+ Sequence of hidden-states at the output of the last layer of the decoder of the model.
+ intermediate_hidden_states (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, hidden_size)`):
+ Stacked intermediate hidden states (output of each layer of the decoder).
+ intermediate_logits (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, config.num_labels)`):
+ Stacked intermediate logits (logits of each layer of the decoder).
+ intermediate_reference_points (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`):
+ Stacked intermediate reference points (reference points of each layer of the decoder).
+ intermediate_predicted_corners (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`):
+ Stacked intermediate predicted corners (predicted corners of each layer of the decoder).
+ initial_reference_points (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`):
+ Stacked initial reference points (initial reference points of each layer of the decoder).
+ init_reference_points (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`):
+ Initial reference points sent through the Transformer decoder.
+ enc_topk_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_labels)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`):
+ Classification logits of the top-scoring region proposals selected in the encoder stage.
+ enc_topk_bboxes (`torch.FloatTensor` of shape `(batch_size, sequence_length, 4)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`):
+ Logits of predicted bounding boxes coordinates in the encoder.
+ enc_outputs_class (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_labels)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`):
+ Predicted bounding boxes scores where the top `config.two_stage_num_proposals` scoring bounding boxes are
+ picked as region proposals in the first stage. Output of bounding box binary classification (i.e.
+ foreground and background).
+ enc_outputs_coord_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, 4)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`):
+ Logits of predicted bounding boxes coordinates in the first stage.
+ denoising_meta_values (`dict`):
+ Extra dictionary for the denoising related values.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ loss_dict: Optional[dict] = None
+ logits: Optional[torch.FloatTensor] = None
+ pred_boxes: Optional[torch.FloatTensor] = None
+ auxiliary_outputs: Optional[list[dict]] = None
+ last_hidden_state: Optional[torch.FloatTensor] = None
+ intermediate_hidden_states: Optional[torch.FloatTensor] = None
+ intermediate_logits: Optional[torch.FloatTensor] = None
+ intermediate_reference_points: Optional[torch.FloatTensor] = None
+ intermediate_predicted_corners: Optional[torch.FloatTensor] = None
+ initial_reference_points: Optional[torch.FloatTensor] = None
+ decoder_hidden_states: Optional[tuple[torch.FloatTensor]] = None
+ decoder_attentions: Optional[tuple[torch.FloatTensor]] = None
+ cross_attentions: Optional[tuple[torch.FloatTensor]] = None
+ encoder_last_hidden_state: Optional[torch.FloatTensor] = None
+ encoder_hidden_states: Optional[tuple[torch.FloatTensor]] = None
+ encoder_attentions: Optional[tuple[torch.FloatTensor]] = None
+ init_reference_points: Optional[tuple[torch.FloatTensor]] = None
+ enc_topk_logits: Optional[torch.FloatTensor] = None
+ enc_topk_bboxes: Optional[torch.FloatTensor] = None
+ enc_outputs_class: Optional[torch.FloatTensor] = None
+ enc_outputs_coord_logits: Optional[torch.FloatTensor] = None
+ denoising_meta_values: Optional[dict] = None
+
+
+def _get_clones(partial_module, N):
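+ # unlike DETR's deepcopy-based clone helper, this expects a module factory (e.g. a functools.partial)
+ # and instantiates N independent, freshly initialized copies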
+ return nn.ModuleList([partial_module() for _ in range(N)])
+
+
+# Copied from transformers.models.conditional_detr.modeling_conditional_detr.inverse_sigmoid
+def inverse_sigmoid(x, eps=1e-5):
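+ # numerically safe logit: e.g. inverse_sigmoid(torch.tensor([0.5, 0.9])) gives roughly [0.0000, 2.1972];
+ # clamping to [eps, 1 - eps] keeps the result finite at the [0, 1] boundaries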
+ x = x.clamp(min=0, max=1)
+ x1 = x.clamp(min=eps)
+ x2 = (1 - x).clamp(min=eps)
+ return torch.log(x1 / x2)
+
+
+# Copied from transformers.models.detr.modeling_detr.DetrFrozenBatchNorm2d with Detr->RTDetr
+class RTDetrFrozenBatchNorm2d(nn.Module):
+ """
+ BatchNorm2d where the batch statistics and the affine parameters are fixed.
+
+ Copy-paste from torchvision.misc.ops with added eps before rsqrt, without which any other models than
+ torchvision.models.resnet[18,34,50,101] produce nans.
+ """
+
+ def __init__(self, n):
+ super().__init__()
+ self.register_buffer("weight", torch.ones(n))
+ self.register_buffer("bias", torch.zeros(n))
+ self.register_buffer("running_mean", torch.zeros(n))
+ self.register_buffer("running_var", torch.ones(n))
+
+ def _load_from_state_dict(
+ self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
+ ):
+ num_batches_tracked_key = prefix + "num_batches_tracked"
+ if num_batches_tracked_key in state_dict:
+ del state_dict[num_batches_tracked_key]
+
+ super()._load_from_state_dict(
+ state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
+ )
+
+ def forward(self, x):
+ # move reshapes to the beginning
+ # to make it user-friendly
+ weight = self.weight.reshape(1, -1, 1, 1)
+ bias = self.bias.reshape(1, -1, 1, 1)
+ running_var = self.running_var.reshape(1, -1, 1, 1)
+ running_mean = self.running_mean.reshape(1, -1, 1, 1)
+ epsilon = 1e-5
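+ # folds the frozen normalization and the affine transform into a single scale and shift:
+ # y = (x - running_mean) / sqrt(running_var + epsilon) * weight + bias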
+ scale = weight * (running_var + epsilon).rsqrt()
+ bias = bias - running_mean * scale
+ return x * scale + bias
+
+
+# Copied from transformers.models.detr.modeling_detr.replace_batch_norm with Detr->RTDetr
+def replace_batch_norm(model):
+ r"""
+ Recursively replace all `torch.nn.BatchNorm2d` with `RTDetrFrozenBatchNorm2d`.
+
+ Args:
+ model (torch.nn.Module):
+ input model
+ """
+ for name, module in model.named_children():
+ if isinstance(module, nn.BatchNorm2d):
+ new_module = RTDetrFrozenBatchNorm2d(module.num_features)
+
+ if module.weight.device != torch.device("meta"):
+ new_module.weight.data.copy_(module.weight)
+ new_module.bias.data.copy_(module.bias)
+ new_module.running_mean.data.copy_(module.running_mean)
+ new_module.running_var.data.copy_(module.running_var)
+
+ model._modules[name] = new_module
+
+ if len(list(module.children())) > 0:
+ replace_batch_norm(module)
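+
+
+# Editorial sketch (not part of the upstream module): freezing the statistics of every BatchNorm2d in
+# an arbitrary module tree. The helper name `_demo_replace_batch_norm` is hypothetical.
+def _demo_replace_batch_norm():
+ backbone = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3), nn.BatchNorm2d(8), nn.ReLU())
+ replace_batch_norm(backbone)
+ # the BatchNorm2d is swapped in place for an RTDetrFrozenBatchNorm2d with its statistics copied over
+ return isinstance(backbone[1], RTDetrFrozenBatchNorm2d)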
+
+
+def get_contrastive_denoising_training_group(
+ targets,
+ num_classes,
+ num_queries,
+ class_embed,
+ num_denoising_queries=100,
+ label_noise_ratio=0.5,
+ box_noise_scale=1.0,
+):
+ """
+ Creates a contrastive denoising training group using ground-truth samples. It adds noise to labels and boxes.
+
+ Args:
+ targets (`list[dict]`):
+ The target objects, each containing 'class_labels' and 'boxes' for objects in an image.
+ num_classes (`int`):
+ Total number of classes in the dataset.
+ num_queries (`int`):
+ Number of query slots in the transformer.
+ class_embed (`callable`):
+ A function or a model layer to embed class labels.
+ num_denoising_queries (`int`, *optional*, defaults to 100):
+ Number of denoising queries.
+ label_noise_ratio (`float`, *optional*, defaults to 0.5):
+ Ratio of noise applied to labels.
+ box_noise_scale (`float`, *optional*, defaults to 1.0):
+ Scale of noise applied to bounding boxes.
+ Returns:
+ `tuple` comprising various elements:
+ - **input_query_class** (`torch.FloatTensor`) --
+ Class queries with applied label noise.
+ - **input_query_bbox** (`torch.FloatTensor`) --
+ Bounding box queries with applied box noise.
+ - **attn_mask** (`torch.FloatTensor`) --
+ Attention mask for separating denoising and reconstruction queries.
+ - **denoising_meta_values** (`dict`) --
+ Metadata including denoising positive indices, number of groups, and split sizes.
+ """
+
+ if num_denoising_queries <= 0:
+ return None, None, None, None
+
+ num_ground_truths = [len(t["class_labels"]) for t in targets]
+ device = targets[0]["class_labels"].device
+
+ max_gt_num = max(num_ground_truths)
+ if max_gt_num == 0:
+ return None, None, None, None
+
+ num_groups_denoising_queries = num_denoising_queries // max_gt_num
+ num_groups_denoising_queries = 1 if num_groups_denoising_queries == 0 else num_groups_denoising_queries
+ # pad gt to max_num of a batch
+ batch_size = len(num_ground_truths)
+
+ input_query_class = torch.full([batch_size, max_gt_num], num_classes, dtype=torch.int32, device=device)
+ input_query_bbox = torch.zeros([batch_size, max_gt_num, 4], device=device)
+ pad_gt_mask = torch.zeros([batch_size, max_gt_num], dtype=torch.bool, device=device)
+
+ for i in range(batch_size):
+ num_gt = num_ground_truths[i]
+ if num_gt > 0:
+ input_query_class[i, :num_gt] = targets[i]["class_labels"]
+ input_query_bbox[i, :num_gt] = targets[i]["boxes"]
+ pad_gt_mask[i, :num_gt] = 1
+ # each group has positive and negative queries.
+ input_query_class = input_query_class.tile([1, 2 * num_groups_denoising_queries])
+ input_query_bbox = input_query_bbox.tile([1, 2 * num_groups_denoising_queries, 1])
+ pad_gt_mask = pad_gt_mask.tile([1, 2 * num_groups_denoising_queries])
+ # positive and negative mask
+ negative_gt_mask = torch.zeros([batch_size, max_gt_num * 2, 1], device=device)
+ negative_gt_mask[:, max_gt_num:] = 1
+ negative_gt_mask = negative_gt_mask.tile([1, num_groups_denoising_queries, 1])
+ positive_gt_mask = 1 - negative_gt_mask
+ # contrastive denoising training positive index
+ positive_gt_mask = positive_gt_mask.squeeze(-1) * pad_gt_mask
+ denoise_positive_idx = torch.nonzero(positive_gt_mask)[:, 1]
+ denoise_positive_idx = torch.split(
+ denoise_positive_idx, [n * num_groups_denoising_queries for n in num_ground_truths]
+ )
+ # total denoising queries
+ num_denoising_queries = torch_int(max_gt_num * 2 * num_groups_denoising_queries)
+
+ if label_noise_ratio > 0:
+ mask = torch.rand_like(input_query_class, dtype=torch.float) < (label_noise_ratio * 0.5)
+ # randomly replace the selected (non-padding) labels with a new class id
+ new_label = torch.randint_like(mask, 0, num_classes, dtype=input_query_class.dtype)
+ input_query_class = torch.where(mask & pad_gt_mask, new_label, input_query_class)
+
+ if box_noise_scale > 0:
+ known_bbox = center_to_corners_format(input_query_bbox)
+ diff = torch.tile(input_query_bbox[..., 2:] * 0.5, [1, 1, 2]) * box_noise_scale
+ rand_sign = torch.randint_like(input_query_bbox, 0, 2) * 2.0 - 1.0
+ rand_part = torch.rand_like(input_query_bbox)
+ rand_part = (rand_part + 1.0) * negative_gt_mask + rand_part * (1 - negative_gt_mask)
+ rand_part *= rand_sign
+ known_bbox += rand_part * diff
+ known_bbox.clip_(min=0.0, max=1.0)
+ input_query_bbox = corners_to_center_format(known_bbox)
+ input_query_bbox = inverse_sigmoid(input_query_bbox)
+
+ input_query_class = class_embed(input_query_class)
+
+ target_size = num_denoising_queries + num_queries
+ attn_mask = torch.full([target_size, target_size], 0, dtype=torch.float, device=device)
+ # match query cannot see the reconstruction
+ attn_mask[num_denoising_queries:, :num_denoising_queries] = -torch.inf
+
+ # reconstructions cannot see each other
+ for i in range(num_groups_denoising_queries):
+ idx_block_start = max_gt_num * 2 * i
+ idx_block_end = max_gt_num * 2 * (i + 1)
+ attn_mask[idx_block_start:idx_block_end, :idx_block_start] = -torch.inf
+ attn_mask[idx_block_start:idx_block_end, idx_block_end:num_denoising_queries] = -torch.inf
+
+ denoising_meta_values = {
+ "dn_positive_idx": denoise_positive_idx,
+ "dn_num_group": num_groups_denoising_queries,
+ "dn_num_split": [num_denoising_queries, num_queries],
+ }
+
+ return input_query_class, input_query_bbox, attn_mask, denoising_meta_values
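+
+
+# Editorial sketch (not part of the upstream module): shapes produced for a toy two-image batch. The
+# 80-class setup, the 300 object queries and the helper name `_demo_denoising_group` are illustrative.
+def _demo_denoising_group():
+ targets = [
+ {"class_labels": torch.tensor([1, 3]), "boxes": torch.rand(2, 4)},
+ {"class_labels": torch.tensor([0]), "boxes": torch.rand(1, 4)},
+ ]
+ class_embed = nn.Embedding(80 + 1, 256)  # one extra row for the padding / "no object" class id
+ query_class, query_bbox, attn_mask, meta = get_contrastive_denoising_training_group(
+ targets, num_classes=80, num_queries=300, class_embed=class_embed, num_denoising_queries=100
+ )
+ # max_gt_num=2 and 100 denoising queries give 50 groups, i.e. 2 * 2 * 50 = 200 denoising queries, so
+ # query_class is (2, 200, 256), query_bbox is (2, 200, 4) and attn_mask is (500, 500)
+ return query_class, query_bbox, attn_mask, meta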
+
+
+class RTDetrConvEncoder(nn.Module):
+ """
+ Convolutional backbone, by default the ResNet variant defined in modeling_rt_detr_resnet.py.
+
+ nn.BatchNorm2d layers are replaced by RTDetrFrozenBatchNorm2d as defined above.
+ https://github.com/lyuwenyu/RT-DETR/blob/main/rtdetr_pytorch/src/nn/backbone/presnet.py#L142
+ """
+
+ def __init__(self, config):
+ super().__init__()
+
+ backbone = load_backbone(config)
+
+ if config.freeze_backbone_batch_norms:
+ # replace batch norm by frozen batch norm
+ with torch.no_grad():
+ replace_batch_norm(backbone)
+ self.model = backbone
+ self.intermediate_channel_sizes = self.model.channels
+
+ def forward(self, pixel_values: torch.Tensor, pixel_mask: torch.Tensor):
+ # send pixel_values through the model to get list of feature maps
+ features = self.model(pixel_values).feature_maps
+
+ out = []
+ for feature_map in features:
+ # downsample pixel_mask to match shape of corresponding feature_map
+ mask = nn.functional.interpolate(pixel_mask[None].float(), size=feature_map.shape[-2:]).to(torch.bool)[0]
+ out.append((feature_map, mask))
+ return out
+
+
+class RTDetrConvNormLayer(nn.Module):
+ def __init__(self, config, in_channels, out_channels, kernel_size, stride, padding=None, activation=None):
+ super().__init__()
+ self.conv = nn.Conv2d(
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride,
+ padding=(kernel_size - 1) // 2 if padding is None else padding,
+ bias=False,
+ )
+ self.norm = nn.BatchNorm2d(out_channels, config.batch_norm_eps)
+ self.activation = nn.Identity() if activation is None else ACT2CLS[activation]()
+
+ def forward(self, hidden_state):
+ hidden_state = self.conv(hidden_state)
+ hidden_state = self.norm(hidden_state)
+ hidden_state = self.activation(hidden_state)
+ return hidden_state
+
+
+class RTDetrEncoderLayer(nn.Module):
+ def __init__(self, config: RTDetrConfig):
+ super().__init__()
+ self.normalize_before = config.normalize_before
+
+ # self-attention
+ self.self_attn = RTDetrMultiheadAttention(
+ embed_dim=config.encoder_hidden_dim,
+ num_heads=config.num_attention_heads,
+ dropout=config.dropout,
+ )
+ self.self_attn_layer_norm = nn.LayerNorm(config.encoder_hidden_dim, eps=config.layer_norm_eps)
+ self.dropout = config.dropout
+ self.activation_fn = ACT2FN[config.encoder_activation_function]
+ self.activation_dropout = config.activation_dropout
+ self.fc1 = nn.Linear(config.encoder_hidden_dim, config.encoder_ffn_dim)
+ self.fc2 = nn.Linear(config.encoder_ffn_dim, config.encoder_hidden_dim)
+ self.final_layer_norm = nn.LayerNorm(config.encoder_hidden_dim, eps=config.layer_norm_eps)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: torch.Tensor,
+ position_embeddings: Optional[torch.Tensor] = None,
+ output_attentions: bool = False,
+ **kwargs,
+ ):
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+ attention_mask (`torch.FloatTensor`): attention mask of size
+ `(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative
+ values.
+ position_embeddings (`torch.FloatTensor`, *optional*):
+ Object queries (also called content embeddings), to be added to the hidden states.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ """
+ residual = hidden_states
+ if self.normalize_before:
+ hidden_states = self.self_attn_layer_norm(hidden_states)
+
+ hidden_states, attn_weights = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_embeddings=position_embeddings,
+ output_attentions=output_attentions,
+ )
+
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+ hidden_states = residual + hidden_states
+ if not self.normalize_before:
+ hidden_states = self.self_attn_layer_norm(hidden_states)
+
+ if self.normalize_before:
+ hidden_states = self.final_layer_norm(hidden_states)
+ residual = hidden_states
+
+ hidden_states = self.activation_fn(self.fc1(hidden_states))
+ hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
+
+ hidden_states = self.fc2(hidden_states)
+
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+
+ hidden_states = residual + hidden_states
+ if not self.normalize_before:
+ hidden_states = self.final_layer_norm(hidden_states)
+
+ if self.training:
+ if torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any():
+ clamp_value = torch.finfo(hidden_states.dtype).max - 1000
+ hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (attn_weights,)
+
+ return outputs
+
+
+class RTDetrRepVggBlock(nn.Module):
+ """
+ RepVGG architecture block introduced by the work "RepVGG: Making VGG-style ConvNets Great Again".
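+
+ Note that, unlike the deployment-time trick described in that paper, the 3x3 and 1x1 branches are
+ kept as separate convolutions here and simply summed in the forward pass (no re-parameterization).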
+ """
+
+ def __init__(self, config: RTDetrConfig):
+ super().__init__()
+
+ activation = config.activation_function
+ hidden_channels = int(config.encoder_hidden_dim * config.hidden_expansion)
+ self.conv1 = RTDetrConvNormLayer(config, hidden_channels, hidden_channels, 3, 1, padding=1)
+ self.conv2 = RTDetrConvNormLayer(config, hidden_channels, hidden_channels, 1, 1, padding=0)
+ self.activation = nn.Identity() if activation is None else ACT2CLS[activation]()
+
+ def forward(self, x):
+ y = self.conv1(x) + self.conv2(x)
+ return self.activation(y)
+
+
+class RTDetrCSPRepLayer(nn.Module):
+ """
+ Cross Stage Partial (CSP) network layer with RepVGG blocks.
+ """
+
+ def __init__(self, config: RTDetrConfig):
+ super().__init__()
+
+ in_channels = config.encoder_hidden_dim * 2
+ out_channels = config.encoder_hidden_dim
+ num_blocks = 3
+ activation = config.activation_function
+
+ hidden_channels = int(out_channels * config.hidden_expansion)
+ self.conv1 = RTDetrConvNormLayer(config, in_channels, hidden_channels, 1, 1, activation=activation)
+ self.conv2 = RTDetrConvNormLayer(config, in_channels, hidden_channels, 1, 1, activation=activation)
+ self.bottlenecks = nn.Sequential(*[RTDetrRepVggBlock(config) for _ in range(num_blocks)])
+ if hidden_channels != out_channels:
+ self.conv3 = RTDetrConvNormLayer(config, hidden_channels, out_channels, 1, 1, activation=activation)
+ else:
+ self.conv3 = nn.Identity()
+
+ def forward(self, hidden_state):
+ hidden_state_1 = self.conv1(hidden_state)
+ hidden_state_1 = self.bottlenecks(hidden_state_1)
+ hidden_state_2 = self.conv2(hidden_state)
+ return self.conv3(hidden_state_1 + hidden_state_2)
+
+
+# Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrMultiscaleDeformableAttention with DeformableDetr->RTDetr
+class RTDetrMultiscaleDeformableAttention(nn.Module):
+ """
+ Multiscale deformable attention as proposed in Deformable DETR.
+ """
+
+ def __init__(self, config: RTDetrConfig, num_heads: int, n_points: int):
+ super().__init__()
+
+ self.attn = MultiScaleDeformableAttention()
+
+ if config.d_model % num_heads != 0:
+ raise ValueError(
+ f"embed_dim (d_model) must be divisible by num_heads, but got {config.d_model} and {num_heads}"
+ )
+ dim_per_head = config.d_model // num_heads
+ # check if dim_per_head is power of 2
+ if not ((dim_per_head & (dim_per_head - 1) == 0) and dim_per_head != 0):
+ warnings.warn(
+ "You'd better set embed_dim (d_model) in RTDetrMultiscaleDeformableAttention to make the"
+ " dimension of each attention head a power of 2 which is more efficient in the authors' CUDA"
+ " implementation."
+ )
+
+ self.im2col_step = 64
+
+ self.d_model = config.d_model
+ self.n_levels = config.num_feature_levels
+ self.n_heads = num_heads
+ self.n_points = n_points
+
+ self.sampling_offsets = nn.Linear(config.d_model, num_heads * self.n_levels * n_points * 2)
+ self.attention_weights = nn.Linear(config.d_model, num_heads * self.n_levels * n_points)
+ self.value_proj = nn.Linear(config.d_model, config.d_model)
+ self.output_proj = nn.Linear(config.d_model, config.d_model)
+
+ self.disable_custom_kernels = config.disable_custom_kernels
+
+ def with_pos_embed(self, tensor: torch.Tensor, position_embeddings: Optional[Tensor]):
+ return tensor if position_embeddings is None else tensor + position_embeddings
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ position_embeddings: Optional[torch.Tensor] = None,
+ reference_points=None,
+ spatial_shapes=None,
+ spatial_shapes_list=None,
+ level_start_index=None,
+ output_attentions: bool = False,
+ ):
+ # add position embeddings to the hidden states before projecting to queries and keys
+ if position_embeddings is not None:
+ hidden_states = self.with_pos_embed(hidden_states, position_embeddings)
+
+ batch_size, num_queries, _ = hidden_states.shape
+ batch_size, sequence_length, _ = encoder_hidden_states.shape
+ total_elements = sum(height * width for height, width in spatial_shapes_list)
+ if total_elements != sequence_length:
+ raise ValueError(
+ "Make sure to align the spatial shapes with the sequence length of the encoder hidden states"
+ )
+
+ value = self.value_proj(encoder_hidden_states)
+ if attention_mask is not None:
+ # we invert the attention_mask
+ value = value.masked_fill(~attention_mask[..., None], float(0))
+ value = value.view(batch_size, sequence_length, self.n_heads, self.d_model // self.n_heads)
+ sampling_offsets = self.sampling_offsets(hidden_states).view(
+ batch_size, num_queries, self.n_heads, self.n_levels, self.n_points, 2
+ )
+ attention_weights = self.attention_weights(hidden_states).view(
+ batch_size, num_queries, self.n_heads, self.n_levels * self.n_points
+ )
+ attention_weights = F.softmax(attention_weights, -1).view(
+ batch_size, num_queries, self.n_heads, self.n_levels, self.n_points
+ )
+ # batch_size, num_queries, n_heads, n_levels, n_points, 2
+ num_coordinates = reference_points.shape[-1]
+ if num_coordinates == 2:
+ offset_normalizer = torch.stack([spatial_shapes[..., 1], spatial_shapes[..., 0]], -1)
+ sampling_locations = (
+ reference_points[:, :, None, :, None, :]
+ + sampling_offsets / offset_normalizer[None, None, None, :, None, :]
+ )
+ elif num_coordinates == 4:
+ sampling_locations = (
+ reference_points[:, :, None, :, None, :2]
+ + sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5
+ )
+ else:
+ raise ValueError(f"Last dim of reference_points must be 2 or 4, but got {reference_points.shape[-1]}")
+
+ output = self.attn(
+ value,
+ spatial_shapes,
+ spatial_shapes_list,
+ level_start_index,
+ sampling_locations,
+ attention_weights,
+ self.im2col_step,
+ )
+
+ output = self.output_proj(output)
+
+ return output, attention_weights
+
+
+class RTDetrMultiheadAttention(nn.Module):
+ """
+ Multi-headed attention from 'Attention Is All You Need' paper.
+
+ Here, we add position embeddings to the queries and keys (as explained in the Deformable DETR paper).
+ """
+
+ def __init__(
+ self,
+ embed_dim: int,
+ num_heads: int,
+ dropout: float = 0.0,
+ bias: bool = True,
+ ):
+ super().__init__()
+ self.embed_dim = embed_dim
+ self.num_heads = num_heads
+ self.dropout = dropout
+ self.head_dim = embed_dim // num_heads
+ if self.head_dim * num_heads != self.embed_dim:
+ raise ValueError(
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
+ f" {num_heads})."
+ )
+ self.scaling = self.head_dim**-0.5
+
+ self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+ self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+ self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+ self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+
+ def _reshape(self, tensor: torch.Tensor, seq_len: int, batch_size: int):
+ return tensor.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
+
+ def with_pos_embed(self, tensor: torch.Tensor, position_embeddings: Optional[Tensor]):
+ return tensor if position_embeddings is None else tensor + position_embeddings
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_embeddings: Optional[torch.Tensor] = None,
+ output_attentions: bool = False,
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
+ """Input shape: Batch x Time x Channel"""
+
+ batch_size, target_len, embed_dim = hidden_states.size()
+ # add position embeddings to the hidden states before projecting to queries and keys
+ if position_embeddings is not None:
+ hidden_states_original = hidden_states
+ hidden_states = self.with_pos_embed(hidden_states, position_embeddings)
+
+ # get queries, keys and values
+ query_states = self.q_proj(hidden_states) * self.scaling
+ key_states = self._reshape(self.k_proj(hidden_states), -1, batch_size)
+ value_states = self._reshape(self.v_proj(hidden_states_original), -1, batch_size)
+
+ proj_shape = (batch_size * self.num_heads, -1, self.head_dim)
+ query_states = self._reshape(query_states, target_len, batch_size).view(*proj_shape)
+ key_states = key_states.view(*proj_shape)
+ value_states = value_states.view(*proj_shape)
+
+ source_len = key_states.size(1)
+
+ attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
+
+ if attn_weights.size() != (batch_size * self.num_heads, target_len, source_len):
+ raise ValueError(
+ f"Attention weights should be of size {(batch_size * self.num_heads, target_len, source_len)}, but is"
+ f" {attn_weights.size()}"
+ )
+
+ # expand attention_mask
+ if attention_mask is not None:
+ # [seq_len, seq_len] -> [batch_size, 1, target_seq_len, source_seq_len]
+ attention_mask = attention_mask.expand(batch_size, 1, *attention_mask.size())
+
+ if attention_mask is not None:
+ if attention_mask.size() != (batch_size, 1, target_len, source_len):
+ raise ValueError(
+ f"Attention mask should be of size {(batch_size, 1, target_len, source_len)}, but is"
+ f" {attention_mask.size()}"
+ )
+ if attention_mask.dtype == torch.bool:
+ attention_mask = torch.zeros_like(attention_mask, dtype=attn_weights.dtype).masked_fill_(
+ attention_mask, -torch.inf
+ )
+ attn_weights = attn_weights.view(batch_size, self.num_heads, target_len, source_len) + attention_mask
+ attn_weights = attn_weights.view(batch_size * self.num_heads, target_len, source_len)
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+
+ if output_attentions:
+ # this operation is a bit awkward, but it's required to
+ # make sure that attn_weights keeps its gradient.
+ # In order to do so, attn_weights has to be reshaped
+ # twice and reused in the following steps
+ attn_weights_reshaped = attn_weights.view(batch_size, self.num_heads, target_len, source_len)
+ attn_weights = attn_weights_reshaped.view(batch_size * self.num_heads, target_len, source_len)
+ else:
+ attn_weights_reshaped = None
+
+ attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
+
+ attn_output = torch.bmm(attn_probs, value_states)
+
+ if attn_output.size() != (batch_size * self.num_heads, target_len, self.head_dim):
+ raise ValueError(
+ f"`attn_output` should be of size {(batch_size, self.num_heads, target_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
+ )
+
+ attn_output = attn_output.view(batch_size, self.num_heads, target_len, self.head_dim)
+ attn_output = attn_output.transpose(1, 2)
+ attn_output = attn_output.reshape(batch_size, target_len, embed_dim)
+
+ attn_output = self.out_proj(attn_output)
+
+ return attn_output, attn_weights_reshaped
+
+
+class RTDetrDecoderLayer(nn.Module):
+ def __init__(self, config: RTDetrConfig):
+ super().__init__()
+ # self-attention
+ self.self_attn = RTDetrMultiheadAttention(
+ embed_dim=config.d_model,
+ num_heads=config.decoder_attention_heads,
+ dropout=config.attention_dropout,
+ )
+ self.dropout = config.dropout
+ self.activation_fn = ACT2FN[config.decoder_activation_function]
+ self.activation_dropout = config.activation_dropout
+
+ self.self_attn_layer_norm = nn.LayerNorm(config.d_model, eps=config.layer_norm_eps)
+ # cross-attention
+ self.encoder_attn = RTDetrMultiscaleDeformableAttention(
+ config,
+ num_heads=config.decoder_attention_heads,
+ n_points=config.decoder_n_points,
+ )
+ self.encoder_attn_layer_norm = nn.LayerNorm(config.d_model, eps=config.layer_norm_eps)
+ # feedforward neural networks
+ self.fc1 = nn.Linear(config.d_model, config.decoder_ffn_dim)
+ self.fc2 = nn.Linear(config.decoder_ffn_dim, config.d_model)
+ self.final_layer_norm = nn.LayerNorm(config.d_model, eps=config.layer_norm_eps)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_embeddings: Optional[torch.Tensor] = None,
+ reference_points=None,
+ spatial_shapes=None,
+ spatial_shapes_list=None,
+ level_start_index=None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = False,
+ ):
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`):
+ Input to the layer of shape `(seq_len, batch, embed_dim)`.
+ position_embeddings (`torch.FloatTensor`, *optional*):
+ Position embeddings that are added to the queries and keys in the self-attention layer.
+ reference_points (`torch.FloatTensor`, *optional*):
+ Reference points.
+ spatial_shapes (`torch.LongTensor`, *optional*):
+ Spatial shapes.
+ level_start_index (`torch.LongTensor`, *optional*):
+ Level start index.
+ encoder_hidden_states (`torch.FloatTensor`):
+ cross attention input to the layer of shape `(seq_len, batch, embed_dim)`
+ encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
+ `(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative
+ values.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ """
+ residual = hidden_states
+
+ # Self Attention
+ hidden_states, self_attn_weights = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=encoder_attention_mask,
+ position_embeddings=position_embeddings,
+ output_attentions=output_attentions,
+ )
+
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+ hidden_states = residual + hidden_states
+ hidden_states = self.self_attn_layer_norm(hidden_states)
+
+ second_residual = hidden_states
+
+ # Cross-Attention
+ cross_attn_weights = None
+ hidden_states, cross_attn_weights = self.encoder_attn(
+ hidden_states=hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ position_embeddings=position_embeddings,
+ reference_points=reference_points,
+ spatial_shapes=spatial_shapes,
+ spatial_shapes_list=spatial_shapes_list,
+ level_start_index=level_start_index,
+ output_attentions=output_attentions,
+ )
+
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+ hidden_states = second_residual + hidden_states
+
+ hidden_states = self.encoder_attn_layer_norm(hidden_states)
+
+ # Fully Connected
+ residual = hidden_states
+ hidden_states = self.activation_fn(self.fc1(hidden_states))
+ hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
+ hidden_states = self.fc2(hidden_states)
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+ hidden_states = residual + hidden_states
+ hidden_states = self.final_layer_norm(hidden_states)
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (self_attn_weights, cross_attn_weights)
+
+ return outputs
+
+
+@auto_docstring
+class RTDetrPreTrainedModel(PreTrainedModel):
+ config: RTDetrConfig
+ base_model_prefix = "rt_detr"
+ main_input_name = "pixel_values"
+ _no_split_modules = [r"RTDetrHybridEncoder", r"RTDetrDecoderLayer"]
+
+ def _init_weights(self, module):
+ """Initialize the weights"""
+ if isinstance(module, (RTDetrForObjectDetection, RTDetrDecoder)):
+ if module.class_embed is not None:
+ for layer in module.class_embed:
+ prior_prob = self.config.initializer_bias_prior_prob or 1 / (self.config.num_labels + 1)
+ bias = float(-math.log((1 - prior_prob) / prior_prob))
+ nn.init.xavier_uniform_(layer.weight)
+ nn.init.constant_(layer.bias, bias)
+
+ if module.bbox_embed is not None:
+ for layer in module.bbox_embed:
+ nn.init.constant_(layer.layers[-1].weight, 0)
+ nn.init.constant_(layer.layers[-1].bias, 0)
+
+ elif isinstance(module, RTDetrMultiscaleDeformableAttention):
+ nn.init.constant_(module.sampling_offsets.weight.data, 0.0)
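+ # with zero offset weights, the initial sampling offsets come entirely from the bias built below:
+ # each head looks along its own fixed direction, and the i-th point starts (i + 1) steps further out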
+ default_dtype = torch.get_default_dtype()
+ thetas = torch.arange(module.n_heads, dtype=torch.int64).to(default_dtype) * (
+ 2.0 * math.pi / module.n_heads
+ )
+ grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
+ grid_init = (
+ (grid_init / grid_init.abs().max(-1, keepdim=True)[0])
+ .view(module.n_heads, 1, 1, 2)
+ .repeat(1, module.n_levels, module.n_points, 1)
+ )
+ for i in range(module.n_points):
+ grid_init[:, :, i, :] *= i + 1
+ with torch.no_grad():
+ module.sampling_offsets.bias = nn.Parameter(grid_init.view(-1))
+ nn.init.constant_(module.attention_weights.weight.data, 0.0)
+ nn.init.constant_(module.attention_weights.bias.data, 0.0)
+ nn.init.xavier_uniform_(module.value_proj.weight.data)
+ nn.init.constant_(module.value_proj.bias.data, 0.0)
+ nn.init.xavier_uniform_(module.output_proj.weight.data)
+ nn.init.constant_(module.output_proj.bias.data, 0.0)
+
+ elif isinstance(module, RTDetrModel):
+ prior_prob = self.config.initializer_bias_prior_prob or 1 / (self.config.num_labels + 1)
+ bias = float(-math.log((1 - prior_prob) / prior_prob))
+ nn.init.xavier_uniform_(module.enc_score_head.weight)
+ nn.init.constant_(module.enc_score_head.bias, bias)
+
+ elif isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)):
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ if module.bias is not None:
+ module.bias.data.zero_()
+
+ elif isinstance(module, nn.LayerNorm):
+ module.weight.data.fill_(1.0)
+ module.bias.data.zero_()
+
+ if hasattr(module, "weight_embedding") and self.config.learn_initial_query:
+ nn.init.xavier_uniform_(module.weight_embedding.weight)
+ if hasattr(module, "denoising_class_embed") and self.config.num_denoising > 0:
+ nn.init.xavier_uniform_(module.denoising_class_embed.weight)
+
+
+class RTDetrEncoder(nn.Module):
+ def __init__(self, config: RTDetrConfig):
+ super().__init__()
+
+ self.layers = nn.ModuleList([RTDetrEncoderLayer(config) for _ in range(config.encoder_layers)])
+
+ def forward(self, src, src_mask=None, pos_embed=None, output_attentions: bool = False) -> torch.Tensor:
+ hidden_states = src
+ for layer in self.layers:
+ hidden_states = layer(
+ hidden_states,
+ attention_mask=src_mask,
+ position_embeddings=pos_embed,
+ output_attentions=output_attentions,
+ )
+ return hidden_states
+
+
+class RTDetrHybridEncoder(nn.Module):
+ """
+ Encoder consisting of a projection layer, a set of `RTDetrEncoder`, a top-down Feature Pyramid Network
+ (FPN) and a bottom-up Path Aggregation Network (PAN). More details in the paper: https://huggingface.co/papers/2304.08069
+
+ Args:
+ config: RTDetrConfig
+ """
+
+ def __init__(self, config: RTDetrConfig):
+ super().__init__()
+ self.config = config
+ self.in_channels = config.encoder_in_channels
+ self.feat_strides = config.feat_strides
+ self.encoder_hidden_dim = config.encoder_hidden_dim
+ self.encode_proj_layers = config.encode_proj_layers
+ self.positional_encoding_temperature = config.positional_encoding_temperature
+ self.eval_size = config.eval_size
+ self.out_channels = [self.encoder_hidden_dim for _ in self.in_channels]
+ self.out_strides = self.feat_strides
+ self.num_fpn_stages = len(self.in_channels) - 1
+ self.num_pan_stages = len(self.in_channels) - 1
+ activation = config.activation_function
+
+ # encoder transformer
+ self.encoder = nn.ModuleList([RTDetrEncoder(config) for _ in range(len(self.encode_proj_layers))])
+
+ # top-down FPN
+ self.lateral_convs = nn.ModuleList()
+ self.fpn_blocks = nn.ModuleList()
+ for _ in range(self.num_fpn_stages):
+ lateral_conv = RTDetrConvNormLayer(
+ config,
+ in_channels=self.encoder_hidden_dim,
+ out_channels=self.encoder_hidden_dim,
+ kernel_size=1,
+ stride=1,
+ activation=activation,
+ )
+ fpn_block = RTDetrCSPRepLayer(config)
+ self.lateral_convs.append(lateral_conv)
+ self.fpn_blocks.append(fpn_block)
+
+ # bottom-up PAN
+ self.downsample_convs = nn.ModuleList()
+ self.pan_blocks = nn.ModuleList()
+ for _ in range(self.num_pan_stages):
+ downsample_conv = RTDetrConvNormLayer(
+ config,
+ in_channels=self.encoder_hidden_dim,
+ out_channels=self.encoder_hidden_dim,
+ kernel_size=3,
+ stride=2,
+ activation=activation,
+ )
+ pan_block = RTDetrCSPRepLayer(config)
+ self.downsample_convs.append(downsample_conv)
+ self.pan_blocks.append(pan_block)
+
+ @staticmethod
+ def build_2d_sincos_position_embedding(
+ width, height, embed_dim=256, temperature=10000.0, device="cpu", dtype=torch.float32
+ ):
+ grid_w = torch.arange(torch_int(width), device=device).to(dtype)
+ grid_h = torch.arange(torch_int(height), device=device).to(dtype)
+ grid_w, grid_h = torch.meshgrid(grid_w, grid_h, indexing="ij")
+ if embed_dim % 4 != 0:
+ raise ValueError("Embed dimension must be divisible by 4 for 2D sin-cos position embedding")
+ pos_dim = embed_dim // 4
+ omega = torch.arange(pos_dim, device=device).to(dtype) / pos_dim
+ omega = 1.0 / (temperature**omega)
+
+ out_w = grid_w.flatten()[..., None] @ omega[None]
+ out_h = grid_h.flatten()[..., None] @ omega[None]
+
+ return torch.concat([out_w.sin(), out_w.cos(), out_h.sin(), out_h.cos()], dim=1)[None, :, :]
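+
+ # e.g. a 32x32 feature map with embed_dim=256 yields a (1, 1024, 256) table, which is passed as
+ # `pos_embed` to the encoder below and added to the queries and keys inside self-attention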
+
+ def forward(
+ self,
+ inputs_embeds=None,
+ attention_mask=None,
+ position_embeddings=None,
+ spatial_shapes=None,
+ level_start_index=None,
+ valid_ratios=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ ):
+ r"""
+ Args:
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+ Flattened feature map (output of the backbone + projection layer) that is passed to the encoder.
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding pixel features. Mask values selected in `[0, 1]`:
+ - 1 for pixel features that are real (i.e. **not masked**),
+ - 0 for pixel features that are padding (i.e. **masked**).
+ [What are attention masks?](../glossary#attention-mask)
+ position_embeddings (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+ Position embeddings that are added to the queries and keys in each self-attention layer.
+ spatial_shapes (`torch.LongTensor` of shape `(num_feature_levels, 2)`):
+ Spatial shapes of each feature map.
+ level_start_index (`torch.LongTensor` of shape `(num_feature_levels)`):
+ Starting index of each feature map.
+ valid_ratios (`torch.FloatTensor` of shape `(batch_size, num_feature_levels, 2)`):
+ Ratio of valid area in each feature level.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+ for more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ hidden_states = inputs_embeds
+
+ encoder_states = () if output_hidden_states else None
+ all_attentions = () if output_attentions else None
+
+ # encoder
+ if self.config.encoder_layers > 0:
+ for i, enc_ind in enumerate(self.encode_proj_layers):
+ if output_hidden_states:
+ encoder_states = encoder_states + (hidden_states[enc_ind],)
+ height, width = hidden_states[enc_ind].shape[2:]
+ # flatten [batch, channel, height, width] to [batch, height*width, channel]
+ src_flatten = hidden_states[enc_ind].flatten(2).permute(0, 2, 1)
+ if self.training or self.eval_size is None:
+ pos_embed = self.build_2d_sincos_position_embedding(
+ width,
+ height,
+ self.encoder_hidden_dim,
+ self.positional_encoding_temperature,
+ device=src_flatten.device,
+ dtype=src_flatten.dtype,
+ )
+ else:
+ pos_embed = None
+
+ layer_outputs = self.encoder[i](
+ src_flatten,
+ pos_embed=pos_embed,
+ output_attentions=output_attentions,
+ )
+ hidden_states[enc_ind] = (
+ layer_outputs[0].permute(0, 2, 1).reshape(-1, self.encoder_hidden_dim, height, width).contiguous()
+ )
+
+ if output_attentions:
+ all_attentions = all_attentions + (layer_outputs[1],)
+
+ if output_hidden_states:
+ encoder_states = encoder_states + (hidden_states[enc_ind],)
+
+ # top-down FPN
+ fpn_feature_maps = [hidden_states[-1]]
+ for idx, (lateral_conv, fpn_block) in enumerate(zip(self.lateral_convs, self.fpn_blocks)):
+ backbone_feature_map = hidden_states[self.num_fpn_stages - idx - 1]
+ top_fpn_feature_map = fpn_feature_maps[-1]
+ # apply lateral block
+ top_fpn_feature_map = lateral_conv(top_fpn_feature_map)
+ fpn_feature_maps[-1] = top_fpn_feature_map
+ # apply fpn block
+ top_fpn_feature_map = F.interpolate(top_fpn_feature_map, scale_factor=2.0, mode="nearest")
+ fused_feature_map = torch.concat([top_fpn_feature_map, backbone_feature_map], dim=1)
+ new_fpn_feature_map = fpn_block(fused_feature_map)
+ fpn_feature_maps.append(new_fpn_feature_map)
+
+ fpn_feature_maps = fpn_feature_maps[::-1]
+
+ # bottom-up PAN
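+ # Bottom-up path aggregation (PANet-style): starting from the highest-resolution FPN output,
+ # progressively downsample with the stride-2 convolutions and fuse with the next FPN level, so the
+ # returned feature maps are ordered from high to low resolution.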
+ pan_feature_maps = [fpn_feature_maps[0]]
+ for idx, (downsample_conv, pan_block) in enumerate(zip(self.downsample_convs, self.pan_blocks)):
+ top_pan_feature_map = pan_feature_maps[-1]
+ fpn_feature_map = fpn_feature_maps[idx + 1]
+ downsampled_feature_map = downsample_conv(top_pan_feature_map)
+ fused_feature_map = torch.concat([downsampled_feature_map, fpn_feature_map], dim=1)
+ new_pan_feature_map = pan_block(fused_feature_map)
+ pan_feature_maps.append(new_pan_feature_map)
+
+ if not return_dict:
+ return tuple(v for v in [pan_feature_maps, encoder_states, all_attentions] if v is not None)
+ return BaseModelOutput(
+ last_hidden_state=pan_feature_maps, hidden_states=encoder_states, attentions=all_attentions
+ )
+
+
+class RTDetrDecoder(RTDetrPreTrainedModel):
+ def __init__(self, config: RTDetrConfig):
+ super().__init__(config)
+
+ self.dropout = config.dropout
+ self.layers = nn.ModuleList([RTDetrDecoderLayer(config) for _ in range(config.decoder_layers)])
+ self.query_pos_head = RTDetrMLPPredictionHead(config, 4, 2 * config.d_model, config.d_model, num_layers=2)
+
+ # hack implementation for iterative bounding box refinement and two-stage Deformable DETR
+ self.bbox_embed = None
+ self.class_embed = None
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def forward(
+ self,
+ inputs_embeds=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ position_embeddings=None,
+ reference_points=None,
+ spatial_shapes=None,
+ spatial_shapes_list=None,
+ level_start_index=None,
+ valid_ratios=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ ):
+ r"""
+ Args:
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`):
+ The query embeddings that are passed into the decoder.
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
+ of the decoder.
+ encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing cross-attention on padding pixel_values of the encoder. Mask values selected
+ in `[0, 1]`:
+ - 1 for pixels that are real (i.e. **not masked**),
+ - 0 for pixels that are padding (i.e. **masked**).
+ position_embeddings (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*):
+ Position embeddings that are added to the queries and keys in each self-attention layer.
+ reference_points (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)` if `as_two_stage` else `(batch_size, num_queries, 2)`, *optional*):
+ Reference point in range `[0, 1]`, top-left (0,0), bottom-right (1, 1), including padding area.
+ spatial_shapes (`torch.LongTensor` of shape `(num_feature_levels, 2)`):
+ Spatial shapes of the feature maps.
+ level_start_index (`torch.LongTensor` of shape `(num_feature_levels)`, *optional*):
+ Indexes for the start of each feature level. In range `[0, sequence_length]`.
+ valid_ratios (`torch.FloatTensor` of shape `(batch_size, num_feature_levels, 2)`, *optional*):
+ Ratio of valid area in each feature level.
+
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+ for more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if inputs_embeds is not None:
+ hidden_states = inputs_embeds
+
+ # decoder layers
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+ all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
+ intermediate = ()
+ intermediate_reference_points = ()
+ intermediate_logits = ()
+
+ reference_points = F.sigmoid(reference_points)
+
+ # https://github.com/lyuwenyu/RT-DETR/blob/94f5e16708329d2f2716426868ec89aa774af016/rtdetr_pytorch/src/zoo/rtdetr/rtdetr_decoder.py#L252
+ for idx, decoder_layer in enumerate(self.layers):
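+ # At every decoder layer the query positional embeddings are regenerated from the current
+ # (sigmoid-space) reference boxes via the query_pos_head MLP, so box refinement from the previous
+ # layer immediately updates the query positions.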
+ reference_points_input = reference_points.unsqueeze(2)
+ position_embeddings = self.query_pos_head(reference_points)
+
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ layer_outputs = decoder_layer(
+ hidden_states,
+ position_embeddings=position_embeddings,
+ encoder_hidden_states=encoder_hidden_states,
+ reference_points=reference_points_input,
+ spatial_shapes=spatial_shapes,
+ spatial_shapes_list=spatial_shapes_list,
+ level_start_index=level_start_index,
+ encoder_attention_mask=encoder_attention_mask,
+ output_attentions=output_attentions,
+ )
+
+ hidden_states = layer_outputs[0]
+
+ # hack implementation for iterative bounding box refinement
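+ # Each decoder layer predicts a box delta in inverse-sigmoid (logit) space which is added to the
+ # previous reference points; the result is squashed back to [0, 1] and detached, so the next layer
+ # refines the boxes without backpropagating through earlier reference points.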
+ if self.bbox_embed is not None:
+ predicted_corners = self.bbox_embed[idx](hidden_states)
+ new_reference_points = F.sigmoid(predicted_corners + inverse_sigmoid(reference_points))
+ reference_points = new_reference_points.detach()
+
+ intermediate += (hidden_states,)
+ intermediate_reference_points += (
+ (new_reference_points,) if self.bbox_embed is not None else (reference_points,)
+ )
+
+ if self.class_embed is not None:
+ logits = self.class_embed[idx](hidden_states)
+ intermediate_logits += (logits,)
+
+ if output_attentions:
+ all_self_attns += (layer_outputs[1],)
+
+ if encoder_hidden_states is not None:
+ all_cross_attentions += (layer_outputs[2],)
+
+ # Keep batch_size as first dimension
+ intermediate = torch.stack(intermediate, dim=1)
+ intermediate_reference_points = torch.stack(intermediate_reference_points, dim=1)
+ if self.class_embed is not None:
+ intermediate_logits = torch.stack(intermediate_logits, dim=1)
+
+ # add hidden states from the last decoder layer
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ if not return_dict:
+ return tuple(
+ v
+ for v in [
+ hidden_states,
+ intermediate,
+ intermediate_logits,
+ intermediate_reference_points,
+ all_hidden_states,
+ all_self_attns,
+ all_cross_attentions,
+ ]
+ if v is not None
+ )
+ return RTDetrDecoderOutput(
+ last_hidden_state=hidden_states,
+ intermediate_hidden_states=intermediate,
+ intermediate_logits=intermediate_logits,
+ intermediate_reference_points=intermediate_reference_points,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ cross_attentions=all_cross_attentions,
+ )
+
+
+# taken from https://github.com/facebookresearch/detr/blob/master/models/detr.py
+class RTDetrMLPPredictionHead(nn.Module):
+ """
+ Very simple multi-layer perceptron (MLP, also called FFN), used to predict the normalized center coordinates,
+ height and width of a bounding box w.r.t. an image.
+
+ Copied from https://github.com/facebookresearch/detr/blob/master/models/detr.py
+ Origin from https://github.com/lyuwenyu/RT-DETR/blob/94f5e16708329d2f2716426868ec89aa774af016/rtdetr_paddle/ppdet/modeling/transformers/utils.py#L453
+
+ """
+
+ def __init__(self, config, input_dim, d_model, output_dim, num_layers):
+ super().__init__()
+ self.num_layers = num_layers
+ h = [d_model] * (num_layers - 1)
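+ # Chains num_layers Linear layers of sizes input_dim -> d_model -> ... -> output_dim; e.g. the
+ # decoder's query_pos_head (input_dim=4, d_model=2 * config.d_model, output_dim=config.d_model,
+ # num_layers=2) becomes Linear(4, 2 * d_model) followed by Linear(2 * d_model, d_model).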
+ self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
+
+ def forward(self, x):
+ for i, layer in enumerate(self.layers):
+ x = nn.functional.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
+ return x
+
+
+@auto_docstring(
+ custom_intro="""
+ RT-DETR Model (consisting of a backbone and encoder-decoder) outputting raw hidden states without any head on top.
+ """
+)
+class RTDetrModel(RTDetrPreTrainedModel):
+ def __init__(self, config: RTDetrConfig):
+ super().__init__(config)
+
+ # Create backbone
+ self.backbone = RTDetrConvEncoder(config)
+ intermediate_channel_sizes = self.backbone.intermediate_channel_sizes
+
+ # Create encoder input projection layers
+ # https://github.com/lyuwenyu/RT-DETR/blob/94f5e16708329d2f2716426868ec89aa774af016/rtdetr_pytorch/src/zoo/rtdetr/hybrid_encoder.py#L212
+ num_backbone_outs = len(intermediate_channel_sizes)
+ encoder_input_proj_list = []
+ for _ in range(num_backbone_outs):
+ in_channels = intermediate_channel_sizes[_]
+ encoder_input_proj_list.append(
+ nn.Sequential(
+ nn.Conv2d(in_channels, config.encoder_hidden_dim, kernel_size=1, bias=False),
+ nn.BatchNorm2d(config.encoder_hidden_dim),
+ )
+ )
+ self.encoder_input_proj = nn.ModuleList(encoder_input_proj_list)
+
+ # Create encoder
+ self.encoder = RTDetrHybridEncoder(config)
+
+ # denoising part
+ if config.num_denoising > 0:
+ self.denoising_class_embed = nn.Embedding(
+ config.num_labels + 1, config.d_model, padding_idx=config.num_labels
+ )
+
+ # decoder embedding
+ if config.learn_initial_query:
+ self.weight_embedding = nn.Embedding(config.num_queries, config.d_model)
+
+ # encoder head
+ self.enc_output = nn.Sequential(
+ nn.Linear(config.d_model, config.d_model),
+ nn.LayerNorm(config.d_model, eps=config.layer_norm_eps),
+ )
+ self.enc_score_head = nn.Linear(config.d_model, config.num_labels)
+ self.enc_bbox_head = RTDetrMLPPredictionHead(config, config.d_model, config.d_model, 4, num_layers=3)
+
+ # init encoder output anchors and valid_mask
+ if config.anchor_image_size:
+ self.anchors, self.valid_mask = self.generate_anchors(dtype=self.dtype)
+
+ # Create decoder input projection layers
+ # https://github.com/lyuwenyu/RT-DETR/blob/94f5e16708329d2f2716426868ec89aa774af016/rtdetr_pytorch/src/zoo/rtdetr/rtdetr_decoder.py#L412
+ num_backbone_outs = len(config.decoder_in_channels)
+ decoder_input_proj_list = []
+ for _ in range(num_backbone_outs):
+ in_channels = config.decoder_in_channels[_]
+ decoder_input_proj_list.append(
+ nn.Sequential(
+ nn.Conv2d(in_channels, config.d_model, kernel_size=1, bias=False),
+ nn.BatchNorm2d(config.d_model, config.batch_norm_eps),
+ )
+ )
+ for _ in range(config.num_feature_levels - num_backbone_outs):
+ decoder_input_proj_list.append(
+ nn.Sequential(
+ nn.Conv2d(in_channels, config.d_model, kernel_size=3, stride=2, padding=1, bias=False),
+ nn.BatchNorm2d(config.d_model, config.batch_norm_eps),
+ )
+ )
+ in_channels = config.d_model
+ self.decoder_input_proj = nn.ModuleList(decoder_input_proj_list)
+
+ # decoder
+ self.decoder = RTDetrDecoder(config)
+
+ self.post_init()
+
+ def get_encoder(self):
+ return self.encoder
+
+ def freeze_backbone(self):
+ for param in self.backbone.parameters():
+ param.requires_grad_(False)
+
+ def unfreeze_backbone(self):
+ for param in self.backbone.parameters():
+ param.requires_grad_(True)
+
+ @compile_compatible_method_lru_cache(maxsize=32)
+ def generate_anchors(self, spatial_shapes=None, grid_size=0.05, device="cpu", dtype=torch.float32):
+ if spatial_shapes is None:
+ spatial_shapes = [
+ [int(self.config.anchor_image_size[0] / s), int(self.config.anchor_image_size[1] / s)]
+ for s in self.config.feat_strides
+ ]
+ anchors = []
+ for level, (height, width) in enumerate(spatial_shapes):
+ grid_y, grid_x = torch.meshgrid(
+ torch.arange(end=height, device=device).to(dtype),
+ torch.arange(end=width, device=device).to(dtype),
+ indexing="ij",
+ )
+ grid_xy = torch.stack([grid_x, grid_y], -1)
+ grid_xy = grid_xy.unsqueeze(0) + 0.5
+ grid_xy[..., 0] /= width
+ grid_xy[..., 1] /= height
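+ # Anchor widths/heights default to grid_size * 2**level, so coarser feature levels propose larger boxes.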
+ wh = torch.ones_like(grid_xy) * grid_size * (2.0**level)
+ anchors.append(torch.concat([grid_xy, wh], -1).reshape(-1, height * width, 4))
+ # define the valid range for anchor coordinates
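+ # Anchors are converted to inverse-sigmoid (logit) space so they can be added directly to the box
+ # head's predictions; anchors falling outside (eps, 1 - eps) are marked invalid and replaced with the
+ # dtype's maximum value, and the corresponding flattened features are later zeroed out via `valid_mask`.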
+ eps = 1e-2
+ anchors = torch.concat(anchors, 1)
+ valid_mask = ((anchors > eps) * (anchors < 1 - eps)).all(-1, keepdim=True)
+ anchors = torch.log(anchors / (1 - anchors))
+ anchors = torch.where(valid_mask, anchors, torch.tensor(torch.finfo(dtype).max, dtype=dtype, device=device))
+
+ return anchors, valid_mask
+
+ @auto_docstring
+ def forward(
+ self,
+ pixel_values: torch.FloatTensor,
+ pixel_mask: Optional[torch.LongTensor] = None,
+ encoder_outputs: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[list[dict]] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[tuple[torch.FloatTensor], RTDetrModelOutput]:
+ r"""
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Optionally, instead of passing the flattened feature map (output of the backbone + projection layer), you
+ can choose to directly pass a flattened representation of an image.
+ decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*):
+ Optionally, instead of initializing the queries with a tensor of zeros, you can choose to directly pass an
+ embedded representation.
+ labels (`list[Dict]` of len `(batch_size,)`, *optional*):
+ Labels for computing the bipartite matching loss. List of dicts, each dictionary containing at least the
+ following 2 keys: 'class_labels' and 'boxes' (the class labels and bounding boxes of an image in the batch
+ respectively). The class labels themselves should be a `torch.LongTensor` of len `(number of bounding boxes
+ in the image,)` and the boxes a `torch.FloatTensor` of shape `(number of bounding boxes in the image, 4)`.
+
+ Examples:
+
+ ```python
+ >>> from transformers import AutoImageProcessor, RTDetrModel
+ >>> from PIL import Image
+ >>> import requests
+
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> image_processor = AutoImageProcessor.from_pretrained("PekingU/rtdetr_r50vd")
+ >>> model = RTDetrModel.from_pretrained("PekingU/rtdetr_r50vd")
+
+ >>> inputs = image_processor(images=image, return_tensors="pt")
+
+ >>> outputs = model(**inputs)
+
+ >>> last_hidden_states = outputs.last_hidden_state
+ >>> list(last_hidden_states.shape)
+ [1, 300, 256]
+ ```"""
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ batch_size, num_channels, height, width = pixel_values.shape
+ device = pixel_values.device
+
+ if pixel_mask is None:
+ pixel_mask = torch.ones(((batch_size, height, width)), device=device)
+
+ features = self.backbone(pixel_values, pixel_mask)
+
+ proj_feats = [self.encoder_input_proj[level](source) for level, (source, mask) in enumerate(features)]
+
+ if encoder_outputs is None:
+ encoder_outputs = self.encoder(
+ proj_feats,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
+ elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
+ encoder_outputs = BaseModelOutput(
+ last_hidden_state=encoder_outputs[0],
+ hidden_states=encoder_outputs[1] if output_hidden_states else None,
+ attentions=encoder_outputs[2]
+ if len(encoder_outputs) > 2
+ else encoder_outputs[1]
+ if output_attentions
+ else None,
+ )
+
+ # Equivalent to def _get_encoder_input
+ # https://github.com/lyuwenyu/RT-DETR/blob/94f5e16708329d2f2716426868ec89aa774af016/rtdetr_pytorch/src/zoo/rtdetr/rtdetr_decoder.py#L412
+ sources = []
+ for level, source in enumerate(encoder_outputs[0]):
+ sources.append(self.decoder_input_proj[level](source))
+
+ # Lowest resolution feature maps are obtained via 3x3 stride 2 convolutions on the final stage
+ if self.config.num_feature_levels > len(sources):
+ _len_sources = len(sources)
+ sources.append(self.decoder_input_proj[_len_sources](encoder_outputs[0][-1]))
+ for i in range(_len_sources + 1, self.config.num_feature_levels):
+ sources.append(self.decoder_input_proj[i](encoder_outputs[0][-1]))
+
+ # Prepare encoder inputs (by flattening)
+ source_flatten = []
+ spatial_shapes_list = []
+ spatial_shapes = torch.empty((len(sources), 2), device=device, dtype=torch.long)
+ for level, source in enumerate(sources):
+ height, width = source.shape[-2:]
+ spatial_shapes[level, 0] = height
+ spatial_shapes[level, 1] = width
+ spatial_shapes_list.append((height, width))
+ source = source.flatten(2).transpose(1, 2)
+ source_flatten.append(source)
+ source_flatten = torch.cat(source_flatten, 1)
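+ # level_start_index[i] is the offset of feature level i in the flattened sequence; e.g. for spatial
+ # shapes [(80, 80), (40, 40), (20, 20)] it is [0, 6400, 8000] (6400 = 80 * 80, 8000 = 6400 + 40 * 40).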
+ level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1]))
+
+ # prepare denoising training
+ if self.training and self.config.num_denoising > 0 and labels is not None:
+ (
+ denoising_class,
+ denoising_bbox_unact,
+ attention_mask,
+ denoising_meta_values,
+ ) = get_contrastive_denoising_training_group(
+ targets=labels,
+ num_classes=self.config.num_labels,
+ num_queries=self.config.num_queries,
+ class_embed=self.denoising_class_embed,
+ num_denoising_queries=self.config.num_denoising,
+ label_noise_ratio=self.config.label_noise_ratio,
+ box_noise_scale=self.config.box_noise_scale,
+ )
+ else:
+ denoising_class, denoising_bbox_unact, attention_mask, denoising_meta_values = None, None, None, None
+
+ batch_size = len(source_flatten)
+ device = source_flatten.device
+ dtype = source_flatten.dtype
+
+ # prepare input for decoder
+ if self.training or self.config.anchor_image_size is None:
+ # Pass spatial_shapes as tuple to make it hashable and make sure
+ # lru_cache is working for generate_anchors()
+ spatial_shapes_tuple = tuple(spatial_shapes_list)
+ anchors, valid_mask = self.generate_anchors(spatial_shapes_tuple, device=device, dtype=dtype)
+ else:
+ anchors, valid_mask = self.anchors, self.valid_mask
+ anchors, valid_mask = anchors.to(device, dtype), valid_mask.to(device, dtype)
+
+ # use the valid_mask to selectively retain values in the feature map where the mask is `True`
+ memory = valid_mask.to(source_flatten.dtype) * source_flatten
+
+ output_memory = self.enc_output(memory)
+
+ enc_outputs_class = self.enc_score_head(output_memory)
+ enc_outputs_coord_logits = self.enc_bbox_head(output_memory) + anchors
+
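+ # Query selection: take the num_queries encoder positions with the highest per-class score and use
+ # their predicted boxes (and optionally their features) to initialize the decoder queries.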
+ _, topk_ind = torch.topk(enc_outputs_class.max(-1).values, self.config.num_queries, dim=1)
+
+ reference_points_unact = enc_outputs_coord_logits.gather(
+ dim=1, index=topk_ind.unsqueeze(-1).repeat(1, 1, enc_outputs_coord_logits.shape[-1])
+ )
+
+ enc_topk_bboxes = F.sigmoid(reference_points_unact)
+ if denoising_bbox_unact is not None:
+ reference_points_unact = torch.concat([denoising_bbox_unact, reference_points_unact], 1)
+
+ enc_topk_logits = enc_outputs_class.gather(
+ dim=1, index=topk_ind.unsqueeze(-1).repeat(1, 1, enc_outputs_class.shape[-1])
+ )
+
+ # extract region features
+ if self.config.learn_initial_query:
+ target = self.weight_embedding.weight.tile([batch_size, 1, 1])
+ else:
+ target = output_memory.gather(dim=1, index=topk_ind.unsqueeze(-1).repeat(1, 1, output_memory.shape[-1]))
+ target = target.detach()
+
+ if denoising_class is not None:
+ target = torch.concat([denoising_class, target], 1)
+
+ init_reference_points = reference_points_unact.detach()
+
+ # decoder
+ decoder_outputs = self.decoder(
+ inputs_embeds=target,
+ encoder_hidden_states=source_flatten,
+ encoder_attention_mask=attention_mask,
+ reference_points=init_reference_points,
+ spatial_shapes=spatial_shapes,
+ spatial_shapes_list=spatial_shapes_list,
+ level_start_index=level_start_index,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ if not return_dict:
+ enc_outputs = tuple(
+ value
+ for value in [enc_topk_logits, enc_topk_bboxes, enc_outputs_class, enc_outputs_coord_logits]
+ if value is not None
+ )
+ dn_outputs = tuple(value if value is not None else None for value in [denoising_meta_values])
+ tuple_outputs = decoder_outputs + encoder_outputs + (init_reference_points,) + enc_outputs + dn_outputs
+
+ return tuple_outputs
+
+ return RTDetrModelOutput(
+ last_hidden_state=decoder_outputs.last_hidden_state,
+ intermediate_hidden_states=decoder_outputs.intermediate_hidden_states,
+ intermediate_logits=decoder_outputs.intermediate_logits,
+ intermediate_reference_points=decoder_outputs.intermediate_reference_points,
+ intermediate_predicted_corners=decoder_outputs.intermediate_predicted_corners,
+ initial_reference_points=decoder_outputs.initial_reference_points,
+ decoder_hidden_states=decoder_outputs.hidden_states,
+ decoder_attentions=decoder_outputs.attentions,
+ cross_attentions=decoder_outputs.cross_attentions,
+ encoder_last_hidden_state=encoder_outputs.last_hidden_state,
+ encoder_hidden_states=encoder_outputs.hidden_states,
+ encoder_attentions=encoder_outputs.attentions,
+ init_reference_points=init_reference_points,
+ enc_topk_logits=enc_topk_logits,
+ enc_topk_bboxes=enc_topk_bboxes,
+ enc_outputs_class=enc_outputs_class,
+ enc_outputs_coord_logits=enc_outputs_coord_logits,
+ denoising_meta_values=denoising_meta_values,
+ )
+
+
+@auto_docstring(
+ custom_intro="""
+ RT-DETR Model (consisting of a backbone and encoder-decoder) outputting bounding boxes and logits to be further
+ decoded into scores and classes.
+ """
+)
+class RTDetrForObjectDetection(RTDetrPreTrainedModel):
+ # When using clones, all layers > 0 will be clones, but layer 0 *is* required
+ _tied_weights_keys = ["bbox_embed", "class_embed"]
+ # We can't initialize the model on meta device as some weights are modified during the initialization
+ _no_split_modules = None
+
+ def __init__(self, config: RTDetrConfig):
+ super().__init__(config)
+
+ # RTDETR encoder-decoder model
+ self.model = RTDetrModel(config)
+
+ # Detection heads on top
+ self.class_embed = partial(nn.Linear, config.d_model, config.num_labels)
+ self.bbox_embed = partial(RTDetrMLPPredictionHead, config, config.d_model, config.d_model, 4, num_layers=3)
+
+ # if two-stage, the last class_embed and bbox_embed is for region proposal generation
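+ # One classification head and one box-regression head per decoder layer, created from the partials
+ # above and shared with the decoder for per-layer (auxiliary) predictions and iterative box refinement.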
+ num_pred = config.decoder_layers
+ if config.with_box_refine:
+ self.class_embed = _get_clones(self.class_embed, num_pred)
+ self.bbox_embed = _get_clones(self.bbox_embed, num_pred)
+ else:
+ self.class_embed = nn.ModuleList([self.class_embed() for _ in range(num_pred)])
+ self.bbox_embed = nn.ModuleList([self.bbox_embed() for _ in range(num_pred)])
+
+ # hack implementation for iterative bounding box refinement
+ self.model.decoder.class_embed = self.class_embed
+ self.model.decoder.bbox_embed = self.bbox_embed
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @torch.jit.unused
+ def _set_aux_loss(self, outputs_class, outputs_coord):
+ # this is a workaround to make torchscript happy, as torchscript
+ # doesn't support dictionary with non-homogeneous values, such
+ # as a dict having both a Tensor and a list.
+ return [{"logits": a, "pred_boxes": b} for a, b in zip(outputs_class, outputs_coord)]
+
+ @auto_docstring
+ def forward(
+ self,
+ pixel_values: torch.FloatTensor,
+ pixel_mask: Optional[torch.LongTensor] = None,
+ encoder_outputs: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[list[dict]] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ **kwargs,
+ ) -> Union[tuple[torch.FloatTensor], RTDetrObjectDetectionOutput]:
+ r"""
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Optionally, instead of passing the flattened feature map (output of the backbone + projection layer), you
+ can choose to directly pass a flattened representation of an image.
+ decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*):
+ Optionally, instead of initializing the queries with a tensor of zeros, you can choose to directly pass an
+ embedded representation.
+ labels (`list[Dict]` of len `(batch_size,)`, *optional*):
+ Labels for computing the bipartite matching loss. List of dicts, each dictionary containing at least the
+ following 2 keys: 'class_labels' and 'boxes' (the class labels and bounding boxes of an image in the batch
+ respectively). The class labels themselves should be a `torch.LongTensor` of len `(number of bounding boxes
+ in the image,)` and the boxes a `torch.FloatTensor` of shape `(number of bounding boxes in the image, 4)`.
+
+ Examples:
+
+ ```python
+ >>> from transformers import RTDetrImageProcessor, RTDetrForObjectDetection
+ >>> from PIL import Image
+ >>> import requests
+ >>> import torch
+
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> image_processor = RTDetrImageProcessor.from_pretrained("PekingU/rtdetr_r50vd")
+ >>> model = RTDetrForObjectDetection.from_pretrained("PekingU/rtdetr_r50vd")
+
+ >>> # prepare image for the model
+ >>> inputs = image_processor(images=image, return_tensors="pt")
+
+ >>> # forward pass
+ >>> outputs = model(**inputs)
+
+ >>> logits = outputs.logits
+ >>> list(logits.shape)
+ [1, 300, 80]
+
+ >>> boxes = outputs.pred_boxes
+ >>> list(boxes.shape)
+ [1, 300, 4]
+
+ >>> # convert outputs (bounding boxes and class logits) to Pascal VOC format (xmin, ymin, xmax, ymax)
+ >>> target_sizes = torch.tensor([image.size[::-1]])
+ >>> results = image_processor.post_process_object_detection(outputs, threshold=0.9, target_sizes=target_sizes)[
+ ... 0
+ ... ]
+
+ >>> for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
+ ... box = [round(i, 2) for i in box.tolist()]
+ ... print(
+ ... f"Detected {model.config.id2label[label.item()]} with confidence "
+ ... f"{round(score.item(), 3)} at location {box}"
+ ... )
+ Detected sofa with confidence 0.97 at location [0.14, 0.38, 640.13, 476.21]
+ Detected cat with confidence 0.96 at location [343.38, 24.28, 640.14, 371.5]
+ Detected cat with confidence 0.958 at location [13.23, 54.18, 318.98, 472.22]
+ Detected remote with confidence 0.951 at location [40.11, 73.44, 175.96, 118.48]
+ Detected remote with confidence 0.924 at location [333.73, 76.58, 369.97, 186.99]
+ ```"""
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.model(
+ pixel_values,
+ pixel_mask=pixel_mask,
+ encoder_outputs=encoder_outputs,
+ inputs_embeds=inputs_embeds,
+ decoder_inputs_embeds=decoder_inputs_embeds,
+ labels=labels,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ denoising_meta_values = (
+ outputs.denoising_meta_values if return_dict else outputs[-1] if self.training else None
+ )
+
+ outputs_class = outputs.intermediate_logits if return_dict else outputs[2]
+ outputs_coord = outputs.intermediate_reference_points if return_dict else outputs[3]
+ predicted_corners = outputs.intermediate_predicted_corners if return_dict else outputs[4]
+ initial_reference_points = outputs.initial_reference_points if return_dict else outputs[5]
+
+ logits = outputs_class[:, -1]
+ pred_boxes = outputs_coord[:, -1]
+
+ loss, loss_dict, auxiliary_outputs, enc_topk_logits, enc_topk_bboxes = None, None, None, None, None
+ if labels is not None:
+ enc_topk_logits = outputs.enc_topk_logits if return_dict else outputs[-5]
+ enc_topk_bboxes = outputs.enc_topk_bboxes if return_dict else outputs[-4]
+ loss, loss_dict, auxiliary_outputs = self.loss_function(
+ logits,
+ labels,
+ self.device,
+ pred_boxes,
+ self.config,
+ outputs_class,
+ outputs_coord,
+ enc_topk_logits=enc_topk_logits,
+ enc_topk_bboxes=enc_topk_bboxes,
+ denoising_meta_values=denoising_meta_values,
+ predicted_corners=predicted_corners,
+ initial_reference_points=initial_reference_points,
+ **kwargs,
+ )
+
+ if not return_dict:
+ if auxiliary_outputs is not None:
+ output = (logits, pred_boxes) + (auxiliary_outputs,) + outputs
+ else:
+ output = (logits, pred_boxes) + outputs
+ return ((loss, loss_dict) + output) if loss is not None else output
+
+ return RTDetrObjectDetectionOutput(
+ loss=loss,
+ loss_dict=loss_dict,
+ logits=logits,
+ pred_boxes=pred_boxes,
+ auxiliary_outputs=auxiliary_outputs,
+ last_hidden_state=outputs.last_hidden_state,
+ intermediate_hidden_states=outputs.intermediate_hidden_states,
+ intermediate_logits=outputs.intermediate_logits,
+ intermediate_reference_points=outputs.intermediate_reference_points,
+ intermediate_predicted_corners=outputs.intermediate_predicted_corners,
+ initial_reference_points=outputs.initial_reference_points,
+ decoder_hidden_states=outputs.decoder_hidden_states,
+ decoder_attentions=outputs.decoder_attentions,
+ cross_attentions=outputs.cross_attentions,
+ encoder_last_hidden_state=outputs.encoder_last_hidden_state,
+ encoder_hidden_states=outputs.encoder_hidden_states,
+ encoder_attentions=outputs.encoder_attentions,
+ init_reference_points=outputs.init_reference_points,
+ enc_topk_logits=outputs.enc_topk_logits,
+ enc_topk_bboxes=outputs.enc_topk_bboxes,
+ enc_outputs_class=outputs.enc_outputs_class,
+ enc_outputs_coord_logits=outputs.enc_outputs_coord_logits,
+ denoising_meta_values=outputs.denoising_meta_values,
+ )
+
+
+__all__ = [
+ "RTDetrForObjectDetection",
+ "RTDetrModel",
+ "RTDetrPreTrainedModel",
+]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/rt_detr/modeling_rt_detr_resnet.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/rt_detr/modeling_rt_detr_resnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..21781dc3573fd8a16f44640ea790109aed720389
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/rt_detr/modeling_rt_detr_resnet.py
@@ -0,0 +1,399 @@
+# coding=utf-8
+# Copyright 2024 Microsoft Research, Inc. and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+PyTorch RT-DETR specific ResNet model. The main difference from the Hugging Face ResNet model is that this RTDetrResNet model forces the use of a shortcut at the first layer of the ResNet-18/34 variants.
+See https://github.com/lyuwenyu/RT-DETR/blob/5b628eaa0a2fc25bdafec7e6148d5296b144af85/rtdetr_pytorch/src/nn/backbone/presnet.py#L126 for details.
+"""
+
+import math
+from typing import Optional
+
+from torch import Tensor, nn
+
+from ...activations import ACT2FN
+from ...modeling_outputs import BackboneOutput, BaseModelOutputWithNoAttention
+from ...modeling_utils import PreTrainedModel
+from ...utils import auto_docstring, logging
+from ...utils.backbone_utils import BackboneMixin
+from .configuration_rt_detr_resnet import RTDetrResNetConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+# Copied from transformers.models.resnet.modeling_resnet.ResNetConvLayer -> RTDetrResNetConvLayer
+class RTDetrResNetConvLayer(nn.Module):
+ def __init__(
+ self, in_channels: int, out_channels: int, kernel_size: int = 3, stride: int = 1, activation: str = "relu"
+ ):
+ super().__init__()
+ self.convolution = nn.Conv2d(
+ in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=kernel_size // 2, bias=False
+ )
+ self.normalization = nn.BatchNorm2d(out_channels)
+ self.activation = ACT2FN[activation] if activation is not None else nn.Identity()
+
+ def forward(self, input: Tensor) -> Tensor:
+ hidden_state = self.convolution(input)
+ hidden_state = self.normalization(hidden_state)
+ hidden_state = self.activation(hidden_state)
+ return hidden_state
+
+
+class RTDetrResNetEmbeddings(nn.Module):
+ """
+ ResNet embeddings (stem) composed of a deep stem of three stacked `3x3` convolutions.
+ """
+
+ def __init__(self, config: RTDetrResNetConfig):
+ super().__init__()
+ self.embedder = nn.Sequential(
+ *[
+ RTDetrResNetConvLayer(
+ config.num_channels,
+ config.embedding_size // 2,
+ kernel_size=3,
+ stride=2,
+ activation=config.hidden_act,
+ ),
+ RTDetrResNetConvLayer(
+ config.embedding_size // 2,
+ config.embedding_size // 2,
+ kernel_size=3,
+ stride=1,
+ activation=config.hidden_act,
+ ),
+ RTDetrResNetConvLayer(
+ config.embedding_size // 2,
+ config.embedding_size,
+ kernel_size=3,
+ stride=1,
+ activation=config.hidden_act,
+ ),
+ ]
+ )
+ self.pooler = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+ self.num_channels = config.num_channels
+
+ def forward(self, pixel_values: Tensor) -> Tensor:
+ num_channels = pixel_values.shape[1]
+ if num_channels != self.num_channels:
+ raise ValueError(
+ "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
+ )
+ embedding = self.embedder(pixel_values)
+ embedding = self.pooler(embedding)
+ return embedding
+
+
+# Copied from transformers.models.resnet.modeling_resnet.ResNetShortCut -> RTDetrResNetShortCut
+class RTDetrResNetShortCut(nn.Module):
+ """
+ ResNet shortcut, used to project the residual features to the correct size. If needed, it is also used to
+ downsample the input using `stride=2`.
+ """
+
+ def __init__(self, in_channels: int, out_channels: int, stride: int = 2):
+ super().__init__()
+ self.convolution = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False)
+ self.normalization = nn.BatchNorm2d(out_channels)
+
+ def forward(self, input: Tensor) -> Tensor:
+ hidden_state = self.convolution(input)
+ hidden_state = self.normalization(hidden_state)
+ return hidden_state
+
+
+class RTDetrResNetBasicLayer(nn.Module):
+ """
+ A classic ResNet residual layer composed of two `3x3` convolutions.
+ See https://github.com/lyuwenyu/RT-DETR/blob/5b628eaa0a2fc25bdafec7e6148d5296b144af85/rtdetr_pytorch/src/nn/backbone/presnet.py#L34.
+ """
+
+ def __init__(
+ self,
+ config: RTDetrResNetConfig,
+ in_channels: int,
+ out_channels: int,
+ stride: int = 1,
+ should_apply_shortcut: bool = False,
+ ):
+ super().__init__()
+ if in_channels != out_channels:
+ self.shortcut = (
+ nn.Sequential(
+ *[nn.AvgPool2d(2, 2, 0, ceil_mode=True), RTDetrResNetShortCut(in_channels, out_channels, stride=1)]
+ )
+ if should_apply_shortcut
+ else nn.Identity()
+ )
+ else:
+ self.shortcut = (
+ RTDetrResNetShortCut(in_channels, out_channels, stride=stride)
+ if should_apply_shortcut
+ else nn.Identity()
+ )
+ self.layer = nn.Sequential(
+ RTDetrResNetConvLayer(in_channels, out_channels, stride=stride),
+ RTDetrResNetConvLayer(out_channels, out_channels, activation=None),
+ )
+ self.activation = ACT2FN[config.hidden_act]
+
+ def forward(self, hidden_state):
+ residual = hidden_state
+ hidden_state = self.layer(hidden_state)
+ residual = self.shortcut(residual)
+ hidden_state += residual
+ hidden_state = self.activation(hidden_state)
+ return hidden_state
+
+
+class RTDetrResNetBottleNeckLayer(nn.Module):
+ """
+ A classic RTDetrResNet bottleneck layer composed of a `1x1`, a `3x3`, and a `1x1` convolution.
+
+ The first `1x1` convolution reduces the input by a factor of `reduction` in order to make the second `3x3`
+ convolution faster. The last `1x1` convolution remaps the reduced features to `out_channels`. If
+ `downsample_in_bottleneck` is `True`, the downsampling is performed in the first `1x1` convolution instead of the `3x3` convolution.
+ """
+
+ def __init__(
+ self,
+ config: RTDetrResNetConfig,
+ in_channels: int,
+ out_channels: int,
+ stride: int = 1,
+ ):
+ super().__init__()
+ reduction = 4
+ should_apply_shortcut = in_channels != out_channels or stride != 1
+ reduces_channels = out_channels // reduction
+ if stride == 2:
+ self.shortcut = nn.Sequential(
+ *[
+ nn.AvgPool2d(2, 2, 0, ceil_mode=True),
+ RTDetrResNetShortCut(in_channels, out_channels, stride=1)
+ if should_apply_shortcut
+ else nn.Identity(),
+ ]
+ )
+ else:
+ self.shortcut = (
+ RTDetrResNetShortCut(in_channels, out_channels, stride=stride)
+ if should_apply_shortcut
+ else nn.Identity()
+ )
+ self.layer = nn.Sequential(
+ RTDetrResNetConvLayer(
+ in_channels, reduces_channels, kernel_size=1, stride=stride if config.downsample_in_bottleneck else 1
+ ),
+ RTDetrResNetConvLayer(
+ reduces_channels, reduces_channels, stride=stride if not config.downsample_in_bottleneck else 1
+ ),
+ RTDetrResNetConvLayer(reduces_channels, out_channels, kernel_size=1, activation=None),
+ )
+ self.activation = ACT2FN[config.hidden_act]
+
+ def forward(self, hidden_state):
+ residual = hidden_state
+ hidden_state = self.layer(hidden_state)
+ residual = self.shortcut(residual)
+ hidden_state += residual
+ hidden_state = self.activation(hidden_state)
+ return hidden_state
+
+
+class RTDetrResNetStage(nn.Module):
+ """
+ An RTDetrResNet stage composed of stacked layers.
+ """
+
+ def __init__(
+ self,
+ config: RTDetrResNetConfig,
+ in_channels: int,
+ out_channels: int,
+ stride: int = 2,
+ depth: int = 2,
+ ):
+ super().__init__()
+
+ layer = RTDetrResNetBottleNeckLayer if config.layer_type == "bottleneck" else RTDetrResNetBasicLayer
+
+ if config.layer_type == "bottleneck":
+ first_layer = layer(
+ config,
+ in_channels,
+ out_channels,
+ stride=stride,
+ )
+ else:
+ first_layer = layer(config, in_channels, out_channels, stride=stride, should_apply_shortcut=True)
+ self.layers = nn.Sequential(
+ first_layer, *[layer(config, out_channels, out_channels) for _ in range(depth - 1)]
+ )
+
+ def forward(self, input: Tensor) -> Tensor:
+ hidden_state = input
+ for layer in self.layers:
+ hidden_state = layer(hidden_state)
+ return hidden_state
+
+
+# Copied from transformers.models.resnet.modeling_resnet.ResNetEncoder with ResNet->RTDetrResNet
+class RTDetrResNetEncoder(nn.Module):
+ def __init__(self, config: RTDetrResNetConfig):
+ super().__init__()
+ self.stages = nn.ModuleList([])
+ # based on `downsample_in_first_stage` the first layer of the first stage may or may not downsample the input
+ self.stages.append(
+ RTDetrResNetStage(
+ config,
+ config.embedding_size,
+ config.hidden_sizes[0],
+ stride=2 if config.downsample_in_first_stage else 1,
+ depth=config.depths[0],
+ )
+ )
+ in_out_channels = zip(config.hidden_sizes, config.hidden_sizes[1:])
+ for (in_channels, out_channels), depth in zip(in_out_channels, config.depths[1:]):
+ self.stages.append(RTDetrResNetStage(config, in_channels, out_channels, depth=depth))
+
+ def forward(
+ self, hidden_state: Tensor, output_hidden_states: bool = False, return_dict: bool = True
+ ) -> BaseModelOutputWithNoAttention:
+ hidden_states = () if output_hidden_states else None
+
+ for stage_module in self.stages:
+ if output_hidden_states:
+ hidden_states = hidden_states + (hidden_state,)
+
+ hidden_state = stage_module(hidden_state)
+
+ if output_hidden_states:
+ hidden_states = hidden_states + (hidden_state,)
+
+ if not return_dict:
+ return tuple(v for v in [hidden_state, hidden_states] if v is not None)
+
+ return BaseModelOutputWithNoAttention(
+ last_hidden_state=hidden_state,
+ hidden_states=hidden_states,
+ )
+
+
+@auto_docstring
+# Copied from transformers.models.resnet.modeling_resnet.ResNetPreTrainedModel with ResNet->RTDetrResNet
+class RTDetrResNetPreTrainedModel(PreTrainedModel):
+ config: RTDetrResNetConfig
+ base_model_prefix = "resnet"
+ main_input_name = "pixel_values"
+ _no_split_modules = ["RTDetrResNetConvLayer", "RTDetrResNetShortCut"]
+
+ def _init_weights(self, module):
+ if isinstance(module, nn.Conv2d):
+ nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
+ # copied from the `reset_parameters` method of `class Linear(Module)` in `torch`.
+ elif isinstance(module, nn.Linear):
+ nn.init.kaiming_uniform_(module.weight, a=math.sqrt(5))
+ if module.bias is not None:
+ fan_in, _ = nn.init._calculate_fan_in_and_fan_out(module.weight)
+ bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
+ nn.init.uniform_(module.bias, -bound, bound)
+ elif isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)):
+ nn.init.constant_(module.weight, 1)
+ nn.init.constant_(module.bias, 0)
+
+
+@auto_docstring(
+ custom_intro="""
+ ResNet backbone, to be used with frameworks like RTDETR.
+ """
+)
+class RTDetrResNetBackbone(RTDetrResNetPreTrainedModel, BackboneMixin):
+ has_attentions = False
+
+ def __init__(self, config):
+ super().__init__(config)
+ super()._init_backbone(config)
+
+ self.num_features = [config.embedding_size] + config.hidden_sizes
+ self.embedder = RTDetrResNetEmbeddings(config)
+ self.encoder = RTDetrResNetEncoder(config)
+
+ # initialize weights and apply final processing
+ self.post_init()
+
+ @auto_docstring
+ def forward(
+ self, pixel_values: Tensor, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None
+ ) -> BackboneOutput:
+ r"""
+ Examples:
+
+ ```python
+ >>> from transformers import RTDetrResNetConfig, RTDetrResNetBackbone
+ >>> import torch
+
+ >>> config = RTDetrResNetConfig()
+ >>> model = RTDetrResNetBackbone(config)
+
+ >>> pixel_values = torch.randn(1, 3, 224, 224)
+
+ >>> with torch.no_grad():
+ ... outputs = model(pixel_values)
+
+ >>> feature_maps = outputs.feature_maps
+ >>> list(feature_maps[-1].shape)
+ [1, 2048, 7, 7]
+ ```"""
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+
+ embedding_output = self.embedder(pixel_values)
+
+ outputs = self.encoder(embedding_output, output_hidden_states=True, return_dict=True)
+
+ hidden_states = outputs.hidden_states
+
+ feature_maps = ()
+ for idx, stage in enumerate(self.stage_names):
+ if stage in self.out_features:
+ feature_maps += (hidden_states[idx],)
+
+ if not return_dict:
+ output = (feature_maps,)
+ if output_hidden_states:
+ output += (outputs.hidden_states,)
+ return output
+
+ return BackboneOutput(
+ feature_maps=feature_maps,
+ hidden_states=outputs.hidden_states if output_hidden_states else None,
+ attentions=None,
+ )
+
+
+__all__ = [
+ "RTDetrResNetBackbone",
+ "RTDetrResNetPreTrainedModel",
+]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/rt_detr/modular_rt_detr.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/rt_detr/modular_rt_detr.py
new file mode 100644
index 0000000000000000000000000000000000000000..61bd055144f0504aaac3ab256b3565f82411a87d
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/rt_detr/modular_rt_detr.py
@@ -0,0 +1,355 @@
+import pathlib
+from typing import Optional, Union
+
+import torch
+from torchvision.transforms.v2 import functional as F
+
+from transformers.models.detr.image_processing_detr_fast import DetrFastImageProcessorKwargs, DetrImageProcessorFast
+
+from ...image_processing_utils import BatchFeature
+from ...image_processing_utils_fast import BaseImageProcessorFast, SizeDict, get_max_height_width
+from ...image_transforms import center_to_corners_format
+from ...image_utils import (
+ IMAGENET_DEFAULT_MEAN,
+ IMAGENET_DEFAULT_STD,
+ AnnotationFormat,
+ AnnotationType,
+ ChannelDimension,
+ ImageInput,
+ PILImageResampling,
+ get_image_size,
+ validate_annotations,
+)
+from ...processing_utils import Unpack
+from ...utils import (
+ TensorType,
+ logging,
+ requires_backends,
+)
+
+
+logger = logging.get_logger(__name__)
+
+SUPPORTED_ANNOTATION_FORMATS = (AnnotationFormat.COCO_DETECTION,)
+
+
+def prepare_coco_detection_annotation(
+ image,
+ target,
+ return_segmentation_masks: bool = False,
+ input_data_format: Optional[Union[ChannelDimension, str]] = None,
+):
+ """
+ Convert the target in COCO format into the format expected by RT-DETR.
+ """
+ image_height, image_width = image.size()[-2:]
+
+ image_id = target["image_id"]
+ image_id = torch.as_tensor([image_id], dtype=torch.int64, device=image.device)
+
+ # Get all COCO annotations for the given image.
+ annotations = target["annotations"]
+ classes = []
+ area = []
+ boxes = []
+ keypoints = []
+ for obj in annotations:
+ if "iscrowd" not in obj or obj["iscrowd"] == 0:
+ classes.append(obj["category_id"])
+ area.append(obj["area"])
+ boxes.append(obj["bbox"])
+ if "keypoints" in obj:
+ keypoints.append(obj["keypoints"])
+
+ classes = torch.as_tensor(classes, dtype=torch.int64, device=image.device)
+ area = torch.as_tensor(area, dtype=torch.float32, device=image.device)
+ iscrowd = torch.zeros_like(classes, dtype=torch.int64, device=image.device)
+ # guard against no boxes via resizing
+ boxes = torch.as_tensor(boxes, dtype=torch.float32, device=image.device).reshape(-1, 4)
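+ # COCO boxes are (x_min, y_min, width, height); convert to (x_min, y_min, x_max, y_max) corner format
+ # and clip the coordinates to the image bounds before filtering out degenerate boxes.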
+ boxes[:, 2:] += boxes[:, :2]
+ boxes[:, 0::2] = boxes[:, 0::2].clip(min=0, max=image_width)
+ boxes[:, 1::2] = boxes[:, 1::2].clip(min=0, max=image_height)
+
+ keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0])
+
+ new_target = {
+ "image_id": image_id,
+ "class_labels": classes[keep],
+ "boxes": boxes[keep],
+ "area": area[keep],
+ "iscrowd": iscrowd[keep],
+ "orig_size": torch.as_tensor([int(image_height), int(image_width)], dtype=torch.int64, device=image.device),
+ }
+
+ if keypoints:
+ keypoints = torch.as_tensor(keypoints, dtype=torch.float32, device=image.device)
+ # Apply the keep mask here to filter the relevant annotations
+ keypoints = keypoints[keep]
+ num_keypoints = keypoints.shape[0]
+ keypoints = keypoints.reshape((-1, 3)) if num_keypoints else keypoints
+ new_target["keypoints"] = keypoints
+
+ return new_target
+
+
+class RTDetrFastImageProcessorKwargs(DetrFastImageProcessorKwargs):
+ pass
+
+
+class RTDetrImageProcessorFast(DetrImageProcessorFast):
+ resample = PILImageResampling.BILINEAR
+ image_mean = IMAGENET_DEFAULT_MEAN
+ image_std = IMAGENET_DEFAULT_STD
+ format = AnnotationFormat.COCO_DETECTION
+ do_convert_annotations = True
+ do_resize = True
+ do_rescale = True
+ do_normalize = False
+ do_pad = False
+ size = {"height": 640, "width": 640}
+ default_to_square = False
+ model_input_names = ["pixel_values", "pixel_mask"]
+ valid_kwargs = RTDetrFastImageProcessorKwargs
+
+ def __init__(self, **kwargs: Unpack[RTDetrFastImageProcessorKwargs]) -> None:
+ # Backwards compatibility
+ do_convert_annotations = kwargs.get("do_convert_annotations")
+ do_normalize = kwargs.get("do_normalize")
+ if do_convert_annotations is None and getattr(self, "do_convert_annotations", None) is None:
+ self.do_convert_annotations = do_normalize if do_normalize is not None else self.do_normalize
+
+ BaseImageProcessorFast.__init__(self, **kwargs)
+
+ def preprocess(
+ self,
+ images: ImageInput,
+ annotations: Optional[Union[AnnotationType, list[AnnotationType]]] = None,
+ masks_path: Optional[Union[str, pathlib.Path]] = None,
+ **kwargs: Unpack[RTDetrFastImageProcessorKwargs],
+ ) -> BatchFeature:
+ return BaseImageProcessorFast.preprocess(self, images, annotations, masks_path, **kwargs)
+
+ def prepare_annotation(
+ self,
+ image: torch.Tensor,
+ target: dict,
+ format: Optional[AnnotationFormat] = None,
+ return_segmentation_masks: Optional[bool] = None,
+ masks_path: Optional[Union[str, pathlib.Path]] = None,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ ) -> dict:
+ format = format if format is not None else self.format
+
+ if format == AnnotationFormat.COCO_DETECTION:
+ return_segmentation_masks = False if return_segmentation_masks is None else return_segmentation_masks
+ target = prepare_coco_detection_annotation(
+ image, target, return_segmentation_masks, input_data_format=input_data_format
+ )
+ else:
+ raise ValueError(f"Format {format} is not supported.")
+ return target
+
+ def _preprocess(
+ self,
+ images: list["torch.Tensor"],
+ annotations: Optional[Union[AnnotationType, list[AnnotationType]]],
+ masks_path: Optional[Union[str, pathlib.Path]],
+ return_segmentation_masks: bool,
+ do_resize: bool,
+ size: SizeDict,
+ interpolation: Optional["F.InterpolationMode"],
+ do_rescale: bool,
+ rescale_factor: float,
+ do_normalize: bool,
+ do_convert_annotations: bool,
+ image_mean: Optional[Union[float, list[float]]],
+ image_std: Optional[Union[float, list[float]]],
+ do_pad: bool,
+ pad_size: Optional[SizeDict],
+ format: Optional[Union[str, AnnotationFormat]],
+ return_tensors: Optional[Union[str, TensorType]],
+ **kwargs,
+ ) -> BatchFeature:
+ """
+ Preprocess an image or a batch of images so that it can be used by the model.
+ """
+
+ if annotations is not None and isinstance(annotations, dict):
+ annotations = [annotations]
+
+ if annotations is not None and len(images) != len(annotations):
+ raise ValueError(
+ f"The number of images ({len(images)}) and annotations ({len(annotations)}) do not match."
+ )
+
+ format = AnnotationFormat(format)
+ if annotations is not None:
+ validate_annotations(format, SUPPORTED_ANNOTATION_FORMATS, annotations)
+
+ data = {}
+ processed_images = []
+ processed_annotations = []
+ pixel_masks = [] # Initialize pixel_masks here
+ for image, annotation in zip(images, annotations if annotations is not None else [None] * len(images)):
+ # prepare (COCO annotations as a list of Dict -> DETR target as a single Dict per image)
+ if annotations is not None:
+ annotation = self.prepare_annotation(
+ image,
+ annotation,
+ format,
+ return_segmentation_masks=return_segmentation_masks,
+ masks_path=masks_path,
+ input_data_format=ChannelDimension.FIRST,
+ )
+
+ if do_resize:
+ resized_image = self.resize(image, size=size, interpolation=interpolation)
+ if annotations is not None:
+ annotation = self.resize_annotation(
+ annotation,
+ orig_size=image.size()[-2:],
+ target_size=resized_image.size()[-2:],
+ )
+ image = resized_image
+ # Fused rescale and normalize
+ image = self.rescale_and_normalize(image, do_rescale, rescale_factor, do_normalize, image_mean, image_std)
+ if do_convert_annotations and annotations is not None:
+ annotation = self.normalize_annotation(annotation, get_image_size(image, ChannelDimension.FIRST))
+
+ processed_images.append(image)
+ processed_annotations.append(annotation)
+ images = processed_images
+ annotations = processed_annotations if annotations is not None else None
+
+ if do_pad:
+ # depends on all resized image shapes so we need another loop
+ if pad_size is not None:
+ padded_size = (pad_size.height, pad_size.width)
+ else:
+ padded_size = get_max_height_width(images)
+
+ padded_images = []
+ padded_annotations = []
+ for image, annotation in zip(images, annotations if annotations is not None else [None] * len(images)):
+ # Pads images and returns their mask: {'pixel_values': ..., 'pixel_mask': ...}
+ if padded_size == image.size()[-2:]:
+ padded_images.append(image)
+ pixel_masks.append(torch.ones(padded_size, dtype=torch.int64, device=image.device))
+ padded_annotations.append(annotation)
+ continue
+ image, pixel_mask, annotation = self.pad(
+ image, padded_size, annotation=annotation, update_bboxes=do_convert_annotations
+ )
+ padded_images.append(image)
+ padded_annotations.append(annotation)
+ pixel_masks.append(pixel_mask)
+ images = padded_images
+ annotations = padded_annotations if annotations is not None else None
+ data.update({"pixel_mask": torch.stack(pixel_masks, dim=0)})
+
+ data.update({"pixel_values": torch.stack(images, dim=0)})
+ encoded_inputs = BatchFeature(data, tensor_type=return_tensors)
+ if annotations is not None:
+ encoded_inputs["labels"] = [
+ BatchFeature(annotation, tensor_type=return_tensors) for annotation in annotations
+ ]
+ return encoded_inputs
+
+ def post_process_object_detection(
+ self,
+ outputs,
+ threshold: float = 0.5,
+ target_sizes: Union[TensorType, list[tuple]] = None,
+ use_focal_loss: bool = True,
+ ):
+ """
+ Converts the raw output of [`RTDetrForObjectDetection`] into final bounding boxes in (top_left_x, top_left_y,
+ bottom_right_x, bottom_right_y) format. Only supports PyTorch.
+
+ Args:
+ outputs ([`RTDetrObjectDetectionOutput`]):
+ Raw outputs of the model.
+ threshold (`float`, *optional*, defaults to 0.5):
+ Score threshold to keep object detection predictions.
+ target_sizes (`torch.Tensor` or `list[tuple[int, int]]`, *optional*):
+ Tensor of shape `(batch_size, 2)` or list of tuples (`tuple[int, int]`) containing the target size
+ `(height, width)` of each image in the batch. If unset, predictions will not be resized.
+ use_focal_loss (`bool`, *optional*, defaults to `True`):
+ Whether focal loss was used to predict the outputs. If `True`, a sigmoid is applied to compute the
+ scores of each detection, otherwise a softmax function is used.
+
+ Returns:
+ `list[Dict]`: A list of dictionaries, each dictionary containing the scores, labels and boxes for an image
+ in the batch as predicted by the model.
+ """
+ requires_backends(self, ["torch"])
+ out_logits, out_bbox = outputs.logits, outputs.pred_boxes
+ # convert from relative cxcywh to absolute xyxy
+ boxes = center_to_corners_format(out_bbox)
+ if target_sizes is not None:
+ if len(out_logits) != len(target_sizes):
+ raise ValueError(
+ "Make sure that you pass in as many target sizes as the batch dimension of the logits"
+ )
+ if isinstance(target_sizes, list):
+ img_h, img_w = torch.as_tensor(target_sizes).unbind(1)
+ else:
+ img_h, img_w = target_sizes.unbind(1)
+ scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1).to(boxes.device)
+ boxes = boxes * scale_fct[:, None, :]
+
+ num_top_queries = out_logits.shape[1]
+ num_classes = out_logits.shape[2]
+
+ if use_focal_loss:
+ scores = torch.nn.functional.sigmoid(out_logits)
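+ # Scores are flattened over (query, class) before top-k; the class id and the query index are then
+ # recovered from the flat indices and used to gather the boxes of the selected queries.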
+ scores, index = torch.topk(scores.flatten(1), num_top_queries, axis=-1)
+ labels = index % num_classes
+ index = index // num_classes
+ boxes = boxes.gather(dim=1, index=index.unsqueeze(-1).repeat(1, 1, boxes.shape[-1]))
+ else:
+ scores = torch.nn.functional.softmax(out_logits, dim=-1)[:, :, :-1]
+ scores, labels = scores.max(dim=-1)
+ if scores.shape[1] > num_top_queries:
+ scores, index = torch.topk(scores, num_top_queries, dim=-1)
+ labels = torch.gather(labels, dim=1, index=index)
+ boxes = torch.gather(boxes, dim=1, index=index.unsqueeze(-1).tile(1, 1, boxes.shape[-1]))
+
+ results = []
+ for score, label, box in zip(scores, labels, boxes):
+ results.append(
+ {
+ "scores": score[score > threshold],
+ "labels": label[score > threshold],
+ "boxes": box[score > threshold],
+ }
+ )
+
+ return results
+
+ def from_dict(self):
+ raise NotImplementedError("No need to override this method for RT-DETR yet.")
+
+ def post_process(self):
+ raise NotImplementedError("Post-processing is not implemented for RT-DETR yet.")
+
+ def post_process_segmentation(self):
+ raise NotImplementedError("Segmentation post-processing is not implemented for RT-DETR yet.")
+
+ def post_process_instance(self):
+ raise NotImplementedError("Instance post-processing is not implemented for RT-DETR yet.")
+
+ def post_process_panoptic(self):
+ raise NotImplementedError("Panoptic post-processing is not implemented for RT-DETR yet.")
+
+ def post_process_instance_segmentation(self):
+ raise NotImplementedError("Instance segmentation post-processing is not implemented for RT-DETR yet.")
+
+ def post_process_semantic_segmentation(self):
+ raise NotImplementedError("Semantic segmentation post-processing is not implemented for RT-DETR yet.")
+
+ def post_process_panoptic_segmentation(self):
+ raise NotImplementedError("Panoptic segmentation post-processing is not implemented for RT-DETR yet.")
+
+
+__all__ = ["RTDetrImageProcessorFast"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/sam/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/sam/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..bb8a2b98e636b7653747861461b7d648b91164c7
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/sam/__init__.py
@@ -0,0 +1,31 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_sam import *
+ from .image_processing_sam import *
+ from .image_processing_sam_fast import *
+ from .modeling_sam import *
+ from .modeling_tf_sam import *
+ from .processing_sam import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/sam/configuration_sam.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/sam/configuration_sam.py
new file mode 100644
index 0000000000000000000000000000000000000000..11a3e421d42eab2d98b67e6b0be00dbcc0bbc469
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/sam/configuration_sam.py
@@ -0,0 +1,337 @@
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""SAM model configuration"""
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class SamPromptEncoderConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`SamPromptEncoder`]. The [`SamPromptEncoder`]
+ module is used to encode the input 2D points and bounding boxes. Instantiating a configuration with the defaults
+ will yield a configuration similar to that of the SAM ViT-H
+ [facebook/sam-vit-huge](https://huggingface.co/facebook/sam-vit-huge) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ hidden_size (`int`, *optional*, defaults to 256):
+ Dimensionality of the hidden states.
+ image_size (`int`, *optional*, defaults to 1024):
+ The expected output resolution of the image.
+ patch_size (`int`, *optional*, defaults to 16):
+ The size (resolution) of each patch.
+ mask_input_channels (`int`, *optional*, defaults to 16):
+ The number of channels to be fed to the `MaskDecoder` module.
+ num_point_embeddings (`int`, *optional*, defaults to 4):
+ The number of point embeddings to be used.
+ hidden_act (`str`, *optional*, defaults to `"gelu"`):
+ The non-linear activation function in the encoder and pooler.
+ layer_norm_eps (`float`, *optional*, defaults to 1e-06):
+ The epsilon used by the layer normalization layers.
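+
+ Example (an illustrative sketch mirroring the other config examples in this file; it is not part of the
+ upstream docstring):
+
+ ```python
+ >>> from transformers import SamPromptEncoderConfig, SamConfig
+
+ >>> # Initializing a SamPromptEncoderConfig with `"facebook/sam-vit-huge"` style defaults
+ >>> prompt_encoder_config = SamPromptEncoderConfig()
+
+ >>> # The prompt encoder config can then be passed to a full SamConfig
+ >>> config = SamConfig(prompt_encoder_config=prompt_encoder_config)
+ ```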
+ """
+
+ base_config_key = "prompt_encoder_config"
+
+ def __init__(
+ self,
+ hidden_size=256,
+ image_size=1024,
+ patch_size=16,
+ mask_input_channels=16,
+ num_point_embeddings=4,
+ hidden_act="gelu",
+ layer_norm_eps=1e-6,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.hidden_size = hidden_size
+ self.image_size = image_size
+ self.patch_size = patch_size
+ self.image_embedding_size = image_size // patch_size
+ self.mask_input_channels = mask_input_channels
+ self.num_point_embeddings = num_point_embeddings
+ self.hidden_act = hidden_act
+ self.layer_norm_eps = layer_norm_eps
+
+
+class SamMaskDecoderConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`SamMaskDecoder`]. It is used to instantiate a SAM
+ mask decoder according to the specified arguments, defining the model architecture. Instantiating a configuration
+ with the defaults will yield a configuration similar to that of the SAM ViT-H
+ [facebook/sam-vit-huge](https://huggingface.co/facebook/sam-vit-huge) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ hidden_size (`int`, *optional*, defaults to 256):
+ Dimensionality of the hidden states.
+ hidden_act (`str`, *optional*, defaults to `"relu"`):
+ The non-linear activation function used inside the `SamMaskDecoder` module.
+ mlp_dim (`int`, *optional*, defaults to 2048):
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
+ num_hidden_layers (`int`, *optional*, defaults to 2):
+ Number of hidden layers in the Transformer encoder.
+ num_attention_heads (`int`, *optional*, defaults to 8):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ attention_downsample_rate (`int`, *optional*, defaults to 2):
+ The downsampling rate of the attention layer.
+ num_multimask_outputs (`int`, *optional*, defaults to 3):
+ The number of outputs from the `SamMaskDecoder` module. In the Segment Anything paper, this is set to 3.
+ iou_head_depth (`int`, *optional*, defaults to 3):
+ The number of layers in the IoU head module.
+ iou_head_hidden_dim (`int`, *optional*, defaults to 256):
+ The dimensionality of the hidden states in the IoU head module.
+ layer_norm_eps (`float`, *optional*, defaults to 1e-06):
+ The epsilon used by the layer normalization layers.
+
+ """
+
+ base_config_key = "mask_decoder_config"
+
+ def __init__(
+ self,
+ hidden_size=256,
+ hidden_act="relu",
+ mlp_dim=2048,
+ num_hidden_layers=2,
+ num_attention_heads=8,
+ attention_downsample_rate=2,
+ num_multimask_outputs=3,
+ iou_head_depth=3,
+ iou_head_hidden_dim=256,
+ layer_norm_eps=1e-6,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.hidden_size = hidden_size
+ self.hidden_act = hidden_act
+ self.mlp_dim = mlp_dim
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.attention_downsample_rate = attention_downsample_rate
+ self.num_multimask_outputs = num_multimask_outputs
+ self.iou_head_depth = iou_head_depth
+ self.iou_head_hidden_dim = iou_head_hidden_dim
+ self.layer_norm_eps = layer_norm_eps
+
+
+class SamVisionConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`SamVisionModel`]. It is used to instantiate a SAM
+ vision encoder according to the specified arguments, defining the model architecture. Instantiating a configuration
+ with the defaults will yield a configuration similar to that of the SAM ViT-H
+ [facebook/sam-vit-huge](https://huggingface.co/facebook/sam-vit-huge) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ hidden_size (`int`, *optional*, defaults to 768):
+ Dimensionality of the encoder layers and the pooler layer.
+ output_channels (`int`, *optional*, defaults to 256):
+ Dimensionality of the output channels in the Patch Encoder.
+ num_hidden_layers (`int`, *optional*, defaults to 12):
+ Number of hidden layers in the Transformer encoder.
+ num_attention_heads (`int`, *optional*, defaults to 12):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ num_channels (`int`, *optional*, defaults to 3):
+ Number of channels in the input image.
+ image_size (`int`, *optional*, defaults to 1024):
+ Expected resolution. Target size of the resized input image.
+ patch_size (`int`, *optional*, defaults to 16):
+ Size of the patches to be extracted from the input image.
+ hidden_act (`str`, *optional*, defaults to `"gelu"`):
+ The non-linear activation function (function or string) in the encoder.
+ layer_norm_eps (`float`, *optional*, defaults to 1e-06):
+ The epsilon used by the layer normalization layers.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ initializer_range (`float`, *optional*, defaults to 1e-10):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ qkv_bias (`bool`, *optional*, defaults to `True`):
+ Whether to add a bias to query, key, value projections.
+ mlp_ratio (`float`, *optional*, defaults to 4.0):
+ Ratio of mlp hidden dim to embedding dim.
+ use_abs_pos (`bool`, *optional*, defaults to `True`):
+ Whether to use absolute position embedding.
+ use_rel_pos (`bool`, *optional*, defaults to `True`):
+ Whether to use relative position embedding.
+ window_size (`int`, *optional*, defaults to 14):
+ Window size for relative position.
+ global_attn_indexes (`list[int]`, *optional*, defaults to `[2, 5, 8, 11]`):
+ The indexes of the global attention layers.
+ num_pos_feats (`int`, *optional*, defaults to 128):
+ The dimensionality of the position embedding.
+ mlp_dim (`int`, *optional*):
+ The dimensionality of the MLP layer in the Transformer encoder. If `None`, defaults to `mlp_ratio *
+ hidden_size`.
+
+ Example:
+
+ ```python
+ >>> from transformers import (
+ ... SamVisionConfig,
+ ... SamVisionModel,
+ ... )
+
+ >>> # Initializing a SamVisionConfig with `"facebook/sam-vit-huge"` style configuration
+ >>> configuration = SamVisionConfig()
+
+ >>> # Initializing a SamVisionModel (with random weights) from the `"facebook/sam-vit-huge"` style configuration
+ >>> model = SamVisionModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ base_config_key = "vision_config"
+ model_type = "sam_vision_model"
+
+ def __init__(
+ self,
+ hidden_size=768,
+ output_channels=256,
+ num_hidden_layers=12,
+ num_attention_heads=12,
+ num_channels=3,
+ image_size=1024,
+ patch_size=16,
+ hidden_act="gelu",
+ layer_norm_eps=1e-06,
+ attention_dropout=0.0,
+ initializer_range=1e-10,
+ qkv_bias=True,
+ mlp_ratio=4.0,
+ use_abs_pos=True,
+ use_rel_pos=True,
+ window_size=14,
+ global_attn_indexes=[2, 5, 8, 11],
+ num_pos_feats=128,
+ mlp_dim=None,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+
+ self.hidden_size = hidden_size
+ self.output_channels = output_channels
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.num_channels = num_channels
+ self.image_size = image_size
+ self.patch_size = patch_size
+ self.hidden_act = hidden_act
+ self.layer_norm_eps = layer_norm_eps
+ self.attention_dropout = attention_dropout
+ self.initializer_range = initializer_range
+ self.qkv_bias = qkv_bias
+ self.mlp_ratio = mlp_ratio
+ self.use_abs_pos = use_abs_pos
+ self.use_rel_pos = use_rel_pos
+ self.window_size = window_size
+ self.global_attn_indexes = global_attn_indexes
+ self.num_pos_feats = num_pos_feats
+ self.mlp_dim = int(hidden_size * mlp_ratio) if mlp_dim is None else mlp_dim
+
+
+class SamConfig(PretrainedConfig):
+ r"""
+ [`SamConfig`] is the configuration class to store the configuration of a [`SamModel`]. It is used to instantiate a
+ SAM model according to the specified arguments, defining the vision model, prompt-encoder model and mask decoder
+ configs. Instantiating a configuration with the defaults will yield a similar configuration to that of the
+ SAM-ViT-H [facebook/sam-vit-huge](https://huggingface.co/facebook/sam-vit-huge) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ vision_config (Union[`dict`, `SamVisionConfig`], *optional*):
+ Dictionary of configuration options used to initialize [`SamVisionConfig`].
+ prompt_encoder_config (Union[`dict`, `SamPromptEncoderConfig`], *optional*):
+ Dictionary of configuration options used to initialize [`SamPromptEncoderConfig`].
+ mask_decoder_config (Union[`dict`, `SamMaskDecoderConfig`], *optional*):
+ Dictionary of configuration options used to initialize [`SamMaskDecoderConfig`].
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+
+ kwargs (*optional*):
+ Dictionary of keyword arguments.
+
+ Example:
+
+ ```python
+ >>> from transformers import (
+ ... SamVisionConfig,
+ ... SamPromptEncoderConfig,
+ ... SamMaskDecoderConfig,
+ ... SamModel,
+ ... )
+
+ >>> # Initializing a SamConfig with `"facebook/sam-vit-huge"` style configuration
+ >>> configuration = SamConfig()
+
+ >>> # Initializing a SamModel (with random weights) from the `"facebook/sam-vit-huge"` style configuration
+ >>> model = SamModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+
+ >>> # We can also initialize a SamConfig from a SamVisionConfig, SamPromptEncoderConfig, and SamMaskDecoderConfig
+
+ >>> # Initializing SAM vision, prompt encoder, and mask decoder configurations
+ >>> vision_config = SamVisionConfig()
+ >>> prompt_encoder_config = SamPromptEncoderConfig()
+ >>> mask_decoder_config = SamMaskDecoderConfig()
+
+ >>> config = SamConfig(vision_config, prompt_encoder_config, mask_decoder_config)
+ ```"""
+
+ model_type = "sam"
+ sub_configs = {
+ "prompt_encoder_config": SamPromptEncoderConfig,
+ "mask_decoder_config": SamMaskDecoderConfig,
+ "vision_config": SamVisionConfig,
+ }
+
+ def __init__(
+ self,
+ vision_config=None,
+ prompt_encoder_config=None,
+ mask_decoder_config=None,
+ initializer_range=0.02,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ vision_config = vision_config if vision_config is not None else {}
+ prompt_encoder_config = prompt_encoder_config if prompt_encoder_config is not None else {}
+ mask_decoder_config = mask_decoder_config if mask_decoder_config is not None else {}
+
+ if isinstance(vision_config, SamVisionConfig):
+ vision_config = vision_config.to_dict()
+ if isinstance(prompt_encoder_config, SamPromptEncoderConfig):
+ prompt_encoder_config = prompt_encoder_config.to_dict()
+ if isinstance(mask_decoder_config, SamMaskDecoderConfig):
+ mask_decoder_config = mask_decoder_config.to_dict()
+
+ self.vision_config = SamVisionConfig(**vision_config)
+ self.prompt_encoder_config = SamPromptEncoderConfig(**prompt_encoder_config)
+ self.mask_decoder_config = SamMaskDecoderConfig(**mask_decoder_config)
+ self.initializer_range = initializer_range
+
+
+__all__ = ["SamConfig", "SamMaskDecoderConfig", "SamPromptEncoderConfig", "SamVisionConfig"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/sam/image_processing_sam.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/sam/image_processing_sam.py
new file mode 100644
index 0000000000000000000000000000000000000000..6acb775b43db458765892b48fe8fa372513aaa81
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/sam/image_processing_sam.py
@@ -0,0 +1,1497 @@
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Image processor class for SAM."""
+
+import math
+from copy import deepcopy
+from itertools import product
+from typing import Any, Optional, Union
+
+import numpy as np
+
+from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
+from ...image_transforms import convert_to_rgb, pad, resize, to_channel_dimension_format
+from ...image_utils import (
+ IMAGENET_DEFAULT_MEAN,
+ IMAGENET_DEFAULT_STD,
+ ChannelDimension,
+ ImageInput,
+ PILImageResampling,
+ get_image_size,
+ infer_channel_dimension_format,
+ is_scaled_image,
+ make_flat_list_of_images,
+ to_numpy_array,
+ valid_images,
+ validate_preprocess_arguments,
+)
+from ...utils import (
+ TensorType,
+ filter_out_non_signature_kwargs,
+ is_tf_available,
+ is_torch_available,
+ is_torchvision_available,
+ logging,
+ requires_backends,
+)
+
+
+if is_torch_available():
+ import torch
+ import torch.nn.functional as F
+
+if is_torchvision_available():
+ from torchvision.ops.boxes import batched_nms
+
+if is_tf_available():
+ import tensorflow as tf
+ from tensorflow.experimental import numpy as tnp
+
+ from ...tf_utils import flatten, shape_list
+
+logger = logging.get_logger(__name__)
+
+
+class SamImageProcessor(BaseImageProcessor):
+ r"""
+ Constructs a SAM image processor.
+
+ Args:
+ do_resize (`bool`, *optional*, defaults to `True`):
+ Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by the
+ `do_resize` parameter in the `preprocess` method.
+ size (`dict`, *optional*, defaults to `{"longest_edge": 1024}`):
+ Size of the output image after resizing. Resizes the longest edge of the image to match
+ `size["longest_edge"]` while maintaining the aspect ratio. Can be overridden by the `size` parameter in the
+ `preprocess` method.
+ mask_size (`dict`, *optional*, defaults to `{"longest_edge": 256}`):
+ Size of the output segmentation map after resizing. Resizes the longest edge of the image to match
+ `mask_size["longest_edge"]` while maintaining the aspect ratio. Can be overridden by the `mask_size` parameter
+ in the `preprocess` method.
+ resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`):
+ Resampling filter to use if resizing the image. Can be overridden by the `resample` parameter in the
+ `preprocess` method.
+ do_rescale (`bool`, *optional*, defaults to `True`):
+ Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the
+ `do_rescale` parameter in the `preprocess` method.
+ rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
+ Scale factor to use if rescaling the image. Only has an effect if `do_rescale` is set to `True`. Can be
+ overridden by the `rescale_factor` parameter in the `preprocess` method.
+ do_normalize (`bool`, *optional*, defaults to `True`):
+ Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess` method.
+ image_mean (`float` or `list[float]`, *optional*, defaults to `IMAGENET_DEFAULT_MEAN`):
+ Mean to use if normalizing the image. This is a float or list of floats the length of the number of
+ channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
+ image_std (`float` or `list[float]`, *optional*, defaults to `IMAGENET_DEFAULT_STD`):
+ Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
+ number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
+ do_pad (`bool`, *optional*, defaults to `True`):
+ Whether to pad the image to the specified `pad_size`. Can be overridden by the `do_pad` parameter in the
+ `preprocess` method.
+ pad_size (`dict`, *optional*, defaults to `{"height": 1024, "width": 1024}`):
+ Size of the output image after padding. Can be overridden by the `pad_size` parameter in the `preprocess`
+ method.
+ mask_pad_size (`dict`, *optional*, defaults to `{"height": 256, "width": 256}`):
+ Size of the output segmentation map after padding. Can be overridden by the `mask_pad_size` parameter in
+ the `preprocess` method.
+ do_convert_rgb (`bool`, *optional*, defaults to `True`):
+ Whether to convert the image to RGB.
+ """
+
+ model_input_names = ["pixel_values"]
+
+ def __init__(
+ self,
+ do_resize: bool = True,
+ size: Optional[dict[str, int]] = None,
+ mask_size: Optional[dict[str, int]] = None,
+ resample: PILImageResampling = PILImageResampling.BILINEAR,
+ do_rescale: bool = True,
+ rescale_factor: Union[int, float] = 1 / 255,
+ do_normalize: bool = True,
+ image_mean: Optional[Union[float, list[float]]] = None,
+ image_std: Optional[Union[float, list[float]]] = None,
+ do_pad: bool = True,
+ pad_size: Optional[dict[str, int]] = None,
+ mask_pad_size: Optional[dict[str, int]] = None,
+ do_convert_rgb: bool = True,
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+ size = size if size is not None else {"longest_edge": 1024}
+ size = get_size_dict(max_size=size, default_to_square=False) if not isinstance(size, dict) else size
+
+ pad_size = pad_size if pad_size is not None else {"height": 1024, "width": 1024}
+ pad_size = get_size_dict(pad_size, default_to_square=True)
+
+ mask_size = mask_size if mask_size is not None else {"longest_edge": 256}
+ mask_size = (
+ get_size_dict(max_size=mask_size, default_to_square=False)
+ if not isinstance(mask_size, dict)
+ else mask_size
+ )
+
+ mask_pad_size = mask_pad_size if mask_pad_size is not None else {"height": 256, "width": 256}
+ mask_pad_size = get_size_dict(mask_pad_size, default_to_square=True)
+
+ self.do_resize = do_resize
+ self.size = size
+ self.mask_size = mask_size
+ self.resample = resample
+ self.do_rescale = do_rescale
+ self.rescale_factor = rescale_factor
+ self.do_normalize = do_normalize
+ self.image_mean = image_mean if image_mean is not None else IMAGENET_DEFAULT_MEAN
+ self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD
+ self.do_pad = do_pad
+ self.pad_size = pad_size
+ self.mask_pad_size = mask_pad_size
+ self.do_convert_rgb = do_convert_rgb
+
+ def pad_image(
+ self,
+ image: np.ndarray,
+ pad_size: dict[str, int],
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ **kwargs,
+ ) -> np.ndarray:
+ """
+ Pad an image to `(pad_size["height"], pad_size["width"])` with zeros to the right and bottom.
+
+ Args:
+ image (`np.ndarray`):
+ Image to pad.
+ pad_size (`dict[str, int]`):
+ Size of the output image after padding.
+ data_format (`str` or `ChannelDimension`, *optional*):
+ The data format of the image. Can be either "channels_first" or "channels_last". If `None`, the
+ `data_format` of the `image` will be used.
+ input_data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format of the input image. If not provided, it will be inferred.
+ """
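+ # Illustrative note (added for this document, not from the library): with the default
+ # pad_size {"height": 1024, "width": 1024}, a resized (684, 1024) image gets 340 rows of
+ # zeros appended at the bottom and no extra columns on the right.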
+ output_height, output_width = pad_size["height"], pad_size["width"]
+ input_height, input_width = get_image_size(image, channel_dim=input_data_format)
+
+ pad_width = output_width - input_width
+ pad_height = output_height - input_height
+
+ padded_image = pad(
+ image,
+ ((0, pad_height), (0, pad_width)),
+ data_format=data_format,
+ input_data_format=input_data_format,
+ **kwargs,
+ )
+ return padded_image
+
+ def _get_preprocess_shape(self, old_shape: tuple[int, int], longest_edge: int):
+ """
+ Compute the output size given input size and target long side length.
+ """
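+ # Illustrative note (added for this document): with longest_edge=1024, a (480, 640) input is
+ # scaled by 1024 / 640 = 1.6, giving an output shape of (768, 1024).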
+ oldh, oldw = old_shape
+ scale = longest_edge * 1.0 / max(oldh, oldw)
+ newh, neww = oldh * scale, oldw * scale
+ newh = int(newh + 0.5)
+ neww = int(neww + 0.5)
+ return (newh, neww)
+
+ def resize(
+ self,
+ image: np.ndarray,
+ size: dict[str, int],
+ resample: PILImageResampling = PILImageResampling.BICUBIC,
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ **kwargs,
+ ) -> np.ndarray:
+ """
+ Resize an image so that its longest edge matches `size["longest_edge"]` while preserving the aspect ratio.
+
+ Args:
+ image (`np.ndarray`):
+ Image to resize.
+ size (`dict[str, int]`):
+ Dictionary in the format `{"longest_edge": int}` specifying the size of the output image. The longest
+ edge of the image will be resized to the specified size, while the other edge will be resized to
+ maintain the aspect ratio.
+ resample:
+ `PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BILINEAR`.
+ data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format for the output image. If unset, the channel dimension format of the input
+ image is used. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ input_data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
+ from the input image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+
+ Returns:
+ `np.ndarray`: The resized image.
+ """
+ size = get_size_dict(size)
+ if "longest_edge" not in size:
+ raise ValueError(f"The `size` dictionary must contain the key `longest_edge`. Got {size.keys()}")
+ input_size = get_image_size(image, channel_dim=input_data_format)
+ output_height, output_width = self._get_preprocess_shape(input_size, size["longest_edge"])
+ return resize(
+ image,
+ size=(output_height, output_width),
+ resample=resample,
+ data_format=data_format,
+ input_data_format=input_data_format,
+ **kwargs,
+ )
+
+ def _preprocess(
+ self,
+ image: ImageInput,
+ do_resize: bool,
+ do_rescale: bool,
+ do_normalize: bool,
+ size: Optional[dict[str, int]] = None,
+ resample: Optional[PILImageResampling] = None,
+ rescale_factor: Optional[float] = None,
+ image_mean: Optional[Union[float, list[float]]] = None,
+ image_std: Optional[Union[float, list[float]]] = None,
+ do_pad: Optional[bool] = None,
+ pad_size: Optional[dict[str, int]] = None,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ ):
+ if do_resize:
+ image = self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
+ reshaped_input_size = get_image_size(image, channel_dim=input_data_format)
+
+ if do_rescale:
+ image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
+
+ if do_normalize:
+ image = self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
+
+ if do_pad:
+ image = self.pad_image(image=image, pad_size=pad_size, input_data_format=input_data_format)
+
+ return image, reshaped_input_size
+
+ def _preprocess_image(
+ self,
+ image: ImageInput,
+ do_resize: Optional[bool] = None,
+ size: Optional[dict[str, int]] = None,
+ resample: Optional[PILImageResampling] = None,
+ do_rescale: Optional[bool] = None,
+ rescale_factor: Optional[float] = None,
+ do_normalize: Optional[bool] = None,
+ image_mean: Optional[Union[float, list[float]]] = None,
+ image_std: Optional[Union[float, list[float]]] = None,
+ do_pad: Optional[bool] = None,
+ pad_size: Optional[dict[str, int]] = None,
+ do_convert_rgb: Optional[bool] = None,
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ ) -> tuple[np.ndarray, tuple[int, int], tuple[int, int]]:
+ # PIL RGBA images are converted to RGB
+ if do_convert_rgb:
+ image = convert_to_rgb(image)
+
+ # All transformations expect numpy arrays.
+ image = to_numpy_array(image)
+
+ if do_rescale and is_scaled_image(image):
+ logger.warning_once(
+ "It looks like you are trying to rescale already rescaled images. If the input"
+ " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
+ )
+
+ if input_data_format is None:
+ input_data_format = infer_channel_dimension_format(image)
+
+ original_size = get_image_size(image, channel_dim=input_data_format)
+
+ image, reshaped_input_size = self._preprocess(
+ image=image,
+ do_resize=do_resize,
+ size=size,
+ resample=resample,
+ do_rescale=do_rescale,
+ rescale_factor=rescale_factor,
+ do_normalize=do_normalize,
+ image_mean=image_mean,
+ image_std=image_std,
+ do_pad=do_pad,
+ pad_size=pad_size,
+ input_data_format=input_data_format,
+ )
+
+ if data_format is not None:
+ image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
+
+ return image, original_size, reshaped_input_size
+
+ def _preprocess_mask(
+ self,
+ segmentation_map: ImageInput,
+ do_resize: Optional[bool] = None,
+ mask_size: Optional[dict[str, int]] = None,
+ do_pad: Optional[bool] = None,
+ mask_pad_size: Optional[dict[str, int]] = None,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ ) -> np.ndarray:
+ segmentation_map = to_numpy_array(segmentation_map)
+
+ # Add channel dimension if missing - needed for certain transformations
+ if segmentation_map.ndim == 2:
+ added_channel_dim = True
+ segmentation_map = segmentation_map[None, ...]
+ input_data_format = ChannelDimension.FIRST
+ else:
+ added_channel_dim = False
+ if input_data_format is None:
+ input_data_format = infer_channel_dimension_format(segmentation_map, num_channels=1)
+
+ original_size = get_image_size(segmentation_map, channel_dim=input_data_format)
+
+ segmentation_map, _ = self._preprocess(
+ image=segmentation_map,
+ do_resize=do_resize,
+ size=mask_size,
+ resample=PILImageResampling.NEAREST,
+ do_rescale=False,
+ do_normalize=False,
+ do_pad=do_pad,
+ pad_size=mask_pad_size,
+ input_data_format=input_data_format,
+ )
+
+ # Remove extra channel dimension if added for processing
+ if added_channel_dim:
+ segmentation_map = segmentation_map.squeeze(0)
+ segmentation_map = segmentation_map.astype(np.int64)
+
+ return segmentation_map, original_size
+
+ def __call__(self, images, segmentation_maps=None, **kwargs):
+ # Overrides the `__call__` method of the `BaseImageProcessor` class such that the images and segmentation maps can both
+ # be passed in as positional arguments.
+ return super().__call__(images, segmentation_maps=segmentation_maps, **kwargs)
+
+ @filter_out_non_signature_kwargs()
+ def preprocess(
+ self,
+ images: ImageInput,
+ segmentation_maps: Optional[ImageInput] = None,
+ do_resize: Optional[bool] = None,
+ size: Optional[dict[str, int]] = None,
+ mask_size: Optional[dict[str, int]] = None,
+ resample: Optional["PILImageResampling"] = None,
+ do_rescale: Optional[bool] = None,
+ rescale_factor: Optional[Union[int, float]] = None,
+ do_normalize: Optional[bool] = None,
+ image_mean: Optional[Union[float, list[float]]] = None,
+ image_std: Optional[Union[float, list[float]]] = None,
+ do_pad: Optional[bool] = None,
+ pad_size: Optional[dict[str, int]] = None,
+ mask_pad_size: Optional[dict[str, int]] = None,
+ do_convert_rgb: Optional[bool] = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ data_format: ChannelDimension = ChannelDimension.FIRST,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ ):
+ """
+ Preprocess an image or batch of images.
+
+ Args:
+ images (`ImageInput`):
+ Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
+ passing in images with pixel values between 0 and 1, set `do_rescale=False`.
+ segmentation_maps (`ImageInput`, *optional*):
+ Segmentation map to preprocess.
+ do_resize (`bool`, *optional*, defaults to `self.do_resize`):
+ Whether to resize the image.
+ size (`dict[str, int]`, *optional*, defaults to `self.size`):
+ Controls the size of the image after `resize`. The longest edge of the image is resized to
+ `size["longest_edge"]` whilst preserving the aspect ratio.
+ mask_size (`dict[str, int]`, *optional*, defaults to `self.mask_size`):
+ Controls the size of the segmentation map after `resize`. The longest edge of the image is resized to
+ `size["longest_edge"]` whilst preserving the aspect ratio.
+ resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
+ `PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BILINEAR`.
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
+ Whether to rescale the image pixel values by rescaling factor.
+ rescale_factor (`int` or `float`, *optional*, defaults to `self.rescale_factor`):
+ Rescale factor to apply to the image pixel values.
+ do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
+ Whether to normalize the image.
+ image_mean (`float` or `list[float]`, *optional*, defaults to `self.image_mean`):
+ Image mean to normalize the image by if `do_normalize` is set to `True`.
+ image_std (`float` or `list[float]`, *optional*, defaults to `self.image_std`):
+ Image standard deviation to normalize the image by if `do_normalize` is set to `True`.
+ do_pad (`bool`, *optional*, defaults to `self.do_pad`):
+ Whether to pad the image.
+ pad_size (`dict[str, int]`, *optional*, defaults to `self.pad_size`):
+ Controls the size of the padding applied to the image. The image is padded to `pad_size["height"]` and
+ `pad_size["width"]` if `do_pad` is set to `True`.
+ mask_pad_size (`dict[str, int]`, *optional*, defaults to `self.mask_pad_size`):
+ Controls the size of the padding applied to the segmentation map. The image is padded to
+ `mask_pad_size["height"]` and `mask_pad_size["width"]` if `do_pad` is set to `True`.
+ do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
+ Whether to convert the image to RGB.
+ return_tensors (`str` or `TensorType`, *optional*):
+ The type of tensors to return. Can be one of:
+ - Unset: Return a list of `np.ndarray`.
+ - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
+ - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
+ data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
+ The channel dimension format for the output image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - Unset: Use the channel dimension format of the input image.
+ input_data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
+ from the input image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
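+
+ Example (an illustrative sketch, not part of the upstream docstring; assumes a PIL image `raw_image`
+ is available):
+
+ ```python
+ >>> from transformers import SamImageProcessor
+
+ >>> processor = SamImageProcessor()
+ >>> inputs = processor(raw_image, return_tensors="pt")
+ >>> sorted(inputs.keys())
+ ['original_sizes', 'pixel_values', 'reshaped_input_sizes']
+ ```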
+ """
+ do_resize = do_resize if do_resize is not None else self.do_resize
+ size = size if size is not None else self.size
+ size = get_size_dict(max_size=size, default_to_square=False) if not isinstance(size, dict) else size
+ mask_size = mask_size if mask_size is not None else self.mask_size
+ mask_size = (
+ get_size_dict(max_size=mask_size, default_to_square=False)
+ if not isinstance(mask_size, dict)
+ else mask_size
+ )
+ resample = resample if resample is not None else self.resample
+ do_rescale = do_rescale if do_rescale is not None else self.do_rescale
+ rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
+ do_normalize = do_normalize if do_normalize is not None else self.do_normalize
+ image_mean = image_mean if image_mean is not None else self.image_mean
+ image_std = image_std if image_std is not None else self.image_std
+ do_pad = do_pad if do_pad is not None else self.do_pad
+ pad_size = pad_size if pad_size is not None else self.pad_size
+ pad_size = get_size_dict(pad_size, default_to_square=True)
+ mask_pad_size = mask_pad_size if mask_pad_size is not None else self.mask_pad_size
+ mask_pad_size = get_size_dict(mask_pad_size, default_to_square=True)
+ do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
+
+ images = make_flat_list_of_images(images)
+
+ if not valid_images(images):
+ raise ValueError(
+ "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
+ "torch.Tensor, tf.Tensor or jax.ndarray."
+ )
+
+ if segmentation_maps is not None:
+ segmentation_maps = make_flat_list_of_images(segmentation_maps, expected_ndims=2)
+
+ if not valid_images(segmentation_maps):
+ raise ValueError(
+ "Invalid segmentation map type. Must be of type PIL.Image.Image, numpy.ndarray, "
+ "torch.Tensor, tf.Tensor or jax.ndarray."
+ )
+ validate_preprocess_arguments(
+ do_rescale=do_rescale,
+ rescale_factor=rescale_factor,
+ do_normalize=do_normalize,
+ image_mean=image_mean,
+ image_std=image_std,
+ do_resize=do_resize,
+ size=size,
+ resample=resample,
+ )
+
+ images, original_sizes, reshaped_input_sizes = zip(
+ *(
+ self._preprocess_image(
+ image=img,
+ do_resize=do_resize,
+ size=size,
+ resample=resample,
+ do_rescale=do_rescale,
+ rescale_factor=rescale_factor,
+ do_normalize=do_normalize,
+ image_mean=image_mean,
+ image_std=image_std,
+ do_pad=do_pad,
+ pad_size=pad_size,
+ do_convert_rgb=do_convert_rgb,
+ data_format=data_format,
+ input_data_format=input_data_format,
+ )
+ for img in images
+ )
+ )
+
+ data = {
+ "pixel_values": images,
+ "original_sizes": original_sizes,
+ "reshaped_input_sizes": reshaped_input_sizes,
+ }
+
+ if segmentation_maps is not None:
+ segmentation_maps, original_mask_sizes = zip(
+ *(
+ self._preprocess_mask(
+ segmentation_map=mask,
+ do_resize=do_resize,
+ mask_size=mask_size,
+ do_pad=do_pad,
+ mask_pad_size=mask_pad_size,
+ input_data_format=input_data_format,
+ )
+ for mask in segmentation_maps
+ )
+ )
+
+ # masks should start out the same size as input images
+ if any(
+ original_im_size != original_mask_size
+ for original_im_size, original_mask_size in zip(original_sizes, original_mask_sizes)
+ ):
+ raise ValueError("Segmentation maps should be the same size as input images.")
+
+ data["labels"] = segmentation_maps
+
+ return BatchFeature(data=data, tensor_type=return_tensors)
+
+ def post_process_masks(
+ self,
+ masks,
+ original_sizes,
+ reshaped_input_sizes,
+ mask_threshold=0.0,
+ binarize=True,
+ pad_size=None,
+ return_tensors="pt",
+ ):
+ """
+ Remove padding and upscale masks to the original image size.
+
+ Args:
+ masks (`Union[list[torch.Tensor], list[np.ndarray], list[tf.Tensor]]`):
+ Batched masks from the mask_decoder in (batch_size, num_channels, height, width) format.
+ original_sizes (`Union[torch.Tensor, tf.Tensor, list[tuple[int,int]]]`):
+ The original sizes of each image before it was resized to the model's expected input shape, in (height,
+ width) format.
+ reshaped_input_sizes (`Union[torch.Tensor, tf.Tensor, list[tuple[int,int]]]`):
+ The size of each image as it is fed to the model, in (height, width) format. Used to remove padding.
+ mask_threshold (`float`, *optional*, defaults to 0.0):
+ The threshold to use for binarizing the masks.
+ binarize (`bool`, *optional*, defaults to `True`):
+ Whether to binarize the masks.
+ pad_size (`dict[str, int]`, *optional*, defaults to `self.pad_size`):
+ The target size the images were padded to before being passed to the model. If None, the target size is
+ assumed to be the processor's `pad_size`.
+ return_tensors (`str`, *optional*, defaults to `"pt"`):
+ If `"pt"`, return PyTorch tensors. If `"tf"`, return TensorFlow tensors.
+ Returns:
+ (`Union[torch.Tensor, tf.Tensor]`): Batched masks in `(batch_size, num_channels, height, width)` format,
+ where (height, width) is given by original_size.
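+
+ Example (an illustrative sketch, not part of the upstream docstring; `model` is assumed to be a
+ `SamModel` and `inputs` the processor output for the same images):
+
+ ```python
+ >>> # `outputs` are assumed to come from a SamModel forward pass on `inputs`
+ >>> masks = processor.post_process_masks(
+ ... outputs.pred_masks.cpu(), inputs["original_sizes"], inputs["reshaped_input_sizes"]
+ ... )
+ >>> len(masks) # one tensor of upscaled masks per input image
+ ```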
+ """
+ if return_tensors == "pt":
+ return self._post_process_masks_pt(
+ masks=masks,
+ original_sizes=original_sizes,
+ reshaped_input_sizes=reshaped_input_sizes,
+ mask_threshold=mask_threshold,
+ binarize=binarize,
+ pad_size=pad_size,
+ )
+ elif return_tensors == "tf":
+ return self._post_process_masks_tf(
+ masks=masks,
+ original_sizes=original_sizes,
+ reshaped_input_sizes=reshaped_input_sizes,
+ mask_threshold=mask_threshold,
+ binarize=binarize,
+ pad_size=pad_size,
+ )
+ else:
+ raise ValueError("return_tensors must be either 'pt' or 'tf'")
+
+ def _post_process_masks_pt(
+ self, masks, original_sizes, reshaped_input_sizes, mask_threshold=0.0, binarize=True, pad_size=None
+ ):
+ """
+ Remove padding and upscale masks to the original image size.
+
+ Args:
+ masks (`Union[list[torch.Tensor], list[np.ndarray]]`):
+ Batched masks from the mask_decoder in (batch_size, num_channels, height, width) format.
+ original_sizes (`Union[torch.Tensor, list[tuple[int,int]]]`):
+ The original sizes of each image before it was resized to the model's expected input shape, in (height,
+ width) format.
+ reshaped_input_sizes (`Union[torch.Tensor, list[tuple[int,int]]]`):
+ The size of each image as it is fed to the model, in (height, width) format. Used to remove padding.
+ mask_threshold (`float`, *optional*, defaults to 0.0):
+ The threshold to use for binarizing the masks.
+ binarize (`bool`, *optional*, defaults to `True`):
+ Whether to binarize the masks.
+ pad_size (`dict[str, int]`, *optional*, defaults to `self.pad_size`):
+ The target size the images were padded to before being passed to the model. If None, the target size is
+ assumed to be the processor's `pad_size`.
+ Returns:
+ (`torch.Tensor`): Batched masks in `(batch_size, num_channels, height, width)` format, where (height, width)
+ is given by original_size.
+ """
+ requires_backends(self, ["torch"])
+ pad_size = self.pad_size if pad_size is None else pad_size
+ target_image_size = (pad_size["height"], pad_size["width"])
+ if isinstance(original_sizes, (torch.Tensor, np.ndarray)):
+ original_sizes = original_sizes.tolist()
+ if isinstance(reshaped_input_sizes, (torch.Tensor, np.ndarray)):
+ reshaped_input_sizes = reshaped_input_sizes.tolist()
+ output_masks = []
+ for i, original_size in enumerate(original_sizes):
+ if isinstance(masks[i], np.ndarray):
+ masks[i] = torch.from_numpy(masks[i])
+ elif not isinstance(masks[i], torch.Tensor):
+ raise TypeError("Input masks should be a list of `torch.tensors` or a list of `np.ndarray`")
+ interpolated_mask = F.interpolate(masks[i], target_image_size, mode="bilinear", align_corners=False)
+ interpolated_mask = interpolated_mask[..., : reshaped_input_sizes[i][0], : reshaped_input_sizes[i][1]]
+ interpolated_mask = F.interpolate(interpolated_mask, original_size, mode="bilinear", align_corners=False)
+ if binarize:
+ interpolated_mask = interpolated_mask > mask_threshold
+ output_masks.append(interpolated_mask)
+
+ return output_masks
+
+ def _post_process_masks_tf(
+ self, masks, original_sizes, reshaped_input_sizes, mask_threshold=0.0, binarize=True, pad_size=None
+ ):
+ """
+ Remove padding and upscale masks to the original image size.
+
+ Args:
+ masks (`tf.Tensor`):
+ Batched masks from the mask_decoder in (batch_size, num_channels, height, width) format.
+ original_sizes (`tf.Tensor`):
+ The original size of the images before resizing for input to the model, in (height, width) format.
+ reshaped_input_sizes (`tf.Tensor`):
+ The size of the image input to the model, in (height, width) format. Used to remove padding.
+ mask_threshold (`float`, *optional*, defaults to 0.0):
+ The threshold to use for binarizing the masks.
+ binarize (`bool`, *optional*, defaults to `True`):
+ Whether to binarize the masks.
+ pad_size (`dict[str, int]`, *optional*, defaults to `self.pad_size`):
+ The target size the images were padded to before being passed to the model. If None, the target size is
+ assumed to be the processor's `pad_size`.
+ Returns:
+ (`tf.Tensor`): Batched masks in `(batch_size, num_channels, height, width)` format, where (height, width) is
+ given by original_size.
+ """
+ requires_backends(self, ["tf"])
+ pad_size = self.pad_size if pad_size is None else pad_size
+ target_image_size = (pad_size["height"], pad_size["width"])
+
+ output_masks = []
+ for i, original_size in enumerate(original_sizes):
+ # tf.image expects NHWC, we transpose the NCHW inputs for it
+ mask = tf.transpose(masks[i], perm=[0, 2, 3, 1])
+ interpolated_mask = tf.image.resize(mask, target_image_size, method="bilinear")
+ interpolated_mask = interpolated_mask[:, : reshaped_input_sizes[i][0], : reshaped_input_sizes[i][1], :]
+ interpolated_mask = tf.image.resize(interpolated_mask, original_size, method="bilinear")
+ if binarize:
+ interpolated_mask = interpolated_mask > mask_threshold
+ # And then we transpose them back at the end
+ output_masks.append(tf.transpose(interpolated_mask, perm=[0, 3, 1, 2]))
+
+ return output_masks
+
+ def post_process_for_mask_generation(
+ self, all_masks, all_scores, all_boxes, crops_nms_thresh, return_tensors="pt"
+ ):
+ """
+ Post-processes the masks generated by the model by applying Non-Maximum Suppression (NMS) to the predicted masks.
+
+ Args:
+ all_masks (`Union[list[torch.Tensor], list[tf.Tensor]]`):
+ List of all predicted segmentation masks
+ all_scores (`Union[list[torch.Tensor], list[tf.Tensor]]`):
+ List of all predicted iou scores
+ all_boxes (`Union[list[torch.Tensor], list[tf.Tensor]]`):
+ List of all bounding boxes of the predicted masks
+ crops_nms_thresh (`float`):
+ Threshold for NMS (Non Maximum Suppression) algorithm.
+ return_tensors (`str`, *optional*, defaults to `pt`):
+ If `pt`, returns `torch.Tensor`. If `tf`, returns `tf.Tensor`.
+ """
+ if return_tensors == "pt":
+ return _postprocess_for_mg(all_masks, all_scores, all_boxes, crops_nms_thresh)
+ elif return_tensors == "tf":
+ return _postprocess_for_mg_tf(all_masks, all_scores, all_boxes, crops_nms_thresh)
+
+ def generate_crop_boxes(
+ self,
+ image,
+ target_size,
+ crop_n_layers: int = 0,
+ overlap_ratio: float = 512 / 1500,
+ points_per_crop: Optional[int] = 32,
+ crop_n_points_downscale_factor: Optional[list[int]] = 1,
+ device: Optional["torch.device"] = None,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ return_tensors: str = "pt",
+ ):
+ """
+ Generates a list of crop boxes of different sizes. The i-th layer has (2**i)**2 boxes.
+
+ Args:
+ image (`np.ndarray`):
+ Input original image
+ target_size (`int`):
+ Target size of the resized image
+ crop_n_layers (`int`, *optional*, defaults to 0):
+ If >0, mask prediction will be run again on crops of the image. Sets the number of layers to run, where
+ each layer has 2**i_layer number of image crops.
+ overlap_ratio (`float`, *optional*, defaults to 512/1500):
+ Sets the degree to which crops overlap. In the first crop layer, crops will overlap by this fraction of
+ the image length. Later layers with more crops scale down this overlap.
+ points_per_crop (`int`, *optional*, defaults to 32):
+ Number of points to sample from each crop.
+ crop_n_points_downscale_factor (`list[int]`, *optional*, defaults to 1):
+ The number of points-per-side sampled in layer n is scaled down by crop_n_points_downscale_factor**n.
+ device (`torch.device`, *optional*, defaults to None):
+ Device to use for the computation. If None, cpu will be used.
+ input_data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format of the input image. If not provided, it will be inferred.
+ return_tensors (`str`, *optional*, defaults to `pt`):
+ If `pt`, returns `torch.Tensor`. If `tf`, returns `tf.Tensor`.
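+
+ Example (an illustrative sketch, not part of the upstream docstring; `processor` is assumed to be a
+ `SamImageProcessor` and `image` an `np.ndarray` in (height, width, channels) format):
+
+ ```python
+ >>> crop_boxes, points_per_crop, cropped_images, input_labels = processor.generate_crop_boxes(
+ ... image, target_size=1024, crop_n_layers=1, return_tensors="pt"
+ ... )
+ >>> crop_boxes.shape[0] # expected: 1 full-image box + (2**1)**2 crops = 5
+ ```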
+ """
+ crop_boxes, points_per_crop, cropped_images, input_labels = _generate_crop_boxes(
+ image,
+ target_size,
+ crop_n_layers,
+ overlap_ratio,
+ points_per_crop,
+ crop_n_points_downscale_factor,
+ input_data_format,
+ )
+ if return_tensors == "pt":
+ if device is None:
+ device = torch.device("cpu")
+ crop_boxes = torch.tensor(crop_boxes, device=device)
+ points_per_crop = torch.tensor(points_per_crop, device=device)
+ # cropped_images stays as np
+ input_labels = torch.tensor(input_labels, device=device)
+
+ elif return_tensors == "tf":
+ if device is not None:
+ raise ValueError("device is not a supported argument when return_tensors is tf!")
+ crop_boxes = tf.convert_to_tensor(crop_boxes)
+ points_per_crop = tf.convert_to_tensor(points_per_crop)
+ # cropped_images stays as np
+ input_labels = tf.convert_to_tensor(input_labels)
+ else:
+ raise ValueError("return_tensors must be either 'pt' or 'tf'.")
+ return crop_boxes, points_per_crop, cropped_images, input_labels
+
+ def filter_masks(
+ self,
+ masks,
+ iou_scores,
+ original_size,
+ cropped_box_image,
+ pred_iou_thresh=0.88,
+ stability_score_thresh=0.95,
+ mask_threshold=0,
+ stability_score_offset=1,
+ return_tensors="pt",
+ ):
+ """
+ Filters the predicted masks by selecting only the ones that meet several criteria. The first criterion is
+ that the IoU scores need to be greater than `pred_iou_thresh`. The second criterion is that the stability
+ score needs to be greater than `stability_score_thresh`. The method also converts the predicted masks to
+ bounding boxes and pads the predicted masks if necessary.
+
+ Args:
+ masks (`Union[torch.Tensor, tf.Tensor]`):
+ Input masks.
+ iou_scores (`Union[torch.Tensor, tf.Tensor]`):
+ List of IoU scores.
+ original_size (`tuple[int,int]`):
+ Size of the original image.
+ cropped_box_image (`np.ndarray`):
+ The cropped image.
+ pred_iou_thresh (`float`, *optional*, defaults to 0.88):
+ The threshold for the iou scores.
+ stability_score_thresh (`float`, *optional*, defaults to 0.95):
+ The threshold for the stability score.
+ mask_threshold (`float`, *optional*, defaults to 0):
+ The threshold for the predicted masks.
+ stability_score_offset (`float`, *optional*, defaults to 1):
+ The offset for the stability score used in the `_compute_stability_score` method.
+ return_tensors (`str`, *optional*, defaults to `pt`):
+ If `pt`, returns `torch.Tensor`. If `tf`, returns `tf.Tensor`.
+ """
+ if return_tensors == "pt":
+ return self._filter_masks_pt(
+ masks=masks,
+ iou_scores=iou_scores,
+ original_size=original_size,
+ cropped_box_image=cropped_box_image,
+ pred_iou_thresh=pred_iou_thresh,
+ stability_score_thresh=stability_score_thresh,
+ mask_threshold=mask_threshold,
+ stability_score_offset=stability_score_offset,
+ )
+ elif return_tensors == "tf":
+ return self._filter_masks_tf(
+ masks=masks,
+ iou_scores=iou_scores,
+ original_size=original_size,
+ cropped_box_image=cropped_box_image,
+ pred_iou_thresh=pred_iou_thresh,
+ stability_score_thresh=stability_score_thresh,
+ mask_threshold=mask_threshold,
+ stability_score_offset=stability_score_offset,
+ )
+
+ def _filter_masks_pt(
+ self,
+ masks,
+ iou_scores,
+ original_size,
+ cropped_box_image,
+ pred_iou_thresh=0.88,
+ stability_score_thresh=0.95,
+ mask_threshold=0,
+ stability_score_offset=1,
+ ):
+ """
+ Filters the predicted masks by selecting only the ones that meet several criteria. The first criterion is
+ that the IoU scores need to be greater than `pred_iou_thresh`. The second criterion is that the stability
+ score needs to be greater than `stability_score_thresh`. The method also converts the predicted masks to
+ bounding boxes and pads the predicted masks if necessary.
+
+ Args:
+ masks (`torch.Tensor`):
+ Input masks.
+ iou_scores (`torch.Tensor`):
+ List of IoU scores.
+ original_size (`tuple[int,int]`):
+ Size of the original image.
+ cropped_box_image (`np.ndarray`):
+ The cropped image.
+ pred_iou_thresh (`float`, *optional*, defaults to 0.88):
+ The threshold for the iou scores.
+ stability_score_thresh (`float`, *optional*, defaults to 0.95):
+ The threshold for the stability score.
+ mask_threshold (`float`, *optional*, defaults to 0):
+ The threshold for the predicted masks.
+ stability_score_offset (`float`, *optional*, defaults to 1):
+ The offset for the stability score used in the `_compute_stability_score` method.
+
+ """
+ requires_backends(self, ["torch"])
+ original_height, original_width = original_size
+ iou_scores = iou_scores.flatten(0, 1)
+ masks = masks.flatten(0, 1)
+
+ if masks.shape[0] != iou_scores.shape[0]:
+ raise ValueError("masks and iou_scores must have the same batch size.")
+
+ if masks.device != iou_scores.device:
+ iou_scores = iou_scores.to(masks.device)
+
+ batch_size = masks.shape[0]
+
+ keep_mask = torch.ones(batch_size, dtype=torch.bool, device=masks.device)
+
+ if pred_iou_thresh > 0.0:
+ keep_mask = keep_mask & (iou_scores > pred_iou_thresh)
+
+ # compute stability score
+ if stability_score_thresh > 0.0:
+ stability_scores = _compute_stability_score_pt(masks, mask_threshold, stability_score_offset)
+ keep_mask = keep_mask & (stability_scores > stability_score_thresh)
+
+ scores = iou_scores[keep_mask]
+ masks = masks[keep_mask]
+
+ # binarize masks
+ masks = masks > mask_threshold
+ converted_boxes = _batched_mask_to_box(masks)
+
+ keep_mask = ~_is_box_near_crop_edge(
+ converted_boxes, cropped_box_image, [0, 0, original_width, original_height]
+ )
+
+ scores = scores[keep_mask]
+ masks = masks[keep_mask]
+ converted_boxes = converted_boxes[keep_mask]
+
+ masks = _pad_masks(masks, cropped_box_image, original_height, original_width)
+ # conversion to rle is necessary to run non-maximum suppression
+ masks = _mask_to_rle_pytorch(masks)
+
+ return masks, scores, converted_boxes
+
+ def _filter_masks_tf(
+ self,
+ masks,
+ iou_scores,
+ original_size,
+ cropped_box_image,
+ pred_iou_thresh=0.88,
+ stability_score_thresh=0.95,
+ mask_threshold=0,
+ stability_score_offset=1,
+ ):
+ """
+ Filters the predicted masks by selecting only the ones that meet several criteria. The first criterion is
+ that the IoU scores need to be greater than `pred_iou_thresh`. The second criterion is that the stability
+ score needs to be greater than `stability_score_thresh`. The method also converts the predicted masks to
+ bounding boxes and pads the predicted masks if necessary.
+
+ Args:
+ masks (`tf.Tensor`):
+ Input masks.
+ iou_scores (`tf.Tensor`):
+ List of IoU scores.
+ original_size (`tuple[int,int]`):
+ Size of the original image.
+ cropped_box_image (`np.array`):
+ The cropped image.
+ pred_iou_thresh (`float`, *optional*, defaults to 0.88):
+ The threshold for the iou scores.
+ stability_score_thresh (`float`, *optional*, defaults to 0.95):
+ The threshold for the stability score.
+ mask_threshold (`float`, *optional*, defaults to 0):
+ The threshold for the predicted masks.
+ stability_score_offset (`float`, *optional*, defaults to 1):
+ The offset for the stability score used in the `_compute_stability_score` method.
+
+ """
+ requires_backends(self, ["tf"])
+ original_height, original_width = original_size
+ iou_scores = tf.reshape(iou_scores, [iou_scores.shape[0] * iou_scores.shape[1], *iou_scores.shape[2:]])
+ masks = tf.reshape(masks, [masks.shape[0] * masks.shape[1], *masks.shape[2:]])
+
+ if masks.shape[0] != iou_scores.shape[0]:
+ raise ValueError("masks and iou_scores must have the same batch size.")
+
+ batch_size = masks.shape[0]
+
+ keep_mask = tf.ones(batch_size, dtype=tf.bool)
+
+ if pred_iou_thresh > 0.0:
+ keep_mask = keep_mask & (iou_scores > pred_iou_thresh)
+
+ # compute stability score
+ if stability_score_thresh > 0.0:
+ stability_scores = _compute_stability_score_tf(masks, mask_threshold, stability_score_offset)
+ keep_mask = keep_mask & (stability_scores > stability_score_thresh)
+
+ scores = iou_scores[keep_mask]
+ masks = masks[keep_mask]
+
+ # binarize masks
+ masks = masks > mask_threshold
+ converted_boxes = _batched_mask_to_box_tf(masks)
+
+ keep_mask = ~_is_box_near_crop_edge_tf(
+ converted_boxes, cropped_box_image, [0, 0, original_width, original_height]
+ )
+
+ scores = scores[keep_mask]
+ masks = masks[keep_mask]
+ converted_boxes = converted_boxes[keep_mask]
+
+ masks = _pad_masks_tf(masks, cropped_box_image, original_height, original_width)
+ # conversion to rle is necessary to run non-maximum suppression
+ masks = _mask_to_rle_tf(masks)
+
+ return masks, scores, converted_boxes
+
+
+def _compute_stability_score_pt(masks: "torch.Tensor", mask_threshold: float, stability_score_offset: int):
+ # One mask is always contained inside the other.
+ # Save memory by preventing unnecessary cast to torch.int64
+ intersections = (
+ (masks > (mask_threshold + stability_score_offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32)
+ )
+ unions = (masks > (mask_threshold - stability_score_offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32)
+ stability_scores = intersections / unions
+ return stability_scores
+
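+# Illustrative note (not part of the library API): the stability score is the IoU between the mask binarized at
+# (mask_threshold + offset) and at (mask_threshold - offset); since one always contains the other, this reduces to
+# a ratio of pixel counts. For logits `masks` of shape (N, H, W) and the defaults used by the filtering step above:
+#
+#     scores = _compute_stability_score_pt(masks, mask_threshold=0.0, stability_score_offset=1.0)
+#
+# `scores` has shape (N,), with values close to 1 indicating masks that are insensitive to the exact threshold.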
+
+def _compute_stability_score_tf(masks: "tf.Tensor", mask_threshold: float, stability_score_offset: int):
+ # Torch does Py3-style division but TF does floor division with ints. We cast to float32 in TF to make sure
+ # we get the right division results.
+ intersections = tf.count_nonzero(
+ masks > (mask_threshold + stability_score_offset), axis=[-1, -2], dtype=tf.float32
+ )
+ unions = tf.count_nonzero(masks > (mask_threshold - stability_score_offset), axis=[-1, -2], dtype=tf.float32)
+ stability_scores = intersections / unions
+ return stability_scores
+
+
+def _build_point_grid(n_per_side: int) -> np.ndarray:
+ """Generates a 2D grid of points evenly spaced in [0,1]x[0,1]."""
+ offset = 1 / (2 * n_per_side)
+ points_one_side = np.linspace(offset, 1 - offset, n_per_side)
+ points_x = np.tile(points_one_side[None, :], (n_per_side, 1))
+ points_y = np.tile(points_one_side[:, None], (1, n_per_side))
+ points = np.stack([points_x, points_y], axis=-1).reshape(-1, 2)
+ return points
+
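+# Illustrative example: `_build_point_grid(2)` returns the four cell centers of a 2x2 grid over the unit square,
+# [[0.25, 0.25], [0.75, 0.25], [0.25, 0.75], [0.75, 0.75]]; these normalized points are later scaled to each
+# crop's size in `_generate_crop_images`.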
+
+def _normalize_coordinates(
+ target_size: int, coords: np.ndarray, original_size: tuple[int, int], is_bounding_box=False
+) -> np.ndarray:
+ """
+ Expects a numpy array of length 2 in the final dimension. Requires the original image size in (height, width)
+ format.
+ """
+ old_height, old_width = original_size
+
+ scale = target_size * 1.0 / max(old_height, old_width)
+ new_height, new_width = old_height * scale, old_width * scale
+ new_width = int(new_width + 0.5)
+ new_height = int(new_height + 0.5)
+
+ coords = deepcopy(coords).astype(float)
+
+ if is_bounding_box:
+ coords = coords.reshape(-1, 2, 2)
+
+ coords[..., 0] = coords[..., 0] * (new_width / old_width)
+ coords[..., 1] = coords[..., 1] * (new_height / old_height)
+
+ if is_bounding_box:
+ coords = coords.reshape(-1, 4)
+
+ return coords
+
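+# Worked example (illustrative): for an original image of size (height=600, width=800) and `target_size=1024`,
+# the scale is 1024 / 800 = 1.28, the resized image is (768, 1024), and a point (x=400, y=300) is mapped to
+# (512.0, 384.0). Bounding boxes are handled by temporarily viewing them as two (x, y) corner points.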
+
+def _generate_crop_boxes(
+ image,
+ target_size: int, # Is it tuple here?
+ crop_n_layers: int = 0,
+ overlap_ratio: float = 512 / 1500,
+ points_per_crop: Optional[int] = 32,
+ crop_n_points_downscale_factor: Optional[list[int]] = 1,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+) -> tuple[list[list[int]], list[int]]:
+ """
+ Generates a list of crop boxes of different sizes. Each layer has (2**i)**2 boxes for the ith layer.
+
+ Args:
+ image (Union[`numpy.ndarray`, `PIL.Image`, `torch.Tensor`]):
+ Image to generate crops for.
+ target_size (`int`):
+ Size of the smallest crop.
+ crop_n_layers (`int`, *optional*):
+ If `crops_n_layers>0`, mask prediction will be run again on crops of the image. Sets the number of layers
+ to run, where each layer has 2**i_layer number of image crops.
+        overlap_ratio (`float`, *optional*):
+ Sets the degree to which crops overlap. In the first crop layer, crops will overlap by this fraction of the
+ image length. Later layers with more crops scale down this overlap.
+ points_per_crop (`int`, *optional*):
+ Number of points to sample per crop.
+ crop_n_points_downscale_factor (`int`, *optional*):
+ The number of points-per-side sampled in layer n is scaled down by crop_n_points_downscale_factor**n.
+ input_data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format of the input image. If not provided, it will be inferred.
+ """
+
+ if isinstance(image, list):
+ raise TypeError("Only one image is allowed for crop generation.")
+ image = to_numpy_array(image)
+ original_size = get_image_size(image, input_data_format)
+
+ points_grid = []
+ for i in range(crop_n_layers + 1):
+ n_points = int(points_per_crop / (crop_n_points_downscale_factor**i))
+ points_grid.append(_build_point_grid(n_points))
+
+ crop_boxes, layer_idxs = _generate_per_layer_crops(crop_n_layers, overlap_ratio, original_size)
+
+ cropped_images, point_grid_per_crop = _generate_crop_images(
+ crop_boxes, image, points_grid, layer_idxs, target_size, original_size, input_data_format
+ )
+ crop_boxes = np.array(crop_boxes)
+ crop_boxes = crop_boxes.astype(np.float32)
+ points_per_crop = np.array([point_grid_per_crop])
+ points_per_crop = np.transpose(points_per_crop, axes=(0, 2, 1, 3))
+
+ input_labels = np.ones_like(points_per_crop[:, :, :, 0], dtype=np.int64)
+
+ return crop_boxes, points_per_crop, cropped_images, input_labels
+
+
+def _generate_per_layer_crops(crop_n_layers, overlap_ratio, original_size):
+ """
+    Generates crop boxes for each layer: layer 0 is the full image and layer i contains (2**i)**2 crops. Crops are
+    returned in the XYXY format, which consists of the following required indices:
+        - LEFT: X coordinate of the top left of the bounding box
+        - TOP: Y coordinate of the top left of the bounding box
+        - RIGHT: X coordinate of the bottom right of the bounding box
+        - BOTTOM: Y coordinate of the bottom right of the bounding box
+ """
+ crop_boxes, layer_idxs = [], []
+ im_height, im_width = original_size
+ short_side = min(im_height, im_width)
+
+ # Original image
+ crop_boxes.append([0, 0, im_width, im_height])
+ layer_idxs.append(0)
+ for i_layer in range(crop_n_layers):
+ n_crops_per_side = 2 ** (i_layer + 1)
+ overlap = int(overlap_ratio * short_side * (2 / n_crops_per_side))
+
+ crop_width = int(math.ceil((overlap * (n_crops_per_side - 1) + im_width) / n_crops_per_side))
+ crop_height = int(math.ceil((overlap * (n_crops_per_side - 1) + im_height) / n_crops_per_side))
+
+ crop_box_x0 = [int((crop_width - overlap) * i) for i in range(n_crops_per_side)]
+ crop_box_y0 = [int((crop_height - overlap) * i) for i in range(n_crops_per_side)]
+
+ for left, top in product(crop_box_x0, crop_box_y0):
+ box = [left, top, min(left + crop_width, im_width), min(top + crop_height, im_height)]
+ crop_boxes.append(box)
+ layer_idxs.append(i_layer + 1)
+
+ return crop_boxes, layer_idxs
+
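+# Worked example (illustrative): for an image of size (height=1000, width=1500) and `crop_n_layers=1`, layer 0 is
+# the full image [0, 0, 1500, 1000]; layer 1 uses 2 crops per side with overlap int(512 / 1500 * 1000) = 341 and
+# crop size 921x671, giving the boxes [0, 0, 921, 671], [0, 330, 921, 1000], [580, 0, 1500, 671] and
+# [580, 330, 1500, 1000] (clipped to the image), i.e. 5 crop boxes in total.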
+
+def _generate_crop_images(
+ crop_boxes, image, points_grid, layer_idxs, target_size, original_size, input_data_format=None
+):
+ """
+    Takes as input the bounding boxes that are used to crop the image. Based on the crops, the corresponding points
+    are also computed and returned.
+ """
+ cropped_images = []
+ total_points_per_crop = []
+ for i, crop_box in enumerate(crop_boxes):
+ left, top, right, bottom = crop_box
+
+ channel_dim = infer_channel_dimension_format(image, input_data_format)
+ if channel_dim == ChannelDimension.LAST:
+ cropped_im = image[top:bottom, left:right, :]
+ else:
+ cropped_im = image[:, top:bottom, left:right]
+
+ cropped_images.append(cropped_im)
+
+ cropped_im_size = get_image_size(cropped_im, channel_dim)
+ points_scale = np.array(cropped_im_size)[None, ::-1]
+
+ points = points_grid[layer_idxs[i]] * points_scale
+ normalized_points = _normalize_coordinates(target_size, points, original_size)
+ total_points_per_crop.append(normalized_points)
+
+ return cropped_images, total_points_per_crop
+
+
+def _pad_masks(masks, crop_box: list[int], orig_height: int, orig_width: int):
+ left, top, right, bottom = crop_box
+ if left == 0 and top == 0 and right == orig_width and bottom == orig_height:
+ return masks
+ # Coordinate transform masks
+ pad_x, pad_y = orig_width - (right - left), orig_height - (bottom - top)
+ pad = (left, pad_x - left, top, pad_y - top)
+ return torch.nn.functional.pad(masks, pad, value=0)
+
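+# Note (illustrative): the `pad` tuple follows `torch.nn.functional.pad`'s (left, right, top, bottom) convention for
+# the last two dimensions and works out to (left, orig_width - right, top, orig_height - bottom), so a mask of shape
+# (N, bottom - top, right - left) predicted on a crop is zero-padded back to (N, orig_height, orig_width) at its
+# original position in the full image.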
+
+def _pad_masks_tf(masks, crop_box: list[int], orig_height: int, orig_width: int):
+ left, top, right, bottom = crop_box
+ if left == 0 and top == 0 and right == orig_width and bottom == orig_height:
+ return masks
+ # Coordinate transform masks
+ pad_x, pad_y = orig_width - (right - left), orig_height - (bottom - top)
+ pad = (left, pad_x - left, top, pad_y - top)
+ return tf.pad(masks, pad, constant_values=0)
+
+
+def _is_box_near_crop_edge(boxes, crop_box, orig_box, atol=20.0):
+ """Filter masks at the edge of a crop, but not at the edge of the original image."""
+ crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device)
+ orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device)
+
+ left, top, _, _ = crop_box
+ offset = torch.tensor([[left, top, left, top]], device=boxes.device)
+ # Check if boxes has a channel dimension
+ if len(boxes.shape) == 3:
+ offset = offset.unsqueeze(1)
+ boxes = (boxes + offset).float()
+
+ near_crop_edge = torch.isclose(boxes, crop_box_torch[None, :], atol=atol, rtol=0)
+ near_image_edge = torch.isclose(boxes, orig_box_torch[None, :], atol=atol, rtol=0)
+ near_crop_edge = torch.logical_and(near_crop_edge, ~near_image_edge)
+ return torch.any(near_crop_edge, dim=1)
+
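+# Note (illustrative): the result is one boolean per box, True when any coordinate lies within `atol` pixels of the
+# crop boundary without also lying within `atol` pixels of the original image boundary. Such boxes are likely
+# truncated by the crop, so the caller drops them; overlapping crops typically contain a less-truncated view of the
+# same object.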
+
+def _is_box_near_crop_edge_tf(boxes, crop_box, orig_box, atol=20.0):
+ """Filter masks at the edge of a crop, but not at the edge of the original image."""
+ crop_box_tf = tf.convert_to_tensor(crop_box, dtype=tf.float32)
+ orig_box_tf = tf.convert_to_tensor(orig_box, dtype=tf.float32)
+
+ left, top, _, _ = crop_box
+ offset = tf.convert_to_tensor([[left, top, left, top]])
+ # Check if boxes has a channel dimension
+ if len(boxes.shape) == 3:
+ offset = tf.expand_dims(offset, 1)
+ boxes = tf.cast(boxes + offset, tf.float32)
+
+ near_crop_edge = tnp.isclose(boxes, crop_box_tf[None, :], atol=atol, rtol=0)
+ near_image_edge = tnp.isclose(boxes, orig_box_tf[None, :], atol=atol, rtol=0)
+ near_crop_edge = tf.math.logical_and(near_crop_edge, ~near_image_edge)
+ return tf.reduce_any(near_crop_edge, axis=1)
+
+
+def _batched_mask_to_box(masks: "torch.Tensor"):
+ """
+ Computes the bounding boxes around the given input masks. The bounding boxes are in the XYXY format which
+    corresponds to the following required indices:
+ - LEFT: left hand side of the bounding box
+ - TOP: top of the bounding box
+ - RIGHT: right of the bounding box
+ - BOTTOM: bottom of the bounding box
+
+ Return [0,0,0,0] for an empty mask. For input shape channel_1 x channel_2 x ... x height x width, the output shape
+ is channel_1 x channel_2 x ... x 4.
+
+ Args:
+ - masks (`torch.Tensor` of shape `(batch, nb_mask, height, width)`)
+ """
+ # torch.max below raises an error on empty inputs, just skip in this case
+
+ if torch.numel(masks) == 0:
+ return torch.zeros(*masks.shape[:-2], 4, device=masks.device)
+
+ # Normalize shape to Cxheightxwidth
+ shape = masks.shape
+ height, width = shape[-2:]
+
+ # Get top and bottom edges
+ in_height, _ = torch.max(masks, dim=-1)
+ in_height_coords = in_height * torch.arange(height, device=in_height.device)[None, :]
+ bottom_edges, _ = torch.max(in_height_coords, dim=-1)
+ in_height_coords = in_height_coords + height * (~in_height)
+ top_edges, _ = torch.min(in_height_coords, dim=-1)
+
+ # Get left and right edges
+ in_width, _ = torch.max(masks, dim=-2)
+ in_width_coords = in_width * torch.arange(width, device=in_width.device)[None, :]
+ right_edges, _ = torch.max(in_width_coords, dim=-1)
+ in_width_coords = in_width_coords + width * (~in_width)
+ left_edges, _ = torch.min(in_width_coords, dim=-1)
+
+ # If the mask is empty the right edge will be to the left of the left edge.
+ # Replace these boxes with [0, 0, 0, 0]
+ empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges)
+ out = torch.stack([left_edges, top_edges, right_edges, bottom_edges], dim=-1)
+ out = out * (~empty_filter).unsqueeze(-1)
+
+ # Return to original shape
+ out = out.reshape(*shape[:-2], 4)
+ return out
+
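+# Worked example (illustrative): a single boolean mask whose foreground pixels span rows 2..4 and columns 3..6
+# yields the XYXY box [3, 2, 6, 4]; an all-False mask yields [0, 0, 0, 0] through the `empty_filter` branch.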
+
+def _batched_mask_to_box_tf(masks: "tf.Tensor"):
+ """
+ Computes the bounding boxes around the given input masks. The bounding boxes are in the XYXY format which
+    corresponds to the following required indices:
+ - LEFT: left hand side of the bounding box
+ - TOP: top of the bounding box
+ - RIGHT: right of the bounding box
+ - BOTTOM: bottom of the bounding box
+
+ Return [0,0,0,0] for an empty mask. For input shape channel_1 x channel_2 x ... x height x width, the output shape
+ is channel_1 x channel_2 x ... x 4.
+
+ Args:
+ - masks (`tf.Tensor` of shape `(batch, nb_mask, height, width)`)
+ """
+
+ if tf.size(masks) == 0:
+ return tf.zeros([*masks.shape[:-2], 4])
+
+ # Normalize shape to Cxheightxwidth
+ shape = shape_list(masks)
+ height, width = shape[-2:]
+
+ # Get top and bottom edges
+ in_height = tf.reduce_max(masks, axis=-1)
+ in_height_coords = in_height * tf.range(height)[None, :]
+ bottom_edges = tf.reduce_max(in_height_coords, axis=-1)
+ in_height_coords = in_height_coords + height * (~in_height)
+ top_edges = tf.reduce_min(in_height_coords, axis=-1)
+
+ # Get left and right edges
+    in_width = tf.reduce_max(masks, axis=-2)
+    in_width_coords = in_width * tf.range(width)[None, :]
+    right_edges = tf.reduce_max(in_width_coords, axis=-1)
+    in_width_coords = in_width_coords + width * (~in_width)
+    left_edges = tf.reduce_min(in_width_coords, axis=-1)
+
+ # If the mask is empty the right edge will be to the left of the left edge.
+ # Replace these boxes with [0, 0, 0, 0]
+ empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges)
+ out = tf.stack([left_edges, top_edges, right_edges, bottom_edges], axis=-1)
+ out = out * tf.expand_dims(~empty_filter, -1)
+
+ # Return to original shape
+    out = tf.reshape(out, [*shape[:-2], 4])
+ return out
+
+
+def _mask_to_rle_pytorch(input_mask: "torch.Tensor"):
+ """
+    Encodes masks to run-length encoding (RLE), in the format expected by pycoco tools.
+ """
+ # Put in fortran order and flatten height and width
+ batch_size, height, width = input_mask.shape
+ input_mask = input_mask.permute(0, 2, 1).flatten(1)
+
+ # Compute change indices
+ diff = input_mask[:, 1:] ^ input_mask[:, :-1]
+ change_indices = diff.nonzero()
+
+ # Encode run length
+ out = []
+ for i in range(batch_size):
+ cur_idxs = change_indices[change_indices[:, 0] == i, 1] + 1
+ if len(cur_idxs) == 0:
+ # No changes => either all 0 or all 1
+ # If the entire mask is 0, RLE is [height*width] or if the entire mask is 1, RLE is [0, height*width].
+ if input_mask[i, 0] == 0:
+ out.append({"size": [height, width], "counts": [height * width]})
+ else:
+ out.append({"size": [height, width], "counts": [0, height * width]})
+ continue
+ btw_idxs = cur_idxs[1:] - cur_idxs[:-1]
+ counts = [] if input_mask[i, 0] == 0 else [0]
+ counts += [cur_idxs[0].item()] + btw_idxs.tolist() + [height * width - cur_idxs[-1].item()]
+ out.append({"size": [height, width], "counts": counts})
+ return out
+
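+# Worked example (illustrative): the 2x3 mask [[0, 1, 1], [0, 1, 0]] is flattened column-major to
+# [0, 0, 1, 1, 1, 0] and encodes to {"size": [2, 3], "counts": [2, 3, 1]}; counts always start with the length of
+# the leading run of zeros (possibly 0) and then alternate between runs of ones and zeros.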
+
+def _mask_to_rle_tf(input_mask: "tf.Tensor"):
+ """
+    Encodes masks to run-length encoding (RLE), in the format expected by pycoco tools.
+ """
+ # Put in fortran order and flatten height and width
+ batch_size, height, width = input_mask.shape
+ input_mask = flatten(tf.transpose(input_mask, perm=(0, 2, 1)), 1)
+
+ # Compute change indices
+ diff = input_mask[:, 1:] ^ input_mask[:, :-1]
+ change_indices = tf.where(diff)
+
+ # Encode run length
+ out = []
+ for i in range(batch_size):
+ cur_idxs = change_indices[change_indices[:, 0] == i][:, 1] + 1
+ if len(cur_idxs) == 0:
+ # No changes => either all 0 or all 1
+ # If the entire mask is 0, RLE is [height*width] or if the entire mask is 1, RLE is [0, height*width].
+ if input_mask[i, 0] == 0:
+ out.append({"size": [height, width], "counts": [height * width]})
+ else:
+ out.append({"size": [height, width], "counts": [0, height * width]})
+ continue
+ btw_idxs = cur_idxs[1:] - cur_idxs[:-1]
+ counts = [] if input_mask[i, 0] == 0 else [0]
+ counts += (
+ [cur_idxs[0].numpy().item()] + btw_idxs.numpy().tolist() + [height * width - cur_idxs[-1].numpy().item()]
+ )
+ out.append({"size": [height, width], "counts": counts})
+ return out
+
+
+def _rle_to_mask(rle: dict[str, Any]) -> np.ndarray:
+ """Compute a binary mask from an uncompressed RLE."""
+ height, width = rle["size"]
+ mask = np.empty(height * width, dtype=bool)
+ idx = 0
+ parity = False
+ for count in rle["counts"]:
+ mask[idx : idx + count] = parity
+ idx += count
+ parity = not parity
+ mask = mask.reshape(width, height)
+ return mask.transpose() # Reshape to original shape
+
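+# Note (illustrative): this is the inverse of the RLE encoders above, e.g.
+# `_rle_to_mask({"size": [2, 3], "counts": [2, 3, 1]})` recovers the 2x3 mask [[0, 1, 1], [0, 1, 0]] from the
+# `_mask_to_rle_pytorch` example.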
+
+def _postprocess_for_mg(rle_masks, iou_scores, mask_boxes, amg_crops_nms_thresh=0.7):
+ """
+ Perform NMS (Non Maximum Suppression) on the outputs.
+
+ Args:
+ rle_masks (`torch.Tensor`):
+ binary masks in the RLE format
+ iou_scores (`torch.Tensor` of shape (nb_masks, 1)):
+ iou_scores predicted by the model
+ mask_boxes (`torch.Tensor`):
+ The bounding boxes corresponding to segmentation masks
+ amg_crops_nms_thresh (`float`, *optional*, defaults to 0.7):
+ NMS threshold.
+ """
+ keep_by_nms = batched_nms(
+ boxes=mask_boxes.float(),
+ scores=iou_scores,
+ idxs=torch.zeros(mask_boxes.shape[0]),
+ iou_threshold=amg_crops_nms_thresh,
+ )
+
+ iou_scores = iou_scores[keep_by_nms]
+ rle_masks = [rle_masks[i] for i in keep_by_nms]
+ mask_boxes = mask_boxes[keep_by_nms]
+ masks = [_rle_to_mask(rle) for rle in rle_masks]
+
+ return masks, iou_scores, rle_masks, mask_boxes
+
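+# Note (illustrative): `batched_nms` is called with a single class id (`idxs` of zeros), so every mask competes with
+# every other mask; boxes overlapping by more than `amg_crops_nms_thresh` are suppressed in favour of the higher
+# predicted IoU score, which deduplicates objects detected in several overlapping crops.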
+
+def _postprocess_for_mg_tf(rle_masks, iou_scores, mask_boxes, amg_crops_nms_thresh=0.7):
+ """
+ Perform NMS (Non Maximum Suppression) on the outputs.
+
+ Args:
+ rle_masks (`tf.Tensor`):
+ binary masks in the RLE format
+ iou_scores (`tf.Tensor` of shape (nb_masks, 1)):
+ iou_scores predicted by the model
+ mask_boxes (`tf.Tensor`):
+ The bounding boxes corresponding to segmentation masks
+ amg_crops_nms_thresh (`float`, *optional*, defaults to 0.7):
+ NMS threshold.
+ """
+    # `tf.image.non_max_suppression` plays the role of torchvision's `batched_nms` with a single class: it returns
+    # the indices of the boxes to keep.
+    keep_by_nms = tf.image.non_max_suppression(
+        boxes=tf.cast(mask_boxes, tf.float32),
+        scores=tf.reshape(tf.cast(iou_scores, tf.float32), [-1]),
+        max_output_size=mask_boxes.shape[0],
+        iou_threshold=amg_crops_nms_thresh,
+    )
+
+    iou_scores = tf.gather(iou_scores, keep_by_nms)
+    rle_masks = [rle_masks[i] for i in keep_by_nms.numpy()]
+    mask_boxes = tf.gather(mask_boxes, keep_by_nms)
+    masks = [_rle_to_mask(rle) for rle in rle_masks]
+
+ return masks, iou_scores, rle_masks, mask_boxes
+
+
+__all__ = ["SamImageProcessor"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/sam/image_processing_sam_fast.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/sam/image_processing_sam_fast.py
new file mode 100644
index 0000000000000000000000000000000000000000..5cfd5314899c4d9df1064ffbb11dcc0fbe9d990e
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/sam/image_processing_sam_fast.py
@@ -0,0 +1,749 @@
+# coding=utf-8
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Fast Image processor class for SAM."""
+
+import math
+from copy import deepcopy
+from itertools import product
+from typing import Any, Optional, Union
+
+import numpy as np
+import torch
+from torch.nn import functional as F
+from torchvision.ops.boxes import batched_nms
+from torchvision.transforms.v2 import functional as F_t
+
+from ...image_processing_utils import BatchFeature, get_size_dict
+from ...image_processing_utils_fast import (
+ BaseImageProcessorFast,
+ DefaultFastImageProcessorKwargs,
+)
+from ...image_utils import (
+ IMAGENET_DEFAULT_MEAN,
+ IMAGENET_DEFAULT_STD,
+ ChannelDimension,
+ ImageInput,
+ PILImageResampling,
+ SizeDict,
+ pil_torch_interpolation_mapping,
+)
+from ...processing_utils import Unpack
+from ...utils import auto_docstring
+
+
+class SamFastImageProcessorKwargs(DefaultFastImageProcessorKwargs):
+ r"""
+ mask_size (`dict[str, int]`, *optional*):
+ The size `{"longest_edge": int}` to resize the segmentation maps to.
+ mask_pad_size (`dict[str, int]`, *optional*):
+ The size `{"height": int, "width": int}` to pad the segmentation maps to. Must be larger than any segmentation
+ map size provided for preprocessing.
+ """
+
+ mask_size: Optional[dict[str, int]]
+ mask_pad_size: Optional[dict[str, int]]
+
+
+@auto_docstring
+class SamImageProcessorFast(BaseImageProcessorFast):
+ resample = PILImageResampling.BILINEAR
+ image_mean = IMAGENET_DEFAULT_MEAN
+ image_std = IMAGENET_DEFAULT_STD
+ size = {"longest_edge": 1024}
+ mask_size = {"longest_edge": 256}
+ do_resize = True
+ do_rescale = True
+ do_normalize = True
+ do_convert_rgb = True
+
+ valid_kwargs = SamFastImageProcessorKwargs
+
+ do_pad = True
+ pad_size = {"height": 1024, "width": 1024}
+ mask_pad_size = {"height": 256, "width": 256}
+
+ def __init__(self, **kwargs: Unpack[SamFastImageProcessorKwargs]):
+ super().__init__(**kwargs)
+
+ def _get_preprocess_shape(self, old_shape: tuple[int, int], longest_edge: int):
+ """
+ Compute the output size given input size and target long side length.
+ """
+ oldh, oldw = old_shape
+ scale = longest_edge * 1.0 / max(oldh, oldw)
+ newh, neww = oldh * scale, oldw * scale
+ newh = int(newh + 0.5)
+ neww = int(neww + 0.5)
+ return (newh, neww)
+
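+    # Worked example (illustrative): for an input of shape (480, 640) and `longest_edge=1024`, the scale is 1.6 and
+    # the output shape is (768, 1024); the longest side always becomes exactly `longest_edge`.
+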
+ def resize(
+ self, image: "torch.Tensor", size: SizeDict, interpolation: Optional["F_t.InterpolationMode"], **kwargs
+ ) -> "torch.Tensor":
+ """
+        Resize an image so that its longest edge matches `size["longest_edge"]` while preserving the aspect ratio.
+
+ Args:
+            image (`torch.Tensor`):
+ Image to resize.
+ size (`dict[str, int]`):
+ Dictionary in the format `{"longest_edge": int}` specifying the size of the output image. The longest
+ edge of the image will be resized to the specified size, while the other edge will be resized to
+ maintain the aspect ratio.
+ interpolation:
+ `F_t.InterpolationMode` filter to use when resizing the image e.g. `F_t.InterpolationMode.BICUBIC`.
+
+ Returns:
+ `torch.Tensor`: The resized image.
+ """
+ if not size.longest_edge:
+ raise ValueError(f"The `size` dictionary must contain the key `longest_edge`. Got {size.keys()}")
+ input_size = image.shape[-2:]
+ output_height, output_width = self._get_preprocess_shape(input_size, size.longest_edge)
+ return super().resize(
+ image, size=SizeDict(height=output_height, width=output_width), interpolation=interpolation, **kwargs
+ )
+
+ def _further_process_kwargs(
+ self,
+ size: Optional[SizeDict] = None,
+ pad_size: Optional[SizeDict] = None,
+ mask_size: Optional[SizeDict] = None,
+ mask_pad_size: Optional[SizeDict] = None,
+ default_to_square: Optional[bool] = None,
+ image_mean: Optional[Union[float, list[float]]] = None,
+ image_std: Optional[Union[float, list[float]]] = None,
+ data_format: Optional[ChannelDimension] = None,
+ **kwargs,
+ ) -> dict:
+ """
+ Update kwargs that need further processing before being validated
+ Can be overridden by subclasses to customize the processing of kwargs.
+ """
+ if kwargs is None:
+ kwargs = {}
+ if size is not None:
+ size = SizeDict(**get_size_dict(size=size, default_to_square=default_to_square))
+ if pad_size is not None:
+ pad_size = SizeDict(**get_size_dict(pad_size, param_name="pad_size"))
+ if mask_size is not None:
+ mask_size = SizeDict(**get_size_dict(mask_size, param_name="mask_size"))
+ if mask_pad_size is not None:
+ mask_pad_size = SizeDict(**get_size_dict(mask_pad_size, param_name="mask_pad_size"))
+ if isinstance(image_mean, list):
+ image_mean = tuple(image_mean)
+ if isinstance(image_std, list):
+ image_std = tuple(image_std)
+ if data_format is None:
+ data_format = ChannelDimension.FIRST
+
+ kwargs["size"] = size
+ kwargs["pad_size"] = pad_size
+ kwargs["mask_size"] = mask_size
+ kwargs["mask_pad_size"] = mask_pad_size
+ kwargs["image_mean"] = image_mean
+ kwargs["image_std"] = image_std
+ kwargs["data_format"] = data_format
+
+ # torch resize uses interpolation instead of resample
+ # Check if resample is an int before checking if it's an instance of PILImageResampling
+ # because if pillow < 9.1.0, resample is an int and PILImageResampling is a module.
+ # Checking PILImageResampling will fail with error `TypeError: isinstance() arg 2 must be a type or tuple of types`.
+ resample = kwargs.pop("resample")
+ kwargs["interpolation"] = (
+ pil_torch_interpolation_mapping[resample] if isinstance(resample, (PILImageResampling, int)) else resample
+ )
+
+ return kwargs
+
+ @auto_docstring
+ def preprocess(
+ self,
+ images: ImageInput,
+ segmentation_maps: Optional[ImageInput] = None,
+ **kwargs: Unpack[SamFastImageProcessorKwargs],
+ ) -> BatchFeature:
+ r"""
+ segmentation_maps (`ImageInput`, *optional*):
+ The segmentation maps to preprocess.
+ """
+ return super().preprocess(images, segmentation_maps, **kwargs)
+
+ def _preprocess_image_like_inputs(
+ self,
+ images: ImageInput,
+ segmentation_maps: Optional[ImageInput],
+ do_convert_rgb: bool,
+ input_data_format: ChannelDimension,
+ device: Optional[Union[str, "torch.device"]] = None,
+ **kwargs: Unpack[SamFastImageProcessorKwargs],
+ ) -> BatchFeature:
+ """
+ Preprocess image-like inputs.
+ """
+ images = self._prepare_image_like_inputs(
+ images=images, do_convert_rgb=do_convert_rgb, input_data_format=input_data_format, device=device
+ )
+ original_sizes = [image.shape[-2:] for image in images]
+ images_kwargs = kwargs.copy()
+ pixel_values = self._preprocess(images, **images_kwargs)["pixel_values"]
+ reshaped_input_sizes = [image.shape[-2:] for image in images]
+ data = {
+ "pixel_values": pixel_values,
+ "original_sizes": original_sizes,
+ "reshaped_input_sizes": reshaped_input_sizes,
+ }
+
+ if segmentation_maps is not None:
+ processed_segmentation_maps = self._prepare_image_like_inputs(
+ images=segmentation_maps,
+ expected_ndims=2,
+ do_convert_rgb=False,
+ input_data_format=ChannelDimension.FIRST,
+ )
+
+ segmentation_maps_kwargs = kwargs.copy()
+ segmentation_maps_kwargs.update(
+ {
+ "do_normalize": False,
+ "do_rescale": False,
+ "interpolation": F_t.InterpolationMode.NEAREST_EXACT,
+ "size": segmentation_maps_kwargs.pop("mask_size"),
+ "pad_size": segmentation_maps_kwargs.pop("mask_pad_size"),
+ }
+ )
+ processed_segmentation_maps = self._preprocess(
+ images=processed_segmentation_maps, **segmentation_maps_kwargs
+ )
+ data["labels"] = processed_segmentation_maps["pixel_values"].squeeze(1).to(torch.int64)
+
+ return BatchFeature(data=data, tensor_type=kwargs["return_tensors"])
+
+ def generate_crop_boxes(
+ self,
+ image: "torch.Tensor",
+ target_size,
+ crop_n_layers: int = 0,
+ overlap_ratio: float = 512 / 1500,
+ points_per_crop: Optional[int] = 32,
+ crop_n_points_downscale_factor: Optional[list[int]] = 1,
+ device: Optional["torch.device"] = None,
+ ):
+ """
+ Generates a list of crop boxes of different sizes. Each layer has (2**i)**2 boxes for the ith layer.
+
+ Args:
+ image (`torch.Tensor`):
+ Input original image
+ target_size (`int`):
+ Target size of the resized image
+ crop_n_layers (`int`, *optional*, defaults to 0):
+ If >0, mask prediction will be run again on crops of the image. Sets the number of layers to run, where
+ each layer has 2**i_layer number of image crops.
+ overlap_ratio (`float`, *optional*, defaults to 512/1500):
+ Sets the degree to which crops overlap. In the first crop layer, crops will overlap by this fraction of
+ the image length. Later layers with more crops scale down this overlap.
+ points_per_crop (`int`, *optional*, defaults to 32):
+ Number of points to sample from each crop.
+ crop_n_points_downscale_factor (`list[int]`, *optional*, defaults to 1):
+ The number of points-per-side sampled in layer n is scaled down by crop_n_points_downscale_factor**n.
+ device (`torch.device`, *optional*, defaults to None):
+ Device to use for the computation. If None, cpu will be used.
+ """
+ image = self._process_image(image)
+ crop_boxes, points_per_crop, cropped_images, input_labels = _generate_crop_boxes(
+ image,
+ target_size,
+ crop_n_layers,
+ overlap_ratio,
+ points_per_crop,
+ crop_n_points_downscale_factor,
+ )
+ if device is None:
+ device = torch.device("cpu")
+ crop_boxes = crop_boxes.to(device)
+ points_per_crop = points_per_crop.to(device)
+ # cropped_images stays as torch.Tensor
+ input_labels = input_labels.to(device)
+
+ return crop_boxes, points_per_crop, cropped_images, input_labels
+
+ def filter_masks(
+ self,
+ masks,
+ iou_scores,
+ original_size,
+ cropped_box_image,
+ pred_iou_thresh=0.88,
+ stability_score_thresh=0.95,
+ mask_threshold=0,
+ stability_score_offset=1,
+ ):
+ """
+        Filters the predicted masks by selecting only the ones that meet several criteria. The first criterion is
+        that the IoU scores need to be greater than `pred_iou_thresh`. The second criterion is that the stability
+        scores need to be greater than `stability_score_thresh`. The method also converts the predicted masks to
+        bounding boxes and pads the predicted masks if necessary.
+
+ Args:
+ masks (`torch.Tensor`):
+ Input masks.
+ iou_scores (`torch.Tensor`):
+ List of IoU scores.
+ original_size (`tuple[int,int]`):
+ Size of the original image.
+            cropped_box_image (`torch.Tensor`):
+                The crop box of the image, in `[left, top, right, bottom]` format.
+ pred_iou_thresh (`float`, *optional*, defaults to 0.88):
+ The threshold for the iou scores.
+ stability_score_thresh (`float`, *optional*, defaults to 0.95):
+ The threshold for the stability score.
+ mask_threshold (`float`, *optional*, defaults to 0):
+ The threshold for the predicted masks.
+ stability_score_offset (`float`, *optional*, defaults to 1):
+ The offset for the stability score used in the `_compute_stability_score` method.
+
+ """
+ original_height, original_width = original_size
+ iou_scores = iou_scores.flatten(0, 1)
+ masks = masks.flatten(0, 1)
+
+ if masks.shape[0] != iou_scores.shape[0]:
+ raise ValueError("masks and iou_scores must have the same batch size.")
+
+ if masks.device != iou_scores.device:
+ iou_scores = iou_scores.to(masks.device)
+
+ batch_size = masks.shape[0]
+
+ keep_mask = torch.ones(batch_size, dtype=torch.bool, device=masks.device)
+
+ if pred_iou_thresh > 0.0:
+ keep_mask = keep_mask & (iou_scores > pred_iou_thresh)
+
+ # compute stability score
+ if stability_score_thresh > 0.0:
+ stability_scores = _compute_stability_score(masks, mask_threshold, stability_score_offset)
+ keep_mask = keep_mask & (stability_scores > stability_score_thresh)
+
+ scores = iou_scores[keep_mask]
+ masks = masks[keep_mask]
+
+ # binarize masks
+ masks = masks > mask_threshold
+ converted_boxes = _batched_mask_to_box(masks)
+
+ keep_mask = ~_is_box_near_crop_edge(
+ converted_boxes, cropped_box_image, [0, 0, original_width, original_height]
+ )
+
+ scores = scores[keep_mask]
+ masks = masks[keep_mask]
+ converted_boxes = converted_boxes[keep_mask]
+
+ masks = _pad_masks(masks, cropped_box_image, original_height, original_width)
+ # conversion to rle is necessary to run non-maximum suppression
+ masks = _mask_to_rle(masks)
+
+ return masks, scores, converted_boxes
+
+ def post_process_masks(
+ self,
+ masks,
+ original_sizes,
+ reshaped_input_sizes,
+ mask_threshold=0.0,
+ binarize=True,
+ pad_size=None,
+ ):
+ """
+ Remove padding and upscale masks to the original image size.
+
+ Args:
+ masks (`Union[List[torch.Tensor], List[np.ndarray]]`):
+ Batched masks from the mask_decoder in (batch_size, num_channels, height, width) format.
+ original_sizes (`Union[torch.Tensor, List[Tuple[int,int]]]`):
+ The original sizes of each image before it was resized to the model's expected input shape, in (height,
+ width) format.
+ reshaped_input_sizes (`Union[torch.Tensor, List[Tuple[int,int]]]`):
+ The size of each image as it is fed to the model, in (height, width) format. Used to remove padding.
+ mask_threshold (`float`, *optional*, defaults to 0.0):
+ The threshold to use for binarizing the masks.
+ binarize (`bool`, *optional*, defaults to `True`):
+ Whether to binarize the masks.
+ pad_size (`int`, *optional*, defaults to `self.pad_size`):
+ The target size the images were padded to before being passed to the model. If None, the target size is
+ assumed to be the processor's `pad_size`.
+ Returns:
+            (`torch.Tensor`): Batched masks in (batch_size, num_channels, height, width) format, where (height, width)
+ is given by original_size.
+ """
+        pad_size = self.pad_size if pad_size is None else pad_size
+ target_image_size = (pad_size["height"], pad_size["width"])
+ if isinstance(original_sizes, (torch.Tensor, np.ndarray)):
+ original_sizes = original_sizes.tolist()
+ if isinstance(reshaped_input_sizes, (torch.Tensor, np.ndarray)):
+ reshaped_input_sizes = reshaped_input_sizes.tolist()
+
+ output_masks = []
+ for i, original_size in enumerate(original_sizes):
+ if isinstance(masks[i], np.ndarray):
+ masks[i] = torch.from_numpy(masks[i])
+ elif not isinstance(masks[i], torch.Tensor):
+ raise ValueError("Input masks should be a list of `torch.tensors` or a list of `np.ndarray`")
+ interpolated_mask = F.interpolate(masks[i], target_image_size, mode="bilinear", align_corners=False)
+ interpolated_mask = interpolated_mask[..., : reshaped_input_sizes[i][0], : reshaped_input_sizes[i][1]]
+ interpolated_mask = F.interpolate(interpolated_mask, original_size, mode="bilinear", align_corners=False)
+ if binarize:
+ interpolated_mask = interpolated_mask > mask_threshold
+ output_masks.append(interpolated_mask)
+
+ return output_masks
+
+ def post_process_for_mask_generation(self, all_masks, all_scores, all_boxes, crops_nms_thresh):
+ """
+        Post-processes the generated masks by applying the Non Maximum Suppression (NMS) algorithm to the predicted masks.
+
+ Args:
+ all_masks (`torch.Tensor`):
+ List of all predicted segmentation masks
+ all_scores (`torch.Tensor`):
+ List of all predicted iou scores
+ all_boxes (`torch.Tensor`):
+ List of all bounding boxes of the predicted masks
+ crops_nms_thresh (`float`):
+ Threshold for NMS (Non Maximum Suppression) algorithm.
+ """
+ return _post_process_for_mask_generation(all_masks, all_scores, all_boxes, crops_nms_thresh)
+
+
+def _compute_stability_score(masks: "torch.Tensor", mask_threshold: float, stability_score_offset: int):
+ # One mask is always contained inside the other.
+ # Save memory by preventing unnecessary cast to torch.int64
+ intersections = (
+ (masks > (mask_threshold + stability_score_offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32)
+ )
+ unions = (masks > (mask_threshold - stability_score_offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32)
+ stability_scores = intersections / unions
+ return stability_scores
+
+
+def _mask_to_rle(input_mask: "torch.Tensor"):
+ """
+    Encodes masks to run-length encoding (RLE), in the format expected by pycoco tools.
+ """
+ # Put in fortran order and flatten height and width
+ batch_size, height, width = input_mask.shape
+ input_mask = input_mask.permute(0, 2, 1).flatten(1)
+
+ # Compute change indices
+ diff = input_mask[:, 1:] ^ input_mask[:, :-1]
+ change_indices = diff.nonzero()
+
+ # Encode run length
+ out = []
+ for i in range(batch_size):
+ cur_idxs = change_indices[change_indices[:, 0] == i, 1] + 1
+ if len(cur_idxs) == 0:
+ # No changes => either all 0 or all 1
+ # If the entire mask is 0, RLE is [height*width] or if the entire mask is 1, RLE is [0, height*width].
+ if input_mask[i, 0] == 0:
+ out.append({"size": [height, width], "counts": [height * width]})
+ else:
+ out.append({"size": [height, width], "counts": [0, height * width]})
+ continue
+ btw_idxs = cur_idxs[1:] - cur_idxs[:-1]
+ counts = [] if input_mask[i, 0] == 0 else [0]
+ counts += [cur_idxs[0].item()] + btw_idxs.tolist() + [height * width - cur_idxs[-1].item()]
+ out.append({"size": [height, width], "counts": counts})
+ return out
+
+
+def _batched_mask_to_box(masks: "torch.Tensor"):
+ """
+ Computes the bounding boxes around the given input masks. The bounding boxes are in the XYXY format which
+    corresponds to the following required indices:
+ - LEFT: left hand side of the bounding box
+ - TOP: top of the bounding box
+ - RIGHT: right of the bounding box
+ - BOTTOM: bottom of the bounding box
+
+ Return [0,0,0,0] for an empty mask. For input shape channel_1 x channel_2 x ... x height x width, the output shape
+ is channel_1 x channel_2 x ... x 4.
+
+ Args:
+ - masks (`torch.Tensor` of shape `(batch, nb_mask, height, width)`)
+ """
+ # torch.max below raises an error on empty inputs, just skip in this case
+
+ if torch.numel(masks) == 0:
+ return torch.zeros(*masks.shape[:-2], 4, device=masks.device)
+
+ # Normalize shape to Cxheightxwidth
+ shape = masks.shape
+ height, width = shape[-2:]
+
+ # Get top and bottom edges
+ in_height, _ = torch.max(masks, dim=-1)
+ in_height_coords = in_height * torch.arange(height, device=in_height.device)[None, :]
+ bottom_edges, _ = torch.max(in_height_coords, dim=-1)
+ in_height_coords = in_height_coords + height * (~in_height)
+ top_edges, _ = torch.min(in_height_coords, dim=-1)
+
+ # Get left and right edges
+ in_width, _ = torch.max(masks, dim=-2)
+ in_width_coords = in_width * torch.arange(width, device=in_width.device)[None, :]
+ right_edges, _ = torch.max(in_width_coords, dim=-1)
+ in_width_coords = in_width_coords + width * (~in_width)
+ left_edges, _ = torch.min(in_width_coords, dim=-1)
+
+ # If the mask is empty the right edge will be to the left of the left edge.
+ # Replace these boxes with [0, 0, 0, 0]
+ empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges)
+ out = torch.stack([left_edges, top_edges, right_edges, bottom_edges], dim=-1)
+ out = out * (~empty_filter).unsqueeze(-1)
+
+ # Return to original shape
+ out = out.reshape(*shape[:-2], 4)
+ return out
+
+
+def _is_box_near_crop_edge(boxes, crop_box, orig_box, atol=20.0):
+ """Filter masks at the edge of a crop, but not at the edge of the original image."""
+ crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device)
+ orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device)
+
+ left, top, _, _ = crop_box
+ offset = torch.tensor([[left, top, left, top]], device=boxes.device)
+ # Check if boxes has a channel dimension
+ if len(boxes.shape) == 3:
+ offset = offset.unsqueeze(1)
+ boxes = (boxes + offset).float()
+
+ near_crop_edge = torch.isclose(boxes, crop_box_torch[None, :], atol=atol, rtol=0)
+ near_image_edge = torch.isclose(boxes, orig_box_torch[None, :], atol=atol, rtol=0)
+ near_crop_edge = torch.logical_and(near_crop_edge, ~near_image_edge)
+ return torch.any(near_crop_edge, dim=1)
+
+
+def _pad_masks(masks, crop_box: list[int], orig_height: int, orig_width: int):
+ left, top, right, bottom = crop_box
+ if left == 0 and top == 0 and right == orig_width and bottom == orig_height:
+ return masks
+ # Coordinate transform masks
+ pad_x, pad_y = orig_width - (right - left), orig_height - (bottom - top)
+ pad = (left, pad_x - left, top, pad_y - top)
+ return torch.nn.functional.pad(masks, pad, value=0)
+
+
+def _generate_crop_boxes(
+ image,
+ target_size: int, # Is it tuple here?
+ crop_n_layers: int = 0,
+ overlap_ratio: float = 512 / 1500,
+ points_per_crop: Optional[int] = 32,
+ crop_n_points_downscale_factor: Optional[list[int]] = 1,
+) -> tuple[list[list[int]], list[int]]:
+ """
+ Generates a list of crop boxes of different sizes. Each layer has (2**i)**2 boxes for the ith layer.
+
+ Args:
+ image (Union[`numpy.ndarray`, `PIL.Image`, `torch.Tensor`]):
+ Image to generate crops for.
+ target_size (`int`):
+ Size of the smallest crop.
+ crop_n_layers (`int`, *optional*):
+ If `crops_n_layers>0`, mask prediction will be run again on crops of the image. Sets the number of layers
+ to run, where each layer has 2**i_layer number of image crops.
+        overlap_ratio (`float`, *optional*):
+ Sets the degree to which crops overlap. In the first crop layer, crops will overlap by this fraction of the
+ image length. Later layers with more crops scale down this overlap.
+ points_per_crop (`int`, *optional*):
+ Number of points to sample per crop.
+ crop_n_points_downscale_factor (`int`, *optional*):
+ The number of points-per-side sampled in layer n is scaled down by crop_n_points_downscale_factor**n.
+ """
+
+ if isinstance(image, list):
+ raise ValueError("Only one image is allowed for crop generation.")
+ original_size = image.shape[-2:]
+
+ points_grid = []
+ for i in range(crop_n_layers + 1):
+ n_points = int(points_per_crop / (crop_n_points_downscale_factor**i))
+ points_grid.append(_build_point_grid(n_points))
+
+ crop_boxes, layer_idxs = _generate_per_layer_crops(crop_n_layers, overlap_ratio, original_size)
+
+ cropped_images, point_grid_per_crop = _generate_crop_images(
+ crop_boxes, image, points_grid, layer_idxs, target_size, original_size
+ )
+ crop_boxes = torch.tensor(crop_boxes)
+ crop_boxes = crop_boxes.float()
+ points_per_crop = torch.stack(point_grid_per_crop)
+ points_per_crop = points_per_crop.unsqueeze(0).permute(0, 2, 1, 3)
+ cropped_images = torch.stack(cropped_images)
+
+ input_labels = torch.ones_like(points_per_crop[:, :, :, 0], dtype=torch.int64)
+
+ return crop_boxes, points_per_crop, cropped_images, input_labels
+
+
+def _generate_per_layer_crops(crop_n_layers, overlap_ratio, original_size):
+ """
+    Generates crop boxes for each layer: layer 0 is the full image and layer i contains (2**i)**2 crops. Crops are
+    returned in the XYXY format, which consists of the following required indices:
+        - LEFT: X coordinate of the top left of the bounding box
+        - TOP: Y coordinate of the top left of the bounding box
+        - RIGHT: X coordinate of the bottom right of the bounding box
+        - BOTTOM: Y coordinate of the bottom right of the bounding box
+ """
+ crop_boxes, layer_idxs = [], []
+ im_height, im_width = original_size
+ short_side = min(im_height, im_width)
+
+ # Original image
+ crop_boxes.append([0, 0, im_width, im_height])
+ layer_idxs.append(0)
+ for i_layer in range(crop_n_layers):
+ n_crops_per_side = 2 ** (i_layer + 1)
+ overlap = int(overlap_ratio * short_side * (2 / n_crops_per_side))
+
+ crop_width = int(math.ceil((overlap * (n_crops_per_side - 1) + im_width) / n_crops_per_side))
+ crop_height = int(math.ceil((overlap * (n_crops_per_side - 1) + im_height) / n_crops_per_side))
+
+ crop_box_x0 = [int((crop_width - overlap) * i) for i in range(n_crops_per_side)]
+ crop_box_y0 = [int((crop_height - overlap) * i) for i in range(n_crops_per_side)]
+
+ for left, top in product(crop_box_x0, crop_box_y0):
+ box = [left, top, min(left + crop_width, im_width), min(top + crop_height, im_height)]
+ crop_boxes.append(box)
+ layer_idxs.append(i_layer + 1)
+
+ return crop_boxes, layer_idxs
+
+
+def _build_point_grid(n_per_side: int) -> torch.Tensor:
+ """Generates a 2D grid of points evenly spaced in [0,1]x[0,1]."""
+ offset = 1 / (2 * n_per_side)
+ points_one_side = torch.linspace(offset, 1 - offset, n_per_side)
+ points_x = torch.tile(points_one_side[None, :], (n_per_side, 1))
+ points_y = torch.tile(points_one_side[:, None], (1, n_per_side))
+ points = torch.stack([points_x, points_y], dim=-1).reshape(-1, 2)
+ return points
+
+
+def _generate_crop_images(
+ crop_boxes, image, points_grid, layer_idxs, target_size, original_size, input_data_format=None
+):
+ """
+    Takes as input the bounding boxes that are used to crop the image. Based on the crops, the corresponding points
+    are also computed and returned.
+ """
+ cropped_images = []
+ total_points_per_crop = []
+ for i, crop_box in enumerate(crop_boxes):
+ left, top, right, bottom = crop_box
+ cropped_im = image[:, top:bottom, left:right]
+
+ cropped_images.append(cropped_im)
+
+ cropped_im_size = cropped_im.shape[-2:]
+ points_scale = torch.tensor(cropped_im_size).flip(dims=(0,)).unsqueeze(0)
+
+ points = points_grid[layer_idxs[i]] * points_scale
+ normalized_points = _normalize_coordinates(target_size, points, original_size)
+ total_points_per_crop.append(normalized_points)
+
+ return cropped_images, total_points_per_crop
+
+
+def _normalize_coordinates(
+ target_size: int, coords: torch.Tensor, original_size: tuple[int, int], is_bounding_box=False
+) -> torch.Tensor:
+ """
+    Expects a tensor with a final dimension of length 2. Requires the original image size in (height, width)
+ format.
+ """
+ old_height, old_width = original_size
+
+ scale = target_size * 1.0 / max(old_height, old_width)
+ new_height, new_width = old_height * scale, old_width * scale
+ new_width = int(new_width + 0.5)
+ new_height = int(new_height + 0.5)
+
+ coords = deepcopy(coords).float()
+
+ if is_bounding_box:
+ coords = coords.reshape(-1, 2, 2)
+
+ coords[..., 0] = coords[..., 0] * (new_width / old_width)
+ coords[..., 1] = coords[..., 1] * (new_height / old_height)
+
+ if is_bounding_box:
+ coords = coords.reshape(-1, 4)
+
+ return coords
+
+
+def _rle_to_mask(rle: dict[str, Any]) -> torch.Tensor:
+ """Compute a binary mask from an uncompressed RLE."""
+ height, width = rle["size"]
+ mask = torch.empty(height * width, dtype=bool)
+ idx = 0
+ parity = False
+ for count in rle["counts"]:
+ mask[idx : idx + count] = parity
+ idx += count
+ parity = not parity
+ mask = mask.reshape(width, height)
+ return mask.transpose(0, 1) # Reshape to original shape
+
+
+def _post_process_for_mask_generation(rle_masks, iou_scores, mask_boxes, amg_crops_nms_thresh=0.7):
+ """
+ Perform NMS (Non Maximum Suppression) on the outputs.
+
+ Args:
+ rle_masks (`torch.Tensor`):
+ binary masks in the RLE format
+ iou_scores (`torch.Tensor` of shape (nb_masks, 1)):
+ iou_scores predicted by the model
+ mask_boxes (`torch.Tensor`):
+ The bounding boxes corresponding to segmentation masks
+ amg_crops_nms_thresh (`float`, *optional*, defaults to 0.7):
+ NMS threshold.
+ """
+ keep_by_nms = batched_nms(
+ boxes=mask_boxes.float(),
+ scores=iou_scores,
+ idxs=torch.zeros(mask_boxes.shape[0]),
+ iou_threshold=amg_crops_nms_thresh,
+ )
+
+ iou_scores = iou_scores[keep_by_nms]
+ rle_masks = [rle_masks[i] for i in keep_by_nms]
+ mask_boxes = mask_boxes[keep_by_nms]
+ masks = [_rle_to_mask(rle) for rle in rle_masks]
+
+ return masks, iou_scores, rle_masks, mask_boxes
+
+
+__all__ = ["SamImageProcessorFast"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/sam/modeling_sam.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/sam/modeling_sam.py
new file mode 100644
index 0000000000000000000000000000000000000000..9e2ab424f1a51ca99196ae414af8927fad6a4eae
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/sam/modeling_sam.py
@@ -0,0 +1,1368 @@
+# coding=utf-8
+# Copyright 2023 The Meta AI Authors and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch SAM model."""
+
+import collections
+from dataclasses import dataclass
+from typing import Callable, Optional, Union
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch import Tensor, nn
+
+from transformers.utils.generic import OutputRecorder, TransformersKwargs, check_model_inputs
+
+from ...activations import ACT2FN
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_outputs import BaseModelOutput
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
+from ...utils import (
+ ModelOutput,
+ auto_docstring,
+ logging,
+)
+from .configuration_sam import SamConfig, SamMaskDecoderConfig, SamPromptEncoderConfig, SamVisionConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+    Base class for SAM vision model's outputs that also contains image embeddings obtained by applying the projection
+ layer to the pooler_output.
+ """
+)
+class SamVisionEncoderOutput(ModelOutput):
+ r"""
+ image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
+ The image embeddings obtained by applying the projection layer to the pooler_output.
+ """
+
+ image_embeds: Optional[torch.FloatTensor] = None
+ last_hidden_state: Optional[torch.FloatTensor] = None
+ hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+ attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Base class for Segment-Anything model's output
+ """
+)
+class SamImageSegmentationOutput(ModelOutput):
+ r"""
+ iou_scores (`torch.FloatTensor` of shape `(batch_size, num_masks)`):
+ The iou scores of the predicted masks.
+ pred_masks (`torch.FloatTensor` of shape `(batch_size, num_masks, height, width)`):
+        The predicted low resolution masks. They need to be post-processed by the processor.
+ vision_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the vision model at the output of each layer plus the optional initial embedding outputs.
+ vision_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ mask_decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ """
+
+ iou_scores: Optional[torch.FloatTensor] = None
+ pred_masks: Optional[torch.FloatTensor] = None
+ vision_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+ vision_attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+ mask_decoder_attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+
+
+class SamPatchEmbeddings(nn.Module):
+ """
+ This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
+ `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
+ Transformer.
+ """
+
+ def __init__(self, config):
+ super().__init__()
+ image_size, patch_size = config.image_size, config.patch_size
+ num_channels, hidden_size = config.num_channels, config.hidden_size
+ image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
+ patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
+ num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
+ self.image_size = image_size
+ self.patch_size = patch_size
+ self.num_channels = num_channels
+ self.num_patches = num_patches
+
+ self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
+
+ def forward(self, pixel_values):
+ batch_size, num_channels, height, width = pixel_values.shape
+ if num_channels != self.num_channels:
+ raise ValueError(
+ "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
+ )
+ if height != self.image_size[0] or width != self.image_size[1]:
+ raise ValueError(
+ f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})."
+ )
+ embeddings = self.projection(pixel_values).permute(0, 2, 3, 1)
+ return embeddings
+
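+# Note (illustrative): with SAM's usual vision settings (1024x1024 inputs, 16x16 patches) the output has shape
+# (batch_size, 64, 64, hidden_size); the permute to channels-last matches the (batch, height, width, channels)
+# layout consumed by the vision encoder's attention blocks.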
+
+class SamMLPBlock(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.lin1 = nn.Linear(config.hidden_size, config.mlp_dim)
+ self.lin2 = nn.Linear(config.mlp_dim, config.hidden_size)
+ self.act = ACT2FN[config.hidden_act]
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.lin1(hidden_states)
+ hidden_states = self.act(hidden_states)
+ hidden_states = self.lin2(hidden_states)
+ return hidden_states
+
+
+# Copied from transformers.models.convnext.modeling_convnext.ConvNextLayerNorm with ConvNext->Sam
+class SamLayerNorm(nn.LayerNorm):
+ r"""LayerNorm that supports two data formats: channels_last (default) or channels_first.
+ The ordering of the dimensions in the inputs. channels_last corresponds to inputs with shape (batch_size, height,
+ width, channels) while channels_first corresponds to inputs with shape (batch_size, channels, height, width).
+ """
+
+ def __init__(self, normalized_shape, *, eps=1e-6, data_format="channels_last", **kwargs):
+ super().__init__(normalized_shape, eps=eps, **kwargs)
+ if data_format not in ["channels_last", "channels_first"]:
+ raise NotImplementedError(f"Unsupported data format: {data_format}")
+ self.data_format = data_format
+
+ def forward(self, features: torch.Tensor) -> torch.Tensor:
+ """
+ Args:
+ features: Tensor of shape (batch_size, channels, height, width) OR (batch_size, height, width, channels)
+ """
+ if self.data_format == "channels_first":
+ features = features.permute(0, 2, 3, 1)
+ features = super().forward(features)
+ features = features.permute(0, 3, 1, 2)
+ else:
+ features = super().forward(features)
+ return features
+
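+# Note (illustrative): with `data_format="channels_first"`, SamLayerNorm permutes a (batch, channels, height, width)
+# tensor to channels-last, applies the standard LayerNorm over the trailing channel dimension, and permutes back,
+# which lets it normalize convolutional feature maps per spatial position.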
+
+def eager_attention_forward(
+ module: nn.Module,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_mask: Optional[torch.Tensor],
+ scaling: float,
+ dropout: float = 0.0,
+ **kwargs,
+):
+ attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
+ if attention_mask is not None:
+ attn_weights = attn_weights + attention_mask
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+ attn_output = torch.matmul(attn_weights, value)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ return attn_output, attn_weights
+
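+# Note (illustrative): `query`, `key` and `value` are expected in (batch, num_heads, seq_len, head_dim) layout, as
+# produced by `SamAttention._separate_heads` below, and `scaling` is typically head_dim ** -0.5. The output is
+# transposed back to (batch, seq_len, num_heads, head_dim) so the caller can merge the heads again.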
+
+class SamAttention(nn.Module):
+ """
+ SAM's attention layer that allows for downscaling the size of the embedding after projection to queries, keys, and
+ values.
+ """
+
+ def __init__(self, config, downsample_rate=None):
+ super().__init__()
+ self.config = config
+ self.hidden_size = config.hidden_size
+
+ downsample_rate = config.attention_downsample_rate if downsample_rate is None else downsample_rate
+
+ self.internal_dim = config.hidden_size // downsample_rate
+ self.num_attention_heads = config.num_attention_heads
+ if self.internal_dim % config.num_attention_heads != 0:
+            raise ValueError("num_attention_heads must divide hidden_size // downsample_rate.")
+ self.scaling = (self.internal_dim // config.num_attention_heads) ** -0.5
+
+ self.q_proj = nn.Linear(self.hidden_size, self.internal_dim)
+ self.k_proj = nn.Linear(self.hidden_size, self.internal_dim)
+ self.v_proj = nn.Linear(self.hidden_size, self.internal_dim)
+ self.out_proj = nn.Linear(self.internal_dim, self.hidden_size)
+
+ self.is_causal = False
+
+ def _separate_heads(self, hidden_states: Tensor, num_attention_heads: int) -> Tensor:
+ batch, point_batch_size, n_tokens, channel = hidden_states.shape
+ c_per_head = channel // num_attention_heads
+ hidden_states = hidden_states.reshape(batch * point_batch_size, n_tokens, num_attention_heads, c_per_head)
+ return hidden_states.transpose(1, 2)
+
+ def _recombine_heads(self, hidden_states: Tensor, point_batch_size: int) -> Tensor:
+ batch, n_tokens, n_heads, c_per_head = hidden_states.shape
+ return hidden_states.reshape(batch // point_batch_size, point_batch_size, n_tokens, n_heads * c_per_head)
+
+ def forward(
+ self,
+ query: Tensor,
+ key: Tensor,
+ value: Tensor,
+ attention_similarity: Optional[Tensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> tuple[Tensor, Tensor]:
+ # Input projections
+ query = self.q_proj(query)
+ key = self.k_proj(key)
+ value = self.v_proj(value)
+
+ point_batch_size = query.shape[1]
+ # Separate into heads
+ query = self._separate_heads(query, self.num_attention_heads)
+ key = self._separate_heads(key, self.num_attention_heads)
+ value = self._separate_heads(value, self.num_attention_heads)
+
+ # SamAttention
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ attn_output, attn_weights = attention_interface(
+ self,
+ query,
+ key,
+ value,
+ attention_mask=attention_similarity,
+ dropout=0.0,
+ scaling=self.scaling,
+ is_causal=self.is_causal,
+ **kwargs,
+ )
+
+ attn_output = self._recombine_heads(attn_output, point_batch_size)
+ attn_output = self.out_proj(attn_output)
+
+ return attn_output, attn_weights
+
+
+class SamTwoWayAttentionBlock(nn.Module):
+ def __init__(self, config, attention_downsample_rate: int = 2, skip_first_layer_pe: bool = False):
+ """
+ A transformer block with four layers:
+ (1) self-attention of sparse inputs (2) cross attention of sparse inputs -> dense inputs (3) mlp block on
+ sparse inputs (4) cross attention of dense inputs -> sparse inputs
+
+ Arguments:
+ config (`SamMaskDecoderConfig`):
+ The configuration used to instantiate the block.
+ attention_downsample_rate (`int`, *optional*, defaults to 2):
+ The downsample ratio of the block used to reduce the inner dim of the attention.
+ skip_first_layer_pe (`bool`, *optional*, defaults to `False`):
+ Whether or not to skip the addition of the query_point_embedding on the first layer.
+ """
+ super().__init__()
+
+ self.hidden_size = config.hidden_size
+ self.layer_norm_eps = config.layer_norm_eps
+
+ self.self_attn = SamAttention(config, downsample_rate=1)
+ self.layer_norm1 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)
+
+ self.cross_attn_token_to_image = SamAttention(config, downsample_rate=attention_downsample_rate)
+ self.layer_norm2 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)
+
+ self.mlp = SamMLPBlock(config)
+ self.layer_norm3 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)
+
+ self.layer_norm4 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)
+ self.cross_attn_image_to_token = SamAttention(config, downsample_rate=attention_downsample_rate)
+ self.skip_first_layer_pe = skip_first_layer_pe
+
+ def forward(
+ self,
+ queries: Tensor,
+ keys: Tensor,
+ query_point_embedding: Tensor,
+ key_point_embedding: Tensor,
+ attention_similarity: Tensor,
+ **kwargs: Unpack[TransformersKwargs],
+ ):
+ # Self attention block
+ if self.skip_first_layer_pe:
+ queries, _ = self.self_attn(query=queries, key=queries, value=queries)
+ else:
+ query = queries + query_point_embedding
+ attn_out, _ = self.self_attn(query=query, key=query, value=queries)
+ queries = queries + attn_out
+ queries = self.layer_norm1(queries)
+
+ # Cross attention block, tokens attending to image embedding
+ query = queries + query_point_embedding
+ key = keys + key_point_embedding
+
+ attn_out, _ = self.cross_attn_token_to_image(
+ query=query, key=key, value=keys, attention_similarity=attention_similarity
+ )
+ queries = queries + attn_out
+
+ queries = self.layer_norm2(queries)
+
+ # MLP block
+ mlp_out = self.mlp(queries)
+ queries = queries + mlp_out
+ queries = self.layer_norm3(queries)
+
+ # Cross attention block, image embedding attending to tokens
+ query = queries + query_point_embedding
+ key = keys + key_point_embedding
+
+ attn_out, _ = self.cross_attn_image_to_token(query=key, key=query, value=queries)
+ keys = keys + attn_out
+
+ keys = self.layer_norm4(keys)
+ return queries, keys, attn_out
+
+
+class SamTwoWayTransformer(nn.Module):
+ def __init__(self, config: SamMaskDecoderConfig):
+ super().__init__()
+ self.config = config
+
+ self.num_hidden_layers = config.num_hidden_layers
+ self.layers = nn.ModuleList()
+
+ for i in range(self.num_hidden_layers):
+ self.layers.append(SamTwoWayAttentionBlock(config, skip_first_layer_pe=(i == 0)))
+
+ self.final_attn_token_to_image = SamAttention(config)
+ self.layer_norm_final_attn = nn.LayerNorm(config.hidden_size)
+
+ def forward(
+ self,
+ point_embeddings: Tensor,
+ image_embeddings: Tensor,
+ image_positional_embeddings: Tensor,
+ attention_similarity: Tensor,
+ target_embedding=None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> Union[tuple, BaseModelOutput]:
+ if image_embeddings is None:
+ raise ValueError("You have to specify an image_embedding")
+
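+ # (batch, channels, h, w) -> (batch, 1, h * w, channels): image tokens gain a point_batch_size axis of 1
+ # so they can broadcast against the per-point prompt queries.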
+ image_embeddings = image_embeddings.flatten(2).permute(0, 2, 1).unsqueeze(1)
+ image_positional_embeddings = image_positional_embeddings.flatten(2).permute(0, 2, 1).unsqueeze(1)
+
+ # Prepare queries
+ queries = point_embeddings
+ keys = image_embeddings
+
+ # Apply transformer blocks and final layernorm
+ for layer in self.layers:
+ if target_embedding is not None:
+ queries += target_embedding
+
+ queries, keys, _ = layer(
+ queries=queries,
+ keys=keys,
+ query_point_embedding=point_embeddings,
+ key_point_embedding=image_positional_embeddings,
+ attention_similarity=attention_similarity,
+ **kwargs,
+ )
+ # Apply the final attention layer from the points to the image
+ query = queries + point_embeddings
+ key = keys + image_positional_embeddings
+
+ attn_out, _ = self.final_attn_token_to_image(query=query, key=key, value=keys)
+
+ queries = queries + attn_out
+ queries = self.layer_norm_final_attn(queries)
+ return queries, keys
+
+
+class SamFeedForward(nn.Module):
+ def __init__(
+ self, input_dim: int, hidden_dim: int, output_dim: int, num_layers: int, sigmoid_output: bool = False
+ ):
+ super().__init__()
+ self.num_layers = num_layers
+ self.activation = nn.ReLU()
+ self.proj_in = nn.Linear(input_dim, hidden_dim)
+ self.proj_out = nn.Linear(hidden_dim, output_dim)
+ self.layers = nn.ModuleList([nn.Linear(hidden_dim, hidden_dim) for _ in range(num_layers - 2)])
+ self.sigmoid_output = sigmoid_output
+
+ def forward(self, hidden_states):
+ hidden_states = self.proj_in(hidden_states)
+ hidden_states = self.activation(hidden_states)
+ for layer in self.layers:
+ hidden_states = self.activation(layer(hidden_states))
+
+ hidden_states = self.proj_out(hidden_states)
+ if self.sigmoid_output:
+ hidden_states = F.sigmoid(hidden_states)
+ return hidden_states
+
+
+class SamMaskDecoder(nn.Module):
+ def __init__(self, config: SamMaskDecoderConfig):
+ super().__init__()
+ self.config = config
+ self.hidden_size = config.hidden_size
+
+ self.num_multimask_outputs = config.num_multimask_outputs
+ self.num_mask_tokens = config.num_multimask_outputs + 1
+
+ self.iou_token = nn.Embedding(1, self.hidden_size)
+ self.mask_tokens = nn.Embedding(self.num_mask_tokens, self.hidden_size)
+
+ self.transformer = SamTwoWayTransformer(config)
+
+ # should we create a new class for this?
+ self.upscale_conv1 = nn.ConvTranspose2d(self.hidden_size, self.hidden_size // 4, kernel_size=2, stride=2)
+ self.upscale_conv2 = nn.ConvTranspose2d(self.hidden_size // 4, self.hidden_size // 8, kernel_size=2, stride=2)
+ self.upscale_layer_norm = SamLayerNorm(self.hidden_size // 4, data_format="channels_first")
+ self.activation = nn.GELU()
+
+ mlps_list = []
+ for _ in range(self.num_mask_tokens):
+ mlps_list += [SamFeedForward(self.hidden_size, self.hidden_size, self.hidden_size // 8, 3)]
+ self.output_hypernetworks_mlps = nn.ModuleList(mlps_list)
+
+ self.iou_prediction_head = SamFeedForward(
+ self.hidden_size, config.iou_head_hidden_dim, self.num_mask_tokens, config.iou_head_depth
+ )
+
+ def forward(
+ self,
+ image_embeddings: torch.Tensor,
+ image_positional_embeddings: torch.Tensor,
+ sparse_prompt_embeddings: torch.Tensor,
+ dense_prompt_embeddings: torch.Tensor,
+ multimask_output: bool,
+ attention_similarity: Optional[torch.Tensor] = None,
+ target_embedding: Optional[torch.Tensor] = None,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ """
+ Predict masks given image and prompt embeddings.
+
+ Args:
+ image_embeddings (`torch.Tensor`):
+ The embeddings from the image encoder.
+ image_positional_embeddings (`torch.Tensor`):
+ Positional encoding with the same shape as image_embeddings.
+ sparse_prompt_embeddings (`torch.Tensor`):
+ The embeddings of the points and boxes.
+ dense_prompt_embeddings (`torch.Tensor`):
+ The embeddings of the mask inputs.
+ multimask_output (`bool`):
+ Whether to return multiple masks or a single mask.
+ """
+ batch_size, num_channels, height, width = image_embeddings.shape
+ point_batch_size = sparse_prompt_embeddings.shape[1] if sparse_prompt_embeddings is not None else 1
+ # Concatenate output tokens
+ output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
+ output_tokens = output_tokens.repeat(batch_size, point_batch_size, 1, 1)
+
+ if sparse_prompt_embeddings is not None:
+ tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=2)
+ else:
+ tokens = output_tokens
+ point_embeddings = tokens.to(self.iou_token.weight.dtype)
+
+ # Expand per-image data in batch direction to be per-point
+ image_embeddings = image_embeddings + dense_prompt_embeddings
+ image_embeddings = image_embeddings.repeat_interleave(point_batch_size, 0)
+ image_positional_embeddings = image_positional_embeddings.repeat_interleave(point_batch_size, 0)
+
+ # Run the transformer, image_positional_embedding are consumed
+ point_embedding, image_embeddings = self.transformer(
+ point_embeddings=point_embeddings,
+ image_embeddings=image_embeddings,
+ image_positional_embeddings=image_positional_embeddings,
+ attention_similarity=attention_similarity,
+ target_embedding=target_embedding,
+ )
+ iou_token_out = point_embedding[:, :, 0, :]
+ mask_tokens_out = point_embedding[:, :, 1 : (1 + self.num_mask_tokens), :]
+
+ # Upscale mask embeddings and predict masks using the mask tokens
+ image_embeddings = image_embeddings.transpose(2, 3).reshape(
+ batch_size * point_batch_size, num_channels, height, width
+ )
+
+ upscaled_embedding = self.upscale_conv1(image_embeddings)
+ upscaled_embedding = self.activation(self.upscale_layer_norm(upscaled_embedding))
+ upscaled_embedding = self.activation(self.upscale_conv2(upscaled_embedding))
+
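+ # Each mask token goes through its own small MLP ("hypernetwork"); the resulting vectors are multiplied
+ # with the upscaled image embedding below to produce one low-resolution mask per token.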
+ hyper_in_list = []
+ for i in range(self.num_mask_tokens):
+ current_mlp = self.output_hypernetworks_mlps[i]
+ hyper_in_list += [current_mlp(mask_tokens_out[:, :, i, :])]
+ hyper_in = torch.stack(hyper_in_list, dim=2)
+
+ _, num_channels, height, width = upscaled_embedding.shape
+ upscaled_embedding = upscaled_embedding.reshape(batch_size, point_batch_size, num_channels, height * width)
+ masks = (hyper_in @ upscaled_embedding).reshape(batch_size, point_batch_size, -1, height, width)
+
+ # Generate mask quality predictions
+ iou_pred = self.iou_prediction_head(iou_token_out)
+
+ # Select the correct mask or masks for output
+ if multimask_output:
+ mask_slice = slice(1, None)
+ else:
+ mask_slice = slice(0, 1)
+ masks = masks[:, :, mask_slice, :, :]
+ iou_pred = iou_pred[:, :, mask_slice]
+ return masks, iou_pred
+
+
+class SamPositionalEmbedding(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.scale = config.hidden_size // 2
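+ # Random (2, num_pos_feats) Gaussian matrix used as a Fourier-feature positional encoding; forward
+ # returns concatenated sin/cos features of size 2 * num_pos_feats.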
+ self.register_buffer("positional_embedding", self.scale * torch.randn((2, config.num_pos_feats)))
+
+ def forward(self, input_coords, input_shape=None):
+ """Positionally encode points that are normalized to [0,1]."""
+ coordinates = input_coords.clone()
+
+ if input_shape is not None:
+ coordinates[:, :, :, 0] = coordinates[:, :, :, 0] / input_shape[1]
+ coordinates[:, :, :, 1] = coordinates[:, :, :, 1] / input_shape[0]
+
+ # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
+ coordinates = 2 * coordinates - 1
+ coordinates = coordinates.to(self.positional_embedding.dtype)
+ coordinates = coordinates @ self.positional_embedding
+ coordinates = 2 * np.pi * coordinates
+ # outputs d_1 x ... x d_n x channel shape
+ return torch.cat([torch.sin(coordinates), torch.cos(coordinates)], dim=-1)
+
+
+class SamMaskEmbedding(nn.Module):
+ def __init__(self, config: SamPromptEncoderConfig):
+ super().__init__()
+ self.mask_input_channels = config.mask_input_channels // 4
+ self.activation = ACT2FN[config.hidden_act]
+ self.conv1 = nn.Conv2d(1, self.mask_input_channels, kernel_size=2, stride=2)
+ self.conv2 = nn.Conv2d(self.mask_input_channels, config.mask_input_channels, kernel_size=2, stride=2)
+ self.conv3 = nn.Conv2d(config.mask_input_channels, config.hidden_size, kernel_size=1)
+ self.layer_norm1 = SamLayerNorm(
+ self.mask_input_channels, eps=config.layer_norm_eps, data_format="channels_first"
+ )
+ self.layer_norm2 = SamLayerNorm(
+ self.mask_input_channels * 4, eps=config.layer_norm_eps, data_format="channels_first"
+ )
+
+ def forward(self, masks):
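+ # Two stride-2 convolutions downscale the input mask by a factor of 4 in each spatial dimension before a
+ # 1x1 convolution projects it to hidden_size.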
+ hidden_states = self.conv1(masks)
+ hidden_states = self.layer_norm1(hidden_states)
+ hidden_states = self.activation(hidden_states)
+
+ hidden_states = self.conv2(hidden_states)
+ hidden_states = self.layer_norm2(hidden_states)
+ hidden_states = self.activation(hidden_states)
+ dense_embeddings = self.conv3(hidden_states)
+ return dense_embeddings
+
+
+class SamPromptEncoder(nn.Module):
+ def __init__(self, config: SamConfig):
+ super().__init__()
+ self.shared_embedding = SamPositionalEmbedding(config.vision_config)
+ config = config.prompt_encoder_config
+ self.mask_embed = SamMaskEmbedding(config)
+ self.no_mask_embed = nn.Embedding(1, config.hidden_size)
+
+ self.image_embedding_size = (config.image_embedding_size, config.image_embedding_size)
+ self.input_image_size = config.image_size
+
+ self.point_embed = nn.ModuleList(
+ [nn.Embedding(1, config.hidden_size) for i in range(config.num_point_embeddings)]
+ )
+ self.hidden_size = config.hidden_size
+ self.not_a_point_embed = nn.Embedding(1, config.hidden_size)
+
+ def _embed_points(self, points: torch.Tensor, labels: torch.Tensor, pad: bool) -> torch.Tensor:
+ """Embeds point prompts."""
+ points = points + 0.5 # Shift to center of pixel
+ if pad:
+ target_point_shape = (points.shape[0], points.shape[1], 1, points.shape[-1])
+ target_labels_shape = (points.shape[0], points.shape[1], 1)
+ padding_point = torch.zeros(target_point_shape, device=points.device)
+ padding_label = -torch.ones(target_labels_shape, device=labels.device)
+ points = torch.cat([points, padding_point], dim=2)
+ labels = torch.cat([labels, padding_label], dim=2)
+ input_shape = (self.input_image_size, self.input_image_size)
+ point_embedding = self.shared_embedding(points, input_shape)
+
+ # torch.where and expanding the labels tensor is required by the ONNX export
+ point_embedding = torch.where(labels[..., None] == -1, self.not_a_point_embed.weight, point_embedding)
+
+ # This is required for the ONNX export. The dtype and device need to be explicitly
+ # specified, as torch.onnx.export would otherwise interpret the value as double
+ point_embedding = torch.where(labels[..., None] != -10, point_embedding, torch.zeros_like(point_embedding))
+
+ point_embedding = torch.where(
+ (labels == 0)[:, :, :, None],
+ point_embedding + self.point_embed[0].weight[None, None, :, :],
+ point_embedding,
+ )
+
+ point_embedding = torch.where(
+ (labels == 1)[:, :, :, None],
+ point_embedding + self.point_embed[1].weight[None, None, :, :],
+ point_embedding,
+ )
+
+ return point_embedding
+
+ def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
+ """Embeds box prompts."""
+ boxes = boxes + 0.5 # Shift to center of pixel
+ batch_size, nb_boxes = boxes.shape[:2]
+ coords = boxes.reshape(batch_size, nb_boxes, 2, 2)
+ input_shape = (self.input_image_size, self.input_image_size)
+ corner_embedding = self.shared_embedding(coords, input_shape)
+ corner_embedding[:, :, 0, :] += self.point_embed[2].weight
+ corner_embedding[:, :, 1, :] += self.point_embed[3].weight
+ return corner_embedding
+
+ def forward(
+ self,
+ input_points: Optional[tuple[torch.Tensor, torch.Tensor]],
+ input_labels: Optional[torch.Tensor],
+ input_boxes: Optional[torch.Tensor],
+ input_masks: Optional[torch.Tensor],
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ """
+ Embeds different types of prompts, returning both sparse and dense embeddings.
+
+ Args:
+ input_points (`torch.Tensor`, *optional*):
+ Point coordinates to embed.
+ input_labels (`torch.Tensor`, *optional*):
+ Labels for the point prompts.
+ input_boxes (`torch.Tensor`, *optional*):
+ Boxes to embed.
+ input_masks (`torch.Tensor`, *optional*):
+ Masks to embed.
+ """
+ sparse_embeddings = None
+ batch_size = 1
+ if input_points is not None:
+ batch_size = input_points.shape[0]
+ if input_labels is None:
+ raise ValueError("If points are provided, labels must also be provided.")
+ point_embeddings = self._embed_points(input_points, input_labels, pad=(input_boxes is None))
+ sparse_embeddings = point_embeddings
+ if input_boxes is not None:
+ batch_size = input_boxes.shape[0]
+ box_embeddings = self._embed_boxes(input_boxes)
+ if sparse_embeddings is None:
+ sparse_embeddings = box_embeddings
+ else:
+ sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=2)
+ if input_masks is not None:
+ dense_embeddings = self.mask_embed(input_masks)
+ else:
+ dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
+ batch_size, -1, self.image_embedding_size[0], self.image_embedding_size[1]
+ )
+
+ return sparse_embeddings, dense_embeddings
+
+
+class SamVisionAttention(nn.Module):
+ """Multi-head Attention block with relative position embeddings."""
+
+ def __init__(self, config, window_size):
+ super().__init__()
+ input_size = (
+ (config.image_size // config.patch_size, config.image_size // config.patch_size)
+ if window_size == 0
+ else (window_size, window_size)
+ )
+
+ self.num_attention_heads = config.num_attention_heads
+ head_dim = config.hidden_size // config.num_attention_heads
+ self.scale = head_dim**-0.5
+ self.dropout = config.attention_dropout
+
+ self.qkv = nn.Linear(config.hidden_size, config.hidden_size * 3, bias=config.qkv_bias)
+ self.proj = nn.Linear(config.hidden_size, config.hidden_size)
+
+ self.use_rel_pos = config.use_rel_pos
+ if self.use_rel_pos:
+ if input_size is None:
+ raise ValueError("Input size must be provided if using relative positional encoding.")
+
+ # initialize relative positional embeddings
+ self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
+ self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))
+
+ def get_rel_pos(self, q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
+ """
+ Get relative positional embeddings according to the relative positions of
+ query and key sizes.
+
+ Args:
+ q_size (int):
+ size of the query.
+ k_size (int):
+ size of key k.
+ rel_pos (`torch.Tensor`):
+ relative position embeddings (L, channel).
+
+ Returns:
+ Extracted positional embeddings according to relative positions.
+ """
+ max_rel_dist = int(2 * max(q_size, k_size) - 1)
+ # Interpolate rel pos.
+ rel_pos_resized = F.interpolate(
+ rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
+ size=max_rel_dist,
+ mode="linear",
+ )
+ rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
+
+ # Scale the coords with short length if shapes for q and k are different.
+ q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
+ k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
+ relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
+
+ return rel_pos_resized[relative_coords.long()]
+
+ def get_decomposed_rel_pos(
+ self,
+ query: torch.Tensor,
+ rel_pos_h: torch.Tensor,
+ rel_pos_w: torch.Tensor,
+ q_size: tuple[int, int],
+ k_size: tuple[int, int],
+ ) -> torch.Tensor:
+ """
+ Calculate decomposed Relative Positional Embeddings as introduced in the MViTv2 paper.
+ https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py
+
+ Args:
+ query (`torch.Tensor`):
+ query q in the attention layer with shape (batch_size, query_height * query_width, channel).
+ rel_pos_h (`torch.Tensor`):
+ relative position embeddings (Lh, channel) for height axis.
+ rel_pos_w (`torch.Tensor`):
+ relative position embeddings (Lw, channel) for width axis.
+ q_size (tuple):
+ spatial sequence size of query q with (query_height, query_width).
+ k_size (tuple):
+ spatial sequence size of key k with (key_height, key_width).
+
+ Returns:
+ decomposed_rel_pos (`torch.Tensor`):
+ decomposed relative position embeddings.
+ """
+ query_height, query_width = q_size
+ key_height, key_width = k_size
+ relative_position_height = self.get_rel_pos(query_height, key_height, rel_pos_h)
+ relative_position_width = self.get_rel_pos(query_width, key_width, rel_pos_w)
+
+ batch_size, _, dim = query.shape
+ reshaped_query = query.reshape(batch_size, query_height, query_width, dim)
+ rel_h = torch.einsum("bhwc,hkc->bhwk", reshaped_query, relative_position_height)
+ rel_w = torch.einsum("bhwc,wkc->bhwk", reshaped_query, relative_position_width)
+
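+ # rel_h: (batch, q_h, q_w, k_h) and rel_w: (batch, q_h, q_w, k_w); broadcasting their sum yields a
+ # (batch, q_h, q_w, k_h, k_w) bias that the caller flattens and adds to the attention scores.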
+ decomposed_rel_pos = rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
+
+ return decomposed_rel_pos
+
+ def forward(self, hidden_states: torch.Tensor, output_attentions=None) -> tuple[torch.Tensor, torch.Tensor]:
+ batch_size, height, width, _ = hidden_states.shape
+ # qkv with shape (3, batch_size, nHead, height * width, channel)
+ qkv = (
+ self.qkv(hidden_states)
+ .reshape(batch_size, height * width, 3, self.num_attention_heads, -1)
+ .permute(2, 0, 3, 1, 4)
+ )
+ # q, k, v with shape (batch_size * nHead, height * width, channel)
+ query, key, value = qkv.reshape(3, batch_size * self.num_attention_heads, height * width, -1).unbind(0)
+
+ attn_weights = (query * self.scale) @ key.transpose(-2, -1)
+
+ if self.use_rel_pos:
+ decomposed_rel_pos = self.get_decomposed_rel_pos(
+ query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width)
+ )
+ decomposed_rel_pos = decomposed_rel_pos.reshape_as(attn_weights)
+ attn_weights = attn_weights + decomposed_rel_pos
+
+ attn_weights = torch.nn.functional.softmax(attn_weights, dtype=torch.float32, dim=-1).to(query.dtype)
+
+ attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
+
+ attn_output = (attn_probs @ value).reshape(batch_size, self.num_attention_heads, height, width, -1)
+ attn_output = attn_output.permute(0, 2, 3, 1, 4).reshape(batch_size, height, width, -1)
+
+ attn_output = self.proj(attn_output)
+ return attn_output, attn_weights
+
+
+class SamVisionSdpaAttention(SamVisionAttention):
+ """
+ Multi-head Attention block with relative position embeddings.
+ Using SDPA instead of the default attention.
+ """
+
+ def __init__(self, config, window_size):
+ super().__init__(config, window_size)
+
+ def forward(self, hidden_states: torch.Tensor, output_attentions=False) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+ if output_attentions:
+ logger.warning_once(
+ "`SamVisionSdpaAttention` is used but `torch.nn.functional.scaled_dot_product_attention` does not support "
+ "`output_attentions=True`. Falling back to the manual attention implementation, but "
+ "specifying the manual implementation will be required from Transformers version v5.0.0 onwards. "
+ 'This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+ )
+ return super().forward(
+ hidden_states=hidden_states,
+ output_attentions=output_attentions,
+ )
+
+ batch_size, height, width, _ = hidden_states.shape
+ # qkv with shape (3, B, nHead, H * W, C)
+ qkv = (
+ self.qkv(hidden_states)
+ .reshape(batch_size, height * width, 3, self.num_attention_heads, -1)
+ .permute(2, 0, 3, 1, 4)
+ )
+ # q, k, v with shape (B * nHead, H * W, C)
+ query, key, value = qkv.reshape(3, batch_size * self.num_attention_heads, height * width, -1).unbind(0)
+
+ attn_bias = None
+ if self.use_rel_pos:
+ decomposed_rel_pos = self.get_decomposed_rel_pos(
+ query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width)
+ )
+ decomposed_rel_pos = decomposed_rel_pos.reshape(
+ batch_size, self.num_attention_heads, height * width, height * width
+ )
+ attn_bias = decomposed_rel_pos
+
+ query = query.view(batch_size, self.num_attention_heads, height * width, -1)
+ key = key.view(batch_size, self.num_attention_heads, height * width, -1)
+ value = value.view(batch_size, self.num_attention_heads, height * width, -1)
+
+ attn_output = torch.nn.functional.scaled_dot_product_attention(query, key, value, attn_mask=attn_bias)
+
+ attn_output = (
+ attn_output.view(batch_size, self.num_attention_heads, height, width, -1)
+ .permute(0, 2, 3, 1, 4)
+ .reshape(batch_size, height, width, -1)
+ )
+
+ attn_output = self.proj(attn_output)
+ return attn_output, None
+
+
+SAM_VISION_ATTENTION_CLASSES = {
+ "eager": SamVisionAttention,
+ "sdpa": SamVisionSdpaAttention,
+}
+
+
+class SamVisionLayer(GradientCheckpointingLayer):
+ def __init__(self, config, window_size):
+ super().__init__()
+ self.layer_norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.attn = SAM_VISION_ATTENTION_CLASSES[config._attn_implementation](config, window_size)
+ self.layer_norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.mlp = SamMLPBlock(config)
+ self.window_size = window_size
+
+ def window_partition(self, hidden_states: torch.Tensor, window_size: int) -> tuple[torch.Tensor, tuple[int, int]]:
+ """
+ Partition into non-overlapping windows, padding the input if needed.
+
+ Args:
+ hidden_states (`torch.Tensor`): input tokens with shape [batch_size, height, width, channel].
+ window_size (`int`): window size.
+
+ Returns:
+ windows: windows after partition with [batch_size * num_windows, window_size, window_size, channel].
+ (pad_height, pad_width): padded height and width before partition
+ """
+ batch_size, height, width, channel = hidden_states.shape
+
+ pad_h = (window_size - height % window_size) % window_size
+ pad_w = (window_size - width % window_size) % window_size
+ hidden_states = F.pad(hidden_states, (0, 0, 0, pad_w, 0, pad_h))
+ pad_height, pad_width = height + pad_h, width + pad_w
+
+ hidden_states = hidden_states.reshape(
+ batch_size, pad_height // window_size, window_size, pad_width // window_size, window_size, channel
+ )
+ windows = hidden_states.permute(0, 1, 3, 2, 4, 5).contiguous().reshape(-1, window_size, window_size, channel)
+ return windows, (pad_height, pad_width)
+
+ def window_unpartition(
+ self, windows: torch.Tensor, window_size: int, padding_shape: tuple[int, int], original_shape: tuple[int, int]
+ ) -> torch.Tensor:
+ """
+ Reverse the window partition, restoring the original sequence and removing the padding.
+
+ Args:
+ windows (`torch.Tensor`):
+ input tokens with [batch_size * num_windows, window_size, window_size, channel].
+ window_size (int):
+ window size.
+ padding_shape (Tuple):
+ padded height and width (pad_height, pad_width).
+ original_shape (Tuple): original height and width (height, width) before padding.
+
+ Returns:
+ hidden_states: unpartitioned sequences with [batch_size, height, width, channel].
+ """
+ pad_height, pad_width = padding_shape
+ height, width = original_shape
+ batch_size = windows.shape[0] // (pad_height * pad_width // window_size // window_size)
+ hidden_states = windows.reshape(
+ batch_size, pad_height // window_size, pad_width // window_size, window_size, window_size, -1
+ )
+ hidden_states = (
+ hidden_states.permute(0, 1, 3, 2, 4, 5).contiguous().reshape(batch_size, pad_height, pad_width, -1)
+ )
+
+ hidden_states = hidden_states[:, :height, :width, :].contiguous()
+ return hidden_states
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ residual = hidden_states
+ hidden_states = self.layer_norm1(hidden_states)
+ # Window partition
+ if self.window_size > 0:
+ height, width = hidden_states.shape[1], hidden_states.shape[2]
+ hidden_states, padding_shape = self.window_partition(hidden_states, self.window_size)
+
+ hidden_states, attn_weights = self.attn(
+ hidden_states=hidden_states,
+ )
+ # Reverse window partition
+ if self.window_size > 0:
+ hidden_states = self.window_unpartition(hidden_states, self.window_size, padding_shape, (height, width))
+
+ hidden_states = residual + hidden_states
+ layernorm_output = self.layer_norm2(hidden_states)
+ hidden_states = hidden_states + self.mlp(layernorm_output)
+ return hidden_states
+
+
+class SamVisionNeck(nn.Module):
+ def __init__(self, config: SamVisionConfig):
+ super().__init__()
+ self.config = config
+
+ self.conv1 = nn.Conv2d(config.hidden_size, config.output_channels, kernel_size=1, bias=False)
+ self.layer_norm1 = SamLayerNorm(config.output_channels, data_format="channels_first")
+ self.conv2 = nn.Conv2d(config.output_channels, config.output_channels, kernel_size=3, padding=1, bias=False)
+ self.layer_norm2 = SamLayerNorm(config.output_channels, data_format="channels_first")
+
+ def forward(self, hidden_states):
+ hidden_states = hidden_states.permute(0, 3, 1, 2)
+ hidden_states = self.conv1(hidden_states)
+ hidden_states = self.layer_norm1(hidden_states)
+
+ hidden_states = self.conv2(hidden_states)
+ hidden_states = self.layer_norm2(hidden_states)
+ return hidden_states
+
+
+@auto_docstring
+class SamPreTrainedModel(PreTrainedModel):
+ config: SamConfig
+ base_model_prefix = "sam"
+ main_input_name = "pixel_values"
+ _no_split_modules = ["SamVisionAttention"]
+ supports_gradient_checkpointing = True
+ _supports_sdpa = True
+
+ def _init_weights(self, module: nn.Module):
+ super()._init_weights(module)
+ if isinstance(module, SamVisionAttention):
+ if module.use_rel_pos:
+ module.rel_pos_h.data.zero_()
+ module.rel_pos_w.data.zero_()
+ elif isinstance(module, SamVisionEncoder):
+ if self.config.use_abs_pos:
+ module.pos_embed.data.zero_()
+
+
+class SamVisionEncoder(SamPreTrainedModel):
+ _can_record_outputs = {"hidden_states": SamVisionLayer, "attentions": SamVisionAttention}
+
+ def __init__(self, config: SamVisionConfig):
+ super().__init__(config)
+ self.config = config
+ self.image_size = config.image_size
+ self.patch_embed = SamPatchEmbeddings(config)
+
+ self.pos_embed = None
+ if config.use_abs_pos:
+ # Initialize absolute positional embedding with pretrain image size.
+ self.pos_embed = nn.Parameter(
+ torch.zeros(
+ 1,
+ config.image_size // config.patch_size,
+ config.image_size // config.patch_size,
+ config.hidden_size,
+ )
+ )
+
+ self.layers = nn.ModuleList()
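+ # Layers listed in config.global_attn_indexes use global attention (window_size=0); all other layers use
+ # windowed attention with config.window_size.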
+ for i in range(config.num_hidden_layers):
+ layer = SamVisionLayer(
+ config,
+ window_size=config.window_size if i not in config.global_attn_indexes else 0,
+ )
+ self.layers.append(layer)
+
+ self.neck = SamVisionNeck(config)
+
+ self.gradient_checkpointing = False
+
+ def get_input_embeddings(self):
+ return self.patch_embed
+
+ @check_model_inputs(tie_last_hidden_states=False)
+ def forward(
+ self, pixel_values: Optional[torch.FloatTensor] = None, **kwargs: Unpack[TransformersKwargs]
+ ) -> SamVisionEncoderOutput:
+ if pixel_values is None:
+ raise ValueError("You have to specify pixel_values")
+
+ hidden_states = self.patch_embed(pixel_values)
+ if self.pos_embed is not None:
+ hidden_states = hidden_states + self.pos_embed
+ for layer_module in self.layers:
+ hidden_states = layer_module(hidden_states)
+ hidden_states = self.neck(hidden_states)
+ return SamVisionEncoderOutput(
+ last_hidden_state=hidden_states,
+ )
+
+
+@auto_docstring(
+ custom_intro="""
+ The vision model from Sam without any head or projection on top.
+ """
+)
+class SamVisionModel(SamPreTrainedModel):
+ config: SamVisionConfig
+ main_input_name = "pixel_values"
+
+ def __init__(self, config: SamVisionConfig):
+ super().__init__(config)
+ self.vision_encoder = SamVisionEncoder(config)
+ self.post_init()
+
+ def get_input_embeddings(self) -> nn.Module:
+ return self.vision_encoder.patch_embed
+
+ @auto_docstring
+ def forward(
+ self,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> Union[tuple, SamVisionEncoderOutput]:
+ return self.vision_encoder(pixel_values, **kwargs)
+
+
+@auto_docstring(
+ custom_intro="""
+ Segment Anything Model (SAM) for generating segmentation masks, given an input image and
+ input points and labels, boxes, or masks.
+ """
+)
+class SamModel(SamPreTrainedModel):
+ _tied_weights_keys = ["prompt_encoder.shared_embedding.positional_embedding"]
+ # need to be ignored, as it's a buffer and will not be correctly detected as tied weight
+ _keys_to_ignore_on_load_missing = ["prompt_encoder.shared_embedding.positional_embedding"]
+ _can_record_outputs = {"mask_decoder_attentions": OutputRecorder(SamTwoWayAttentionBlock, index=2)}
+
+ def __init__(self, config: SamConfig):
+ super().__init__(config)
+ self.shared_image_embedding = SamPositionalEmbedding(config.vision_config)
+
+ self.vision_encoder = SamVisionEncoder(config.vision_config)
+ self.prompt_encoder = SamPromptEncoder(config)
+ # The module using it is not a PreTrainedModel subclass so we need this
+ config.mask_decoder_config._attn_implementation = config._attn_implementation
+ self.mask_decoder = SamMaskDecoder(config.mask_decoder_config)
+
+ self.post_init()
+
+ def _tie_weights(self):
+ self.prompt_encoder.shared_embedding.positional_embedding.data = (
+ self.shared_image_embedding.positional_embedding.data
+ )
+
+ def get_input_embeddings(self):
+ return self.vision_encoder.get_input_embeddings()
+
+ def get_image_wide_positional_embeddings(self):
+ size = self.config.prompt_encoder_config.image_embedding_size
+ target_device = self.shared_image_embedding.positional_embedding.device
+ target_dtype = self.shared_image_embedding.positional_embedding.dtype
+ grid = torch.ones((size, size), device=target_device, dtype=target_dtype)
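+ # Cumulative sums over a grid of ones give 1-based (row, column) indices; subtracting 0.5 centers them on
+ # each pixel and dividing by size normalizes the coordinates to (0, 1) for the Fourier-feature embedding.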
+ y_embed = grid.cumsum(dim=0) - 0.5
+ x_embed = grid.cumsum(dim=1) - 0.5
+ y_embed = y_embed / size
+ x_embed = x_embed / size
+
+ positional_embedding = self.shared_image_embedding(torch.stack([x_embed, y_embed], dim=-1))
+ return positional_embedding.permute(2, 0, 1).unsqueeze(0) # channel x height x width
+
+ @torch.no_grad()
+ def get_image_embeddings(self, pixel_values, **kwargs: Unpack[TransformersKwargs]):
+ r"""
+ Returns the image embeddings by passing the pixel values through the vision encoder.
+
+ Args:
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+ Input pixel values
+ """
+ vision_output = self.vision_encoder(
+ pixel_values,
+ **kwargs,
+ )
+ image_embeddings = vision_output[0]
+ return image_embeddings
+
+ @torch.no_grad()
+ def get_prompt_embeddings(
+ self,
+ input_points: Optional[torch.FloatTensor] = None,
+ input_labels: Optional[torch.LongTensor] = None,
+ input_boxes: Optional[torch.FloatTensor] = None,
+ input_masks: Optional[torch.LongTensor] = None,
+ ):
+ r"""
+ Returns the prompt embeddings by passing the input points, labels, boxes and masks through the prompt encoder.
+
+ Args:
+ input_points (`torch.FloatTensor` of shape `(batch_size, point_batch_size, num_points_per_image, 2)`):
+ Optional input points for the prompt encoder. The padding of the point is automatically done by the
+ processor. `point_batch_size` refers to the number of masks that we want the model to predict per
+ point. The model will output `point_batch_size` times 3 masks in total.
+ input_labels (`torch.LongTensor` of shape `(batch_size, point_batch_size, num_points_per_image)`):
+ Optional input labels for the prompt encoder. The padding of the labels is automatically done by the
+ processor, or can be fed by the user.
+ input_boxes (`torch.FloatTensor` of shape `(batch_size, num_boxes_per_image, 4)`):
+ Optional input boxes for the prompt encoder. The padding of the boxes is automatically done by the
+ processor. Users can also pass the input boxes manually.
+ input_masks (`torch.LongTensor` of shape `(batch_size, image_size, image_size)`):
+ Optional input masks for the prompt encoder.
+ """
+ prompt_output = self.prompt_encoder(
+ input_points=input_points,
+ input_labels=input_labels,
+ input_boxes=input_boxes,
+ input_masks=input_masks,
+ )
+ return prompt_output
+
+ @check_model_inputs
+ @auto_docstring
+ def forward(
+ self,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ input_points: Optional[torch.FloatTensor] = None,
+ input_labels: Optional[torch.LongTensor] = None,
+ input_boxes: Optional[torch.FloatTensor] = None,
+ input_masks: Optional[torch.LongTensor] = None,
+ image_embeddings: Optional[torch.FloatTensor] = None,
+ multimask_output: bool = True,
+ attention_similarity: Optional[torch.FloatTensor] = None,
+ target_embedding: Optional[torch.FloatTensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> SamImageSegmentationOutput:
+ r"""
+ input_points (`torch.FloatTensor` of shape `(batch_size, num_points, 2)`):
+ Input 2D spatial points; these are used by the prompt encoder to encode the prompt. Generally yields much
+ better results. The points can be obtained by passing a list of list of list to the processor that will
+ create corresponding `torch` tensors of dimension 4. The first dimension is the image batch size, the
+ second dimension is the point batch size (i.e. how many segmentation masks do we want the model to predict
+ per input point), the third dimension is the number of points per segmentation mask (it is possible to pass
+ multiple points for a single mask), and the last dimension is the x (vertical) and y (horizontal)
+ coordinates of the point. If a different number of points is passed either for each image, or for each
+ mask, the processor will create "PAD" points that will correspond to the (0, 0) coordinate, and the
+ computation of the embedding will be skipped for these points using the labels.
+ input_labels (`torch.LongTensor` of shape `(batch_size, point_batch_size, num_points)`):
+ Input labels for the points; these are used by the prompt encoder to encode the prompt. According to the
+ official implementation, there are 3 types of labels
+
+ - `1`: the point is a point that contains the object of interest
+ - `0`: the point is a point that does not contain the object of interest
+ - `-1`: the point corresponds to the background
+
+ We added the label:
+
+ - `-10`: the point is a padding point, thus should be ignored by the prompt encoder
+
+ The padding labels should be automatically done by the processor.
+ input_boxes (`torch.FloatTensor` of shape `(batch_size, num_boxes, 4)`):
+ Input bounding boxes; these are used by the prompt encoder to encode the prompt. Generally yields
+ much better generated masks. The boxes can be obtained by passing a list of list of list to the processor,
+ that will generate a `torch` tensor, with each dimension corresponding respectively to the image batch
+ size, the number of boxes per image and the coordinates of the top left and bottom right point of the box.
+ In the order (`x1`, `y1`, `x2`, `y2`):
+
+ - `x1`: the x coordinate of the top left point of the input box
+ - `y1`: the y coordinate of the top left point of the input box
+ - `x2`: the x coordinate of the bottom right point of the input box
+ - `y2`: the y coordinate of the bottom right point of the input box
+ input_masks (`torch.FloatTensor` of shape `(batch_size, image_size, image_size)`):
+ SAM model also accepts segmentation masks as input. The mask will be embedded by the prompt encoder to
+ generate a corresponding embedding that will later be fed to the mask decoder. These masks need to be
+ manually fed by the user, and they need to be of shape (`batch_size`, `image_size`, `image_size`).
+ image_embeddings (`torch.FloatTensor` of shape `(batch_size, output_channels, window_size, window_size)`):
+ Image embeddings; these are used by the mask decoder to generate masks and IoU scores. For more memory
+ efficient computation, users can first retrieve the image embeddings using the `get_image_embeddings`
+ method, and then feed them to the `forward` method instead of feeding the `pixel_values`.
+ multimask_output (`bool`, *optional*):
+ In the original implementation and paper, the model always outputs 3 masks per image (or per point / per
+ bounding box if relevant). However, it is possible to just output a single mask, that corresponds to the
+ "best" mask, by specifying `multimask_output=False`.
+ attention_similarity (`torch.FloatTensor`, *optional*):
+ Attention similarity tensor, to be provided to the mask decoder for target-guided attention in case the
+ model is used for personalization as introduced in [PerSAM](https://huggingface.co/papers/2305.03048).
+ target_embedding (`torch.FloatTensor`, *optional*):
+ Embedding of the target concept, to be provided to the mask decoder for target-semantic prompting in case
+ the model is used for personalization as introduced in [PerSAM](https://huggingface.co/papers/2305.03048).
+
+ Example:
+
+ ```python
+ >>> from PIL import Image
+ >>> import requests
+ >>> from transformers import AutoModel, AutoProcessor
+
+ >>> model = AutoModel.from_pretrained("facebook/sam-vit-base")
+ >>> processor = AutoProcessor.from_pretrained("facebook/sam-vit-base")
+
+ >>> img_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/sam-car.png"
+ >>> raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
+ >>> input_points = [[[400, 650]]] # 2D location of a window on the car
+ >>> inputs = processor(images=raw_image, input_points=input_points, return_tensors="pt")
+
+ >>> # Get segmentation mask
+ >>> outputs = model(**inputs)
+
+ >>> # Postprocess masks
+ >>> masks = processor.post_process_masks(
+ ... outputs.pred_masks, inputs["original_sizes"], inputs["reshaped_input_sizes"]
+ ... )
+ ```
+ """
+ if pixel_values is None and image_embeddings is None:
+ raise ValueError("Either pixel_values or image_embeddings must be provided.")
+
+ if pixel_values is not None and image_embeddings is not None:
+ raise ValueError("Only one of pixel_values and image_embeddings can be provided.")
+
+ if input_points is not None and len(input_points.shape) != 4:
+ raise ValueError(
+ "The input_points must be a 4D tensor. Of shape `batch_size`, `point_batch_size`, `nb_points_per_image`, `2`.",
+ f" got {input_points.shape}.",
+ )
+ if input_boxes is not None and len(input_boxes.shape) != 3:
+ raise ValueError(
+ "The input_points must be a 3D tensor. Of shape `batch_size`, `nb_boxes`, `4`.",
+ f" got {input_boxes.shape}.",
+ )
+ if input_points is not None and input_boxes is not None:
+ point_batch_size = input_points.shape[1]
+ box_batch_size = input_boxes.shape[1]
+ if point_batch_size != box_batch_size:
+ raise ValueError(
+ f"You should provide as many bounding boxes as input points per box. Got {point_batch_size} and {box_batch_size}."
+ )
+
+ image_positional_embeddings = self.get_image_wide_positional_embeddings()
+ # repeat with batch size
+ batch_size = pixel_values.shape[0] if pixel_values is not None else image_embeddings.shape[0]
+ image_positional_embeddings = image_positional_embeddings.repeat(batch_size, 1, 1, 1)
+
+ vision_attentions = None
+ vision_hidden_states = None
+
+ if pixel_values is not None:
+ vision_outputs: SamVisionEncoderOutput = self.vision_encoder(pixel_values, **kwargs)
+ image_embeddings = vision_outputs.last_hidden_state
+ vision_hidden_states = vision_outputs.hidden_states
+ vision_attentions = vision_outputs.attentions
+
+ if input_points is not None and input_labels is None:
+ input_labels = torch.ones_like(input_points[:, :, :, 0], dtype=torch.int, device=input_points.device)
+
+ if input_points is not None and image_embeddings.shape[0] != input_points.shape[0]:
+ raise ValueError(
+ "The batch size of the image embeddings and the input points must be the same. ",
+ f"Got {image_embeddings.shape[0]} and {input_points.shape[0]} respectively.",
+ " if you want to pass multiple points for the same image, make sure that you passed ",
+ " input_points of shape (batch_size, point_batch_size, num_points_per_image, 3) and ",
+ " input_labels of shape (batch_size, point_batch_size, num_points_per_image)",
+ )
+
+ sparse_embeddings, dense_embeddings = self.prompt_encoder(
+ input_points=input_points,
+ input_labels=input_labels,
+ input_boxes=input_boxes,
+ input_masks=input_masks,
+ )
+
+ low_res_masks, iou_predictions = self.mask_decoder(
+ image_embeddings=image_embeddings,
+ image_positional_embeddings=image_positional_embeddings,
+ sparse_prompt_embeddings=sparse_embeddings,
+ dense_prompt_embeddings=dense_embeddings,
+ multimask_output=multimask_output,
+ attention_similarity=attention_similarity,
+ target_embedding=target_embedding,
+ )
+
+ return SamImageSegmentationOutput(
+ iou_scores=iou_predictions,
+ pred_masks=low_res_masks,
+ vision_hidden_states=vision_hidden_states,
+ vision_attentions=vision_attentions,
+ )
+
+
+__all__ = ["SamVisionModel", "SamModel", "SamPreTrainedModel"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/sam/modeling_tf_sam.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/sam/modeling_tf_sam.py
new file mode 100644
index 0000000000000000000000000000000000000000..ac81288fa182b027752e16f0ff2525803f58a8b7
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/sam/modeling_tf_sam.py
@@ -0,0 +1,1723 @@
+# coding=utf-8
+# Copyright 2023 The Meta AI Authors and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+TensorFlow SAM model. This file was mostly generated by auto-translation from the PyTorch original. In the event of a
+discrepancy, the original file should be regarded as the 'reference' version.
+"""
+
+from __future__ import annotations
+
+import collections
+from dataclasses import dataclass
+
+import numpy as np
+import tensorflow as tf
+
+from ...activations_tf import ACT2FN
+from ...modeling_tf_outputs import TFBaseModelOutput
+from ...modeling_tf_utils import TFModelInputType, TFPreTrainedModel, keras, shape_list, unpack_inputs
+from ...tf_utils import flatten, functional_layernorm
+from ...utils import (
+ ModelOutput,
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ logging,
+ replace_return_docstrings,
+)
+from .configuration_sam import SamConfig, SamMaskDecoderConfig, SamPromptEncoderConfig, SamVisionConfig
+
+
+logger = logging.get_logger(__name__)
+
+_CONFIG_FOR_DOC = "SamConfig"
+_CHECKPOINT_FOR_DOC = "facebook/sam-vit-huge"
+
+
+@dataclass
+class TFSamVisionEncoderOutput(ModelOutput):
+ """
+ Base class for sam vision model's outputs that also contains image embeddings obtained by applying the projection
+ layer to the pooler_output.
+
+ Args:
+ image_embeds (`tf.Tensor` of shape `(batch_size, output_dim)`, *optional*, returned when model is initialized with `with_projection=True`):
+ The image embeddings obtained by applying the projection layer to the pooler_output.
+ last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`):
+ Sequence of hidden-states at the output of the last layer of the model.
+ hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `tf.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + one for
+ the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+ attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ """
+
+ image_embeds: tf.Tensor | None = None
+ last_hidden_state: tf.Tensor | None = None
+ hidden_states: tuple[tf.Tensor, ...] | None = None
+ attentions: tuple[tf.Tensor, ...] | None = None
+
+
+@dataclass
+class TFSamImageSegmentationOutput(ModelOutput):
+ """
+ Base class for Segment-Anything model's output
+
+ Args:
+ iou_scores (`tf.Tensor` of shape `(batch_size, num_masks)`):
+ The IoU scores of the predicted masks.
+ pred_masks (`tf.Tensor` of shape `(batch_size, num_masks, height, width)`):
+ The predicted low-resolution masks. These need to be post-processed by the processor.
+ vision_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `tf.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + one for
+ the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the vision model at the output of each layer plus the optional initial embedding outputs.
+ vision_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ mask_decoder_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ """
+
+ iou_scores: tf.Tensor | None = None
+ pred_masks: tf.Tensor | None = None
+ vision_hidden_states: tuple[tf.Tensor, ...] | None = None
+ vision_attentions: tuple[tf.Tensor, ...] | None = None
+ mask_decoder_attentions: tuple[tf.Tensor, ...] | None = None
+
+
+class TFSamPatchEmbeddings(keras.layers.Layer):
+ """
+ This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
+ `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
+ Transformer.
+ """
+
+ def __init__(self, config, **kwargs):
+ super().__init__(**kwargs)
+ image_size, patch_size = config.image_size, config.patch_size
+ num_channels, hidden_size = config.num_channels, config.hidden_size
+ image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
+ patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
+ num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
+ self.image_size = image_size
+ self.patch_size = patch_size
+ self.num_channels = num_channels
+ self.num_patches = num_patches
+
+ self.projection = keras.layers.Conv2D(
+ hidden_size, kernel_size=patch_size, strides=patch_size, name="projection"
+ )
+
+ def call(self, pixel_values):
+ batch_size, num_channels, height, width = shape_list(pixel_values)
+ if num_channels != self.num_channels:
+ raise ValueError(
+ "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
+ )
+ if height != self.image_size[0] or width != self.image_size[1]:
+ raise ValueError(
+ f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})."
+ )
+ embeddings = self.projection(tf.transpose(pixel_values, perm=[0, 2, 3, 1]))
+ return embeddings
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "projection", None) is not None:
+ with tf.name_scope(self.projection.name):
+ self.projection.build([None, None, None, self.num_channels])
+
+
+class TFSamMLPBlock(keras.layers.Layer):
+ def __init__(self, config, **kwargs):
+ super().__init__(**kwargs)
+ self.lin1 = keras.layers.Dense(config.mlp_dim, name="lin1")
+ self.lin2 = keras.layers.Dense(config.hidden_size, name="lin2")
+ self.act = ACT2FN[config.hidden_act]
+ self.config = config
+
+ def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
+ hidden_states = self.lin1(hidden_states)
+ hidden_states = self.act(hidden_states)
+ hidden_states = self.lin2(hidden_states)
+ return hidden_states
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "lin1", None) is not None:
+ with tf.name_scope(self.lin1.name):
+ self.lin1.build([None, None, self.config.hidden_size])
+ if getattr(self, "lin2", None) is not None:
+ with tf.name_scope(self.lin2.name):
+ self.lin2.build([None, None, self.config.mlp_dim])
+
+
+class TFSamLayerNorm(keras.layers.Layer):
+ r"""LayerNorm that supports two data formats: channels_last (default) or channels_first.
+ The ordering of the dimensions in the inputs. channels_last corresponds to inputs with shape (batch_size, height,
+ width, channels) while channels_first corresponds to inputs with shape (batch_size, channels, height, width).
+ """
+
+ def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last", **kwargs):
+ super().__init__(**kwargs)
+ self.eps = eps
+ self.data_format = data_format
+ self.normalized_shape = normalized_shape
+ if self.data_format not in ["channels_last", "channels_first"]:
+ raise NotImplementedError(f"Unsupported data format: {self.data_format}")
+
+ def build(self, input_shape):
+ self.weight = self.add_weight(shape=self.normalized_shape, initializer="ones", name="weight")
+ self.bias = self.add_weight(shape=self.normalized_shape, initializer="zeros", name="bias")
+ super().build(input_shape)
+
+ def call(self, x: tf.Tensor) -> tf.Tensor:
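+ # channels_last normalizes over the last axis; channels_first normalizes over axis 1 (the channel axis).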
+ if self.data_format == "channels_last":
+ x = functional_layernorm(x, weight=self.weight, bias=self.bias, epsilon=self.eps, axis=-1)
+ elif self.data_format == "channels_first":
+ x = functional_layernorm(x, weight=self.weight, bias=self.bias, epsilon=self.eps, axis=1)
+ return x
+
+
+class TFSamAttention(keras.layers.Layer):
+ """
+ SAM's attention layer that allows for downscaling the size of the embedding after projection to queries, keys, and
+ values.
+ """
+
+ def __init__(self, config, downsample_rate=None, **kwargs):
+ super().__init__(**kwargs)
+ self.hidden_size = config.hidden_size
+
+ downsample_rate = config.attention_downsample_rate if downsample_rate is None else downsample_rate
+
+ self.internal_dim = config.hidden_size // downsample_rate
+ self.num_attention_heads = config.num_attention_heads
+ if self.internal_dim % config.num_attention_heads != 0:
+ raise ValueError("num_attention_heads must divide internal_dim (hidden_size // downsample_rate).")
+
+ self.q_proj = keras.layers.Dense(self.internal_dim, name="q_proj")
+ self.k_proj = keras.layers.Dense(self.internal_dim, name="k_proj")
+ self.v_proj = keras.layers.Dense(self.internal_dim, name="v_proj")
+ self.out_proj = keras.layers.Dense(self.hidden_size, name="out_proj")
+
+ def _separate_heads(self, hidden_states: tf.Tensor, num_attention_heads: int) -> tf.Tensor:
+ batch, point_batch_size, n_tokens, channel = shape_list(hidden_states)
+ c_per_head = channel // num_attention_heads
+ hidden_states = tf.reshape(
+ hidden_states, (batch * point_batch_size, n_tokens, num_attention_heads, c_per_head)
+ )
+ return tf.transpose(hidden_states, perm=[0, 2, 1, 3])
+
+ def _recombine_heads(self, hidden_states: tf.Tensor, point_batch_size: int) -> tf.Tensor:
+ batch, n_heads, n_tokens, c_per_head = shape_list(hidden_states)
+ hidden_states = tf.transpose(hidden_states, perm=[0, 2, 1, 3])
+ return tf.reshape(
+ hidden_states,
+ (batch // tf.reduce_max([1, point_batch_size]), point_batch_size, n_tokens, n_heads * c_per_head),
+ )
+
+ def call(self, query: tf.Tensor, key: tf.Tensor, value: tf.Tensor) -> tf.Tensor:
+ # Input projections
+ query = self.q_proj(query)
+ key = self.k_proj(key)
+ value = self.v_proj(value)
+
+ point_batch_size = shape_list(query)[1]
+ # Separate into heads
+ query = self._separate_heads(query, self.num_attention_heads)
+ key = self._separate_heads(key, self.num_attention_heads)
+ value = self._separate_heads(value, self.num_attention_heads)
+
+ # SamAttention
+ _, _, _, c_per_head = shape_list(query)
+ attn = tf.matmul(
+ query, tf.transpose(key, perm=[0, 1, 3, 2])
+ ) # batch_size * point_batch_size x N_heads x N_tokens x N_tokens
+ attn = attn / tf.math.sqrt(float(c_per_head))
+ attn = tf.nn.softmax(attn, axis=-1)
+
+ # Get output
+ out = tf.matmul(attn, value)
+ out = self._recombine_heads(out, point_batch_size)
+ out = self.out_proj(out)
+
+ return out
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "q_proj", None) is not None:
+ with tf.name_scope(self.q_proj.name):
+ self.q_proj.build([None, None, self.hidden_size])
+ if getattr(self, "k_proj", None) is not None:
+ with tf.name_scope(self.k_proj.name):
+ self.k_proj.build([None, None, self.hidden_size])
+ if getattr(self, "v_proj", None) is not None:
+ with tf.name_scope(self.v_proj.name):
+ self.v_proj.build([None, None, self.hidden_size])
+ if getattr(self, "out_proj", None) is not None:
+ with tf.name_scope(self.out_proj.name):
+ self.out_proj.build([None, None, self.internal_dim])
+
+
+class TFSamTwoWayAttentionBlock(keras.layers.Layer):
+ def __init__(self, config, attention_downsample_rate: int = 2, skip_first_layer_pe: bool = False, **kwargs):
+ """
+ A transformer block with four layers:
+ (1) self-attention of sparse inputs, (2) cross-attention of sparse inputs -> dense inputs, (3) MLP block on
+ sparse inputs, (4) cross-attention of dense inputs -> sparse inputs
+
+ Arguments:
+ config (`SamMaskDecoderConfig`):
+ The configuration used to instantiate the block
+ attention_downsample_rate (`int`, *optional*, defaults to 2):
+ The downsample ratio of the block used to reduce the inner dim of the attention.
+ skip_first_layer_pe (`bool`, *optional*, defaults to `False`):
+ Whether or not to skip the addition of the query_point_embedding on the first layer.
+ """
+ super().__init__(**kwargs)
+
+ self.hidden_size = config.hidden_size
+ self.layer_norm_eps = config.layer_norm_eps
+
+ self.self_attn = TFSamAttention(config, downsample_rate=1, name="self_attn")
+ self.layer_norm1 = keras.layers.LayerNormalization(epsilon=self.layer_norm_eps, name="layer_norm1")
+
+ self.cross_attn_token_to_image = TFSamAttention(
+ config, downsample_rate=attention_downsample_rate, name="cross_attn_token_to_image"
+ )
+ self.layer_norm2 = keras.layers.LayerNormalization(epsilon=self.layer_norm_eps, name="layer_norm2")
+
+ self.mlp = TFSamMLPBlock(config, name="mlp")
+ self.layer_norm3 = keras.layers.LayerNormalization(epsilon=self.layer_norm_eps, name="layer_norm3")
+
+ self.layer_norm4 = keras.layers.LayerNormalization(epsilon=self.layer_norm_eps, name="layer_norm4")
+ self.cross_attn_image_to_token = TFSamAttention(
+ config, downsample_rate=attention_downsample_rate, name="cross_attn_image_to_token"
+ )
+
+ self.skip_first_layer_pe = skip_first_layer_pe
+
+ def call(
+ self,
+ queries: tf.Tensor,
+ keys: tf.Tensor,
+ query_point_embedding: tf.Tensor,
+ key_point_embedding: tf.Tensor,
+ output_attentions: bool = False,
+ ):
+ # Self attention block
+ if self.skip_first_layer_pe:
+ queries = self.self_attn(query=queries, key=queries, value=queries)
+ else:
+ query = queries + query_point_embedding
+ attn_out = self.self_attn(query=query, key=query, value=queries)
+ queries = queries + attn_out
+ queries = self.layer_norm1(queries)
+
+ # Cross attention block, tokens attending to image embedding
+ query = queries + query_point_embedding
+ key = keys + key_point_embedding
+
+ attn_out = self.cross_attn_token_to_image(query=query, key=key, value=keys)
+ queries = queries + attn_out
+
+ queries = self.layer_norm2(queries)
+
+ # MLP block
+ mlp_out = self.mlp(queries)
+ queries = queries + mlp_out
+ queries = self.layer_norm3(queries)
+
+ # Cross attention block, image embedding attending to tokens
+ query = queries + query_point_embedding
+ key = keys + key_point_embedding
+
+ attn_out = self.cross_attn_image_to_token(query=key, key=query, value=queries)
+ keys = keys + attn_out
+
+ keys = self.layer_norm4(keys)
+
+ outputs = (queries, keys)
+
+ if output_attentions:
+ outputs = outputs + (attn_out,)
+ else:
+ outputs = outputs + (None,)
+
+ return outputs
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "self_attn", None) is not None:
+ with tf.name_scope(self.self_attn.name):
+ self.self_attn.build(None)
+ if getattr(self, "layer_norm1", None) is not None:
+ with tf.name_scope(self.layer_norm1.name):
+ self.layer_norm1.build([None, None, None, self.hidden_size])
+ if getattr(self, "cross_attn_token_to_image", None) is not None:
+ with tf.name_scope(self.cross_attn_token_to_image.name):
+ self.cross_attn_token_to_image.build(None)
+ if getattr(self, "layer_norm2", None) is not None:
+ with tf.name_scope(self.layer_norm2.name):
+ self.layer_norm2.build([None, None, None, self.hidden_size])
+ if getattr(self, "mlp", None) is not None:
+ with tf.name_scope(self.mlp.name):
+ self.mlp.build(None)
+ if getattr(self, "layer_norm3", None) is not None:
+ with tf.name_scope(self.layer_norm3.name):
+ self.layer_norm3.build([None, None, None, self.hidden_size])
+ if getattr(self, "layer_norm4", None) is not None:
+ with tf.name_scope(self.layer_norm4.name):
+ self.layer_norm4.build([None, None, None, self.hidden_size])
+ if getattr(self, "cross_attn_image_to_token", None) is not None:
+ with tf.name_scope(self.cross_attn_image_to_token.name):
+ self.cross_attn_image_to_token.build(None)
+
+
+class TFSamTwoWayTransformer(keras.layers.Layer):
+ def __init__(self, config: SamMaskDecoderConfig, **kwargs):
+ super().__init__(**kwargs)
+ self.config = config
+
+ self.num_hidden_layers = config.num_hidden_layers
+ self.layers = []
+
+ for i in range(self.num_hidden_layers):
+ self.layers.append(TFSamTwoWayAttentionBlock(config, skip_first_layer_pe=(i == 0), name=f"layers_._{i}"))
+
+ self.final_attn_token_to_image = TFSamAttention(config, name="final_attn_token_to_image")
+ self.layer_norm_final_attn = keras.layers.LayerNormalization(
+ epsilon=config.layer_norm_eps, name="layer_norm_final_attn"
+ )
+
+ def call(
+ self,
+ point_embeddings: tf.Tensor,
+ image_embeddings: tf.Tensor,
+ image_positional_embeddings: tf.Tensor,
+ output_attentions: bool | None = None,
+ output_hidden_states: bool | None = None,
+ return_dict: bool | None = None,
+ ) -> tuple | TFBaseModelOutput:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ all_attentions = ()
+
+ if image_embeddings is None:
+ raise ValueError("You have to specify an image_embedding")
+
+ image_embeddings = tf.transpose(flatten(image_embeddings, 2), perm=(0, 2, 1))[:, None]
+ image_positional_embeddings = tf.transpose(flatten(image_positional_embeddings, 2), (0, 2, 1))[:, None]
+
+ # Prepare queries
+ queries = point_embeddings
+ keys = image_embeddings
+
+ # Apply transformer blocks and final layernorm
+ for layer in self.layers:
+ queries, keys, attention_outputs = layer(
+ queries=queries,
+ keys=keys,
+ query_point_embedding=point_embeddings,
+ key_point_embedding=image_positional_embeddings,
+ output_attentions=output_attentions,
+ )
+
+ if output_attentions:
+ all_attentions = all_attentions + (attention_outputs,)
+
+ # Apply the final attention layer from the points to the image
+ query = queries + point_embeddings
+ key = keys + image_positional_embeddings
+
+ attn_out = self.final_attn_token_to_image(query=query, key=key, value=keys)
+
+ queries = queries + attn_out
+ queries = self.layer_norm_final_attn(queries)
+ return queries, keys, all_attentions
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "final_attn_token_to_image", None) is not None:
+ with tf.name_scope(self.final_attn_token_to_image.name):
+ self.final_attn_token_to_image.build(None)
+ if getattr(self, "layer_norm_final_attn", None) is not None:
+ with tf.name_scope(self.layer_norm_final_attn.name):
+ self.layer_norm_final_attn.build([None, None, None, self.config.hidden_size])
+ for layer in self.layers:
+ with tf.name_scope(layer.name):
+ layer.build(None)
+
+
+class TFSamFeedForward(keras.layers.Layer):
+ def __init__(
+ self, input_dim: int, hidden_dim: int, output_dim: int, num_layers: int, sigmoid_output: bool = False, **kwargs
+ ):
+ super().__init__(**kwargs)
+ self.num_layers = num_layers
+ self.activation = keras.layers.ReLU()
+ self.proj_in = keras.layers.Dense(hidden_dim, input_shape=(input_dim,), name="proj_in")
+ self.proj_out = keras.layers.Dense(output_dim, input_shape=(hidden_dim,), name="proj_out")
+ self.layers = [
+ keras.layers.Dense(hidden_dim, input_shape=(hidden_dim,), name=f"layers_._{i}")
+ for i in range(num_layers - 2)
+ ]
+ self.sigmoid_output = sigmoid_output
+ self.hidden_dim = hidden_dim
+ self.input_dim = input_dim
+
+ def call(self, hidden_states):
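+ # proj_in -> ReLU -> (num_layers - 2) hidden Dense + ReLU layers -> proj_out, with an optional final sigmoid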
+ hidden_states = self.proj_in(hidden_states)
+ hidden_states = self.activation(hidden_states)
+ for layer in self.layers:
+ hidden_states = self.activation(layer(hidden_states))
+
+ hidden_states = self.proj_out(hidden_states)
+ if self.sigmoid_output:
+ hidden_states = tf.sigmoid(hidden_states)
+ return hidden_states
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "proj_in", None) is not None:
+ with tf.name_scope(self.proj_in.name):
+ self.proj_in.build([None, None, self.input_dim])
+ if getattr(self, "proj_out", None) is not None:
+ with tf.name_scope(self.proj_out.name):
+ self.proj_out.build([None, None, self.hidden_dim])
+ if getattr(self, "layers", None) is not None:
+ for layer in self.layers:
+ with tf.name_scope(layer.name):
+ layer.build([None, None, self.hidden_dim])
+
+
+class TFSamMaskDecoder(keras.layers.Layer):
+ def __init__(self, config: SamMaskDecoderConfig, **kwargs):
+ super().__init__(**kwargs)
+
+ self.hidden_size = config.hidden_size
+
+ self.num_multimask_outputs = config.num_multimask_outputs
+ self.num_mask_tokens = config.num_multimask_outputs + 1
+
+ self.transformer = TFSamTwoWayTransformer(config, name="transformer")
+
+ self.upscale_conv1 = keras.layers.Conv2DTranspose(
+ self.hidden_size // 4, kernel_size=2, strides=2, name="upscale_conv1", data_format="channels_first"
+ )
+ self.upscale_conv2 = keras.layers.Conv2DTranspose(
+ self.hidden_size // 8, kernel_size=2, strides=2, name="upscale_conv2", data_format="channels_first"
+ )
+ self.upscale_layer_norm = TFSamLayerNorm(
+ self.hidden_size // 4, data_format="channels_first", name="upscale_layer_norm"
+ )
+ self.activation = tf.nn.gelu
+
+ mlps_list = []
+ for i in range(self.num_mask_tokens):
+ mlps_list += [
+ TFSamFeedForward(
+ self.hidden_size,
+ self.hidden_size,
+ self.hidden_size // 8,
+ 3,
+ name=f"output_hypernetworks_mlps_._{i}",
+ )
+ ]
+ self.output_hypernetworks_mlps = mlps_list
+
+ self.iou_prediction_head = TFSamFeedForward(
+ self.hidden_size,
+ config.iou_head_hidden_dim,
+ self.num_mask_tokens,
+ config.iou_head_depth,
+ name="iou_prediction_head",
+ )
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ self.iou_token = self.add_weight(shape=(1, self.hidden_size), name="iou_token.weight", trainable=True)
+ self.mask_tokens = self.add_weight(
+ shape=(self.num_mask_tokens, self.hidden_size), name="mask_tokens.weight", trainable=True
+ )
+
+ if getattr(self, "transformer", None) is not None:
+ with tf.name_scope(self.transformer.name):
+ self.transformer.build(None)
+ if getattr(self, "upscale_conv1", None) is not None:
+ with tf.name_scope(self.upscale_conv1.name):
+ self.upscale_conv1.build([None, self.hidden_size, None, None])
+ if getattr(self, "upscale_conv2", None) is not None:
+ with tf.name_scope(self.upscale_conv2.name):
+ self.upscale_conv2.build([None, self.hidden_size // 4, None, None])
+ if getattr(self, "upscale_layer_norm", None) is not None:
+ with tf.name_scope(self.upscale_layer_norm.name):
+ self.upscale_layer_norm.build(None)
+ if getattr(self, "iou_prediction_head", None) is not None:
+ with tf.name_scope(self.iou_prediction_head.name):
+ self.iou_prediction_head.build(None)
+ for mlp in self.output_hypernetworks_mlps:
+ with tf.name_scope(mlp.name):
+ mlp.build(None)
+
+ def call(
+ self,
+ image_embeddings: tf.Tensor,
+ image_positional_embeddings: tf.Tensor,
+ sparse_prompt_embeddings: tf.Tensor,
+ dense_prompt_embeddings: tf.Tensor,
+ multimask_output: bool,
+ output_attentions: bool | None = None,
+ ) -> tuple[tf.Tensor, tf.Tensor]:
+ batch_size, num_channels, height, width = shape_list(image_embeddings)
+ point_batch_size = tf.math.maximum(1, tf.shape(sparse_prompt_embeddings)[1])
+
+ output_tokens = tf.concat([self.iou_token, self.mask_tokens], axis=0) # Should be (1, 32) + (4, 32) = (5, 32)
+ output_tokens = tf.tile(
+ output_tokens[None, None, :], [batch_size, point_batch_size, 1, 1]
+ ) # Should be (batch_size, point_size, 5, 32)
+
+ # Matt: The original Torch code checked that the sum of sparse_prompt_embeddings equalled 0. However, this only
+ # happens when the sparse prompt embeddings are an empty tensor with shape[1] == 0. I replaced
+ # it with an explicit shape check to avoid data-dependent control flow which breaks XLA.
+ if shape_list(sparse_prompt_embeddings)[1] != 0:
+ tokens = tf.concat((output_tokens, sparse_prompt_embeddings), axis=2)
+ else:
+ tokens = output_tokens
+ point_embeddings = tf.cast(tokens, self.iou_token.dtype)
+
+ image_embeddings = image_embeddings + dense_prompt_embeddings
+ image_embeddings = tf.repeat(image_embeddings, point_batch_size, axis=0)
+ image_positional_embeddings = tf.repeat(image_positional_embeddings, point_batch_size, axis=0)
+
+ point_embedding, image_embeddings, attentions = self.transformer(
+ point_embeddings=point_embeddings,
+ image_embeddings=image_embeddings,
+ image_positional_embeddings=image_positional_embeddings,
+ output_attentions=output_attentions,
+ )
+ iou_token_out = point_embedding[:, :, 0, :]
+ mask_tokens_out = point_embedding[:, :, 1 : (1 + self.num_mask_tokens), :]
+
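+ # Reshape the flattened image embeddings back into a channels-first spatial feature map for upscaling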
+ image_embeddings = tf.transpose(image_embeddings, perm=(0, 1, 3, 2))
+ image_embeddings = tf.reshape(image_embeddings, [batch_size * point_batch_size, num_channels, height, width])
+
+ upscaled_embedding = self.upscale_conv1(image_embeddings)
+ upscaled_embedding = self.activation(self.upscale_layer_norm(upscaled_embedding))
+ upscaled_embedding = self.activation(self.upscale_conv2(upscaled_embedding))
+
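+ # Each mask token goes through its own hypernetwork MLP; the outputs act as per-mask filters over the upscaled embedding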
+ hyper_in_list = []
+ for i in range(self.num_mask_tokens):
+ current_mlp = self.output_hypernetworks_mlps[i]
+ hyper_in_list += [current_mlp(mask_tokens_out[:, :, i, :])]
+ hyper_in = tf.stack(hyper_in_list, axis=2)
+
+ _, num_channels, height, width = shape_list(upscaled_embedding)
+ upscaled_embedding = tf.reshape(
+ upscaled_embedding, [batch_size, point_batch_size, num_channels, height * width]
+ )
+ masks = tf.reshape(hyper_in @ upscaled_embedding, [batch_size, point_batch_size, -1, height, width])
+
+ iou_pred = self.iou_prediction_head(iou_token_out)
+
+ if multimask_output:
+ mask_slice = slice(1, None)
+ else:
+ mask_slice = slice(0, 1)
+ masks = masks[:, :, mask_slice, :, :]
+ iou_pred = iou_pred[:, :, mask_slice]
+
+ outputs = (masks, iou_pred)
+
+ if output_attentions:
+ outputs = outputs + (attentions,)
+ else:
+ outputs = outputs + (None,)
+
+ return outputs
+
+
+class TFSamPositionalEmbedding(keras.layers.Layer):
+ def __init__(self, config, **kwargs):
+ super().__init__(**kwargs)
+ self.scale = config.hidden_size // 2
+ self.config = config
+
+ def build(self, input_shape):
+ # TODO Matt: What is going on here? Why is a non-trainable weight randomly initialized?
+ self.positional_embedding = self.add_weight(
+ name="positional_embedding",
+ shape=(2, self.config.num_pos_feats),
+ initializer=keras.initializers.RandomNormal(mean=0.0, stddev=self.scale),
+ trainable=False,
+ )
+ super().build(input_shape)
+
+ def call(self, input_coords, input_shape=None):
+ """Positionally encode points that are normalized to [0,1]."""
+ coordinates = tf.identity(input_coords)
+
+ if input_shape is not None:
+ coordinates = tf.stack(
+ [
+ tf.cast(coordinates[:, :, :, 0], tf.float32) / input_shape[1],
+ tf.cast(coordinates[:, :, :, 1], tf.float32) / input_shape[0],
+ ],
+ axis=-1,
+ )
+
+ # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
+ coordinates = 2 * coordinates - 1
+ coordinates = tf.cast(coordinates, self.positional_embedding.dtype)
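+ # Random Fourier features: project the coordinates with the fixed random Gaussian matrix, then take sin/cos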
+ coordinates = tf.matmul(coordinates, self.positional_embedding)
+ coordinates = 2 * np.pi * coordinates
+ # outputs d_1 x ... x d_n x channel shape
+ return tf.concat([tf.sin(coordinates), tf.cos(coordinates)], axis=-1)
+
+
+class TFSamMaskEmbedding(keras.layers.Layer):
+ def __init__(self, config: SamPromptEncoderConfig, **kwargs):
+ super().__init__(**kwargs)
+ self.mask_input_channels = config.mask_input_channels // 4
+ self.activation = ACT2FN[config.hidden_act]
+ self.conv1 = keras.layers.Conv2D(self.mask_input_channels, kernel_size=2, strides=2, name="conv1")
+ self.conv2 = keras.layers.Conv2D(config.mask_input_channels, kernel_size=2, strides=2, name="conv2")
+ self.conv3 = keras.layers.Conv2D(config.hidden_size, kernel_size=1, name="conv3")
+ self.layer_norm1 = TFSamLayerNorm(self.mask_input_channels, config.layer_norm_eps, name="layer_norm1")
+ self.layer_norm2 = TFSamLayerNorm(self.mask_input_channels * 4, config.layer_norm_eps, name="layer_norm2")
+ self.config = config
+
+ def call(self, masks):
+ masks = tf.transpose(masks, perm=(0, 2, 3, 1)) # Convert to channels-last
+ hidden_states = self.conv1(masks)
+ hidden_states = self.layer_norm1(hidden_states)
+ hidden_states = self.activation(hidden_states)
+
+ hidden_states = self.conv2(hidden_states)
+ hidden_states = self.layer_norm2(hidden_states)
+ hidden_states = self.activation(hidden_states)
+ dense_embeddings = self.conv3(hidden_states)
+ dense_embeddings = tf.transpose(dense_embeddings, perm=(0, 3, 1, 2)) # Convert back to channels-first
+ return dense_embeddings
+
+ def build(self, input_shape=None):
+ # This class needs an explicit build method because it isn't called with the standard dummy inputs
+ if self.built:
+ return
+ self.built = True
+ with tf.name_scope("conv1"):
+ self.conv1.build([None, None, None, 1])
+ with tf.name_scope("conv2"):
+ self.conv2.build([None, None, None, self.mask_input_channels])
+ with tf.name_scope("conv3"):
+ self.conv3.build([None, None, None, self.mask_input_channels * 4])
+ with tf.name_scope("layer_norm1"):
+ self.layer_norm1.build([None, None, None, self.mask_input_channels])
+ with tf.name_scope("layer_norm2"):
+ self.layer_norm2.build([None, None, None, self.mask_input_channels * 4])
+
+
+class TFSamPromptEncoder(keras.layers.Layer):
+ def __init__(self, config: SamPromptEncoderConfig, shared_patch_embedding, **kwargs):
+ super().__init__(**kwargs)
+ self.shared_embedding = shared_patch_embedding
+ self.mask_embed = TFSamMaskEmbedding(config, name="mask_embed")
+ self.no_mask_embed = None
+
+ self.image_embedding_size = (config.image_embedding_size, config.image_embedding_size)
+ self.input_image_size = config.image_size
+
+ self.point_embed = []
+ self.hidden_size = config.hidden_size
+ self.not_a_point_embed = None
+ self.config = config
+
+ def build(self, input_shape=None):
+ self.no_mask_embed = self.add_weight(
+ name="no_mask_embed.weight",
+ shape=(1, self.hidden_size),
+ initializer=keras.initializers.RandomNormal(mean=0.0, stddev=0.02),
+ trainable=True,
+ )
+ self.point_embed = [
+ self.add_weight(
+ name=f"point_embed_._{i}.weight",
+ shape=(1, self.hidden_size),
+ initializer=keras.initializers.RandomNormal(mean=0.0, stddev=0.02),
+ trainable=True,
+ )
+ for i in range(self.config.num_point_embeddings)
+ ]
+ self.not_a_point_embed = self.add_weight(
+ name="not_a_point_embed.weight",
+ shape=(1, self.hidden_size),
+ initializer=keras.initializers.RandomNormal(mean=0.0, stddev=0.02),
+ trainable=True,
+ )
+ with tf.name_scope("mask_embed"):
+ # We must explicitly build the mask embed because it isn't touched by the standard dummy inputs
+ self.mask_embed.build(
+ (None, self.config.mask_input_channels, self.config.image_size, self.config.image_size)
+ )
+
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "mask_embed", None) is not None:
+ with tf.name_scope(self.mask_embed.name):
+ self.mask_embed.build(None)
+
+ def _embed_points(self, points: tf.Tensor, labels: tf.Tensor, pad: bool) -> tf.Tensor:
+ """Embeds point prompts."""
+ points = points + 0.5 # Shift to center of pixel
+ if pad:
+ target_point_shape = (shape_list(points)[0], shape_list(points)[1], 1, shape_list(points)[-1])
+ target_labels_shape = (shape_list(points)[0], shape_list(points)[1], 1)
+ padding_point = tf.zeros(target_point_shape, dtype=points.dtype)
+ padding_label = -tf.ones(target_labels_shape, dtype=labels.dtype)
+ points = tf.concat([points, padding_point], axis=2)
+ labels = tf.concat([labels, padding_label], axis=2)
+ input_shape = (self.input_image_size, self.input_image_size)
+ point_embedding = self.shared_embedding(points, input_shape)
+
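+ # Labels: -1 -> learned "not a point" embedding, -10 (processor padding) -> zeroed out, 0/1 -> add the matching point embedding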
+ point_embedding = tf.where(labels[..., None] == -1, self.not_a_point_embed[0], point_embedding)
+
+ point_embedding = tf.where(
+ labels[..., None] != -10,
+ point_embedding,
+ tf.zeros_like(point_embedding),
+ )
+ point_embedding = tf.where(
+ (labels == 0)[:, :, :, None], point_embedding + self.point_embed[0], point_embedding
+ )
+ point_embedding = tf.where(
+ (labels == 1)[:, :, :, None], point_embedding + self.point_embed[1], point_embedding
+ )
+ return point_embedding
+
+ def _embed_boxes(self, boxes: tf.Tensor) -> tf.Tensor:
+ """Embeds box prompts."""
+ boxes = boxes + 0.5 # Shift to center of pixel
+ batch_size, nb_boxes = shape_list(boxes)[:2]
+ coords = tf.reshape(boxes, (batch_size, nb_boxes, 2, 2))
+ input_shape = (self.input_image_size, self.input_image_size)
+ corner_embedding = self.shared_embedding(coords, input_shape)
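+ # The two box corners get distinct learned embeddings: point_embed[2] for the first corner, point_embed[3] for the second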
+ corner_embedding += tf.where(
+ tf.range(shape_list(corner_embedding)[2])[None, None, :, None] == 0,
+ self.point_embed[2][0],
+ self.point_embed[3][0],
+ )
+ return corner_embedding
+
+ def call(
+ self,
+ batch_size: int | None,
+ input_points: tuple[tf.Tensor, tf.Tensor] | None,
+ input_labels: tf.Tensor | None,
+ input_boxes: tf.Tensor | None,
+ input_masks: tf.Tensor | None,
+ ) -> tuple[tf.Tensor, tf.Tensor]:
+ """
+ Embeds different types of prompts, returning both sparse and dense embeddings.
+
+ Args:
+ input_points (`tf.Tensor`, *optional*):
+ point coordinates to embed.
+ input_labels (`tf.Tensor`, *optional*):
+ point labels to embed.
+ input_boxes (`tf.Tensor`, *optional*):
+ boxes to embed.
+ input_masks (`tf.Tensor`, *optional*):
+ masks to embed.
+ """
+ sparse_embeddings = None
+ if input_points is not None:
+ batch_size, point_batch_size = shape_list(input_points)[:2]
+ if input_labels is None:
+ raise ValueError("If points are provided, labels must also be provided.")
+ point_embeddings = self._embed_points(input_points, input_labels, pad=(input_boxes is None))
+ sparse_embeddings = tf.zeros(
+ (batch_size, point_batch_size, 0, self.hidden_size), dtype=point_embeddings.dtype
+ )
+ sparse_embeddings = tf.concat([sparse_embeddings, point_embeddings], axis=2)
+ if input_boxes is not None:
+ batch_size = shape_list(input_boxes)[0]
+ box_embeddings = self._embed_boxes(input_boxes)
+ if sparse_embeddings is None:
+ sparse_embeddings = box_embeddings
+ else:
+ sparse_embeddings = tf.concat([sparse_embeddings, box_embeddings], axis=2)
+ if input_masks is not None:
+ dense_embeddings = self.mask_embed(input_masks)
+ else:
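+ # No mask prompt: broadcast the learned "no mask" embedding over the full image embedding grid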
+ dense_embeddings = self.no_mask_embed[0]
+ dense_embeddings = tf.reshape(dense_embeddings, (1, -1, 1, 1))
+ dense_embeddings = tf.tile(
+ dense_embeddings, (batch_size, 1, self.image_embedding_size[0], self.image_embedding_size[1])
+ )
+ if sparse_embeddings is None:
+ sparse_embeddings = tf.zeros((batch_size, 0, 1, self.hidden_size), dtype=dense_embeddings.dtype)
+
+ return sparse_embeddings, dense_embeddings
+
+
+class TFSamVisionAttention(keras.layers.Layer):
+ """Multi-head Attention block with relative position embeddings."""
+
+ def __init__(self, config, window_size, **kwargs):
+ super().__init__(**kwargs)
+ input_size = (
+ (config.image_size // config.patch_size, config.image_size // config.patch_size)
+ if window_size == 0
+ else (window_size, window_size)
+ )
+ self.input_size = input_size
+
+ self.num_attention_heads = config.num_attention_heads
+ head_dim = config.hidden_size // config.num_attention_heads
+ self.head_dim = head_dim
+ self.scale = head_dim**-0.5
+ self.dropout = config.attention_dropout
+
+ self.qkv = keras.layers.Dense(config.hidden_size * 3, use_bias=config.qkv_bias, name="qkv")
+ self.proj = keras.layers.Dense(config.hidden_size, name="proj")
+
+ self.use_rel_pos = config.use_rel_pos
+ if self.use_rel_pos:
+ if input_size is None:
+ raise ValueError("Input size must be provided if using relative positional encoding.")
+ self.config = config
+
+ def build(self, input_shape=None):
+ if self.input_size is not None:
+ # initialize relative positional embeddings
+ self.rel_pos_h = self.add_weight(
+ shape=(2 * self.input_size[0] - 1, self.head_dim), initializer="zeros", name="rel_pos_h"
+ )
+ self.rel_pos_w = self.add_weight(
+ shape=(2 * self.input_size[1] - 1, self.head_dim), initializer="zeros", name="rel_pos_w"
+ )
+
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "qkv", None) is not None:
+ with tf.name_scope(self.qkv.name):
+ self.qkv.build([None, None, self.config.hidden_size])
+ if getattr(self, "proj", None) is not None:
+ with tf.name_scope(self.proj.name):
+ self.proj.build([None, None, self.config.hidden_size])
+
+ def get_rel_pos(self, q_size: int, k_size: int, rel_pos: tf.Tensor) -> tf.Tensor:
+ """
+ Get relative positional embeddings according to the relative positions of
+ query and key sizes.
+
+ Args:
+ q_size (int):
+ size of the query.
+ k_size (int):
+ size of key k.
+ rel_pos (`tf.Tensor`):
+ relative position embeddings (L, channel).
+
+ Returns:
+ Extracted positional embeddings according to relative positions.
+ """
+ max_rel_dist = int(2 * max(q_size, k_size) - 1)
+ # Interpolate rel pos if needed.
+ if rel_pos.shape[0] != max_rel_dist:
+ # Interpolate rel pos.
+ rel_pos_resized = tf.image.resize(
+ tf.reshape(rel_pos, (1, rel_pos.shape[0], -1)),
+ size=(max_rel_dist, rel_pos.shape[1]),
+ method="bilinear",
+ )
+ rel_pos_resized = tf.reshape(rel_pos_resized, (-1, max_rel_dist))
+ else:
+ rel_pos_resized = rel_pos
+
+ # Scale the coords with short length if shapes for q and k are different.
+ q_coords = tf.expand_dims(tf.range(q_size, dtype=tf.float32), 1) * max(k_size / q_size, 1.0)
+ k_coords = tf.expand_dims(tf.range(k_size, dtype=tf.float32), 0) * max(q_size / k_size, 1.0)
+ relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
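+ # relative_coords holds non-negative offsets between query and key positions, used to index rel_pos_resized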
+
+ return tf.gather(rel_pos_resized, tf.cast(relative_coords, tf.int32))
+
+ def get_decomposed_rel_pos(
+ self,
+ query: tf.Tensor,
+ rel_pos_h: tf.Tensor,
+ rel_pos_w: tf.Tensor,
+ q_size: tuple[int, int],
+ k_size: tuple[int, int],
+ ) -> tf.Tensor:
+ """
+ Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
+ https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py
+
+ Args:
+ query (`tf.Tensor`):
+ query q in the attention layer with shape (batch_size, query_height * query_width, channel).
+ rel_pos_h (`tf.Tensor`):
+ relative position embeddings (Lh, channel) for height axis.
+ rel_pos_w (`tf.Tensor`):
+ relative position embeddings (Lw, channel) for width axis.
+ q_size (tuple):
+ spatial sequence size of query q with (query_height, query_width).
+ k_size (tuple):
+ spatial sequence size of key k with (key_height, key_width).
+
+ Returns:
+ decomposed_rel_pos (`tf.Tensor`):
+ decomposed relative position embeddings.
+ """
+ query_height, query_width = q_size
+ key_height, key_width = k_size
+ relative_position_height = self.get_rel_pos(query_height, key_height, rel_pos_h)
+ relative_position_width = self.get_rel_pos(query_width, key_width, rel_pos_w)
+
+ batch_size, _, dim = shape_list(query)
+ reshaped_query = tf.reshape(query, (batch_size, query_height, query_width, dim))
+ rel_h = tf.einsum("bhwc,hkc->bhwk", reshaped_query, relative_position_height)
+ rel_w = tf.einsum("bhwc,wkc->bhwk", reshaped_query, relative_position_width)
+
+ rel_h = tf.expand_dims(rel_h, axis=-1)
+ rel_w = tf.expand_dims(rel_w, axis=-2)
+ decomposed_rel_pos = rel_h + rel_w
+
+ return decomposed_rel_pos
+
+ def call(self, hidden_states: tf.Tensor, output_attentions=False, training=False) -> tf.Tensor:
+ batch_size, height, width, _ = shape_list(hidden_states)
+ # qkv with shape (3, batch_size, nHead, height * width, channel)
+ qkv = tf.reshape(self.qkv(hidden_states), (batch_size, height * width, 3, self.num_attention_heads, -1))
+ qkv = tf.transpose(qkv, perm=(2, 0, 3, 1, 4))
+ # q, k, v with shape (batch_size * nHead, height * width, channel)
+ query, key, value = tf.unstack(
+ tf.reshape(qkv, (3, batch_size * self.num_attention_heads, height * width, -1)), axis=0
+ )
+ attn_weights = tf.matmul(query * self.scale, key, transpose_b=True)
+
+ if self.use_rel_pos:
+ decomposed_rel_pos = self.get_decomposed_rel_pos(
+ query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width)
+ )
+ decomposed_rel_pos = tf.reshape(decomposed_rel_pos, shape_list(attn_weights))
+ attn_weights = attn_weights + decomposed_rel_pos
+
+ attn_weights = tf.nn.softmax(attn_weights, axis=-1)
+
+ if training:
+ attn_probs = tf.nn.dropout(attn_weights, rate=self.dropout)
+ else:
+ attn_probs = attn_weights
+
+ attn_output = tf.reshape(attn_probs @ value, (batch_size, self.num_attention_heads, height, width, -1))
+ attn_output = tf.transpose(attn_output, perm=(0, 2, 3, 1, 4))
+ attn_output = tf.reshape(attn_output, (batch_size, height, width, self.config.hidden_size))
+
+ attn_output = self.proj(attn_output)
+
+ if output_attentions:
+ outputs = (attn_output, attn_weights)
+ else:
+ outputs = (attn_output, None)
+
+ return outputs
+
+
+class TFSamVisionLayer(keras.layers.Layer):
+ def __init__(self, config, window_size, **kwargs):
+ super().__init__(**kwargs)
+ self.layer_norm1 = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm1")
+ self.attn = TFSamVisionAttention(config, window_size, name="attn")
+ self.layer_norm2 = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm2")
+ self.mlp = TFSamMLPBlock(config, name="mlp")
+ self.window_size = window_size
+ self.config = config
+
+ def window_partition(self, hidden_states: tf.Tensor, window_size: int) -> tuple[tf.Tensor, tuple[int, int]]:
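+ # Pad height and width up to a multiple of window_size, then split the feature map into non-overlapping windows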
+ batch_size, height, width, channel = shape_list(hidden_states)
+
+ pad_h = (window_size - height % window_size) % window_size
+ pad_w = (window_size - width % window_size) % window_size
+ if pad_h > 0 or pad_w > 0:
+ hidden_states = tf.pad(hidden_states, [[0, 0], [0, pad_h], [0, pad_w], [0, 0]])
+ pad_height, pad_width = height + pad_h, width + pad_w
+
+ hidden_states = tf.reshape(
+ hidden_states,
+ [batch_size, pad_height // window_size, window_size, pad_width // window_size, window_size, channel],
+ )
+ windows = tf.reshape(
+ tf.transpose(hidden_states, perm=[0, 1, 3, 2, 4, 5]), [-1, window_size, window_size, channel]
+ )
+ return windows, (pad_height, pad_width)
+
+ def window_unpartition(
+ self, windows: tf.Tensor, window_size: int, padding_shape: tuple[int, int], original_shape: tuple[int, int]
+ ) -> tf.Tensor:
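+ # Reassemble the windows into the padded feature map, then crop away the padding to recover the original size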
+ pad_height, pad_width = padding_shape
+ height, width = original_shape
+ batch_size = shape_list(windows)[0] // (pad_height * pad_width // window_size // window_size)
+ hidden_states = tf.reshape(
+ windows, [batch_size, pad_height // window_size, pad_width // window_size, window_size, window_size, -1]
+ )
+ hidden_states = tf.reshape(
+ tf.transpose(hidden_states, perm=[0, 1, 3, 2, 4, 5]), [batch_size, pad_height, pad_width, -1]
+ )
+
+ if pad_height > height or pad_width > width:
+ hidden_states = hidden_states[:, :height, :width, :]
+ return hidden_states
+
+ def call(
+ self,
+ hidden_states: tf.Tensor,
+ output_attentions: bool | None = False,
+ training: bool | None = False,
+ ) -> tuple[tf.Tensor]:
+ residual = hidden_states
+
+ hidden_states = self.layer_norm1(hidden_states)
+ if self.window_size > 0:
+ height, width = hidden_states.shape[1], hidden_states.shape[2]
+ hidden_states, padding_shape = self.window_partition(hidden_states, self.window_size)
+
+ hidden_states, attn_weights = self.attn(
+ hidden_states=hidden_states,
+ output_attentions=output_attentions,
+ training=training,
+ )
+ if self.window_size > 0:
+ hidden_states = self.window_unpartition(hidden_states, self.window_size, padding_shape, (height, width))
+
+ hidden_states = residual + hidden_states
+ layernorm_output = self.layer_norm2(hidden_states)
+ hidden_states = hidden_states + self.mlp(layernorm_output)
+
+ outputs = (hidden_states,)
+ if output_attentions:
+ outputs += (attn_weights,)
+
+ return outputs
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "layer_norm1", None) is not None:
+ with tf.name_scope(self.layer_norm1.name):
+ self.layer_norm1.build([None, None, None, self.config.hidden_size])
+ if getattr(self, "attn", None) is not None:
+ with tf.name_scope(self.attn.name):
+ self.attn.build(None)
+ if getattr(self, "layer_norm2", None) is not None:
+ with tf.name_scope(self.layer_norm2.name):
+ self.layer_norm2.build([None, None, None, self.config.hidden_size])
+ if getattr(self, "mlp", None) is not None:
+ with tf.name_scope(self.mlp.name):
+ self.mlp.build(None)
+
+
+class TFSamVisionNeck(keras.layers.Layer):
+ def __init__(self, config: SamVisionConfig, **kwargs):
+ super().__init__(**kwargs)
+ self.config = config
+
+ self.conv1 = keras.layers.Conv2D(
+ config.output_channels,
+ kernel_size=1,
+ use_bias=False,
+ name="conv1",
+ )
+ self.layer_norm1 = TFSamLayerNorm(config.output_channels, name="layer_norm1")
+ self.conv2 = keras.layers.Conv2D(
+ config.output_channels,
+ kernel_size=3,
+ padding="same",
+ use_bias=False,
+ name="conv2",
+ )
+ self.layer_norm2 = TFSamLayerNorm(config.output_channels, name="layer_norm2")
+
+ def call(self, hidden_states):
+ hidden_states = self.conv1(hidden_states)
+ hidden_states = self.layer_norm1(hidden_states)
+
+ hidden_states = self.conv2(hidden_states)
+ hidden_states = self.layer_norm2(hidden_states)
+ hidden_states = tf.transpose(hidden_states, perm=[0, 3, 1, 2])
+ return hidden_states
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "conv1", None) is not None:
+ with tf.name_scope(self.conv1.name):
+ self.conv1.build([None, None, None, self.config.hidden_size])
+ if getattr(self, "layer_norm1", None) is not None:
+ with tf.name_scope(self.layer_norm1.name):
+ self.layer_norm1.build(None)
+ if getattr(self, "conv2", None) is not None:
+ with tf.name_scope(self.conv2.name):
+ self.conv2.build([None, None, None, self.config.output_channels])
+ if getattr(self, "layer_norm2", None) is not None:
+ with tf.name_scope(self.layer_norm2.name):
+ self.layer_norm2.build(None)
+
+
+class TFSamVisionEncoder(keras.layers.Layer):
+ def __init__(self, config: SamVisionConfig, **kwargs):
+ super().__init__(**kwargs)
+ self.config = config
+ self.image_size = config.image_size
+
+ self.patch_embed = TFSamPatchEmbeddings(config, name="patch_embed")
+
+ self.pos_embed = None
+
+ self.layers = []
+ for i in range(config.num_hidden_layers):
+ layer = TFSamVisionLayer(
+ config,
+ window_size=config.window_size if i not in config.global_attn_indexes else 0,
+ name=f"layers_._{i}",
+ )
+ self.layers.append(layer)
+
+ self.neck = TFSamVisionNeck(config, name="neck")
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if self.config.use_abs_pos:
+ # Initialize absolute positional embedding with pretrain image size.
+ self.pos_embed = self.add_weight(
+ shape=[
+ 1,
+ self.config.image_size // self.config.patch_size,
+ self.config.image_size // self.config.patch_size,
+ self.config.hidden_size,
+ ],
+ initializer="zeros",
+ trainable=True,
+ name="pos_embed",
+ )
+
+ if getattr(self, "patch_embed", None) is not None:
+ with tf.name_scope(self.patch_embed.name):
+ self.patch_embed.build(None)
+ if getattr(self, "neck", None) is not None:
+ with tf.name_scope(self.neck.name):
+ self.neck.build(None)
+ for layer in self.layers:
+ with tf.name_scope(layer.name):
+ layer.build(None)
+
+ def get_input_embeddings(self):
+ return self.patch_embed
+
+ def call(
+ self,
+ pixel_values: tf.Tensor | None = None,
+ output_attentions: bool | None = None,
+ output_hidden_states: bool | None = None,
+ return_dict: bool | None = None,
+ training: bool | None = False,
+ ) -> tuple | TFSamVisionEncoderOutput:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if pixel_values is None:
+ raise ValueError("You have to specify pixel_values")
+
+ hidden_states = self.patch_embed(pixel_values)
+ if self.pos_embed is not None:
+ hidden_states = hidden_states + self.pos_embed
+
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attentions = () if output_attentions else None
+
+ for i, layer_module in enumerate(self.layers):
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ layer_outputs = layer_module(hidden_states, output_attentions=output_attentions, training=training)
+
+ hidden_states = layer_outputs[0]
+
+ if output_attentions:
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ hidden_states = self.neck(hidden_states)
+
+ if not return_dict:
+ outputs = (hidden_states,)
+ if output_hidden_states:
+ outputs = outputs + (all_hidden_states,)
+ if output_attentions:
+ outputs = outputs + (all_self_attentions,)
+ return outputs
+
+ return TFSamVisionEncoderOutput(
+ last_hidden_state=hidden_states,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attentions,
+ )
+
+
+class TFSamPreTrainedModel(TFPreTrainedModel):
+ config_class = SamConfig
+ base_model_prefix = "sam"
+ main_input_name = "pixel_values"
+
+
+SAM_START_DOCSTRING = r"""
+ This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
+ etc.)
+
+ This model is also a TensorFlow [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model)
+ subclass. Use it as a regular TensorFlow Model and refer to the TensorFlow documentation for all matter related to
+ general usage and behavior.
+
+ Parameters:
+ config ([`SamConfig`]): Model configuration class with all the parameters of the model.
+ Initializing with a config file does not load the weights associated with the model, only the
+ configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+
+SAM_INPUTS_DOCSTRING = r"""
+ Args:
+ pixel_values (`tf.Tensor` of shape `(batch_size, num_channels, height, width)`):
+ Pixel values. Pixel values can be obtained using [`SamProcessor`]. See [`SamProcessor.__call__`] for
+ details.
+ input_points (`tf.Tensor` of shape `(batch_size, num_points, 2)`):
+ Input 2D spatial points, this is used by the prompt encoder to encode the prompt. Generally yields much
+ better results. The points can be obtained by passing a list of lists of lists to the processor, which will
+ create corresponding `tf` tensors of dimension 4. The first dimension is the image batch size, the second
+ dimension is the point batch size (i.e. how many segmentation masks do we want the model to predict per
+ input point), the third dimension is the number of points per segmentation mask (it is possible to pass
+ multiple points for a single mask), and the last dimension is the x (vertical) and y (horizontal)
+ coordinates of the point. If a different number of points is passed either for each image, or for each
+ mask, the processor will create "PAD" points that will correspond to the (0, 0) coordinate, and the
+ computation of the embedding will be skipped for these points using the labels.
+ input_labels (`tf.Tensor` of shape `(batch_size, point_batch_size, num_points)`):
+ Input labels for the points, this is used by the prompt encoder to encode the prompt. According to the
+ official implementation, there are 3 types of labels
+
+ - `1`: the point is a point that contains the object of interest
+ - `0`: the point is a point that does not contain the object of interest
+ - `-1`: the point corresponds to the background
+
+ We added the label:
+
+ - `-10`: the point is a padding point, thus should be ignored by the prompt encoder
+
+ The padding labels should be automatically done by the processor.
+ input_boxes (`tf.Tensor` of shape `(batch_size, num_boxes, 4)`):
+ Input boxes for the points, this is used by the prompt encoder to encode the prompt. Generally yields
+ much better generated masks. The boxes can be obtained by passing a list of lists of lists to the processor,
+ which will generate a `tf` tensor, with each dimension corresponding respectively to the image batch size,
+ the number of boxes per image and the coordinates of the top left and bottom right point of the box. In the
+ order (`x1`, `y1`, `x2`, `y2`):
+
+ - `x1`: the x coordinate of the top left point of the input box
+ - `y1`: the y coordinate of the top left point of the input box
+ - `x2`: the x coordinate of the bottom right point of the input box
+ - `y2`: the y coordinate of the bottom right point of the input box
+
+ input_masks (`tf.Tensor` of shape `(batch_size, image_size, image_size)`):
+ SAM model also accepts segmentation masks as input. The mask will be embedded by the prompt encoder to
+ generate a corresponding embedding, which will later be fed to the mask decoder. These masks need to be
+ manually fed by the user, and they need to be of shape (`batch_size`, `image_size`, `image_size`).
+
+ image_embeddings (`tf.Tensor` of shape `(batch_size, output_channels, window_size, window_size)`):
+ Image embeddings, this is used by the mask decoder to generate masks and iou scores. For more memory
+ efficient computation, users can first retrieve the image embeddings using the `get_image_embeddings`
+ method, and then feed them to the `call` method instead of feeding the `pixel_values`.
+ multimask_output (`bool`, *optional*):
+ In the original implementation and paper, the model always outputs 3 masks per image (or per point / per
+ bounding box if relevant). However, it is possible to just output a single mask, that corresponds to the
+ "best" mask, by specifying `multimask_output=False`.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+SAM_VISION_INPUTS_DOCSTRING = r"""
+ Args:
+ pixel_values (`tf.Tensor` of shape `(batch_size, num_channels, height, width)`):
+ Pixel values. Pixel values can be obtained using [`SamProcessor`]. See [`SamProcessor.__call__`] for
+ details.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+@add_start_docstrings(
+ """The vision model from Sam without any head or projection on top.""",
+ SAM_START_DOCSTRING,
+)
+class TFSamVisionModel(TFSamPreTrainedModel):
+ config_class = SamVisionConfig
+ main_input_name = "pixel_values"
+
+ def __init__(self, config: SamVisionConfig, **kwargs):
+ super().__init__(config, **kwargs)
+ self.vision_encoder = TFSamVisionEncoder(config, name="vision_encoder")
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "vision_encoder", None) is not None:
+ with tf.name_scope(self.vision_encoder.name):
+ self.vision_encoder.build(None)
+
+ def get_input_embeddings(self):
+ return self.vision_encoder.patch_embed
+
+ @unpack_inputs
+ @add_start_docstrings_to_model_forward(SAM_VISION_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=TFSamVisionEncoderOutput, config_class=SamVisionConfig)
+ def call(
+ self,
+ pixel_values: TFModelInputType | None = None,
+ output_attentions: bool | None = None,
+ output_hidden_states: bool | None = None,
+ return_dict: bool | None = None,
+ training: bool = False,
+ **kwargs,
+ ) -> TFSamVisionEncoderOutput | tuple[tf.Tensor]:
+ r"""
+ Returns:
+
+ """
+ return self.vision_encoder(
+ pixel_values,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ training=training,
+ )
+
+
+@add_start_docstrings(
+ "Segment Anything Model (SAM) for generating segmentation masks, given an input image and ",
+ " optional 2D location and bounding boxes.",
+ SAM_START_DOCSTRING,
+)
+class TFSamModel(TFSamPreTrainedModel):
+ _keys_to_ignore_on_load_missing = [r"prompt_encoder.shared_embedding.positional_embedding"]
+
+ def __init__(self, config, **kwargs):
+ super().__init__(config, **kwargs)
+ self.shared_image_embedding = TFSamPositionalEmbedding(config.vision_config, name="shared_image_embedding")
+
+ self.vision_encoder = TFSamVisionEncoder(config.vision_config, name="vision_encoder")
+ self.prompt_encoder = TFSamPromptEncoder(
+ config.prompt_encoder_config, self.shared_image_embedding, name="prompt_encoder"
+ )
+ self.mask_decoder = TFSamMaskDecoder(config.mask_decoder_config, name="mask_decoder")
+ self.config = config
+
+ def get_input_embeddings(self):
+ return self.vision_encoder.get_input_embeddings()
+
+ def get_image_wide_positional_embeddings(self):
+ size = self.config.prompt_encoder_config.image_embedding_size
+ grid = tf.ones((size, size))
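+ # Build a normalized coordinate grid over the image embedding and encode it with the shared positional embedding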
+ y_embed = tf.math.cumsum(grid, axis=0) - 0.5
+ x_embed = tf.math.cumsum(grid, axis=1) - 0.5
+ y_embed = y_embed / size
+ x_embed = x_embed / size
+
+ positional_embedding = self.shared_image_embedding(tf.stack([x_embed, y_embed], axis=-1))
+ return tf.expand_dims(tf.transpose(positional_embedding, perm=[2, 0, 1]), axis=0) # channel x height x width
+
+ def get_image_embeddings(
+ self,
+ pixel_values,
+ output_attentions: bool | None = None,
+ output_hidden_states: bool | None = None,
+ return_dict: bool | None = None,
+ ):
+ r"""
+ Returns the image embeddings by passing the pixel values through the vision encoder.
+
+ Args:
+ pixel_values (`tf.Tensor` of shape `(batch_size, num_channels, height, width)`):
+ Input pixel values
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.TFModelOutput`] instead of a plain tuple.
+
+ """
+ vision_output = self.vision_encoder(
+ pixel_values,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ image_embeddings = vision_output[0]
+ return image_embeddings
+
+ def get_prompt_embeddings(
+ self,
+ input_points: tf.Tensor | None = None,
+ input_labels: tf.Tensor | None = None,
+ input_boxes: tf.Tensor | None = None,
+ input_masks: tf.Tensor | None = None,
+ ):
+ r"""
+ Returns the prompt embeddings by passing the input points, labels, boxes and masks through the prompt encoder.
+
+ Args:
+ input_points (`tf.Tensor` of shape `(batch_size, point_batch_size, num_points_per_image, 2)`):
+ Optional input points for the prompt encoder. The padding of the point is automatically done by the
+ processor. `point_batch_size` refers to the number of masks that we want the model to predict per
+ point. The model will output `point_batch_size` times 3 masks in total.
+ input_labels (`tf.Tensor` of shape `(batch_size, point_batch_size, num_points_per_image)`):
+ Optional input labels for the prompt encoder. The padding of the labels is automatically done by the
+ processor, or can be fed by the user.
+ input_boxes (`tf.Tensor` of shape `(batch_size, num_boxes_per_image, 4)`):
+ Optional input boxes for the prompt encoder. The padding of the boxes is automatically done by the
+ processor. Users can also pass the input boxes manually.
+ input_masks (`tf.Tensor` of shape `(batch_size, image_size, image_size)`):
+ Optional input masks for the prompt encoder.
+ """
+ prompt_output = self.prompt_encoder(
+ input_points=input_points,
+ input_labels=input_labels,
+ input_boxes=input_boxes,
+ input_masks=input_masks,
+ )
+ return prompt_output
+
+ @unpack_inputs
+ @add_start_docstrings_to_model_forward(SAM_INPUTS_DOCSTRING)
+ def call(
+ self,
+ pixel_values: TFModelInputType | None = None,
+ input_points: tf.Tensor | None = None,
+ input_labels: tf.Tensor | None = None,
+ input_boxes: tf.Tensor | None = None,
+ input_masks: tf.Tensor | None = None,
+ image_embeddings: tf.Tensor | None = None,
+ multimask_output: bool = True,
+ output_attentions: bool | None = None,
+ output_hidden_states: bool | None = None,
+ return_dict: bool | None = None,
+ training: bool = False,
+ **kwargs,
+ ) -> TFSamImageSegmentationOutput | tuple[tf.Tensor]:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if pixel_values is None and image_embeddings is None:
+ raise ValueError("Either pixel_values or image_embeddings must be provided.")
+
+ if pixel_values is not None and image_embeddings is not None:
+ raise ValueError("Only one of pixel_values and image_embeddings can be provided.")
+
+ if input_points is not None and len(input_points.shape) != 4:
+ raise ValueError(
+ "The input_points must be a 4D tensor. Of shape `batch_size`, `point_batch_size`, `nb_points_per_image`, `2`.",
+ f" got {input_points.shape}.",
+ )
+ if input_boxes is not None and len(input_boxes.shape) != 3:
+ raise ValueError(
+ "The input_points must be a 3D tensor. Of shape `batch_size`, `nb_boxes`, `4`.",
+ f" got {input_boxes.shape}.",
+ )
+ if input_points is not None and input_boxes is not None:
+ point_batch_size = shape_list(input_points)[1]
+ box_batch_size = shape_list(input_boxes)[1]
+ if point_batch_size != box_batch_size:
+ raise ValueError(
+ f"You should provide as many bounding boxes as input points per box. Got {point_batch_size} and {box_batch_size}."
+ )
+ if pixel_values is not None:
+ # Ensures that later checks pass even with an all-None shape from the serving signature
+ pixel_values = tf.ensure_shape(
+ pixel_values,
+ [
+ None,
+ self.config.vision_config.num_channels,
+ self.config.vision_config.image_size,
+ self.config.vision_config.image_size,
+ ],
+ )
+ image_positional_embeddings = self.get_image_wide_positional_embeddings()
+ # repeat with batch size
+ batch_size = shape_list(pixel_values)[0] if pixel_values is not None else shape_list(image_embeddings)[0]
+ image_positional_embeddings = tf.repeat(image_positional_embeddings, batch_size, axis=0)
+
+ vision_attentions = None
+ vision_hidden_states = None
+
+ if pixel_values is not None:
+ vision_outputs = self.vision_encoder(
+ pixel_values,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=True,
+ training=training,
+ )
+ image_embeddings = vision_outputs["last_hidden_state"]
+
+ if output_hidden_states:
+ vision_hidden_states = vision_outputs["hidden_states"]
+ if output_attentions:
+ vision_attentions = vision_outputs["attentions"]
+
+ if input_points is not None and input_labels is None:
+ input_labels = tf.ones_like(input_points[:, :, :, 0], dtype=tf.int32)
+
+ if input_points is not None and image_embeddings.shape[0] != input_points.shape[0]:
+ raise ValueError(
+ "The batch size of the image embeddings and the input points must be the same. ",
+ f"Got {image_embeddings.shape[0]} and {input_points.shape[0]} respectively.",
+ " if you want to pass multiple points for the same image, make sure that you passed ",
+ " input_points of shape (batch_size, point_batch_size, num_points_per_image, 3) and ",
+ " input_labels of shape (batch_size, point_batch_size, num_points_per_image)",
+ )
+
+ sparse_embeddings, dense_embeddings = self.prompt_encoder(
+ batch_size=shape_list(image_embeddings)[0],
+ input_points=input_points,
+ input_labels=input_labels,
+ input_boxes=input_boxes,
+ input_masks=input_masks,
+ )
+
+ low_res_masks, iou_predictions, mask_decoder_attentions = self.mask_decoder(
+ image_embeddings=image_embeddings,
+ image_positional_embeddings=image_positional_embeddings,
+ sparse_prompt_embeddings=sparse_embeddings,
+ dense_prompt_embeddings=dense_embeddings,
+ multimask_output=multimask_output,
+ output_attentions=output_attentions,
+ )
+
+ if not return_dict:
+ output = (iou_predictions, low_res_masks)
+ if output_hidden_states:
+ output = output + (vision_hidden_states,)
+
+ if output_attentions:
+ output = output + (vision_attentions, mask_decoder_attentions)
+ return output
+
+ return TFSamImageSegmentationOutput(
+ iou_scores=iou_predictions,
+ pred_masks=low_res_masks,
+ vision_hidden_states=vision_hidden_states,
+ vision_attentions=vision_attentions,
+ mask_decoder_attentions=mask_decoder_attentions,
+ )
+
+ def serving_output(self, output: TFSamImageSegmentationOutput) -> TFSamImageSegmentationOutput:
+ hs = tf.convert_to_tensor(output.vision_hidden_states) if self.config.output_hidden_states else None
+ attns = tf.convert_to_tensor(output.vision_attentions) if self.config.output_attentions else None
+
+ return TFSamImageSegmentationOutput(
+ iou_scores=output.iou_scores,
+ pred_masks=output.pred_masks,
+ vision_hidden_states=hs if self.config.output_hidden_states else None,
+ vision_attentions=attns if self.config.output_attentions else None,
+ mask_decoder_attentions=output.mask_decoder_attentions if self.config.output_attentions else None,
+ )
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "shared_image_embedding", None) is not None:
+ with tf.name_scope(self.shared_image_embedding.name):
+ self.shared_image_embedding.build(None)
+ if getattr(self, "vision_encoder", None) is not None:
+ with tf.name_scope(self.vision_encoder.name):
+ self.vision_encoder.build(None)
+ if getattr(self, "prompt_encoder", None) is not None:
+ with tf.name_scope(self.prompt_encoder.name):
+ self.prompt_encoder.build(None)
+ if getattr(self, "mask_decoder", None) is not None:
+ with tf.name_scope(self.mask_decoder.name):
+ self.mask_decoder.build(None)
+
+
+__all__ = ["TFSamVisionModel", "TFSamModel", "TFSamPreTrainedModel"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/sam/processing_sam.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/sam/processing_sam.py
new file mode 100644
index 0000000000000000000000000000000000000000..b3b3728fe273efc92adf57f119bce1eb899534b7
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/sam/processing_sam.py
@@ -0,0 +1,297 @@
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Processor class for SAM.
+"""
+
+from copy import deepcopy
+from typing import Optional, Union
+
+import numpy as np
+
+from ...image_utils import ImageInput
+from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin
+from ...tokenization_utils_base import AudioInput, BatchEncoding, PreTokenizedInput, TextInput
+from ...utils import is_tf_available, is_torch_available
+from ...video_utils import VideoInput
+
+
+if is_torch_available():
+ import torch
+
+if is_tf_available():
+ import tensorflow as tf
+
+
+class SamImagesKwargs(ImagesKwargs):
+ segmentation_maps: Optional[ImageInput]
+ input_points: Optional[list[list[float]]]
+ input_labels: Optional[list[list[int]]]
+ input_boxes: Optional[list[list[list[float]]]]
+ point_pad_value: Optional[int]
+
+
+class SamProcessorKwargs(ProcessingKwargs, total=False):
+ images_kwargs: SamImagesKwargs
+ _defaults = {
+ "images_kwargs": {
+ "point_pad_value": -10,
+ }
+ }
+
+
+class SamProcessor(ProcessorMixin):
+ r"""
+    Constructs a SAM processor which wraps a SAM image processor and a 2D points & bounding boxes processor into a
+    single processor.
+
+ [`SamProcessor`] offers all the functionalities of [`SamImageProcessor`]. See the docstring of
+ [`~SamImageProcessor.__call__`] for more information.
+
+ Args:
+ image_processor (`SamImageProcessor`):
+ An instance of [`SamImageProcessor`]. The image processor is a required input.
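+
+    Example (a minimal usage sketch; the checkpoint name and image URL below are only illustrative):
+
+    ```python
+    >>> from PIL import Image
+    >>> import requests
+    >>> from transformers import SamProcessor
+
+    >>> processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
+    >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+    >>> image = Image.open(requests.get(url, stream=True).raw)
+
+    >>> # a single (x, y) point prompt for the single input image
+    >>> inputs = processor(images=image, input_points=[[[450, 600]]], return_tensors="pt")
+    ```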
+ """
+
+ attributes = ["image_processor"]
+ image_processor_class = "SamImageProcessor"
+
+ def __init__(self, image_processor):
+ super().__init__(image_processor)
+ self.target_size = self.image_processor.size["longest_edge"]
+
+ def __call__(
+ self,
+ images: Optional[ImageInput] = None,
+ text: Optional[Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]]] = None,
+ audio: Optional[AudioInput] = None,
+ video: Optional[VideoInput] = None,
+ **kwargs,
+ ) -> BatchEncoding:
+ """
+        This method uses [`SamImageProcessor.__call__`] to prepare image(s) for the model. It also prepares 2D
+ points and bounding boxes for the model if they are provided.
+ """
+ output_kwargs = self._merge_kwargs(
+ SamProcessorKwargs,
+ tokenizer_init_kwargs={},
+ **kwargs,
+ )
+ input_points = output_kwargs["images_kwargs"].pop("input_points", None)
+ input_labels = output_kwargs["images_kwargs"].pop("input_labels", None)
+ input_boxes = output_kwargs["images_kwargs"].pop("input_boxes", None)
+ point_pad_value = output_kwargs["images_kwargs"].pop("point_pad_value", None)
+
+ encoding_image_processor = self.image_processor(
+ images,
+ **output_kwargs["images_kwargs"],
+ )
+
+        # original_sizes is not a model input, but it is needed below to normalize the prompts
+ original_sizes = encoding_image_processor["original_sizes"]
+
+ if hasattr(original_sizes, "numpy"): # Checks if Torch or TF tensor
+ original_sizes = original_sizes.numpy()
+
+ input_points, input_labels, input_boxes = self._check_and_preprocess_points(
+ input_points=input_points,
+ input_labels=input_labels,
+ input_boxes=input_boxes,
+ )
+
+ encoding_image_processor = self._normalize_and_convert(
+ encoding_image_processor,
+ original_sizes,
+ input_points=input_points,
+ input_labels=input_labels,
+ input_boxes=input_boxes,
+ return_tensors=output_kwargs["common_kwargs"].get("return_tensors"),
+ point_pad_value=point_pad_value,
+ )
+
+ return encoding_image_processor
+
+ def _normalize_and_convert(
+ self,
+ encoding_image_processor,
+ original_sizes,
+ input_points=None,
+ input_labels=None,
+ input_boxes=None,
+ return_tensors="pt",
+ point_pad_value=-10,
+ ):
+ if input_points is not None:
+ if len(original_sizes) != len(input_points):
+ input_points = [
+ self._normalize_coordinates(self.target_size, point, original_sizes[0]) for point in input_points
+ ]
+ else:
+ input_points = [
+ self._normalize_coordinates(self.target_size, point, original_size)
+ for point, original_size in zip(input_points, original_sizes)
+ ]
+            # if the point arrays do not all have the same shape, pad them (and the labels) to the same length
+ if not all(point.shape == input_points[0].shape for point in input_points):
+ if input_labels is not None:
+ input_points, input_labels = self._pad_points_and_labels(
+ input_points, input_labels, point_pad_value
+ )
+
+ input_points = np.array(input_points)
+
+ if input_labels is not None:
+ input_labels = np.array(input_labels)
+
+ if input_boxes is not None:
+ if len(original_sizes) != len(input_boxes):
+ input_boxes = [
+ self._normalize_coordinates(self.target_size, box, original_sizes[0], is_bounding_box=True)
+ for box in input_boxes
+ ]
+ else:
+ input_boxes = [
+ self._normalize_coordinates(self.target_size, box, original_size, is_bounding_box=True)
+ for box, original_size in zip(input_boxes, original_sizes)
+ ]
+ input_boxes = np.array(input_boxes)
+
+ if input_boxes is not None:
+ if return_tensors == "pt":
+ input_boxes = torch.from_numpy(input_boxes)
+ # boxes batch size of 1 by default
+ input_boxes = input_boxes.unsqueeze(1) if len(input_boxes.shape) != 3 else input_boxes
+ elif return_tensors == "tf":
+ input_boxes = tf.convert_to_tensor(input_boxes)
+ # boxes batch size of 1 by default
+ input_boxes = tf.expand_dims(input_boxes, 1) if len(input_boxes.shape) != 3 else input_boxes
+ encoding_image_processor.update({"input_boxes": input_boxes})
+ if input_points is not None:
+ if return_tensors == "pt":
+ input_points = torch.from_numpy(input_points)
+ # point batch size of 1 by default
+ input_points = input_points.unsqueeze(1) if len(input_points.shape) != 4 else input_points
+ elif return_tensors == "tf":
+ input_points = tf.convert_to_tensor(input_points)
+ # point batch size of 1 by default
+ input_points = tf.expand_dims(input_points, 1) if len(input_points.shape) != 4 else input_points
+ encoding_image_processor.update({"input_points": input_points})
+ if input_labels is not None:
+ if return_tensors == "pt":
+ input_labels = torch.from_numpy(input_labels)
+ # point batch size of 1 by default
+ input_labels = input_labels.unsqueeze(1) if len(input_labels.shape) != 3 else input_labels
+ elif return_tensors == "tf":
+ input_labels = tf.convert_to_tensor(input_labels)
+ # point batch size of 1 by default
+ input_labels = tf.expand_dims(input_labels, 1) if len(input_labels.shape) != 3 else input_labels
+ encoding_image_processor.update({"input_labels": input_labels})
+
+ return encoding_image_processor
+
+ def _pad_points_and_labels(self, input_points, input_labels, point_pad_value):
+ r"""
+ The method pads the 2D points and labels to the maximum number of points in the batch.
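+        For example, with `point_pad_value=-10`, point arrays of shapes `(2, 2)` and `(1, 2)` both become `(2, 2)`,
+        the missing point being filled with `-10` and the corresponding label list extended with `-10` as well.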
+ """
+ expected_nb_points = max(point.shape[0] for point in input_points)
+ processed_input_points = []
+ for i, point in enumerate(input_points):
+ if point.shape[0] != expected_nb_points:
+ point = np.concatenate(
+ [point, np.zeros((expected_nb_points - point.shape[0], 2)) + point_pad_value], axis=0
+ )
+ input_labels[i] = np.append(input_labels[i], [point_pad_value])
+ processed_input_points.append(point)
+ input_points = processed_input_points
+ return input_points, input_labels
+
+ def _normalize_coordinates(
+ self, target_size: int, coords: np.ndarray, original_size, is_bounding_box=False
+ ) -> np.ndarray:
+ """
+ Expects a numpy array of length 2 in the final dimension. Requires the original image size in (H, W) format.
+ """
+ old_h, old_w = original_size
+ new_h, new_w = self.image_processor._get_preprocess_shape(original_size, longest_edge=target_size)
+ coords = deepcopy(coords).astype(float)
+
+ if is_bounding_box:
+ coords = coords.reshape(-1, 2, 2)
+
+ coords[..., 0] = coords[..., 0] * (new_w / old_w)
+ coords[..., 1] = coords[..., 1] * (new_h / old_h)
+
+ if is_bounding_box:
+ coords = coords.reshape(-1, 4)
+
+ return coords
+
+ def _check_and_preprocess_points(
+ self,
+ input_points=None,
+ input_labels=None,
+ input_boxes=None,
+ ):
+ r"""
+        Checks and preprocesses the 2D points, labels and bounding boxes. It validates the inputs and, if they are
+        valid, converts the coordinates of the points and bounding boxes. If a user passes a `torch.Tensor` directly,
+        it is converted to a `numpy.ndarray` and then to a `list`.
+ """
+ if input_points is not None:
+ if hasattr(input_points, "numpy"): # Checks for TF or Torch tensor
+ input_points = input_points.numpy().tolist()
+
+ if not isinstance(input_points, list) or not isinstance(input_points[0], list):
+ raise ValueError("Input points must be a list of list of floating points.")
+ input_points = [np.array(input_point) for input_point in input_points]
+ else:
+ input_points = None
+
+ if input_labels is not None:
+ if hasattr(input_labels, "numpy"):
+ input_labels = input_labels.numpy().tolist()
+
+ if not isinstance(input_labels, list) or not isinstance(input_labels[0], list):
+ raise ValueError("Input labels must be a list of list integers.")
+ input_labels = [np.array(label) for label in input_labels]
+ else:
+ input_labels = None
+
+ if input_boxes is not None:
+ if hasattr(input_boxes, "numpy"):
+ input_boxes = input_boxes.numpy().tolist()
+
+ if (
+ not isinstance(input_boxes, list)
+ or not isinstance(input_boxes[0], list)
+ or not isinstance(input_boxes[0][0], list)
+ ):
+ raise ValueError("Input boxes must be a list of list of list of floating points.")
+ input_boxes = [np.array(box).astype(np.float32) for box in input_boxes]
+ else:
+ input_boxes = None
+
+ return input_points, input_labels, input_boxes
+
+ @property
+ def model_input_names(self):
+ image_processor_input_names = self.image_processor.model_input_names
+ return list(image_processor_input_names + ["original_sizes", "reshaped_input_sizes"])
+
+ def post_process_masks(self, *args, **kwargs):
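+        """Delegates to [`SamImageProcessor.post_process_masks`]; refer to that method for the accepted arguments."""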
+ return self.image_processor.post_process_masks(*args, **kwargs)
+
+
+__all__ = ["SamProcessor"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/sam2/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/sam2/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..38e6621b063af7a706186a1bf810e3f4709dc1b8
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/sam2/__init__.py
@@ -0,0 +1,29 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_sam2 import *
+ from .image_processing_sam2_fast import *
+ from .modeling_sam2 import *
+ from .processing_sam2 import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/sam2/configuration_sam2.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/sam2/configuration_sam2.py
new file mode 100644
index 0000000000000000000000000000000000000000..e14583181d38abfff63209160ab2a6eb5219a05f
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/sam2/configuration_sam2.py
@@ -0,0 +1,453 @@
+# coding=utf-8
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""SAM2 model configuration"""
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+from ..auto import CONFIG_MAPPING, AutoConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+class Sam2HieraDetConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`Sam2HieraDetModel`]. It is used to instantiate
+ a HieraDet model as defined in the original sam2 repo according to the specified arguments, defining the model architecture.
+    Instantiating a configuration with the defaults will yield a similar configuration to that of SAM 2.1 Hiera-tiny
+ [facebook/sam2.1-hiera-tiny](https://huggingface.co/facebook/sam2.1-hiera-tiny) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ hidden_size (`int`, *optional*, defaults to 96):
+ The hidden dimension of the image encoder.
+ num_attention_heads (`int`, *optional*, defaults to 1):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ num_channels (`int`, *optional*, defaults to 3):
+ The number of channels in the image.
+ image_size (`list[int]`, *optional*, defaults to `[1024, 1024]`):
+ The size of the image.
+ patch_kernel_size (`list[int]`, *optional*, defaults to `[7, 7]`):
+ The kernel size of the patch.
+ patch_stride (`list[int]`, *optional*, defaults to `[4, 4]`):
+ The stride of the patch.
+ patch_padding (`list[int]`, *optional*, defaults to `[3, 3]`):
+ The padding of the patch.
+ query_stride (`list[int]`, *optional*, defaults to `[2, 2]`):
+ The downsample stride between stages.
+ window_positional_embedding_background_size (`list[int]`, *optional*, defaults to `[7, 7]`):
+ The window size per stage when not using global attention.
+ num_query_pool_stages (`int`, *optional*, defaults to 3):
+ The number of query pool stages.
+ blocks_per_stage (`list[int]`, *optional*, defaults to `[1, 2, 7, 2]`):
+ The number of blocks per stage.
+ embed_dim_per_stage (`list[int]`, *optional*, defaults to `[96, 192, 384, 768]`):
+ The embedding dimension per stage.
+ num_attention_heads_per_stage (`list[int]`, *optional*, defaults to `[1, 2, 4, 8]`):
+ The number of attention heads per stage.
+ window_size_per_stage (`list[int]`, *optional*, defaults to `[8, 4, 14, 7]`):
+ The window size per stage.
+ global_attention_blocks (`list[int]`, *optional*, defaults to `[5, 7, 9]`):
+ The blocks where global attention is used.
+ mlp_ratio (`float`, *optional*, defaults to 4.0):
+ The ratio of the MLP hidden dimension to the embedding dimension.
+ hidden_act (`str`, *optional*, defaults to `"gelu"`):
+ The non-linear activation function in the neck.
+ layer_norm_eps (`float`, *optional*, defaults to 1e-06):
+ The epsilon for the layer normalization.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+
+ """
+
+ base_config_key = "backbone_config"
+ model_type = "sam2_hiera_det_model"
+
+ def __init__(
+ self,
+ hidden_size=96,
+ num_attention_heads=1,
+ num_channels=3,
+ image_size=None,
+ patch_kernel_size=None,
+ patch_stride=None,
+ patch_padding=None,
+ query_stride=None,
+ window_positional_embedding_background_size=None,
+ num_query_pool_stages=3,
+ blocks_per_stage=None,
+ embed_dim_per_stage=None,
+ num_attention_heads_per_stage=None,
+ window_size_per_stage=None,
+ global_attention_blocks=None,
+ mlp_ratio=4.0,
+ hidden_act="gelu",
+ layer_norm_eps=1e-6,
+ initializer_range=0.02,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+
+ image_size = image_size if image_size is not None else [1024, 1024]
+ patch_kernel_size = patch_kernel_size if patch_kernel_size is not None else [7, 7]
+ patch_stride = patch_stride if patch_stride is not None else [4, 4]
+ patch_padding = patch_padding if patch_padding is not None else [3, 3]
+ query_stride = query_stride if query_stride is not None else [2, 2]
+ window_positional_embedding_background_size = (
+ window_positional_embedding_background_size
+ if window_positional_embedding_background_size is not None
+ else [7, 7]
+ )
+ blocks_per_stage = blocks_per_stage if blocks_per_stage is not None else [1, 2, 7, 2]
+ embed_dim_per_stage = embed_dim_per_stage if embed_dim_per_stage is not None else [96, 192, 384, 768]
+ num_attention_heads_per_stage = (
+ num_attention_heads_per_stage if num_attention_heads_per_stage is not None else [1, 2, 4, 8]
+ )
+ window_size_per_stage = window_size_per_stage if window_size_per_stage is not None else [8, 4, 14, 7]
+ global_attention_blocks = global_attention_blocks if global_attention_blocks is not None else [5, 7, 9]
+
+ self.hidden_size = hidden_size
+ self.num_attention_heads = num_attention_heads
+ self.num_channels = num_channels
+ self.image_size = image_size
+ self.patch_kernel_size = patch_kernel_size
+ self.patch_stride = patch_stride
+ self.patch_padding = patch_padding
+ self.query_stride = query_stride
+ self.window_positional_embedding_background_size = window_positional_embedding_background_size
+ self.num_query_pool_stages = num_query_pool_stages
+ self.blocks_per_stage = blocks_per_stage
+ self.embed_dim_per_stage = embed_dim_per_stage
+ self.num_attention_heads_per_stage = num_attention_heads_per_stage
+ self.window_size_per_stage = window_size_per_stage
+ self.global_attention_blocks = global_attention_blocks
+ self.mlp_ratio = mlp_ratio
+ self.hidden_act = hidden_act
+ self.layer_norm_eps = layer_norm_eps
+ self.initializer_range = initializer_range
+
+
+class Sam2VisionConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`Sam2VisionModel`]. It is used to instantiate a SAM
+    vision encoder according to the specified arguments, defining the model architecture. Instantiating a configuration
+    with the defaults will yield a similar configuration to that of SAM 2.1 Hiera-tiny
+ [facebook/sam2.1-hiera-tiny](https://huggingface.co/facebook/sam2.1-hiera-tiny) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ backbone_config (`Union[dict, "PretrainedConfig"]`, *optional*):
+ Configuration for the vision backbone. This is used to instantiate the backbone using
+ `AutoModel.from_config`.
+ backbone_channel_list (`List[int]`, *optional*, defaults to `[768, 384, 192, 96]`):
+ The list of channel dimensions for the backbone.
+ backbone_feature_sizes (`List[List[int]]`, *optional*, defaults to `[[256, 256], [128, 128], [64, 64]]`):
+ The spatial sizes of the feature maps from the backbone.
+ fpn_hidden_size (`int`, *optional*, defaults to 256):
+ The hidden dimension of the FPN.
+ fpn_kernel_size (`int`, *optional*, defaults to 1):
+ The kernel size for the convolutions in the neck.
+ fpn_stride (`int`, *optional*, defaults to 1):
+ The stride for the convolutions in the neck.
+ fpn_padding (`int`, *optional*, defaults to 0):
+ The padding for the convolutions in the neck.
+ fpn_top_down_levels (`List[int]`, *optional*, defaults to `[2, 3]`):
+ The levels for the top-down FPN connections.
+ num_feature_levels (`int`, *optional*, defaults to 3):
+ The number of feature levels from the FPN to use.
+ hidden_act (`str`, *optional*, defaults to `"gelu"`):
+ The non-linear activation function in the neck.
+ layer_norm_eps (`float`, *optional*, defaults to 1e-06):
+ The epsilon for the layer normalization.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+
+ """
+
+ base_config_key = "vision_config"
+ model_type = "sam2_vision_model"
+ sub_configs = {
+ "backbone_config": AutoConfig,
+ }
+
+ def __init__(
+ self,
+ backbone_config=None,
+ backbone_channel_list=None,
+ backbone_feature_sizes=None,
+ fpn_hidden_size=256,
+ fpn_kernel_size=1,
+ fpn_stride=1,
+ fpn_padding=0,
+ fpn_top_down_levels=None,
+ num_feature_levels=3,
+ hidden_act="gelu",
+ layer_norm_eps=1e-6,
+ initializer_range=0.02,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+
+ backbone_channel_list = [768, 384, 192, 96] if backbone_channel_list is None else backbone_channel_list
+ backbone_feature_sizes = (
+ [[256, 256], [128, 128], [64, 64]] if backbone_feature_sizes is None else backbone_feature_sizes
+ )
+ fpn_top_down_levels = [2, 3] if fpn_top_down_levels is None else fpn_top_down_levels
+
+ if isinstance(backbone_config, dict):
+ backbone_config["model_type"] = backbone_config.get("model_type", "sam2_hiera_det_model")
+ backbone_config = CONFIG_MAPPING[backbone_config["model_type"]](**backbone_config)
+ elif isinstance(backbone_config, Sam2HieraDetConfig):
+ pass
+ elif backbone_config is None:
+ backbone_config = Sam2HieraDetConfig()
+
+ self.backbone_config = backbone_config
+
+ # Neck
+ self.backbone_channel_list = backbone_channel_list
+ self.backbone_feature_sizes = backbone_feature_sizes
+ self.fpn_hidden_size = fpn_hidden_size
+ self.fpn_kernel_size = fpn_kernel_size
+ self.fpn_stride = fpn_stride
+ self.fpn_padding = fpn_padding
+ self.fpn_top_down_levels = fpn_top_down_levels
+ self.num_feature_levels = num_feature_levels
+
+ self.hidden_act = hidden_act
+ self.layer_norm_eps = layer_norm_eps
+ self.initializer_range = initializer_range
+
+
+class Sam2PromptEncoderConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`Sam2PromptEncoder`]. The [`Sam2PromptEncoder`]
+ module is used to encode the input 2D points and bounding boxes.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ hidden_size (`int`, *optional*, defaults to 256):
+ Dimensionality of the hidden states.
+ image_size (`int`, *optional*, defaults to 1024):
+ The expected output resolution of the image.
+ patch_size (`int`, *optional*, defaults to 16):
+ The size (resolution) of each patch.
+ mask_input_channels (`int`, *optional*, defaults to 16):
+ The number of channels to be fed to the `MaskDecoder` module.
+ num_point_embeddings (`int`, *optional*, defaults to 4):
+ The number of point embeddings to be used.
+ hidden_act (`str`, *optional*, defaults to `"gelu"`):
+ The non-linear activation function in the encoder and pooler.
+ layer_norm_eps (`float`, *optional*, defaults to 1e-06):
+ The epsilon used by the layer normalization layers.
+ scale (`float`, *optional*, defaults to 1):
+ The scale factor for the prompt encoder.
+ """
+
+ base_config_key = "prompt_encoder_config"
+
+ def __init__(
+ self,
+ hidden_size=256,
+ image_size=1024,
+ patch_size=16,
+ mask_input_channels=16,
+ num_point_embeddings=4,
+ hidden_act="gelu",
+ layer_norm_eps=1e-6,
+ scale=1,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.hidden_size = hidden_size
+ self.image_size = image_size
+ self.patch_size = patch_size
+ self.mask_input_channels = mask_input_channels
+ self.num_point_embeddings = num_point_embeddings
+ self.hidden_act = hidden_act
+ self.layer_norm_eps = layer_norm_eps
+ self.scale = scale
+
+
+class Sam2MaskDecoderConfig(PretrainedConfig):
+ r"""
+    This is the configuration class to store the configuration of a [`Sam2MaskDecoder`]. It is used to instantiate a SAM2
+    mask decoder according to the specified arguments, defining the model architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ hidden_size (`int`, *optional*, defaults to 256):
+ Dimensionality of the hidden states.
+ hidden_act (`str`, *optional*, defaults to `"gelu"`):
+ The non-linear activation function in the SAM2 mask decoder.
+ mlp_dim (`int`, *optional*, defaults to 2048):
+ The dimension of the MLP in the two-way transformer.
+ num_hidden_layers (`int`, *optional*, defaults to 2):
+ The number of hidden layers in the two-way transformer.
+ num_attention_heads (`int`, *optional*, defaults to 8):
+ The number of attention heads in the two-way transformer.
+ attention_downsample_rate (`int`, *optional*, defaults to 2):
+ The downsample rate for the attention layers.
+ num_multimask_outputs (`int`, *optional*, defaults to 3):
+ The number of multimask outputs.
+ iou_head_depth (`int`, *optional*, defaults to 3):
+ The depth of the IoU head.
+ iou_head_hidden_dim (`int`, *optional*, defaults to 256):
+ The hidden dimension of the IoU head.
+ dynamic_multimask_via_stability (`bool`, *optional*, defaults to `True`):
+ Whether to use dynamic multimask via stability.
+ dynamic_multimask_stability_delta (`float`, *optional*, defaults to 0.05):
+ The stability delta for the dynamic multimask.
+ dynamic_multimask_stability_thresh (`float`, *optional*, defaults to 0.98):
+ The stability threshold for the dynamic multimask.
+
+ """
+
+ base_config_key = "mask_decoder_config"
+
+ def __init__(
+ self,
+ hidden_size=256,
+ hidden_act="gelu",
+ mlp_dim=2048,
+ num_hidden_layers=2,
+ num_attention_heads=8,
+ attention_downsample_rate=2,
+ num_multimask_outputs=3,
+ iou_head_depth=3,
+ iou_head_hidden_dim=256,
+ dynamic_multimask_via_stability=True,
+ dynamic_multimask_stability_delta=0.05,
+ dynamic_multimask_stability_thresh=0.98,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+
+ self.hidden_size = hidden_size
+ self.num_multimask_outputs = num_multimask_outputs
+ self.hidden_act = hidden_act
+ self.iou_head_depth = iou_head_depth
+ self.iou_head_hidden_dim = iou_head_hidden_dim
+ self.dynamic_multimask_via_stability = dynamic_multimask_via_stability
+ self.dynamic_multimask_stability_delta = dynamic_multimask_stability_delta
+ self.dynamic_multimask_stability_thresh = dynamic_multimask_stability_thresh
+
+ # TwoWayTransformer configuration
+ self.num_hidden_layers = num_hidden_layers
+ self.hidden_size = hidden_size
+ self.num_attention_heads = num_attention_heads
+ self.mlp_dim = mlp_dim
+ self.attention_downsample_rate = attention_downsample_rate
+
+
+class Sam2Config(PretrainedConfig):
+ r"""
+ [`Sam2Config`] is the configuration class to store the configuration of a [`Sam2Model`]. It is used to instantiate a
+    SAM2 model according to the specified arguments, defining the vision encoder, prompt encoder, and mask decoder
+    configs. Instantiating a configuration with the defaults will yield a similar configuration to that of the SAM 2.1 Hiera-tiny
+ [facebook/sam2.1-hiera-tiny](https://huggingface.co/facebook/sam2.1-hiera-tiny) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ vision_config (Union[`dict`, `Sam2VisionConfig`], *optional*):
+ Dictionary of configuration options used to initialize [`Sam2VisionConfig`].
+ prompt_encoder_config (Union[`dict`, `Sam2PromptEncoderConfig`], *optional*):
+ Dictionary of configuration options used to initialize [`Sam2PromptEncoderConfig`].
+ mask_decoder_config (Union[`dict`, `Sam2MaskDecoderConfig`], *optional*):
+ Dictionary of configuration options used to initialize [`Sam2MaskDecoderConfig`].
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ Standard deviation for parameter initialization.
+
+ Example:
+
+ ```python
+ >>> from transformers import (
+    ...     Sam2Config,
+    ...     Sam2VisionConfig,
+ ... Sam2PromptEncoderConfig,
+ ... Sam2MaskDecoderConfig,
+ ... Sam2Model,
+ ... )
+
+ >>> # Initializing a Sam2Config with `"facebook/sam2.1_hiera_tiny"` style configuration
+    >>> configuration = Sam2Config()
+
+ >>> # Initializing a Sam2Model (with random weights) from the `"facebook/sam2.1_hiera_tiny"` style configuration
+ >>> model = Sam2Model(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+
+ >>> # We can also initialize a Sam2Config from a Sam2VisionConfig, Sam2PromptEncoderConfig, and Sam2MaskDecoderConfig
+
+    >>> # Initializing SAM2 vision encoder, prompt encoder, and mask decoder configurations
+ >>> vision_config = Sam2VisionConfig()
+ >>> prompt_encoder_config = Sam2PromptEncoderConfig()
+ >>> mask_decoder_config = Sam2MaskDecoderConfig()
+
+ >>> config = Sam2Config(vision_config, prompt_encoder_config, mask_decoder_config)
+ ```"""
+
+ model_type = "sam2"
+ sub_configs = {
+ "vision_config": AutoConfig,
+ "prompt_encoder_config": Sam2PromptEncoderConfig,
+ "mask_decoder_config": Sam2MaskDecoderConfig,
+ }
+
+ def __init__(
+ self,
+ vision_config=None,
+ prompt_encoder_config=None,
+ mask_decoder_config=None,
+ initializer_range=0.02,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ vision_config = vision_config if vision_config is not None else {}
+ prompt_encoder_config = prompt_encoder_config if prompt_encoder_config is not None else {}
+ mask_decoder_config = mask_decoder_config if mask_decoder_config is not None else {}
+
+ if isinstance(vision_config, dict):
+ vision_config["model_type"] = vision_config.get("model_type", "sam2_vision_model")
+ vision_config = CONFIG_MAPPING[vision_config["model_type"]](**vision_config)
+ if isinstance(prompt_encoder_config, Sam2PromptEncoderConfig):
+ prompt_encoder_config = prompt_encoder_config.to_dict()
+ if isinstance(mask_decoder_config, Sam2MaskDecoderConfig):
+ mask_decoder_config = mask_decoder_config.to_dict()
+
+ self.vision_config = vision_config
+ self.prompt_encoder_config = Sam2PromptEncoderConfig(**prompt_encoder_config)
+ self.mask_decoder_config = Sam2MaskDecoderConfig(**mask_decoder_config)
+
+ self.initializer_range = initializer_range
+
+
+__all__ = [
+ "Sam2Config",
+ "Sam2HieraDetConfig",
+ "Sam2VisionConfig",
+ "Sam2PromptEncoderConfig",
+ "Sam2MaskDecoderConfig",
+]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/sam2/image_processing_sam2_fast.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/sam2/image_processing_sam2_fast.py
new file mode 100644
index 0000000000000000000000000000000000000000..a773e8ad54d7b640f39a17c0274621e1e4d815ee
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/sam2/image_processing_sam2_fast.py
@@ -0,0 +1,730 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/sam2/modular_sam2.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_sam2.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2025 The Meta AI Authors and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import math
+from copy import deepcopy
+from itertools import product
+from typing import Any, Optional, Union
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torchvision.ops.boxes import batched_nms
+
+from ...image_processing_utils import BatchFeature, get_size_dict
+from ...image_processing_utils_fast import BaseImageProcessorFast, DefaultFastImageProcessorKwargs
+from ...image_utils import (
+ IMAGENET_DEFAULT_MEAN,
+ IMAGENET_DEFAULT_STD,
+ ChannelDimension,
+ ImageInput,
+ PILImageResampling,
+ SizeDict,
+ pil_torch_interpolation_mapping,
+)
+from ...processing_utils import Unpack
+from ...utils import TensorType, auto_docstring
+
+
+class Sam2FastImageProcessorKwargs(DefaultFastImageProcessorKwargs):
+ r"""
+ mask_size (`dict[str, int]`, *optional*):
+ The size `{"height": int, "width": int}` to resize the segmentation maps to.
+ """
+
+ mask_size: Optional[dict[str, int]]
+
+
+def _compute_stability_score(masks: "torch.Tensor", mask_threshold: float, stability_score_offset: int):
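+    """
+    Compute a per-mask stability score: the IoU between the mask binarized at
+    (mask_threshold + stability_score_offset) and at (mask_threshold - stability_score_offset).
+    """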
+ # One mask is always contained inside the other.
+ # Save memory by preventing unnecessary cast to torch.int64
+ intersections = (
+ (masks > (mask_threshold + stability_score_offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32)
+ )
+ unions = (masks > (mask_threshold - stability_score_offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32)
+ stability_scores = intersections / unions
+ return stability_scores
+
+
+def _mask_to_rle(input_mask: "torch.Tensor"):
+ """
+    Encodes masks into run-length encoding (RLE), in the format expected by pycocotools.
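+
+    For example, a flattened (column-major) mask ``[0, 0, 1, 1, 1, 0]`` is encoded as
+    ``{"size": [height, width], "counts": [2, 3, 1]}``.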
+ """
+ # Put in fortran order and flatten height and width
+ batch_size, height, width = input_mask.shape
+ input_mask = input_mask.permute(0, 2, 1).flatten(1)
+
+ # Compute change indices
+ diff = input_mask[:, 1:] ^ input_mask[:, :-1]
+ change_indices = diff.nonzero()
+
+ # Encode run length
+ out = []
+ for i in range(batch_size):
+ cur_idxs = change_indices[change_indices[:, 0] == i, 1] + 1
+ if len(cur_idxs) == 0:
+ # No changes => either all 0 or all 1
+ # If the entire mask is 0, RLE is [height*width] or if the entire mask is 1, RLE is [0, height*width].
+ if input_mask[i, 0] == 0:
+ out.append({"size": [height, width], "counts": [height * width]})
+ else:
+ out.append({"size": [height, width], "counts": [0, height * width]})
+ continue
+ btw_idxs = cur_idxs[1:] - cur_idxs[:-1]
+ counts = [] if input_mask[i, 0] == 0 else [0]
+ counts += [cur_idxs[0].item()] + btw_idxs.tolist() + [height * width - cur_idxs[-1].item()]
+ out.append({"size": [height, width], "counts": counts})
+ return out
+
+
+def _batched_mask_to_box(masks: "torch.Tensor"):
+ """
+ Computes the bounding boxes around the given input masks. The bounding boxes are in the XYXY format which
+    corresponds to the following required indices:
+ - LEFT: left hand side of the bounding box
+ - TOP: top of the bounding box
+ - RIGHT: right of the bounding box
+ - BOTTOM: bottom of the bounding box
+
+ Return [0,0,0,0] for an empty mask. For input shape channel_1 x channel_2 x ... x height x width, the output shape
+ is channel_1 x channel_2 x ... x 4.
+
+ Args:
+ - masks (`torch.Tensor` of shape `(batch, nb_mask, height, width)`)
+ """
+ # torch.max below raises an error on empty inputs, just skip in this case
+
+ if torch.numel(masks) == 0:
+ return torch.zeros(*masks.shape[:-2], 4, device=masks.device)
+
+ # Normalize shape to Cxheightxwidth
+ shape = masks.shape
+ height, width = shape[-2:]
+
+ # Get top and bottom edges
+ in_height, _ = torch.max(masks, dim=-1)
+ in_height_coords = in_height * torch.arange(height, device=in_height.device)[None, :]
+ bottom_edges, _ = torch.max(in_height_coords, dim=-1)
+ in_height_coords = in_height_coords + height * (~in_height)
+ top_edges, _ = torch.min(in_height_coords, dim=-1)
+
+ # Get left and right edges
+ in_width, _ = torch.max(masks, dim=-2)
+ in_width_coords = in_width * torch.arange(width, device=in_width.device)[None, :]
+ right_edges, _ = torch.max(in_width_coords, dim=-1)
+ in_width_coords = in_width_coords + width * (~in_width)
+ left_edges, _ = torch.min(in_width_coords, dim=-1)
+
+ # If the mask is empty the right edge will be to the left of the left edge.
+ # Replace these boxes with [0, 0, 0, 0]
+ empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges)
+ out = torch.stack([left_edges, top_edges, right_edges, bottom_edges], dim=-1)
+ out = out * (~empty_filter).unsqueeze(-1)
+
+ # Return to original shape
+ out = out.reshape(*shape[:-2], 4)
+ return out
+
+
+def _is_box_near_crop_edge(boxes, crop_box, orig_box, atol=20.0):
+ """Filter masks at the edge of a crop, but not at the edge of the original image."""
+ crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device)
+ orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device)
+
+ left, top, _, _ = crop_box
+ offset = torch.tensor([[left, top, left, top]], device=boxes.device)
+ # Check if boxes has a channel dimension
+ if len(boxes.shape) == 3:
+ offset = offset.unsqueeze(1)
+ boxes = (boxes + offset).float()
+
+ near_crop_edge = torch.isclose(boxes, crop_box_torch[None, :], atol=atol, rtol=0)
+ near_image_edge = torch.isclose(boxes, orig_box_torch[None, :], atol=atol, rtol=0)
+ near_crop_edge = torch.logical_and(near_crop_edge, ~near_image_edge)
+ return torch.any(near_crop_edge, dim=1)
+
+
+def _pad_masks(masks, crop_box: list[int], orig_height: int, orig_width: int):
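+    """Pad masks predicted on a crop back to the full (orig_height, orig_width) canvas, placed at the crop location."""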
+ left, top, right, bottom = crop_box
+ if left == 0 and top == 0 and right == orig_width and bottom == orig_height:
+ return masks
+ # Coordinate transform masks
+ pad_x, pad_y = orig_width - (right - left), orig_height - (bottom - top)
+ pad = (left, pad_x - left, top, pad_y - top)
+ return torch.nn.functional.pad(masks, pad, value=0)
+
+
+def _generate_crop_boxes(
+ image,
+ target_size: int, # Is it tuple here?
+ crop_n_layers: int = 0,
+ overlap_ratio: float = 512 / 1500,
+ points_per_crop: Optional[int] = 32,
+ crop_n_points_downscale_factor: Optional[list[int]] = 1,
+) -> tuple[list[list[int]], list[int]]:
+ """
+ Generates a list of crop boxes of different sizes. Each layer has (2**i)**2 boxes for the ith layer.
+
+ Args:
+ image (Union[`numpy.ndarray`, `PIL.Image`, `torch.Tensor`]):
+ Image to generate crops for.
+ target_size (`int`):
+ Size of the smallest crop.
+ crop_n_layers (`int`, *optional*):
+            If `crop_n_layers > 0`, mask prediction will be run again on crops of the image. Sets the number of layers
+ to run, where each layer has 2**i_layer number of image crops.
+ overlap_ratio (`int`, *optional*):
+ Sets the degree to which crops overlap. In the first crop layer, crops will overlap by this fraction of the
+ image length. Later layers with more crops scale down this overlap.
+ points_per_crop (`int`, *optional*):
+            Number of points to sample per crop.
+ crop_n_points_downscale_factor (`int`, *optional*):
+            The number of points-per-side sampled in layer n is scaled down by crop_n_points_downscale_factor**n.
+ """
+
+ if isinstance(image, list):
+ raise ValueError("Only one image is allowed for crop generation.")
+ original_size = image.shape[-2:]
+
+ points_grid = []
+ for i in range(crop_n_layers + 1):
+ n_points = int(points_per_crop / (crop_n_points_downscale_factor**i))
+ points_grid.append(_build_point_grid(n_points))
+
+ crop_boxes, layer_idxs = _generate_per_layer_crops(crop_n_layers, overlap_ratio, original_size)
+
+ cropped_images, point_grid_per_crop = _generate_crop_images(
+ crop_boxes, image, points_grid, layer_idxs, target_size, original_size
+ )
+ crop_boxes = torch.tensor(crop_boxes)
+ crop_boxes = crop_boxes.float()
+ points_per_crop = torch.stack(point_grid_per_crop)
+ points_per_crop = points_per_crop.unsqueeze(0).permute(0, 2, 1, 3)
+ cropped_images = torch.stack(cropped_images)
+
+ input_labels = torch.ones_like(points_per_crop[:, :, :, 0], dtype=torch.int64)
+
+ return crop_boxes, points_per_crop, cropped_images, input_labels
+
+
+def _generate_per_layer_crops(crop_n_layers, overlap_ratio, original_size):
+ """
+    Generates 2 ** (layer_idx + 1) crops per side for each crop layer, plus one box covering the full image. Crops are in the XYWH format: the XYWH format
+ consists of the following required indices:
+ - X: X coordinate of the top left of the bounding box
+ - Y: Y coordinate of the top left of the bounding box
+ - W: width of the bounding box
+ - H: height of the bounding box
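+
+    For example, `crop_n_layers=1` yields 1 full-image box plus 2**2 = 4 overlapping boxes for layer 1.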
+ """
+ crop_boxes, layer_idxs = [], []
+ im_height, im_width = original_size
+ short_side = min(im_height, im_width)
+
+ # Original image
+ crop_boxes.append([0, 0, im_width, im_height])
+ layer_idxs.append(0)
+ for i_layer in range(crop_n_layers):
+ n_crops_per_side = 2 ** (i_layer + 1)
+ overlap = int(overlap_ratio * short_side * (2 / n_crops_per_side))
+
+ crop_width = int(math.ceil((overlap * (n_crops_per_side - 1) + im_width) / n_crops_per_side))
+ crop_height = int(math.ceil((overlap * (n_crops_per_side - 1) + im_height) / n_crops_per_side))
+
+ crop_box_x0 = [int((crop_width - overlap) * i) for i in range(n_crops_per_side)]
+ crop_box_y0 = [int((crop_height - overlap) * i) for i in range(n_crops_per_side)]
+
+ for left, top in product(crop_box_x0, crop_box_y0):
+ box = [left, top, min(left + crop_width, im_width), min(top + crop_height, im_height)]
+ crop_boxes.append(box)
+ layer_idxs.append(i_layer + 1)
+
+ return crop_boxes, layer_idxs
+
+
+def _build_point_grid(n_per_side: int) -> torch.Tensor:
+ """Generates a 2D grid of points evenly spaced in [0,1]x[0,1]."""
+ offset = 1 / (2 * n_per_side)
+ points_one_side = torch.linspace(offset, 1 - offset, n_per_side)
+ points_x = torch.tile(points_one_side[None, :], (n_per_side, 1))
+ points_y = torch.tile(points_one_side[:, None], (1, n_per_side))
+ points = torch.stack([points_x, points_y], dim=-1).reshape(-1, 2)
+ return points
+
+
+def _generate_crop_images(
+ crop_boxes, image, points_grid, layer_idxs, target_size, original_size, input_data_format=None
+):
+ """
+    Takes as input the bounding boxes that are used to crop the image. Based on the crops, the corresponding points are
+    also passed.
+ """
+ cropped_images = []
+ total_points_per_crop = []
+ for i, crop_box in enumerate(crop_boxes):
+ left, top, right, bottom = crop_box
+ cropped_im = image[:, top:bottom, left:right]
+
+ cropped_images.append(cropped_im)
+
+ cropped_im_size = cropped_im.shape[-2:]
+ points_scale = torch.tensor(cropped_im_size).flip(dims=(0,)).unsqueeze(0)
+
+ points = points_grid[layer_idxs[i]] * points_scale
+ normalized_points = _normalize_coordinates(target_size, points, original_size)
+ total_points_per_crop.append(normalized_points)
+
+ return cropped_images, total_points_per_crop
+
+
+def _normalize_coordinates(
+ target_size: int, coords: torch.Tensor, original_size: tuple[int, int], is_bounding_box=False
+) -> torch.Tensor:
+ """
+    Expects a tensor with length 2 in the final dimension. Requires the original image size in (height, width)
+ format.
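+
+    For example, with `original_size=(600, 800)` and `target_size=1024` the longest side is scaled by 1.28,
+    so a point `(x=400, y=300)` maps to `(512.0, 384.0)`.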
+ """
+ old_height, old_width = original_size
+
+ scale = target_size * 1.0 / max(old_height, old_width)
+ new_height, new_width = old_height * scale, old_width * scale
+ new_width = int(new_width + 0.5)
+ new_height = int(new_height + 0.5)
+
+ coords = deepcopy(coords).float()
+
+ if is_bounding_box:
+ coords = coords.reshape(-1, 2, 2)
+
+ coords[..., 0] = coords[..., 0] * (new_width / old_width)
+ coords[..., 1] = coords[..., 1] * (new_height / old_height)
+
+ if is_bounding_box:
+ coords = coords.reshape(-1, 4)
+
+ return coords
+
+
+def _rle_to_mask(rle: dict[str, Any]) -> torch.Tensor:
+ """Compute a binary mask from an uncompressed RLE."""
+ height, width = rle["size"]
+ mask = torch.empty(height * width, dtype=bool)
+ idx = 0
+ parity = False
+ for count in rle["counts"]:
+ mask[idx : idx + count] = parity
+ idx += count
+ parity = not parity
+ mask = mask.reshape(width, height)
+ return mask.transpose(0, 1) # Reshape to original shape
+
+
+def _post_process_for_mask_generation(rle_masks, iou_scores, mask_boxes, amg_crops_nms_thresh=0.7):
+ """
+ Perform NMS (Non Maximum Suppression) on the outputs.
+
+ Args:
+ rle_masks (`torch.Tensor`):
+ binary masks in the RLE format
+ iou_scores (`torch.Tensor` of shape (nb_masks, 1)):
+ iou_scores predicted by the model
+ mask_boxes (`torch.Tensor`):
+ The bounding boxes corresponding to segmentation masks
+ amg_crops_nms_thresh (`float`, *optional*, defaults to 0.7):
+ NMS threshold.
+ """
+ keep_by_nms = batched_nms(
+ boxes=mask_boxes.float(),
+ scores=iou_scores,
+ idxs=torch.zeros(mask_boxes.shape[0]),
+ iou_threshold=amg_crops_nms_thresh,
+ )
+
+ iou_scores = iou_scores[keep_by_nms]
+ rle_masks = [rle_masks[i] for i in keep_by_nms]
+ mask_boxes = mask_boxes[keep_by_nms]
+ masks = [_rle_to_mask(rle) for rle in rle_masks]
+
+ return masks, iou_scores, rle_masks, mask_boxes
+
+
+@auto_docstring
+class Sam2ImageProcessorFast(BaseImageProcessorFast):
+ resample = PILImageResampling.BILINEAR
+ image_mean = IMAGENET_DEFAULT_MEAN
+ image_std = IMAGENET_DEFAULT_STD
+ size = {"height": 1024, "width": 1024}
+ mask_size = {"height": 256, "width": 256}
+ do_resize = True
+ do_rescale = True
+ do_normalize = True
+ do_convert_rgb = True
+
+ valid_kwargs = Sam2FastImageProcessorKwargs
+
+ # modular artefacts
+ do_pad = None
+ pad_size = None
+ mask_pad_size = None
+
+ def __init__(self, **kwargs: Unpack[Sam2FastImageProcessorKwargs]):
+ super().__init__(**kwargs)
+
+ def _further_process_kwargs(
+ self,
+ size: Optional[SizeDict] = None,
+ mask_size: Optional[SizeDict] = None,
+ default_to_square: Optional[bool] = None,
+ image_mean: Optional[Union[float, list[float]]] = None,
+ image_std: Optional[Union[float, list[float]]] = None,
+ data_format: Optional[ChannelDimension] = None,
+ **kwargs,
+ ) -> dict:
+ """
+ Update kwargs that need further processing before being validated
+ Can be overridden by subclasses to customize the processing of kwargs.
+ """
+ if kwargs is None:
+ kwargs = {}
+ if size is not None:
+ size = SizeDict(**get_size_dict(size=size, default_to_square=default_to_square))
+ if mask_size is not None:
+ mask_size = SizeDict(**get_size_dict(mask_size, param_name="mask_size"))
+ if isinstance(image_mean, list):
+ image_mean = tuple(image_mean)
+ if isinstance(image_std, list):
+ image_std = tuple(image_std)
+ if data_format is None:
+ data_format = ChannelDimension.FIRST
+
+ kwargs["size"] = size
+ kwargs["mask_size"] = mask_size
+ kwargs["image_mean"] = image_mean
+ kwargs["image_std"] = image_std
+ kwargs["data_format"] = data_format
+
+ # torch resize uses interpolation instead of resample
+ # Check if resample is an int before checking if it's an instance of PILImageResampling
+ # because if pillow < 9.1.0, resample is an int and PILImageResampling is a module.
+ # Checking PILImageResampling will fail with error `TypeError: isinstance() arg 2 must be a type or tuple of types`.
+ resample = kwargs.pop("resample")
+ kwargs["interpolation"] = (
+ pil_torch_interpolation_mapping[resample] if isinstance(resample, (PILImageResampling, int)) else resample
+ )
+
+ return kwargs
+
+ @auto_docstring
+ def preprocess(
+ self,
+ images: ImageInput,
+ segmentation_maps: Optional[ImageInput] = None,
+ **kwargs: Unpack[Sam2FastImageProcessorKwargs],
+ ) -> BatchFeature:
+ r"""
+ segmentation_maps (`ImageInput`, *optional*):
+ The segmentation maps to preprocess.
+ """
+ return super().preprocess(images, segmentation_maps, **kwargs)
+
+ def _preprocess_image_like_inputs(
+ self,
+ images: ImageInput,
+ segmentation_maps: Optional[ImageInput],
+ do_convert_rgb: bool,
+ input_data_format: ChannelDimension,
+ device: Optional[Union[str, "torch.device"]] = None,
+ **kwargs: Unpack[Sam2FastImageProcessorKwargs],
+ ) -> BatchFeature:
+ """
+ Preprocess image-like inputs.
+ """
+ images = self._prepare_image_like_inputs(
+ images=images, do_convert_rgb=do_convert_rgb, input_data_format=input_data_format, device=device
+ )
+ original_sizes = [image.shape[-2:] for image in images]
+ images_kwargs = kwargs.copy()
+ pixel_values = self._preprocess(images, **images_kwargs)
+ reshaped_input_sizes = [image.shape[-2:] for image in images]
+ data = {
+ "pixel_values": pixel_values,
+ "original_sizes": original_sizes,
+ "reshaped_input_sizes": reshaped_input_sizes,
+ }
+
+ if segmentation_maps is not None:
+ processed_segmentation_maps = self._prepare_image_like_inputs(
+ images=segmentation_maps,
+ expected_ndims=2,
+ do_convert_rgb=False,
+ input_data_format=ChannelDimension.FIRST,
+ )
+
+ segmentation_maps_kwargs = kwargs.copy()
+ segmentation_maps_kwargs.update(
+ {
+ "do_normalize": False,
+ "do_rescale": False,
+ "interpolation": pil_torch_interpolation_mapping[PILImageResampling.NEAREST],
+ "size": segmentation_maps_kwargs.pop("mask_size"),
+ }
+ )
+ processed_segmentation_maps = self._preprocess(
+ images=processed_segmentation_maps, **segmentation_maps_kwargs
+ )
+ data["labels"] = processed_segmentation_maps.squeeze(1).to(torch.int64)
+
+ return BatchFeature(data=data, tensor_type=kwargs["return_tensors"])
+
+ def generate_crop_boxes(
+ self,
+ image: "torch.Tensor",
+ target_size,
+ crop_n_layers: int = 0,
+ overlap_ratio: float = 512 / 1500,
+ points_per_crop: Optional[int] = 32,
+ crop_n_points_downscale_factor: Optional[list[int]] = 1,
+ device: Optional["torch.device"] = None,
+ ):
+ """
+ Generates a list of crop boxes of different sizes. Each layer has (2**i)**2 boxes for the ith layer.
+
+ Args:
+ image (`torch.Tensor`):
+ Input original image
+ target_size (`int`):
+ Target size of the resized image
+ crop_n_layers (`int`, *optional*, defaults to 0):
+ If >0, mask prediction will be run again on crops of the image. Sets the number of layers to run, where
+ each layer has 2**i_layer number of image crops.
+ overlap_ratio (`float`, *optional*, defaults to 512/1500):
+ Sets the degree to which crops overlap. In the first crop layer, crops will overlap by this fraction of
+ the image length. Later layers with more crops scale down this overlap.
+ points_per_crop (`int`, *optional*, defaults to 32):
+                Number of points to sample from each crop.
+ crop_n_points_downscale_factor (`list[int]`, *optional*, defaults to 1):
+                The number of points-per-side sampled in layer n is scaled down by crop_n_points_downscale_factor**n.
+ device (`torch.device`, *optional*, defaults to None):
+ Device to use for the computation. If None, cpu will be used.
+ """
+ image = self._process_image(image)
+ crop_boxes, points_per_crop, cropped_images, input_labels = _generate_crop_boxes(
+ image,
+ target_size,
+ crop_n_layers,
+ overlap_ratio,
+ points_per_crop,
+ crop_n_points_downscale_factor,
+ )
+ if device is None:
+ device = torch.device("cpu")
+ crop_boxes = crop_boxes.to(device)
+ points_per_crop = points_per_crop.to(device)
+ # cropped_images stays as torch.Tensor
+ input_labels = input_labels.to(device)
+
+ return crop_boxes, points_per_crop, cropped_images, input_labels
+
+ def filter_masks(
+ self,
+ masks,
+ iou_scores,
+ original_size,
+ cropped_box_image,
+ pred_iou_thresh=0.88,
+ stability_score_thresh=0.95,
+ mask_threshold=0,
+ stability_score_offset=1,
+ ):
+ """
+        Filters the predicted masks by selecting only the ones that meet several criteria. The first criterion is
+        that the iou score needs to be greater than `pred_iou_thresh`. The second criterion is that the stability
+        score needs to be greater than `stability_score_thresh`. The method also converts the predicted masks to
+        bounding boxes and pads the predicted masks if necessary.
+
+ Args:
+ masks (`torch.Tensor`):
+ Input masks.
+ iou_scores (`torch.Tensor`):
+ List of IoU scores.
+ original_size (`tuple[int,int]`):
+ Size of the original image.
+ cropped_box_image (`torch.Tensor`):
+ The cropped image.
+ pred_iou_thresh (`float`, *optional*, defaults to 0.88):
+ The threshold for the iou scores.
+ stability_score_thresh (`float`, *optional*, defaults to 0.95):
+ The threshold for the stability score.
+ mask_threshold (`float`, *optional*, defaults to 0):
+ The threshold for the predicted masks.
+ stability_score_offset (`float`, *optional*, defaults to 1):
+ The offset for the stability score used in the `_compute_stability_score` method.
+
+ """
+ original_height, original_width = original_size
+ iou_scores = iou_scores.flatten(0, 1)
+ masks = masks.flatten(0, 1)
+
+ if masks.shape[0] != iou_scores.shape[0]:
+ raise ValueError("masks and iou_scores must have the sam2e batch size.")
+
+ if masks.device != iou_scores.device:
+ iou_scores = iou_scores.to(masks.device)
+
+ batch_size = masks.shape[0]
+
+ keep_mask = torch.ones(batch_size, dtype=torch.bool, device=masks.device)
+
+ if pred_iou_thresh > 0.0:
+ keep_mask = keep_mask & (iou_scores > pred_iou_thresh)
+
+ # compute stability score
+ if stability_score_thresh > 0.0:
+ stability_scores = _compute_stability_score(masks, mask_threshold, stability_score_offset)
+ keep_mask = keep_mask & (stability_scores > stability_score_thresh)
+
+ scores = iou_scores[keep_mask]
+ masks = masks[keep_mask]
+
+ # binarize masks
+ masks = masks > mask_threshold
+ converted_boxes = _batched_mask_to_box(masks)
+
+ keep_mask = ~_is_box_near_crop_edge(
+ converted_boxes, cropped_box_image, [0, 0, original_width, original_height]
+ )
+
+ scores = scores[keep_mask]
+ masks = masks[keep_mask]
+ converted_boxes = converted_boxes[keep_mask]
+
+ masks = _pad_masks(masks, cropped_box_image, original_height, original_width)
+ # conversion to rle is necessary to run non-maximum suppression
+ masks = _mask_to_rle(masks)
+
+ return masks, scores, converted_boxes
+
+ def post_process_masks(
+ self,
+ masks,
+ original_sizes,
+ mask_threshold=0.0,
+ binarize=True,
+ max_hole_area=0.0,
+ max_sprinkle_area=0.0,
+ apply_non_overlapping_constraints=False,
+ **kwargs,
+ ):
+ """
+ Remove padding and upscale masks to the original image size.
+
+ Args:
+ masks (`Union[torch.Tensor, List[torch.Tensor], np.ndarray, List[np.ndarray]]`):
+ Batched masks from the mask_decoder in (batch_size, num_channels, height, width) format.
+ original_sizes (`Union[torch.Tensor, List[Tuple[int,int]]]`):
+ The original sizes of each image before it was resized to the model's expected input shape, in (height,
+ width) format.
+ mask_threshold (`float`, *optional*, defaults to 0.0):
+ Threshold for binarization and post-processing operations.
+ binarize (`bool`, *optional*, defaults to `True`):
+ Whether to binarize the masks.
+ max_hole_area (`float`, *optional*, defaults to 0.0):
+ The maximum area of a hole to fill.
+ max_sprinkle_area (`float`, *optional*, defaults to 0.0):
+ The maximum area of a sprinkle to fill.
+ apply_non_overlapping_constraints (`bool`, *optional*, defaults to `False`):
+ Whether to apply non-overlapping constraints to the masks.
+
+ Returns:
+            (`torch.Tensor`): Batched masks in (batch_size, num_channels, height, width) format, where (height, width)
+ is given by original_size.
+ """
+ if isinstance(original_sizes, (torch.Tensor, np.ndarray)):
+ original_sizes = original_sizes.tolist()
+ # TODO: add connected components kernel for postprocessing
+ output_masks = []
+ for i, original_size in enumerate(original_sizes):
+ if isinstance(masks[i], np.ndarray):
+ masks[i] = torch.from_numpy(masks[i])
+ elif not isinstance(masks[i], torch.Tensor):
+ raise ValueError("Input masks should be a list of `torch.tensors` or a list of `np.ndarray`")
+ interpolated_mask = F.interpolate(masks[i], original_size, mode="bilinear", align_corners=False)
+ if apply_non_overlapping_constraints:
+ interpolated_mask = self._apply_non_overlapping_constraints(interpolated_mask)
+ if binarize:
+ interpolated_mask = interpolated_mask > mask_threshold
+ output_masks.append(interpolated_mask)
+
+ return output_masks
+
+ def post_process_for_mask_generation(self, all_masks, all_scores, all_boxes, crops_nms_thresh):
+ """
+        Post processes the masks that were generated, by running the Non Maximum Suppression algorithm on the predicted masks.
+
+ Args:
+ all_masks (`torch.Tensor`):
+ List of all predicted segmentation masks
+ all_scores (`torch.Tensor`):
+ List of all predicted iou scores
+ all_boxes (`torch.Tensor`):
+ List of all bounding boxes of the predicted masks
+ crops_nms_thresh (`float`):
+ Threshold for NMS (Non Maximum Suppression) algorithm.
+ """
+ return _post_process_for_mask_generation(all_masks, all_scores, all_boxes, crops_nms_thresh)
+
+ def pad_image(self):
+ raise NotImplementedError("No pad_image for SAM 2.")
+
+ def _preprocess(
+ self,
+ images: list["torch.Tensor"],
+ return_tensors: Optional[Union[str, TensorType]],
+ **kwargs,
+ ) -> "torch.Tensor":
+ return super()._preprocess(images, return_tensors=return_tensors, **kwargs).pixel_values
+
+ def _apply_non_overlapping_constraints(self, pred_masks: torch.Tensor) -> torch.Tensor:
+ """
+ Apply non-overlapping constraints to the object scores in pred_masks. Here we
+ keep only the highest scoring object at each spatial location in pred_masks.
+ """
+ batch_size = pred_masks.size(0)
+ if batch_size == 1:
+ return pred_masks
+
+ device = pred_masks.device
+ # "max_obj_inds": object index of the object with the highest score at each location
+ max_obj_inds = torch.argmax(pred_masks, dim=0, keepdim=True)
+ # "batch_obj_inds": object index of each object slice (along dim 0) in `pred_masks`
+ batch_obj_inds = torch.arange(batch_size, device=device)[:, None, None, None]
+ keep = max_obj_inds == batch_obj_inds
+ # suppress overlapping regions' scores below -10.0 so that the foreground regions
+ # don't overlap (here sigmoid(-10.0)=4.5398e-05)
+ pred_masks = torch.where(keep, pred_masks, torch.clamp(pred_masks, max=-10.0))
+ return pred_masks
+
+
+__all__ = ["Sam2ImageProcessorFast"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/sam2/modeling_sam2.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/sam2/modeling_sam2.py
new file mode 100644
index 0000000000000000000000000000000000000000..fe42cc39cacf1a29b42f633fd3dcb4a744d8afc7
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/sam2/modeling_sam2.py
@@ -0,0 +1,1611 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/sam2/modular_sam2.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_sam2.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2025 The Meta AI Authors and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from dataclasses import dataclass
+from typing import Callable, Optional, Union
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch import Tensor
+
+from transformers.utils.generic import OutputRecorder
+
+from ...activations import ACT2FN
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_outputs import BaseModelOutput
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
+from ...pytorch_utils import compile_compatible_method_lru_cache
+from ...utils import ModelOutput, auto_docstring
+from ...utils.generic import TransformersKwargs, check_model_inputs
+from ..auto import AutoModel
+from .configuration_sam2 import (
+ Sam2Config,
+ Sam2HieraDetConfig,
+ Sam2MaskDecoderConfig,
+ Sam2PromptEncoderConfig,
+ Sam2VisionConfig,
+)
+
+
+@dataclass
+@auto_docstring(custom_intro="Base class for the vision encoder's outputs.")
+class Sam2VisionEncoderOutput(ModelOutput):
+ r"""
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, height, width, hidden_size)`):
+ Sequence of hidden-states at the output of the last layer of the model.
+ fpn_hidden_states (`tuple(torch.FloatTensor)`):
+ Tuple of `torch.FloatTensor` (one for each feature level, from high to low resolution) of shape
+ `(batch_size, hidden_size, height, width)`. Feature maps from the Feature Pyramid Network neck.
+ fpn_position_encoding (`tuple(torch.FloatTensor)`):
+ Tuple of `torch.FloatTensor` (one for each feature level, from high to low resolution) of shape
+ `(batch_size, hidden_size, height, width)`. Positional encodings corresponding to the `fpn_hidden_states`.
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each stage) of shape `(batch_size, height, width, hidden_size)`. Hidden-states of the
+ model at the output of each stage.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
+ the self-attention heads.
+ """
+
+ last_hidden_state: Optional[torch.FloatTensor] = None
+ fpn_hidden_states: Optional[torch.FloatTensor] = None
+ fpn_position_encoding: Optional[torch.FloatTensor] = None
+ hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+ attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+
+
+@dataclass
+@auto_docstring(custom_intro="Base class for the Sam2 model's output.")
+class Sam2ImageSegmentationOutput(ModelOutput):
+ r"""
+ iou_scores (`torch.FloatTensor` of shape `(batch_size, point_batch_size, num_masks)`):
+ The Intersection over Union (IoU) scores of the predicted masks.
+ pred_masks (`torch.FloatTensor` of shape `(batch_size, point_batch_size, num_masks, height, width)`):
+ The predicted low-resolution masks. This is an alias for `low_res_masks`. These masks need to be post-processed
+ by the processor to be brought to the original image size.
+ object_score_logits (`torch.FloatTensor` of shape `(batch_size, point_batch_size, 1)`):
+ Logits for the object score, indicating if an object is present.
+ image_embeddings (`tuple(torch.FloatTensor)`):
+ The features from the FPN, which are used by the mask decoder. This is a tuple of `torch.FloatTensor` where each
+ tensor has shape `(batch_size, channels, height, width)`.
+ vision_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of each stage) of shape `(batch_size, height, width, hidden_size)`.
+ Hidden-states of the vision model at the output of each stage.
+ vision_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, sequence_length)`.
+ Attentions weights of the vision model.
+ mask_decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, sequence_length)`.
+ Attentions weights of the mask decoder.
+ """
+
+ iou_scores: Optional[torch.FloatTensor] = None
+ pred_masks: Optional[torch.FloatTensor] = None
+ object_score_logits: Optional[torch.FloatTensor] = None
+ image_embeddings: Optional[tuple[torch.FloatTensor, ...]] = None
+ vision_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+ vision_attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+ mask_decoder_attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+
+
+class Sam2PatchEmbeddings(nn.Module):
+ r"""
+ Turns pixel values into patch embeddings for transformer consumption.
+
+ Args:
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+ Pixel values. Pixel values can be obtained using
+ [`AutoImageProcessor`]. See [`Sam2ImageProcessorFast.__call__`] for details.
+
+ Returns:
+ embeddings (`torch.FloatTensor`):
+ Patch embeddings depend on image_size, patch_kernel_size, patch_stride and patch_padding
+ """
+
+ def __init__(self, config: Sam2HieraDetConfig):
+ super().__init__()
+ num_channels = config.num_channels
+ hidden_size = config.hidden_size
+
+ self.projection = nn.Conv2d(
+ num_channels,
+ hidden_size,
+ kernel_size=config.patch_kernel_size,
+ stride=config.patch_stride,
+ padding=config.patch_padding,
+ )
+
+ def forward(self, pixel_values):
+ _, num_channels, height, width = pixel_values.shape
+ embeddings = self.projection(pixel_values).permute(0, 2, 3, 1)
+ return embeddings
+
+
+# copied and adapted from original implementation, also practically equal to DetrSinePositionEmbedding
+class Sam2SinePositionEmbedding(nn.Module):
+ """
+ This is a more standard version of the position embedding, very similar to the one used by the Attention is all you
+ need paper, generalized to work on images.
+ """
+
+ def __init__(
+ self, num_pos_feats: int = 64, temperature: int = 10000, normalize: bool = False, scale: Optional[float] = None
+ ):
+ super().__init__()
+ if scale is not None and normalize is False:
+ raise ValueError("normalize should be True if scale is passed")
+ self.num_pos_feats = num_pos_feats
+ self.temperature = temperature
+ self.normalize = normalize
+ self.scale = 2 * math.pi if scale is None else scale
+
+ @compile_compatible_method_lru_cache(maxsize=1)
+ def forward(
+ self,
+ shape: torch.Size,
+ device: Union[torch.device, str],
+ dtype: torch.dtype,
+ mask: Optional[Tensor] = None,
+ ) -> Tensor:
+ if mask is None:
+ mask = torch.zeros((shape[0], shape[2], shape[3]), device=device, dtype=torch.bool)
+ not_mask = (~mask).to(dtype)
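+ # cumulative sums over the unmasked positions give 1-based y/x coordinates for every pixel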
+ y_embed = not_mask.cumsum(1)
+ x_embed = not_mask.cumsum(2)
+ if self.normalize:
+ eps = 1e-6
+ y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
+ x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
+
+ dim_t = torch.arange(self.num_pos_feats, dtype=torch.int64, device=device).to(dtype)
+ dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / self.num_pos_feats)
+
+ pos_x = x_embed[:, :, :, None] / dim_t
+ pos_y = y_embed[:, :, :, None] / dim_t
+ pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
+ pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
+ pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
+ return pos
+
+
+class Sam2VisionNeck(nn.Module):
+ def __init__(self, config: Sam2VisionConfig):
+ super().__init__()
+ self.config = config
+
+ self.position_encoding = Sam2SinePositionEmbedding(num_pos_feats=config.fpn_hidden_size // 2, normalize=True)
+ self.convs = nn.ModuleList()
+ for in_channels in config.backbone_channel_list:
+ self.convs.append(
+ nn.Conv2d(
+ in_channels=in_channels,
+ out_channels=config.fpn_hidden_size,
+ kernel_size=config.fpn_kernel_size,
+ stride=config.fpn_stride,
+ padding=config.fpn_padding,
+ ),
+ )
+ self.fpn_top_down_levels = config.fpn_top_down_levels
+
+ def forward(self, hidden_states: torch.Tensor) -> tuple[tuple[torch.Tensor, ...], tuple[torch.Tensor, ...]]:
+ fpn_hidden_states = ()
+ fpn_position_encoding = ()
+
+ # forward in top-down order (from low to high resolution)
+ n = len(self.convs) - 1
+ for i in range(n, -1, -1):
+ lateral_features = hidden_states[i].permute(0, 3, 1, 2)
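+ # self.convs follows config.backbone_channel_list order, so the lateral conv for stage i sits at index n - i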
+ lateral_features = self.convs[n - i](lateral_features)
+ if i not in self.fpn_top_down_levels or i == n:
+ prev_features = lateral_features
+ else:
+ top_down_features = F.interpolate(
+ prev_features.to(dtype=torch.float32),
+ scale_factor=2.0,
+ mode="nearest",
+ align_corners=None,
+ antialias=False,
+ ).to(lateral_features.dtype)
+ prev_features = lateral_features + top_down_features
+
+ prev_position_encoding = self.position_encoding(
+ prev_features.shape, prev_features.device, prev_features.dtype
+ ).to(prev_features.dtype)
+
+ fpn_hidden_states += (prev_features,)
+ fpn_position_encoding += (prev_position_encoding,)
+
+ return fpn_hidden_states, fpn_position_encoding
+
+
+def eager_attention_forward(
+ module: nn.Module,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_mask: Optional[torch.Tensor],
+ scaling: float,
+ dropout: float = 0.0,
+ **kwargs,
+):
+ attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
+ if attention_mask is not None:
+ attn_weights = attn_weights + attention_mask
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+ attn_output = torch.matmul(attn_weights, value)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ return attn_output, attn_weights
+
+
+def do_pool(x: torch.Tensor, query_stride: Optional[int] = None) -> torch.Tensor:
+ if query_stride is None:
+ return x
+ # (B, H, W, C) -> (B, C, H, W)
+ x = x.permute(0, 3, 1, 2)
+ x = nn.functional.max_pool2d(x, kernel_size=query_stride, stride=query_stride, ceil_mode=False)
+ # (B, C, H', W') -> (B, H', W', C)
+ x = x.permute(0, 2, 3, 1)
+ return x
+
+
+class Sam2MultiScaleAttention(nn.Module):
+ def __init__(
+ self,
+ config: Sam2HieraDetConfig,
+ dim: int,
+ dim_out: int,
+ num_attention_heads: int,
+ query_stride: Optional[tuple[int, int]] = None,
+ ):
+ super().__init__()
+
+ self.config = config
+
+ self.dim = dim
+ self.dim_out = dim_out
+ self.query_stride = query_stride
+
+ self.num_attention_heads = num_attention_heads
+ head_dim = dim_out // num_attention_heads
+ self.scale = head_dim**-0.5
+ self.qkv = nn.Linear(dim, dim_out * 3)
+ self.proj = nn.Linear(dim_out, dim_out)
+
+ self.is_causal = False
+
+ def forward(self, hidden_states: torch.Tensor, **kwargs) -> torch.Tensor:
+ batch_size, height, width, _ = hidden_states.shape
+ # qkv with shape (B, H * W, 3, nHead, C)
+ qkv = self.qkv(hidden_states).reshape(batch_size, height * width, 3, self.num_attention_heads, -1)
+ # q, k, v with shape (B, H * W, nheads, C)
+ query, key, value = torch.unbind(qkv, 2)
+
+
+ # Q pooling (for downsample at stage changes)
+ if self.query_stride:
+ query = do_pool(query.reshape(batch_size, height, width, -1), self.query_stride)
+ height, width = query.shape[1:3] # downsampled shape
+ query = query.reshape(batch_size, height * width, self.num_attention_heads, -1)
+
+ # transpose query, key, value to (B, nHead, H * W, C)
+ query = query.transpose(1, 2)
+ key = key.transpose(1, 2)
+ value = value.transpose(1, 2)
+
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+ attn_output, _ = attention_interface(
+ self,
+ query,
+ key,
+ value,
+ attention_mask=None,
+ is_causal=self.is_causal,
+ scaling=self.scale,
+ **kwargs,
+ )
+ attn_output = attn_output.reshape(batch_size, height, width, -1)
+
+ attn_output = self.proj(attn_output)
+
+ return attn_output
+
+
+class Sam2FeedForward(nn.Module):
+ def __init__(
+ self,
+ input_dim: int,
+ hidden_dim: int,
+ output_dim: int,
+ num_layers: int,
+ activation: str = "relu",
+ sigmoid_output: bool = False,
+ ):
+ super().__init__()
+ self.num_layers = num_layers
+ self.activation = ACT2FN[activation]
+ self.proj_in = nn.Linear(input_dim, hidden_dim)
+ self.proj_out = nn.Linear(hidden_dim, output_dim)
+ self.layers = nn.ModuleList([nn.Linear(hidden_dim, hidden_dim) for _ in range(num_layers - 2)])
+ self.sigmoid_output = sigmoid_output
+
+ def forward(self, hidden_states):
+ hidden_states = self.proj_in(hidden_states)
+ hidden_states = self.activation(hidden_states)
+ for layer in self.layers:
+ hidden_states = self.activation(layer(hidden_states))
+
+ hidden_states = self.proj_out(hidden_states)
+ if self.sigmoid_output:
+ hidden_states = F.sigmoid(hidden_states)
+ return hidden_states
+
+
+def window_partition(hidden_state, window_size):
+ """
+ Partition into non-overlapping windows with padding if needed.
+
+ Args:
+ hidden_state (`torch.Tensor`):
+ Input tokens with [batch_size, height, width, num_channels].
+ window_size (`int`):
+ Window size.
+
+ Returns:
+ `tuple(torch.FloatTensor)` comprising various elements:
+ - windows: windows after partition with [batch_size * num_windows, window_size, window_size, num_channels].
+ - (padded_height, padded_width): padded height and width before partition
+ """
+ batch_size, height, width, num_channels = hidden_state.shape
+
+ pad_height = (window_size - height % window_size) % window_size
+ pad_width = (window_size - width % window_size) % window_size
+
+ # Noop in case pad_width == 0 and pad_height == 0.
+ hidden_state = nn.functional.pad(hidden_state, (0, 0, 0, pad_width, 0, pad_height))
+
+ padded_height, padded_width = height + pad_height, width + pad_width
+
+ hidden_state = hidden_state.view(
+ batch_size, padded_height // window_size, window_size, padded_width // window_size, window_size, num_channels
+ )
+ windows = hidden_state.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, num_channels)
+ return windows, (padded_height, padded_width)
+
+
+def window_unpartition(windows, window_size, pad_height_width, height_width):
+ """
+ Reverses window partition, restoring the original sequences and removing padding.
+
+ Args:
+ windows (`torch.Tensor`):
+ Input tokens with [batch_size * num_windows, window_size, window_size, num_channels].
+ window_size (`int`):
+ Window size.
+ pad_height_width (`tuple[int]`):
+ Padded height and width (padded_height, padded_width).
+ height_width (`tuple[int]`):
+ Original height and width before padding.
+
+ Returns:
+ hidden_state: unpartitioned sequences with [batch_size, height, width, num_channels].
+ """
+ padded_height, padded_width = pad_height_width
+ height, width = height_width
+ batch_size = windows.shape[0] // (padded_height * padded_width // window_size // window_size)
+ hidden_state = windows.view(
+ batch_size, padded_height // window_size, padded_width // window_size, window_size, window_size, -1
+ )
+ hidden_state = hidden_state.permute(0, 1, 3, 2, 4, 5).contiguous()
+ hidden_state = hidden_state.view(batch_size, padded_height, padded_width, -1)
+
+ # We always have height <= padded_height and width <= padded_width
+ hidden_state = hidden_state[:, :height, :width, :].contiguous()
+ return hidden_state
+
+
+class Sam2MultiScaleBlock(GradientCheckpointingLayer):
+ def __init__(
+ self,
+ config: Sam2HieraDetConfig,
+ stage_idx: int,
+ block_idx: int,
+ total_block_idx: int,
+ ):
+ super().__init__()
+
+ # take embed dim from previous stage if first block of stage
+ self.dim = (
+ config.embed_dim_per_stage[stage_idx - 1]
+ if stage_idx > 0 and block_idx == 0
+ else config.embed_dim_per_stage[stage_idx]
+ )
+ self.dim_out = config.embed_dim_per_stage[stage_idx]
+ self.layer_norm1 = nn.LayerNorm(self.dim, eps=config.layer_norm_eps)
+ # take window size from previous stage if first block of stage
+ self.window_size = (
+ config.window_size_per_stage[stage_idx - 1]
+ if stage_idx > 0 and block_idx == 0
+ else config.window_size_per_stage[stage_idx]
+ )
+ self.window_size = 0 if total_block_idx in config.global_attention_blocks else self.window_size
+ # use query stride for first block of stage if stage is a query pool stage
+ self.query_stride = (
+ config.query_stride if 0 < stage_idx <= config.num_query_pool_stages and block_idx == 0 else None
+ )
+
+ self.attn = Sam2MultiScaleAttention(
+ config,
+ self.dim,
+ self.dim_out,
+ num_attention_heads=config.num_attention_heads_per_stage[stage_idx],
+ query_stride=self.query_stride,
+ )
+ self.layer_norm2 = nn.LayerNorm(self.dim_out, eps=config.layer_norm_eps)
+ self.mlp = Sam2FeedForward(
+ self.dim_out,
+ int(self.dim_out * config.mlp_ratio),
+ self.dim_out,
+ num_layers=2,
+ activation=config.hidden_act,
+ )
+ if self.dim != self.dim_out:
+ self.proj = nn.Linear(self.dim, self.dim_out)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> torch.FloatTensor:
+ residual = hidden_states # batch_size, height, width, channel
+
+ hidden_states = self.layer_norm1(hidden_states)
+
+ # Skip connection
+ if self.dim != self.dim_out:
+ residual = do_pool(self.proj(hidden_states), self.query_stride)
+
+ # Window partition
+ window_size = self.window_size
+ if self.window_size > 0:
+ H, W = hidden_states.shape[1], hidden_states.shape[2]
+ hidden_states, pad_hw = window_partition(hidden_states, window_size)
+
+ # Window Attention + Q Pooling (if stage change)
+ attn_output = self.attn(
+ hidden_states=hidden_states,
+ **kwargs,
+ )
+ hidden_states = attn_output
+ if self.query_stride:
+ # Shapes have changed due to Q pooling
+ window_size = self.window_size // self.query_stride[0]
+ H, W = residual.shape[1:3]
+
+ pad_h = (window_size - H % window_size) % window_size
+ pad_w = (window_size - W % window_size) % window_size
+ pad_hw = (H + pad_h, W + pad_w)
+
+ # Reverse window partition
+ if self.window_size > 0:
+ hidden_states = window_unpartition(hidden_states, window_size, pad_hw, (H, W))
+
+ hidden_states = residual + hidden_states
+ layernorm_output = self.layer_norm2(hidden_states)
+ hidden_states = hidden_states + self.mlp(layernorm_output)
+
+ return hidden_states
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Hiera model's outputs that also contain the intermediate hidden states of each stage.
+ """
+)
+class Sam2HieraDetModelOutput(ModelOutput):
+ r"""
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, height, width, hidden_size)`):
+ Sequence of hidden-states at the output of the last layer of the model.
+ intermediate_hidden_states (`tuple[torch.FloatTensor]` of shape `(batch_size, height, width, hidden_size)`):
+ Sequence of hidden-states at the output of the intermediate layers of the model.
+ """
+
+ last_hidden_state: Optional[torch.FloatTensor] = None
+ intermediate_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+
+
+@auto_docstring
+class Sam2PreTrainedModel(PreTrainedModel):
+ config_class = Sam2Config
+ base_model_prefix = "sam2"
+ main_input_name = "pixel_values"
+ _supports_sdpa = True
+ _supports_flash_attn_2 = True
+ _supports_attention_backend = True
+
+ def _init_weights(self, module):
+ std = self.config.initializer_range
+ if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+ elif isinstance(module, (nn.LayerNorm, Sam2LayerNorm)):
+ module.weight.data.fill_(1.0)
+ module.bias.data.zero_()
+ if isinstance(module, Sam2HieraDetModel):
+ if module.pos_embed is not None:
+ module.pos_embed.data.zero_()
+ if module.pos_embed_window is not None:
+ module.pos_embed_window.data.zero_()
+ if isinstance(module, Sam2Model):
+ if module.no_memory_embedding is not None:
+ module.no_memory_embedding.data.zero_()
+
+
+class Sam2HieraDetModel(Sam2PreTrainedModel):
+ config_class = Sam2HieraDetConfig
+ main_input_name = "pixel_values"
+ _can_record_outputs = {
+ "hidden_states": Sam2MultiScaleBlock,
+ "attentions": Sam2MultiScaleAttention,
+ }
+
+ def __init__(self, config: Sam2HieraDetConfig):
+ super().__init__(config)
+
+ self.patch_embed = Sam2PatchEmbeddings(config)
+ # Windowed positional embedding (https://huggingface.co/papers/2311.05613)
+ self.pos_embed = nn.Parameter(
+ torch.zeros(1, config.hidden_size, *config.window_positional_embedding_background_size)
+ )
+ self.pos_embed_window = nn.Parameter(
+ torch.zeros(1, config.hidden_size, config.window_size_per_stage[0], config.window_size_per_stage[0])
+ )
+ self.stage_ends = (np.cumsum(config.blocks_per_stage) - 1).tolist()
+ self.blocks = nn.ModuleList()
+ total_block_idx = 0
+ for stage_idx, blocks_per_stage in enumerate(config.blocks_per_stage):
+ for block_idx in range(blocks_per_stage):
+ block = Sam2MultiScaleBlock(
+ config=config, stage_idx=stage_idx, block_idx=block_idx, total_block_idx=total_block_idx
+ )
+ self.blocks.append(block)
+ total_block_idx += 1
+
+ def get_input_embeddings(self):
+ return self.patch_embed
+
+ def _get_pos_embed(self, hw: tuple[int, int]) -> torch.Tensor:
+ h, w = hw
+ window_embed = self.pos_embed_window
+ pos_embed = F.interpolate(self.pos_embed, size=(h, w), mode="bicubic")
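+ # tile the window-sized embedding to cover the full (h, w) grid before adding it to the interpolated global embedding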
+ pos_embed = pos_embed + window_embed.tile([x // y for x, y in zip(pos_embed.shape, window_embed.shape)])
+ pos_embed = pos_embed.permute(0, 2, 3, 1)
+ return pos_embed
+
+ @check_model_inputs
+ def forward(
+ self,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> Union[tuple, Sam2HieraDetModelOutput]:
+ if pixel_values is None:
+ raise ValueError("You have to specify pixel_values")
+
+ hidden_states = self.patch_embed(pixel_values)
+ hidden_states = hidden_states + self._get_pos_embed(hidden_states.shape[1:3])
+
+ intermediate_hidden_states = ()
+ for i, block_module in enumerate(self.blocks):
+ hidden_states = block_module(hidden_states, **kwargs)
+
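+ # keep the output of the last block of each stage; these multi-scale features are later fused by the FPN neck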
+ if i in self.stage_ends:
+ intermediate_hidden_states = intermediate_hidden_states + (hidden_states,)
+
+ return Sam2HieraDetModelOutput(
+ last_hidden_state=hidden_states,
+ intermediate_hidden_states=intermediate_hidden_states,
+ )
+
+
+@auto_docstring(
+ custom_intro="""
+ The vision model from SAM 2 without any head or projection on top.
+ """
+)
+class Sam2VisionModel(Sam2PreTrainedModel):
+ config_class = Sam2VisionConfig
+ main_input_name = "pixel_values"
+ _can_record_outputs = {
+ "hidden_states": Sam2MultiScaleBlock,
+ "attentions": Sam2MultiScaleAttention,
+ }
+
+ def __init__(self, config: Sam2VisionConfig):
+ super().__init__(config)
+ self.config = config
+
+ self.backbone = AutoModel.from_config(config.backbone_config)
+
+ self.neck = Sam2VisionNeck(config)
+ self.num_feature_levels = config.num_feature_levels
+
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.backbone.get_input_embeddings()
+
+ @check_model_inputs
+ def forward(
+ self,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> Union[tuple, Sam2VisionEncoderOutput]:
+ if pixel_values is None:
+ raise ValueError("You have to specify pixel_values")
+
+ # Forward through backbone
+ backbone_output = self.backbone(pixel_values, **kwargs)
+ hidden_states = backbone_output.last_hidden_state
+ intermediate_hidden_states = backbone_output.intermediate_hidden_states
+
+ fpn_hidden_states, fpn_position_encoding = self.neck(intermediate_hidden_states)
+ # Select last `num_feature_levels` feature levels from FPN and reverse order to get features from high to low resolution
+ fpn_hidden_states = fpn_hidden_states[-self.num_feature_levels :][::-1]
+ fpn_position_encoding = fpn_position_encoding[-self.num_feature_levels :][::-1]
+
+ return Sam2VisionEncoderOutput(
+ last_hidden_state=hidden_states,
+ fpn_hidden_states=fpn_hidden_states,
+ fpn_position_encoding=fpn_position_encoding,
+ )
+
+
+class Sam2PositionalEmbedding(nn.Module):
+ def __init__(self, config: Sam2PromptEncoderConfig):
+ super().__init__()
+ self.scale = config.scale
+ positional_embedding = self.scale * torch.randn((2, config.hidden_size // 2))
+ self.register_buffer("positional_embedding", positional_embedding)
+
+ def forward(self, input_coords, input_shape=None):
+ """Positionally encode points that are normalized to [0,1]."""
+ coordinates = input_coords.clone()
+
+ if input_shape is not None:
+ coordinates[:, :, :, 0] = coordinates[:, :, :, 0] / input_shape[1]
+ coordinates[:, :, :, 1] = coordinates[:, :, :, 1] / input_shape[0]
+ coordinates.to(torch.float32)
+
+ # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
+ coordinates = 2 * coordinates - 1
+ coordinates = coordinates.to(self.positional_embedding.dtype)
+ coordinates = coordinates @ self.positional_embedding
+ coordinates = 2 * np.pi * coordinates
+ # outputs d_1 x ... x d_n x channel shape
+ return torch.cat([torch.sin(coordinates), torch.cos(coordinates)], dim=-1)
+
+
+class Sam2MaskEmbedding(nn.Module):
+ def __init__(self, config: Sam2PromptEncoderConfig):
+ super().__init__()
+ self.mask_input_channels = config.mask_input_channels // 4
+ self.activation = ACT2FN[config.hidden_act]
+ self.conv1 = nn.Conv2d(1, self.mask_input_channels, kernel_size=2, stride=2)
+ self.conv2 = nn.Conv2d(self.mask_input_channels, config.mask_input_channels, kernel_size=2, stride=2)
+ self.conv3 = nn.Conv2d(config.mask_input_channels, config.hidden_size, kernel_size=1)
+ self.layer_norm1 = Sam2LayerNorm(
+ self.mask_input_channels, eps=config.layer_norm_eps, data_format="channels_first"
+ )
+ self.layer_norm2 = Sam2LayerNorm(
+ self.mask_input_channels * 4, eps=config.layer_norm_eps, data_format="channels_first"
+ )
+
+ def forward(self, masks):
+ hidden_states = self.conv1(masks)
+ hidden_states = self.layer_norm1(hidden_states)
+ hidden_states = self.activation(hidden_states)
+
+ hidden_states = self.conv2(hidden_states)
+ hidden_states = self.layer_norm2(hidden_states)
+ hidden_states = self.activation(hidden_states)
+ dense_embeddings = self.conv3(hidden_states)
+ return dense_embeddings
+
+
+class Sam2PromptEncoder(nn.Module):
+ def __init__(self, config: Sam2PromptEncoderConfig):
+ super().__init__()
+ self.shared_embedding = Sam2PositionalEmbedding(config)
+ self.mask_embed = Sam2MaskEmbedding(config)
+ self.no_mask_embed = nn.Embedding(1, config.hidden_size)
+
+ self.image_embedding_size = (config.image_size // config.patch_size, config.image_size // config.patch_size)
+ self.mask_input_size = (4 * config.image_size // config.patch_size, 4 * config.image_size // config.patch_size)
+ self.input_image_size = config.image_size
+
+ self.point_embed = nn.Embedding(config.num_point_embeddings, config.hidden_size)
+ self.hidden_size = config.hidden_size
+ self.not_a_point_embed = nn.Embedding(1, config.hidden_size)
+
+ def _embed_points(self, points: torch.Tensor, labels: torch.Tensor, pad: bool) -> torch.Tensor:
+ """Embeds point prompts."""
+ points = points + 0.5 # Shift to center of pixel
+ if pad:
+ points = torch.nn.functional.pad(points, (0, 0, 0, 1), mode="constant", value=0)
+ labels = torch.nn.functional.pad(labels, (0, 1), mode="constant", value=-1)
+ input_shape = (self.input_image_size, self.input_image_size)
+ point_embedding = self.shared_embedding(points, input_shape)
+
+ # torch.where and expanding the labels tensor is required by the ONNX export
+ point_embedding = torch.where(labels[..., None] == -1, self.not_a_point_embed.weight, point_embedding)
+
+ # This is required for the ONNX export. The dtype, device need to be explicitly
+ # specified as otherwise torch.onnx.export interprets as double
+ point_embedding = torch.where(
+ labels[..., None] != -10,
+ point_embedding,
+ torch.zeros_like(point_embedding),
+ )
+
+ # Add point embeddings for labels >= 0
+ point_embedding = point_embedding + self.point_embed(labels.clamp(min=0)) * (labels >= 0).unsqueeze(-1)
+
+ return point_embedding
+
+ def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
+ """Embeds box prompts."""
+ boxes += 0.5 # Shift to center of pixel
+ coords = boxes.view(*boxes.shape[:2], 2, 2)
+ # add padding point for consistency with the original implementation
+ coords = torch.nn.functional.pad(coords, (0, 0, 0, 1), mode="constant", value=0)
+ corner_embedding = self.shared_embedding(coords, (self.input_image_size, self.input_image_size))
+ corner_embedding[:, :, 0, :] += self.point_embed.weight[2]
+ corner_embedding[:, :, 1, :] += self.point_embed.weight[3]
+ corner_embedding[:, :, 2, :] = self.not_a_point_embed.weight.expand_as(corner_embedding[:, :, 2, :])
+ return corner_embedding
+
+ def forward(
+ self,
+ input_points: Optional[tuple[torch.Tensor, torch.Tensor]],
+ input_labels: Optional[torch.Tensor],
+ input_boxes: Optional[torch.Tensor],
+ input_masks: Optional[torch.Tensor],
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ """
+ Embeds different types of prompts, returning both sparse and dense embeddings.
+
+ Args:
+ input_points (`torch.Tensor`, *optional*):
+ Point coordinates to embed.
+ input_labels (`torch.Tensor`, *optional*):
+ Labels of the points to embed.
+ input_boxes (`torch.Tensor`, *optional*):
+ Boxes to embed.
+ input_masks (`torch.Tensor`, *optional*):
+ Masks to embed.
+ """
+ sparse_embeddings = None
+ batch_size = 1
+ if input_points is not None:
+ batch_size = input_points.shape[0]
+ if input_labels is None:
+ raise ValueError("If points are provided, labels must also be provided.")
+ point_embeddings = self._embed_points(input_points, input_labels, pad=(input_boxes is None))
+ sparse_embeddings = point_embeddings
+ if input_boxes is not None:
+ batch_size = input_boxes.shape[0]
+ box_embeddings = self._embed_boxes(input_boxes)
+ if sparse_embeddings is None:
+ sparse_embeddings = box_embeddings
+ else:
+ sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=2)
+ if input_masks is not None:
+ dense_embeddings = self.mask_embed(input_masks)
+ else:
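+ # no mask prompt: broadcast the learned no-mask embedding over the whole image embedding grid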
+ dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
+ batch_size, -1, self.image_embedding_size[0], self.image_embedding_size[1]
+ )
+
+ return sparse_embeddings, dense_embeddings
+
+
+class Sam2Attention(nn.Module):
+ """
+ SAM2's attention layer that allows for downscaling the size of the embedding after projection to queries, keys, and
+ values.
+ """
+
+ def __init__(self, config, downsample_rate=None):
+ super().__init__()
+ downsample_rate = config.attention_downsample_rate if downsample_rate is None else downsample_rate
+ self.config = config
+ self.hidden_size = config.hidden_size
+ self.internal_dim = config.hidden_size // downsample_rate
+ self.num_attention_heads = config.num_attention_heads
+ self.head_dim = self.internal_dim // config.num_attention_heads
+ self.scaling = self.head_dim**-0.5
+ self.is_causal = False
+
+ self.q_proj = nn.Linear(self.hidden_size, self.internal_dim)
+ self.k_proj = nn.Linear(self.hidden_size, self.internal_dim)
+ self.v_proj = nn.Linear(self.hidden_size, self.internal_dim)
+ self.o_proj = nn.Linear(self.internal_dim, self.hidden_size)
+
+ def forward(
+ self,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_similarity: Optional[torch.Tensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ # Input projections
+ batch_size, point_batch_size = query.shape[:2]
+ new_shape = (batch_size * point_batch_size, -1, self.num_attention_heads, self.head_dim)
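+ # fold the point batch into the batch dimension so attention runs independently for each prompt group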
+
+ query = self.q_proj(query).view(*new_shape).transpose(1, 2)
+ key = self.k_proj(key).view(*new_shape).transpose(1, 2)
+ value = self.v_proj(value).view(*new_shape).transpose(1, 2)
+
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ attn_output, attn_weights = attention_interface(
+ self,
+ query,
+ key,
+ value,
+ attention_mask=attention_similarity,
+ dropout=0.0,
+ scaling=self.scaling,
+ is_causal=self.is_causal,
+ **kwargs,
+ )
+
+ attn_output = attn_output.reshape(
+ batch_size, point_batch_size, -1, self.num_attention_heads * self.head_dim
+ ).contiguous()
+ attn_output = self.o_proj(attn_output)
+
+ return attn_output, attn_weights
+
+
+class Sam2TwoWayAttentionBlock(nn.Module):
+ def __init__(self, config: Sam2MaskDecoderConfig, skip_first_layer_pe: bool = False):
+ """
+ A transformer block with four layers:
+ (1) self-attention of sparse inputs (2) cross attention of sparse inputs -> dense inputs (3) mlp block on
+ sparse inputs (4) cross attention of dense inputs -> sparse inputs
+
+ Args:
+ config (`Sam2MaskDecoderConfig`):
+ The configuration used to instantiate the block.
+ attention_downsample_rate (`int`, *optional*, defaults to 2):
+ The downsample ratio of the block used to reduce the inner dim of the attention.
+ skip_first_layer_pe (`bool`, *optional*, defaults to `False`):
+ Whether or not to skip adding the query_point_embedding on the first layer.
+ """
+ super().__init__()
+ self.self_attn = Sam2Attention(config, downsample_rate=1)
+ self.layer_norm1 = nn.LayerNorm(config.hidden_size)
+
+ self.cross_attn_token_to_image = Sam2Attention(config)
+ self.layer_norm2 = nn.LayerNorm(config.hidden_size)
+
+ self.mlp = Sam2FeedForward(
+ config.hidden_size, config.mlp_dim, config.hidden_size, num_layers=config.num_hidden_layers
+ )
+ self.layer_norm3 = nn.LayerNorm(config.hidden_size)
+
+ self.layer_norm4 = nn.LayerNorm(config.hidden_size)
+ self.cross_attn_image_to_token = Sam2Attention(config)
+
+ self.skip_first_layer_pe = skip_first_layer_pe
+
+ def forward(
+ self,
+ queries: Tensor,
+ keys: Tensor,
+ query_point_embedding: Tensor,
+ key_point_embedding: Tensor,
+ attention_similarity: Tensor,
+ **kwargs: Unpack[TransformersKwargs],
+ ):
+ # Self attention block
+ if self.skip_first_layer_pe:
+ queries, _ = self.self_attn(query=queries, key=queries, value=queries)
+ else:
+ query = queries + query_point_embedding
+ attn_out, _ = self.self_attn(query=query, key=query, value=queries)
+ queries = queries + attn_out
+ queries = self.layer_norm1(queries)
+
+ # Cross attention block, tokens attending to image embedding
+ query = queries + query_point_embedding
+ key = keys + key_point_embedding
+
+ attn_out, _ = self.cross_attn_token_to_image(
+ query=query, key=key, value=keys, attention_similarity=attention_similarity
+ )
+ queries = queries + attn_out
+
+ queries = self.layer_norm2(queries)
+
+ # MLP block
+ mlp_out = self.mlp(queries)
+ queries = queries + mlp_out
+ queries = self.layer_norm3(queries)
+
+ # Cross attention block, image embedding attending to tokens
+ query = queries + query_point_embedding
+ key = keys + key_point_embedding
+
+ attn_out, _ = self.cross_attn_image_to_token(query=key, key=query, value=queries)
+ keys = keys + attn_out
+
+ keys = self.layer_norm4(keys)
+ return queries, keys, attn_out
+
+
+class Sam2TwoWayTransformer(nn.Module):
+ def __init__(self, config: Sam2MaskDecoderConfig):
+ super().__init__()
+ self.config = config
+
+ self.num_hidden_layers = config.num_hidden_layers
+ self.layers = nn.ModuleList()
+
+ for i in range(self.num_hidden_layers):
+ self.layers.append(Sam2TwoWayAttentionBlock(config, skip_first_layer_pe=(i == 0)))
+
+ self.final_attn_token_to_image = Sam2Attention(config)
+ self.layer_norm_final_attn = nn.LayerNorm(config.hidden_size)
+
+ def forward(
+ self,
+ point_embeddings: Tensor,
+ image_embeddings: Tensor,
+ image_positional_embeddings: Tensor,
+ attention_similarity: Tensor,
+ target_embedding=None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> Union[tuple, BaseModelOutput]:
+ if image_embeddings is None:
+ raise ValueError("You have to specify an image_embedding")
+
+ image_embeddings = image_embeddings.flatten(2).permute(0, 2, 1).unsqueeze(1)
+ image_positional_embeddings = image_positional_embeddings.flatten(2).permute(0, 2, 1).unsqueeze(1)
+
+ # Prepare queries
+ queries = point_embeddings
+ keys = image_embeddings
+
+ # Apply transformer blocks and final layernorm
+ for layer in self.layers:
+ if target_embedding is not None:
+ queries += target_embedding
+
+ queries, keys, _ = layer(
+ queries=queries,
+ keys=keys,
+ query_point_embedding=point_embeddings,
+ key_point_embedding=image_positional_embeddings,
+ attention_similarity=attention_similarity,
+ **kwargs,
+ )
+ # Apply the final attention layer from the points to the image
+ query = queries + point_embeddings
+ key = keys + image_positional_embeddings
+
+ attn_out, _ = self.final_attn_token_to_image(query=query, key=key, value=keys)
+
+ queries = queries + attn_out
+ queries = self.layer_norm_final_attn(queries)
+ return queries, keys
+
+
+class Sam2LayerNorm(nn.LayerNorm):
+ r"""LayerNorm that supports two data formats: channels_last (default) or channels_first.
+ The ordering of the dimensions in the inputs. channels_last corresponds to inputs with shape (batch_size, height,
+ width, channels) while channels_first corresponds to inputs with shape (batch_size, channels, height, width).
+ """
+
+ def __init__(self, normalized_shape, *, eps=1e-6, data_format="channels_last", **kwargs):
+ super().__init__(normalized_shape, eps=eps, **kwargs)
+ if data_format not in ["channels_last", "channels_first"]:
+ raise NotImplementedError(f"Unsupported data format: {data_format}")
+ self.data_format = data_format
+
+ def forward(self, features: torch.Tensor) -> torch.Tensor:
+ """
+ Args:
+ features: Tensor of shape (batch_size, channels, height, width) OR (batch_size, height, width, channels)
+ """
+ if self.data_format == "channels_first":
+ features = features.permute(0, 2, 3, 1)
+ features = super().forward(features)
+ features = features.permute(0, 3, 1, 2)
+ else:
+ features = super().forward(features)
+ return features
+
+
+class Sam2MaskDecoder(nn.Module):
+ def __init__(self, config: Sam2MaskDecoderConfig):
+ super().__init__()
+ self.config = config
+ self.hidden_size = config.hidden_size
+
+ self.num_multimask_outputs = config.num_multimask_outputs
+ self.num_mask_tokens = config.num_multimask_outputs + 1
+
+ self.iou_token = nn.Embedding(1, self.hidden_size)
+ self.mask_tokens = nn.Embedding(self.num_mask_tokens, self.hidden_size)
+
+ self.transformer = Sam2TwoWayTransformer(config)
+
+ # should we create a new class for this?
+ self.upscale_conv1 = nn.ConvTranspose2d(self.hidden_size, self.hidden_size // 4, kernel_size=2, stride=2)
+ self.upscale_conv2 = nn.ConvTranspose2d(self.hidden_size // 4, self.hidden_size // 8, kernel_size=2, stride=2)
+ self.upscale_layer_norm = Sam2LayerNorm(self.hidden_size // 4, data_format="channels_first")
+ self.activation = nn.GELU()
+
+ mlps_list = []
+ for _ in range(self.num_mask_tokens):
+ mlps_list += [Sam2FeedForward(self.hidden_size, self.hidden_size, self.hidden_size // 8, 3)]
+ self.output_hypernetworks_mlps = nn.ModuleList(mlps_list)
+ self.iou_prediction_head = Sam2FeedForward(
+ self.hidden_size,
+ config.iou_head_hidden_dim,
+ self.num_mask_tokens,
+ config.iou_head_depth,
+ sigmoid_output=True,
+ )
+
+ self.conv_s0 = nn.Conv2d(config.hidden_size, config.hidden_size // 8, kernel_size=1, stride=1)
+ self.conv_s1 = nn.Conv2d(config.hidden_size, config.hidden_size // 4, kernel_size=1, stride=1)
+
+ self.obj_score_token = nn.Embedding(1, self.hidden_size)
+ self.pred_obj_score_head = Sam2FeedForward(self.hidden_size, self.hidden_size, 1, 3)
+
+ self.dynamic_multimask_via_stability = config.dynamic_multimask_via_stability
+ self.dynamic_multimask_stability_delta = config.dynamic_multimask_stability_delta
+ self.dynamic_multimask_stability_thresh = config.dynamic_multimask_stability_thresh
+
+ def forward(
+ self,
+ image_embeddings: torch.Tensor,
+ image_positional_embeddings: torch.Tensor,
+ sparse_prompt_embeddings: torch.Tensor,
+ dense_prompt_embeddings: torch.Tensor,
+ multimask_output: bool,
+ high_resolution_features: list[torch.Tensor],
+ attention_similarity: Optional[torch.Tensor] = None,
+ target_embedding: Optional[torch.Tensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ """
+ Predict masks given image and prompt embeddings.
+
+ Args:
+ image_embeddings (`torch.Tensor`):
+ The embeddings from the image encoder.
+ image_positional_embeddings (`torch.Tensor`):
+ Positional encoding with the shape of image_embeddings.
+ sparse_prompt_embeddings (`torch.Tensor`):
+ The embeddings of the points and boxes.
+ dense_prompt_embeddings (`torch.Tensor`):
+ The embeddings of the mask inputs.
+ multimask_output (`bool`):
+ Whether to return multiple masks or a single mask.
+ high_resolution_features (`list[torch.Tensor]`, *optional*):
+ The high-resolution features from the vision encoder.
+ attention_similarity (`torch.Tensor`, *optional*):
+ The attention similarity tensor.
+ target_embedding (`torch.Tensor`, *optional*):
+ The target embedding.
+ """
+ batch_size, num_channels, height, width = image_embeddings.shape
+ point_batch_size = sparse_prompt_embeddings.shape[1]
+ # Concatenate output tokens
+ output_tokens = torch.cat(
+ [
+ self.obj_score_token.weight,
+ self.iou_token.weight,
+ self.mask_tokens.weight,
+ ],
+ dim=0,
+ )
+ output_tokens = output_tokens.repeat(batch_size, point_batch_size, 1, 1)
+
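+ # append the sparse prompt embeddings (points / boxes), if any, after the learned output tokens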
+ if sparse_prompt_embeddings.shape[0] != 0:
+ tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=2)
+ else:
+ tokens = output_tokens
+ point_embeddings = tokens.to(self.iou_token.weight.dtype)
+
+ # Expand per-image data in batch direction to be per-mask
+ image_embeddings = image_embeddings + dense_prompt_embeddings
+ image_embeddings = image_embeddings.repeat_interleave(point_batch_size, dim=0)
+ image_positional_embeddings = image_positional_embeddings.repeat_interleave(point_batch_size, 0)
+ # Run the transformer
+ point_embeddings, image_embeddings = self.transformer(
+ point_embeddings=point_embeddings,
+ image_embeddings=image_embeddings,
+ image_positional_embeddings=image_positional_embeddings,
+ attention_similarity=attention_similarity,
+ target_embedding=target_embedding,
+ **kwargs,
+ )
+ iou_token_out = point_embeddings[:, :, 1, :]
+ mask_tokens_out = point_embeddings[:, :, 2 : (2 + self.num_mask_tokens), :]
+
+ # Upscale mask embeddings and predict masks using the mask tokens
+ image_embeddings = image_embeddings.transpose(2, 3).view(
+ batch_size * point_batch_size, num_channels, height, width
+ )
+
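+ # the two high-resolution FPN feature maps are added as skip connections while upscaling the mask embedding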
+ feat_s0, feat_s1 = high_resolution_features
+ feat_s0 = feat_s0.repeat_interleave(point_batch_size, dim=0)
+ feat_s1 = feat_s1.repeat_interleave(point_batch_size, dim=0)
+ upscaled_embedding = self.upscale_conv1(image_embeddings) + feat_s1
+ upscaled_embedding = self.activation(self.upscale_layer_norm(upscaled_embedding))
+ upscaled_embedding = self.activation(self.upscale_conv2(upscaled_embedding) + feat_s0)
+
+ hyper_in_list: list[torch.Tensor] = []
+ for i in range(self.num_mask_tokens):
+ current_mlp = self.output_hypernetworks_mlps[i]
+ hyper_in_list += [current_mlp(mask_tokens_out[:, :, i, :])]
+ hyper_in = torch.stack(hyper_in_list, dim=2)
+
+ _, num_channels, height, width = upscaled_embedding.shape
+ upscaled_embedding = upscaled_embedding.view(batch_size, point_batch_size, num_channels, height * width)
+ masks = (hyper_in @ upscaled_embedding).view(batch_size, point_batch_size, -1, height, width)
+
+ # Generate mask quality predictions
+ iou_pred = self.iou_prediction_head(iou_token_out)
+ object_score_logits = self.pred_obj_score_head(point_embeddings[:, :, 0, :])
+
+ # Select the correct mask or masks for output
+ if multimask_output:
+ mask_slice = slice(1, None)
+ masks = masks[:, :, mask_slice, :, :]
+ iou_pred = iou_pred[:, :, mask_slice]
+ elif self.dynamic_multimask_via_stability and not self.training:
+ mask_slice = slice(0, 1)
+ masks, iou_pred = self._dynamic_multimask_via_stability(masks, iou_pred)
+ else:
+ mask_slice = slice(0, 1)
+ masks = masks[:, :, mask_slice, :, :]
+ iou_pred = iou_pred[:, :, mask_slice]
+
+ sam_tokens_out = mask_tokens_out[:, :, mask_slice] # [b, 3, c] shape
+
+ return masks, iou_pred, sam_tokens_out, object_score_logits
+
+ def _get_stability_scores(self, mask_logits):
+ """
+ Compute stability scores of the mask logits based on the IoU between upper and
+ lower thresholds.
+ """
+ mask_logits = mask_logits.flatten(-2)
+ stability_delta = self.dynamic_multimask_stability_delta
+ area_i = torch.sum(mask_logits > stability_delta, dim=-1).float()
+ area_u = torch.sum(mask_logits > -stability_delta, dim=-1).float()
+ stability_scores = torch.where(area_u > 0, area_i / area_u, 1.0)
+ return stability_scores
+
+ def _dynamic_multimask_via_stability(self, all_mask_logits, all_iou_scores):
+ """
+ When outputting a single mask, if the stability score from the current single-mask
+ output (based on output token 0) falls below a threshold, we instead select from
+ multi-mask outputs (based on output token 1~3) the mask with the highest predicted
+ IoU score. This is intended to ensure a valid mask for both clicking and tracking.
+ """
+ # The best mask from multimask output tokens (1~3)
+ multimask_logits = all_mask_logits[:, :, 1:, :, :]
+ multimask_iou_scores = all_iou_scores[:, :, 1:]
+ best_scores_inds = torch.argmax(multimask_iou_scores, dim=-1) # [B, P]
+ best_scores_inds_expanded = best_scores_inds.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
+ best_scores_inds_expanded = best_scores_inds_expanded.expand(
+ -1, -1, 1, multimask_logits.size(-2), multimask_logits.size(-1)
+ )
+ best_multimask_logits = torch.gather(multimask_logits, 2, best_scores_inds_expanded) # [B, P, 1, H, W]
+ best_multimask_iou_scores = torch.gather(multimask_iou_scores, 2, best_scores_inds.unsqueeze(-1)) # [B, P, 1]
+
+ # The mask from singlemask output token 0 and its stability score
+ singlemask_logits = all_mask_logits[:, :, 0:1, :, :]
+ singlemask_iou_scores = all_iou_scores[:, :, 0:1]
+ stability_scores = self._get_stability_scores(singlemask_logits)
+ is_stable = stability_scores >= self.dynamic_multimask_stability_thresh
+
+ # Dynamically fall back to best multimask output upon low stability scores.
+ mask_logits_out = torch.where(
+ is_stable[..., None, None].expand_as(singlemask_logits),
+ singlemask_logits,
+ best_multimask_logits,
+ )
+ iou_scores_out = torch.where(
+ is_stable.expand_as(singlemask_iou_scores),
+ singlemask_iou_scores,
+ best_multimask_iou_scores,
+ )
+ return mask_logits_out, iou_scores_out
+
+
+@auto_docstring(
+ custom_intro="""
+ Segment Anything Model 2 (SAM 2) for generating segmentation masks, given an input image and
+ input points and labels, boxes, or masks.
+ """
+)
+class Sam2Model(Sam2PreTrainedModel):
+ _tied_weights_keys = ["prompt_encoder.shared_embedding.positional_embedding"]
+ # need to be ignored, as it's a buffer and will not be correctly detected as tied weight
+ _keys_to_ignore_on_load_missing = ["prompt_encoder.shared_embedding.positional_embedding"]
+ _can_record_outputs = {"mask_decoder_attentions": OutputRecorder(Sam2TwoWayAttentionBlock, index=2)}
+ _keys_to_ignore_on_load_unexpected = [
+ r"^memory_.*",
+ r"^mask_downsample.*",
+ r"^object_pointer_proj.*",
+ r"^temporal_positional_encoding_projection_layer.*",
+ "no_memory_positional_encoding",
+ "no_object_pointer",
+ "occlusion_spatial_embedding_parameter",
+ ]
+
+ def __init__(self, config: Sam2Config):
+ super().__init__(config)
+ self.shared_image_embedding = Sam2PositionalEmbedding(config.prompt_encoder_config)
+ self.vision_encoder = AutoModel.from_config(config.vision_config)
+ self.prompt_encoder = Sam2PromptEncoder(config.prompt_encoder_config)
+ # The module using it is not a PreTrainedModel subclass so we need this
+ config.mask_decoder_config._attn_implementation = config._attn_implementation
+ self.mask_decoder = Sam2MaskDecoder(config.mask_decoder_config)
+
+ self.num_feature_levels = config.vision_config.num_feature_levels
+ self.backbone_feature_sizes = config.vision_config.backbone_feature_sizes
+ # a single token to indicate no memory embedding from previous frames
+ self.hidden_dim = config.vision_config.fpn_hidden_size
+ self.no_memory_embedding = torch.nn.Parameter(torch.zeros(1, 1, self.hidden_dim))
+
+ self.post_init()
+
+ def _tie_weights(self):
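+ # share the positional embedding buffer between the model-level shared_image_embedding and the prompt encoder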
+ self.prompt_encoder.shared_embedding.positional_embedding.data = (
+ self.shared_image_embedding.positional_embedding.data
+ )
+
+ def get_input_embeddings(self):
+ return self.vision_encoder.get_input_embeddings()
+
+ def get_image_wide_positional_embeddings(self) -> torch.Tensor:
+ size = self.prompt_encoder.image_embedding_size
+ target_device = self.shared_image_embedding.positional_embedding.device
+ target_dtype = self.shared_image_embedding.positional_embedding.dtype
+ grid = torch.ones(size, device=target_device, dtype=target_dtype)
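+ # cumulative sums over a grid of ones yield per-pixel row/column indices; subtracting 0.5 centers them before normalizing to [0, 1]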
+ y_embed = grid.cumsum(dim=0) - 0.5
+ x_embed = grid.cumsum(dim=1) - 0.5
+ y_embed = y_embed / size[0]
+ x_embed = x_embed / size[1]
+
+ positional_embedding = self.shared_image_embedding(torch.stack([x_embed, y_embed], dim=-1))
+ return positional_embedding.permute(2, 0, 1).unsqueeze(0) # channel x height x width
+
+ @torch.no_grad()
+ def get_image_embeddings(
+ self,
+ pixel_values: torch.FloatTensor,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> list[torch.Tensor]:
+ r"""
+ Returns the image embeddings by passing the pixel values through the vision encoder.
+
+ Args:
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+ Input pixel values
+ """
+ batch_size = pixel_values.shape[0]
+ feature_maps, _, _, _ = self.get_image_features(pixel_values, **kwargs)
+
+ # add no memory embedding to the last feature map
+ feature_maps[-1] = feature_maps[-1] + self.no_memory_embedding
+
+ # reshape feature maps to the same shape as the backbone feature sizes
+ image_embeddings = [
+ feat.permute(1, 2, 0).view(batch_size, -1, *feat_size)
+ for feat, feat_size in zip(feature_maps, self.backbone_feature_sizes)
+ ]
+
+ return image_embeddings
+
+ @torch.no_grad()
+ def get_prompt_embeddings(
+ self,
+ input_points: Optional[torch.FloatTensor] = None,
+ input_labels: Optional[torch.LongTensor] = None,
+ input_boxes: Optional[torch.FloatTensor] = None,
+ input_masks: Optional[torch.LongTensor] = None,
+ ):
+ r"""
+ Returns the prompt embeddings by passing the input points, labels, boxes and masks through the prompt encoder.
+
+ Args:
+ input_points (`torch.FloatTensor` of shape `(batch_size, point_batch_size, num_points_per_image, 2)`):
+ Optional input points for the prompt encoder. The padding of the point is automatically done by the
+ processor. `point_batch_size` refers to the number of masks that we want the model to predict per
+ point. The model will output `point_batch_size` times 3 masks in total.
+ input_labels (`torch.LongTensor` of shape `(batch_size, point_batch_size, num_points_per_image)`):
+ Optional input labels for the prompt encoder. The padding of the labels is automatically done by the
+ processor, or can be fed by the user.
+ input_boxes (`torch.FloatTensor` of shape `(batch_size, num_boxes_per_image, 4)`):
+ Optional input boxes for the prompt encoder. The padding of the boxes is automatically done by the
+ processor. Users can also pass the input boxes manually.
+ input_masks (`torch.LongTensor` of shape `(batch_size, image_size, image_size)`):
+ Optional input masks for the prompt encoder.
+ """
+ prompt_output = self.prompt_encoder(
+ input_points=input_points,
+ input_labels=input_labels,
+ input_boxes=input_boxes,
+ input_masks=input_masks,
+ )
+ return prompt_output
+
+ @check_model_inputs
+ @auto_docstring
+ def forward(
+ self,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ input_points: Optional[torch.FloatTensor] = None,
+ input_labels: Optional[torch.LongTensor] = None,
+ input_boxes: Optional[torch.FloatTensor] = None,
+ input_masks: Optional[torch.LongTensor] = None,
+ image_embeddings: Optional[torch.FloatTensor] = None,
+ multimask_output: bool = True,
+ attention_similarity: Optional[torch.FloatTensor] = None,
+ target_embedding: Optional[torch.FloatTensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> Sam2ImageSegmentationOutput:
+ r"""
+ input_points (`torch.FloatTensor` of shape `(batch_size, num_points, 2)`):
+ Input 2D spatial points; these are used by the prompt encoder to encode the prompt. Generally yields much
+ better results. The points can be obtained by passing a list of list of list to the processor that will
+ create corresponding `torch` tensors of dimension 4. The first dimension is the image batch size, the
+ second dimension is the point batch size (i.e. how many segmentation masks do we want the model to predict
+ per input point), the third dimension is the number of points per segmentation mask (it is possible to pass
+ multiple points for a single mask), and the last dimension is the x (vertical) and y (horizontal)
+ coordinates of the point. If a different number of points is passed either for each image, or for each
+ mask, the processor will create "PAD" points that will correspond to the (0, 0) coordinate, and the
+ computation of the embedding will be skipped for these points using the labels.
+ input_labels (`torch.LongTensor` of shape `(batch_size, point_batch_size, num_points)`):
+ Input labels for the points, this is used by the prompt encoder to encode the prompt. According to the
+ official implementation, there are 3 types of labels
+
+ - `1`: the point is a point that contains the object of interest
+ - `0`: the point is a point that does not contain the object of interest
+ - `-1`: the point corresponds to the background
+
+ We added the label:
+
+ - `-10`: the point is a padding point, thus should be ignored by the prompt encoder
+
+ The padding labels should be automatically done by the processor.
+ input_boxes (`torch.FloatTensor` of shape `(batch_size, num_boxes, 4)`):
+ Input boxes for the points, used by the prompt encoder to encode the prompt. Generally yields much better
+ generated masks. The boxes can be obtained by passing a list of lists of lists to the processor, which will
+ generate a `torch` tensor, with each dimension corresponding respectively to the image batch size, the
+ number of boxes per image, and the coordinates of the top left and bottom right points of the box.
+ In the order (`x1`, `y1`, `x2`, `y2`):
+
+ - `x1`: the x coordinate of the top left point of the input box
+ - `y1`: the y coordinate of the top left point of the input box
+ - `x2`: the x coordinate of the bottom right point of the input box
+ - `y2`: the y coordinate of the bottom right point of the input box
+ input_masks (`torch.FloatTensor` of shape `(batch_size, image_size, image_size)`):
+ The SAM 2 model also accepts segmentation masks as input. The mask will be embedded by the prompt encoder
+ to generate a corresponding embedding that will later be fed to the mask decoder. These masks need to be
+ fed manually by the user, and they need to be of shape (`batch_size`, `image_size`, `image_size`).
+ image_embeddings (`torch.FloatTensor` of shape `(batch_size, output_channels, window_size, window_size)`):
+ Image embeddings, this is used by the mask decoder to generate masks and iou scores. For more memory
+ efficient computation, users can first retrieve the image embeddings using the `get_image_embeddings`
+ method, and then feed them to the `forward` method instead of feeding the `pixel_values`.
+ multimask_output (`bool`, *optional*):
+ In the original implementation and paper, the model always outputs 3 masks per image (or per point / per
+ bounding box if relevant). However, it is possible to output just a single mask that corresponds to the
+ "best" mask, by specifying `multimask_output=False`.
+ attention_similarity (`torch.FloatTensor`, *optional*):
+ Attention similarity tensor, to be provided to the mask decoder for target-guided attention in case the
+ model is used for personalization as introduced in [PerSAM](https://huggingface.co/papers/2305.03048).
+ target_embedding (`torch.FloatTensor`, *optional*):
+ Embedding of the target concept, to be provided to the mask decoder for target-semantic prompting in case
+ the model is used for personalization as introduced in [PerSAM](https://huggingface.co/papers/2305.03048).
+
+ Example:
+
+ ```python
+ >>> from PIL import Image
+ >>> import requests
+ >>> from transformers import AutoModel, AutoProcessor
+
+ >>> model = AutoModel.from_pretrained("danelcsb/sam2.1_hiera_tiny")
+ >>> processor = AutoProcessor.from_pretrained("danelcsb/sam2.1_hiera_tiny")
+
+ >>> img_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/sam-car.png"
+ >>> raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
+ >>> input_points = [[[400, 650]]] # 2D location of a window on the car
+ >>> inputs = processor(images=raw_image, input_points=input_points, return_tensors="pt")
+
+ >>> # Get segmentation mask
+ >>> outputs = model(**inputs)
+
+ >>> # Postprocess masks
+ >>> masks = processor.post_process_masks(
+ ... outputs.pred_masks, inputs["original_sizes"], inputs["reshaped_input_sizes"]
+ ... )
+ ```
+ """
+ if not ((pixel_values is None) ^ (image_embeddings is None)):
+ raise ValueError("Exactly one of pixel_values or image_embeddings must be provided.")
+ if input_points is not None and input_boxes is not None:
+ if input_points.shape[1] != input_boxes.shape[1]:
+ raise ValueError(
+ f"You should provide as many bounding boxes as input points per box. Got {input_points.shape[1]} and {input_boxes.shape[1]}."
+ )
+
+ image_positional_embeddings = self.get_image_wide_positional_embeddings()
+ # repeat with batch size
+ batch_size = pixel_values.shape[0] if pixel_values is not None else image_embeddings[-1].shape[0]
+ image_positional_embeddings = image_positional_embeddings.repeat(batch_size, 1, 1, 1)
+
+ vision_attentions = None
+ vision_hidden_states = None
+
+ if pixel_values is not None:
+ feature_maps, _, vision_hidden_states, vision_attentions = self.get_image_features(
+ pixel_values,
+ **kwargs,
+ )
+
+ # add no memory embedding to the last feature map
+ feature_maps[-1] = feature_maps[-1] + self.no_memory_embedding
+
+ # reshape feature maps to the same shape as the backbone feature sizes
+ image_embeddings = [
+ feat.permute(1, 2, 0).view(batch_size, -1, *feat_size)
+ for feat, feat_size in zip(feature_maps, self.backbone_feature_sizes)
+ ]
+
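+ # If points are given without labels, assume every point is a foreground point (label 1)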
+ if input_points is not None and input_labels is None:
+ input_labels = torch.ones_like(input_points[:, :, :, 0], dtype=torch.int, device=input_points.device)
+
+ if input_points is None and input_boxes is None:
+ # If no points are provided, pad with an empty point (with label -1)
+ input_points = torch.zeros(
+ batch_size, 1, 1, 2, dtype=image_embeddings[-1].dtype, device=image_embeddings[-1].device
+ )
+ input_labels = -torch.ones(batch_size, 1, 1, dtype=torch.int32, device=image_embeddings[-1].device)
+
+ if input_masks is not None:
+ # If input_masks is provided, downsize it into a low-res mask input if needed
+ # and feed it as a dense mask prompt into the SAM mask encoder
+ if input_masks.shape[-2:] != self.prompt_encoder.mask_input_size:
+ input_masks = F.interpolate(
+ input_masks.float(),
+ size=self.prompt_encoder.mask_input_size,
+ align_corners=False,
+ mode="bilinear",
+ antialias=True, # use antialias for downsampling
+ ).to(input_masks.dtype)
+
+ sparse_embeddings, dense_embeddings = self.prompt_encoder(
+ input_points=input_points,
+ input_labels=input_labels,
+ input_boxes=input_boxes,
+ input_masks=input_masks,
+ )
+ low_res_multimasks, iou_scores, _, object_score_logits = self.mask_decoder(
+ image_embeddings=image_embeddings[-1],
+ image_positional_embeddings=image_positional_embeddings,
+ sparse_prompt_embeddings=sparse_embeddings,
+ dense_prompt_embeddings=dense_embeddings,
+ multimask_output=multimask_output,
+ high_resolution_features=image_embeddings[:-1],
+ attention_similarity=attention_similarity,
+ target_embedding=target_embedding,
+ **kwargs,
+ )
+
+ return Sam2ImageSegmentationOutput(
+ iou_scores=iou_scores,
+ pred_masks=low_res_multimasks,
+ object_score_logits=object_score_logits,
+ image_embeddings=image_embeddings,
+ vision_hidden_states=vision_hidden_states,
+ vision_attentions=vision_attentions,
+ )
+
+ def get_image_features(
+ self,
+ pixel_values: torch.FloatTensor,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> tuple[
+ list[torch.Tensor],
+ list[torch.Tensor],
+ Optional[tuple[torch.FloatTensor, ...]],
+ Optional[tuple[torch.FloatTensor, ...]],
+ ]:
+ r"""
+ Extract and preprocess image features using the vision encoder.
+
+ Args:
+ pixel_values (`torch.FloatTensor`):
+ Input pixel values of shape `(batch_size, num_channels, height, width)`.
+
+ Returns:
+ `tuple`: A tuple containing:
+ - feature_maps (`list[torch.Tensor]`): List of feature maps from different levels.
+ - feature_maps_position_embeddings (`list[torch.Tensor]`): List of positional embeddings for each feature level.
+ - vision_hidden_states (`tuple[torch.FloatTensor]`, *optional*): Hidden states from the vision encoder.
+ - vision_attentions (`tuple[torch.FloatTensor]`, *optional*): Attention weights from the vision encoder.
+ """
+ vision_outputs: Sam2VisionEncoderOutput = self.vision_encoder(
+ pixel_values,
+ **kwargs,
+ )
+
+ feature_maps = vision_outputs.fpn_hidden_states
+ feature_maps_position_embeddings = vision_outputs.fpn_position_encoding
+
+ # precompute projected level 0 and level 1 features in SAM decoder
+ # to avoid running it again on every SAM click
+ feature_maps = list(feature_maps)
+ feature_maps[0] = self.mask_decoder.conv_s0(feature_maps[0])
+ feature_maps[1] = self.mask_decoder.conv_s1(feature_maps[1])
+
+ # flatten NxCxHxW to HWxNxC
+ feature_maps = [feature_map.flatten(2).permute(2, 0, 1) for feature_map in feature_maps]
+ feature_maps_position_embeddings = [
+ feature_map_position_embedding.flatten(2).permute(2, 0, 1)
+ for feature_map_position_embedding in feature_maps_position_embeddings
+ ]
+
+ return feature_maps, feature_maps_position_embeddings, vision_outputs.hidden_states, vision_outputs.attentions
+
+
+__all__ = ["Sam2Model", "Sam2VisionModel", "Sam2PreTrainedModel", "Sam2HieraDetModel"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/sam2/modular_sam2.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/sam2/modular_sam2.py
new file mode 100644
index 0000000000000000000000000000000000000000..daab10855512ec2a226848db7f0dbd33e251c097
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/sam2/modular_sam2.py
@@ -0,0 +1,1463 @@
+# coding=utf-8
+# Copyright 2025 The Meta AI Authors and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch SAM 2 model."""
+
+from dataclasses import dataclass
+from typing import Callable, Optional, Union
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ...activations import ACT2FN
+from ...image_processing_utils import BatchFeature, get_size_dict
+from ...image_processing_utils_fast import BaseImageProcessorFast, DefaultFastImageProcessorKwargs
+from ...image_utils import (
+ IMAGENET_DEFAULT_MEAN,
+ IMAGENET_DEFAULT_STD,
+ ChannelDimension,
+ ImageInput,
+ PILImageResampling,
+ SizeDict,
+ pil_torch_interpolation_mapping,
+)
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
+from ...utils import (
+ ModelOutput,
+ TensorType,
+ auto_docstring,
+ logging,
+)
+from ...utils.generic import TransformersKwargs, check_model_inputs
+from ..auto import AutoModel
+from ..maskformer.modeling_maskformer import MaskFormerSinePositionEmbedding
+from ..sam.image_processing_sam_fast import SamImageProcessorFast
+from ..sam.modeling_sam import (
+ SamLayerNorm,
+ SamMaskDecoder,
+ SamMaskEmbedding,
+ SamModel,
+ SamPromptEncoder,
+ SamTwoWayAttentionBlock,
+ SamTwoWayTransformer,
+ eager_attention_forward,
+)
+from ..vitdet.modeling_vitdet import window_partition, window_unpartition
+from .configuration_sam2 import (
+ Sam2Config,
+ Sam2HieraDetConfig,
+ Sam2MaskDecoderConfig,
+ Sam2PromptEncoderConfig,
+ Sam2VisionConfig,
+)
+
+
+logger = logging.get_logger(__name__)
+
+
+class Sam2FastImageProcessorKwargs(DefaultFastImageProcessorKwargs):
+ r"""
+ mask_size (`dict[str, int]`, *optional*):
+ The size `{"height": int, "width": int}` to resize the segmentation maps to.
+ """
+
+ mask_size: Optional[dict[str, int]]
+
+
+@auto_docstring
+class Sam2ImageProcessorFast(SamImageProcessorFast):
+ resample = PILImageResampling.BILINEAR
+ image_mean = IMAGENET_DEFAULT_MEAN
+ image_std = IMAGENET_DEFAULT_STD
+ size = {"height": 1024, "width": 1024}
+ mask_size = {"height": 256, "width": 256}
+ do_resize = True
+ do_rescale = True
+ do_normalize = True
+ do_convert_rgb = True
+
+ valid_kwargs = Sam2FastImageProcessorKwargs
+
+ # modular artefacts
+ do_pad = None
+ pad_size = None
+ mask_pad_size = None
+
+ def __init__(self, **kwargs: Unpack[Sam2FastImageProcessorKwargs]):
+ BaseImageProcessorFast.__init__(self, **kwargs)
+
+ def pad_image(self):
+ raise NotImplementedError("No pad_image for SAM 2.")
+
+ def _get_preprocess_shape(self):
+ raise NotImplementedError("No _get_preprocess_shape for SAM 2.")
+
+ def resize(self):
+ raise NotImplementedError("No need to override resize for SAM 2.")
+
+ def _preprocess(
+ self,
+ images: list["torch.Tensor"],
+ return_tensors: Optional[Union[str, TensorType]],
+ **kwargs,
+ ) -> "torch.Tensor":
+ return BaseImageProcessorFast._preprocess(self, images, return_tensors=return_tensors, **kwargs).pixel_values
+
+ def _preprocess_image_like_inputs(
+ self,
+ images: ImageInput,
+ segmentation_maps: Optional[ImageInput],
+ do_convert_rgb: bool,
+ input_data_format: ChannelDimension,
+ device: Optional[Union[str, "torch.device"]] = None,
+ **kwargs: Unpack[Sam2FastImageProcessorKwargs],
+ ) -> BatchFeature:
+ """
+ Preprocess image-like inputs.
+ """
+ images = self._prepare_image_like_inputs(
+ images=images, do_convert_rgb=do_convert_rgb, input_data_format=input_data_format, device=device
+ )
+ original_sizes = [image.shape[-2:] for image in images]
+ images_kwargs = kwargs.copy()
+ pixel_values = self._preprocess(images, **images_kwargs)
+ reshaped_input_sizes = [image.shape[-2:] for image in images]
+ data = {
+ "pixel_values": pixel_values,
+ "original_sizes": original_sizes,
+ "reshaped_input_sizes": reshaped_input_sizes,
+ }
+
+ if segmentation_maps is not None:
+ processed_segmentation_maps = self._prepare_image_like_inputs(
+ images=segmentation_maps,
+ expected_ndims=2,
+ do_convert_rgb=False,
+ input_data_format=ChannelDimension.FIRST,
+ )
+
+ segmentation_maps_kwargs = kwargs.copy()
+ segmentation_maps_kwargs.update(
+ {
+ "do_normalize": False,
+ "do_rescale": False,
+ "interpolation": pil_torch_interpolation_mapping[PILImageResampling.NEAREST],
+ "size": segmentation_maps_kwargs.pop("mask_size"),
+ }
+ )
+ processed_segmentation_maps = self._preprocess(
+ images=processed_segmentation_maps, **segmentation_maps_kwargs
+ )
+ data["labels"] = processed_segmentation_maps.squeeze(1).to(torch.int64)
+
+ return BatchFeature(data=data, tensor_type=kwargs["return_tensors"])
+
+ def _further_process_kwargs(
+ self,
+ size: Optional[SizeDict] = None,
+ mask_size: Optional[SizeDict] = None,
+ default_to_square: Optional[bool] = None,
+ image_mean: Optional[Union[float, list[float]]] = None,
+ image_std: Optional[Union[float, list[float]]] = None,
+ data_format: Optional[ChannelDimension] = None,
+ **kwargs,
+ ) -> dict:
+ """
+ Update kwargs that need further processing before being validated.
+ Can be overridden by subclasses to customize the processing of kwargs.
+ """
+ if kwargs is None:
+ kwargs = {}
+ if size is not None:
+ size = SizeDict(**get_size_dict(size=size, default_to_square=default_to_square))
+ if mask_size is not None:
+ mask_size = SizeDict(**get_size_dict(mask_size, param_name="mask_size"))
+ if isinstance(image_mean, list):
+ image_mean = tuple(image_mean)
+ if isinstance(image_std, list):
+ image_std = tuple(image_std)
+ if data_format is None:
+ data_format = ChannelDimension.FIRST
+
+ kwargs["size"] = size
+ kwargs["mask_size"] = mask_size
+ kwargs["image_mean"] = image_mean
+ kwargs["image_std"] = image_std
+ kwargs["data_format"] = data_format
+
+ # torch resize uses interpolation instead of resample
+ # Check if resample is an int before checking if it's an instance of PILImageResampling
+ # because if pillow < 9.1.0, resample is an int and PILImageResampling is a module.
+ # Checking PILImageResampling will fail with error `TypeError: isinstance() arg 2 must be a type or tuple of types`.
+ resample = kwargs.pop("resample")
+ kwargs["interpolation"] = (
+ pil_torch_interpolation_mapping[resample] if isinstance(resample, (PILImageResampling, int)) else resample
+ )
+
+ return kwargs
+
+ def _apply_non_overlapping_constraints(self, pred_masks: torch.Tensor) -> torch.Tensor:
+ """
+ Apply non-overlapping constraints to the object scores in pred_masks. Here we
+ keep only the highest scoring object at each spatial location in pred_masks.
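+
+ For example (illustrative values), given two objects whose logits at two spatial locations are `[1.0, 5.0]`
+ and `[3.0, 2.0]`, object 1 wins the first location and object 0 wins the second, so the returned logits are
+ `[-10.0, 5.0]` and `[3.0, -10.0]` (losing scores are clamped to at most -10.0).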
+ """
+ batch_size = pred_masks.size(0)
+ if batch_size == 1:
+ return pred_masks
+
+ device = pred_masks.device
+ # "max_obj_inds": object index of the object with the highest score at each location
+ max_obj_inds = torch.argmax(pred_masks, dim=0, keepdim=True)
+ # "batch_obj_inds": object index of each object slice (along dim 0) in `pred_masks`
+ batch_obj_inds = torch.arange(batch_size, device=device)[:, None, None, None]
+ keep = max_obj_inds == batch_obj_inds
+ # suppress overlapping regions' scores below -10.0 so that the foreground regions
+ # don't overlap (here sigmoid(-10.0)=4.5398e-05)
+ pred_masks = torch.where(keep, pred_masks, torch.clamp(pred_masks, max=-10.0))
+ return pred_masks
+
+ def post_process_masks(
+ self,
+ masks,
+ original_sizes,
+ mask_threshold=0.0,
+ binarize=True,
+ max_hole_area=0.0,
+ max_sprinkle_area=0.0,
+ apply_non_overlapping_constraints=False,
+ **kwargs,
+ ):
+ """
+ Remove padding and upscale masks to the original image size.
+
+ Args:
+ masks (`Union[torch.Tensor, List[torch.Tensor], np.ndarray, List[np.ndarray]]`):
+ Batched masks from the mask_decoder in (batch_size, num_channels, height, width) format.
+ original_sizes (`Union[torch.Tensor, List[Tuple[int,int]]]`):
+ The original sizes of each image before it was resized to the model's expected input shape, in (height,
+ width) format.
+ mask_threshold (`float`, *optional*, defaults to 0.0):
+ Threshold for binarization and post-processing operations.
+ binarize (`bool`, *optional*, defaults to `True`):
+ Whether to binarize the masks.
+ max_hole_area (`float`, *optional*, defaults to 0.0):
+ The maximum area of a hole to fill.
+ max_sprinkle_area (`float`, *optional*, defaults to 0.0):
+ The maximum area of a sprinkle to fill.
+ apply_non_overlapping_constraints (`bool`, *optional*, defaults to `False`):
+ Whether to apply non-overlapping constraints to the masks.
+
+ Returns:
+ (`torch.Tensor`): Batched masks in `(batch_size, num_channels, height, width)` format, where `(height, width)`
+ is given by `original_size`.
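+
+ Example (a minimal sketch; assumes `outputs` is a `Sam2ImageSegmentationOutput` from a `Sam2Model` forward
+ pass and `inputs` is the corresponding processor output, with `processor` exposing this method as in the
+ model-level usage example):
+
+ ```python
+ >>> masks = processor.post_process_masks(outputs.pred_masks, inputs["original_sizes"])
+ >>> # `masks` is a list with one tensor per image, upscaled to that image's original (height, width)
+ ```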
+ """
+ if isinstance(original_sizes, (torch.Tensor, np.ndarray)):
+ original_sizes = original_sizes.tolist()
+ # TODO: add connected components kernel for postprocessing
+ output_masks = []
+ for i, original_size in enumerate(original_sizes):
+ if isinstance(masks[i], np.ndarray):
+ masks[i] = torch.from_numpy(masks[i])
+ elif not isinstance(masks[i], torch.Tensor):
+ raise ValueError("Input masks should be a list of `torch.tensors` or a list of `np.ndarray`")
+ interpolated_mask = F.interpolate(masks[i], original_size, mode="bilinear", align_corners=False)
+ if apply_non_overlapping_constraints:
+ interpolated_mask = self._apply_non_overlapping_constraints(interpolated_mask)
+ if binarize:
+ interpolated_mask = interpolated_mask > mask_threshold
+ output_masks.append(interpolated_mask)
+
+ return output_masks
+
+
+@dataclass
+@auto_docstring(custom_intro="Base class for the vision encoder's outputs.")
+class Sam2VisionEncoderOutput(ModelOutput):
+ r"""
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, height, width, hidden_size)`):
+ Sequence of hidden-states at the output of the last layer of the model.
+ fpn_hidden_states (`tuple(torch.FloatTensor)`):
+ Tuple of `torch.FloatTensor` (one for each feature level, from high to low resolution) of shape
+ `(batch_size, hidden_size, height, width)`. Feature maps from the Feature Pyramid Network neck.
+ fpn_position_encoding (`tuple(torch.FloatTensor)`):
+ Tuple of `torch.FloatTensor` (one for each feature level, from high to low resolution) of shape
+ `(batch_size, hidden_size, height, width)`. Positional encodings corresponding to the `fpn_hidden_states`.
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each stage) of shape `(batch_size, height, width, hidden_size)`. Hidden-states of the
+ model at the output of each stage.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
+ the self-attention heads.
+ """
+
+ last_hidden_state: Optional[torch.FloatTensor] = None
+ fpn_hidden_states: Optional[torch.FloatTensor] = None
+ fpn_position_encoding: Optional[torch.FloatTensor] = None
+ hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+ attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+
+
+@dataclass
+@auto_docstring(custom_intro="Base class for the Sam2 model's output.")
+class Sam2ImageSegmentationOutput(ModelOutput):
+ r"""
+ iou_scores (`torch.FloatTensor` of shape `(batch_size, point_batch_size, num_masks)`):
+ The Intersection over Union (IoU) scores of the predicted masks.
+ pred_masks (`torch.FloatTensor` of shape `(batch_size, point_batch_size, num_masks, height, width)`):
+ The predicted low-resolution masks. This is an alias for `low_res_masks`. These masks need to be post-processed
+ by the processor to be brought to the original image size.
+ object_score_logits (`torch.FloatTensor` of shape `(batch_size, point_batch_size, 1)`):
+ Logits for the object score, indicating if an object is present.
+ image_embeddings (`tuple(torch.FloatTensor)`):
+ The features from the FPN, which are used by the mask decoder. This is a tuple of `torch.FloatTensor` where each
+ tensor has shape `(batch_size, channels, height, width)`.
+ vision_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of each stage) of shape `(batch_size, height, width, hidden_size)`.
+ Hidden-states of the vision model at the output of each stage.
+ vision_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, sequence_length)`.
+ Attentions weights of the vision model.
+ mask_decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, sequence_length)`.
+ Attentions weights of the mask decoder.
+ """
+
+ iou_scores: Optional[torch.FloatTensor] = None
+ pred_masks: Optional[torch.FloatTensor] = None
+ object_score_logits: Optional[torch.FloatTensor] = None
+ image_embeddings: tuple[torch.FloatTensor, ...] = None
+ vision_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+ vision_attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+ mask_decoder_attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+
+
+class Sam2PatchEmbeddings(nn.Module):
+ r"""
+ Turns pixel values into patch embeddings for transformer consumption.
+
+ Args:
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+ Pixel values. Pixel values can be obtained using
+ [`AutoImageProcessor`]. See [`Sam2ImageProcessorFast.__call__`] for details.
+
+ Returns:
+ embeddings (`torch.FloatTensor`):
+ Patch embeddings depend on image_size, patch_kernel_size, patch_stride and patch_padding
+ """
+
+ def __init__(self, config: Sam2HieraDetConfig):
+ super().__init__()
+ num_channels = config.num_channels
+ hidden_size = config.hidden_size
+
+ self.projection = nn.Conv2d(
+ num_channels,
+ hidden_size,
+ kernel_size=config.patch_kernel_size,
+ stride=config.patch_stride,
+ padding=config.patch_padding,
+ )
+
+ def forward(self, pixel_values):
+ _, num_channels, height, width = pixel_values.shape
+ embeddings = self.projection(pixel_values).permute(0, 2, 3, 1)
+ return embeddings
+
+
+class Sam2SinePositionEmbedding(MaskFormerSinePositionEmbedding):
+ pass
+
+
+class Sam2VisionNeck(nn.Module):
+ def __init__(self, config: Sam2VisionConfig):
+ super().__init__()
+ self.config = config
+
+ self.position_encoding = Sam2SinePositionEmbedding(num_pos_feats=config.fpn_hidden_size // 2, normalize=True)
+ self.convs = nn.ModuleList()
+ for in_channels in config.backbone_channel_list:
+ self.convs.append(
+ nn.Conv2d(
+ in_channels=in_channels,
+ out_channels=config.fpn_hidden_size,
+ kernel_size=config.fpn_kernel_size,
+ stride=config.fpn_stride,
+ padding=config.fpn_padding,
+ ),
+ )
+ self.fpn_top_down_levels = config.fpn_top_down_levels
+
+ def forward(self, hidden_states: torch.Tensor) -> tuple[tuple[torch.Tensor, ...], tuple[torch.Tensor, ...]]:
+ fpn_hidden_states = ()
+ fpn_position_encoding = ()
+
+ # forward in top-down order (from low to high resolution)
+ n = len(self.convs) - 1
+ for i in range(n, -1, -1):
+ lateral_features = hidden_states[i].permute(0, 3, 1, 2)
+ lateral_features = self.convs[n - i](lateral_features)
+ if i not in self.fpn_top_down_levels or i == n:
+ prev_features = lateral_features
+ else:
+ top_down_features = F.interpolate(
+ prev_features.to(dtype=torch.float32),
+ scale_factor=2.0,
+ mode="nearest",
+ align_corners=None,
+ antialias=False,
+ ).to(lateral_features.dtype)
+ prev_features = lateral_features + top_down_features
+
+ prev_position_encoding = self.position_encoding(
+ prev_features.shape, prev_features.device, prev_features.dtype
+ ).to(prev_features.dtype)
+
+ fpn_hidden_states += (prev_features,)
+ fpn_position_encoding += (prev_position_encoding,)
+
+ return fpn_hidden_states, fpn_position_encoding
+
+
+def do_pool(x: torch.Tensor, query_stride: Optional[int] = None) -> torch.Tensor:
+ if query_stride is None:
+ return x
+ # (B, H, W, C) -> (B, C, H, W)
+ x = x.permute(0, 3, 1, 2)
+ x = nn.functional.max_pool2d(x, kernel_size=query_stride, stride=query_stride, ceil_mode=False)
+ # (B, C, H', W') -> (B, H', W', C)
+ x = x.permute(0, 2, 3, 1)
+ return x
+
+
+class Sam2MultiScaleAttention(nn.Module):
+ def __init__(
+ self,
+ config: Sam2HieraDetConfig,
+ dim: int,
+ dim_out: int,
+ num_attention_heads: int,
+ query_stride: Optional[tuple[int, int]] = None,
+ ):
+ super().__init__()
+
+ self.config = config
+
+ self.dim = dim
+ self.dim_out = dim_out
+ self.query_stride = query_stride
+
+ self.num_attention_heads = num_attention_heads
+ head_dim = dim_out // num_attention_heads
+ self.scale = head_dim**-0.5
+ self.qkv = nn.Linear(dim, dim_out * 3)
+ self.proj = nn.Linear(dim_out, dim_out)
+
+ self.is_causal = False
+
+ def forward(self, hidden_states: torch.Tensor, **kwargs) -> torch.Tensor:
+ batch_size, height, width, _ = hidden_states.shape
+ # qkv with shape (B, H * W, 3, nHead, C)
+ qkv = self.qkv(hidden_states).reshape(batch_size, height * width, 3, self.num_attention_heads, -1)
+ # q, k, v with shape (B, H * W, nheads, C)
+ query, key, value = torch.unbind(qkv, 2)
+
+ attn_weights = (query * self.scale) @ key.transpose(-2, -1)
+ attn_weights = torch.nn.functional.softmax(attn_weights, dtype=torch.float32, dim=-1).to(query.dtype)
+
+ # Q pooling (for downsample at stage changes)
+ if self.query_stride:
+ query = do_pool(query.reshape(batch_size, height, width, -1), self.query_stride)
+ height, width = query.shape[1:3] # downsampled shape
+ query = query.reshape(batch_size, height * width, self.num_attention_heads, -1)
+
+ # transpose query, key, value to (B, nHead, H * W, C)
+ query = query.transpose(1, 2)
+ key = key.transpose(1, 2)
+ value = value.transpose(1, 2)
+
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+ attn_output, _ = attention_interface(
+ self,
+ query,
+ key,
+ value,
+ attention_mask=None,
+ is_causal=self.is_causal,
+ scaling=self.scale,
+ **kwargs,
+ )
+ attn_output = attn_output.reshape(batch_size, height, width, -1)
+
+ attn_output = self.proj(attn_output)
+
+ return attn_output
+
+
+class Sam2FeedForward(nn.Module):
+ def __init__(
+ self,
+ input_dim: int,
+ hidden_dim: int,
+ output_dim: int,
+ num_layers: int,
+ activation: str = "relu",
+ sigmoid_output: bool = False,
+ ):
+ super().__init__()
+ self.num_layers = num_layers
+ self.activation = ACT2FN[activation]
+ self.proj_in = nn.Linear(input_dim, hidden_dim)
+ self.proj_out = nn.Linear(hidden_dim, output_dim)
+ self.layers = nn.ModuleList([nn.Linear(hidden_dim, hidden_dim) for _ in range(num_layers - 2)])
+ self.sigmoid_output = sigmoid_output
+
+ def forward(self, hidden_states):
+ hidden_states = self.proj_in(hidden_states)
+ hidden_states = self.activation(hidden_states)
+ for layer in self.layers:
+ hidden_states = self.activation(layer(hidden_states))
+
+ hidden_states = self.proj_out(hidden_states)
+ if self.sigmoid_output:
+ hidden_states = F.sigmoid(hidden_states)
+ return hidden_states
+
+
+class Sam2MultiScaleBlock(GradientCheckpointingLayer):
+ def __init__(
+ self,
+ config: Sam2HieraDetConfig,
+ stage_idx: int,
+ block_idx: int,
+ total_block_idx: int,
+ ):
+ super().__init__()
+
+ # take embed dim from previous stage if first block of stage
+ self.dim = (
+ config.embed_dim_per_stage[stage_idx - 1]
+ if stage_idx > 0 and block_idx == 0
+ else config.embed_dim_per_stage[stage_idx]
+ )
+ self.dim_out = config.embed_dim_per_stage[stage_idx]
+ self.layer_norm1 = nn.LayerNorm(self.dim, eps=config.layer_norm_eps)
+ # take window size from previous stage if first block of stage
+ self.window_size = (
+ config.window_size_per_stage[stage_idx - 1]
+ if stage_idx > 0 and block_idx == 0
+ else config.window_size_per_stage[stage_idx]
+ )
+ self.window_size = 0 if total_block_idx in config.global_attention_blocks else self.window_size
+ # use query stride for first block of stage if stage is a query pool stage
+ self.query_stride = (
+ config.query_stride if 0 < stage_idx <= config.num_query_pool_stages and block_idx == 0 else None
+ )
+
+ self.attn = Sam2MultiScaleAttention(
+ config,
+ self.dim,
+ self.dim_out,
+ num_attention_heads=config.num_attention_heads_per_stage[stage_idx],
+ query_stride=self.query_stride,
+ )
+ self.layer_norm2 = nn.LayerNorm(self.dim_out, eps=config.layer_norm_eps)
+ self.mlp = Sam2FeedForward(
+ self.dim_out,
+ int(self.dim_out * config.mlp_ratio),
+ self.dim_out,
+ num_layers=2,
+ activation=config.hidden_act,
+ )
+ if self.dim != self.dim_out:
+ self.proj = nn.Linear(self.dim, self.dim_out)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> torch.FloatTensor:
+ residual = hidden_states # batch_size, height, width, channel
+
+ hidden_states = self.layer_norm1(hidden_states)
+
+ # Skip connection
+ if self.dim != self.dim_out:
+ residual = do_pool(self.proj(hidden_states), self.query_stride)
+
+ # Window partition
+ window_size = self.window_size
+ if self.window_size > 0:
+ H, W = hidden_states.shape[1], hidden_states.shape[2]
+ hidden_states, pad_hw = window_partition(hidden_states, window_size)
+
+ # Window Attention + Q Pooling (if stage change)
+ attn_output = self.attn(
+ hidden_states=hidden_states,
+ **kwargs,
+ )
+ hidden_states = attn_output
+ if self.query_stride:
+ # Shapes have changed due to Q pooling
+ window_size = self.window_size // self.query_stride[0]
+ H, W = residual.shape[1:3]
+
+ pad_h = (window_size - H % window_size) % window_size
+ pad_w = (window_size - W % window_size) % window_size
+ pad_hw = (H + pad_h, W + pad_w)
+
+ # Reverse window partition
+ if self.window_size > 0:
+ hidden_states = window_unpartition(hidden_states, window_size, pad_hw, (H, W))
+
+ hidden_states = residual + hidden_states
+ layernorm_output = self.layer_norm2(hidden_states)
+ hidden_states = hidden_states + self.mlp(layernorm_output)
+
+ return hidden_states
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Hiera model's outputs that also contain the intermediate hidden states from each stage.
+ """
+)
+class Sam2HieraDetModelOutput(ModelOutput):
+ r"""
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, height, width, hidden_size)`):
+ Hidden-states at the output of the last layer of the model.
+ intermediate_hidden_states (`tuple[torch.FloatTensor]` of shape `(batch_size, height, width, hidden_size)`):
+ Sequence of hidden-states at the output of the intermediate layers of the model.
+ """
+
+ last_hidden_state: Optional[torch.FloatTensor] = None
+ intermediate_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+
+
+@auto_docstring
+class Sam2PreTrainedModel(PreTrainedModel):
+ config_class = Sam2Config
+ base_model_prefix = "sam2"
+ main_input_name = "pixel_values"
+ _supports_sdpa = True
+ _supports_flash_attn_2 = True
+ _supports_attention_backend = True
+
+ def _init_weights(self, module):
+ std = self.config.initializer_range
+ if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+ elif isinstance(module, (nn.LayerNorm, Sam2LayerNorm)):
+ module.weight.data.fill_(1.0)
+ module.bias.data.zero_()
+ if isinstance(module, Sam2HieraDetModel):
+ if module.pos_embed is not None:
+ module.pos_embed.data.zero_()
+ if module.pos_embed_window is not None:
+ module.pos_embed_window.data.zero_()
+ if isinstance(module, Sam2Model):
+ if module.no_memory_embedding is not None:
+ module.no_memory_embedding.data.zero_()
+
+
+class Sam2HieraDetModel(Sam2PreTrainedModel):
+ config_class = Sam2HieraDetConfig
+ main_input_name = "pixel_values"
+ _can_record_outputs = {
+ "hidden_states": Sam2MultiScaleBlock,
+ "attentions": Sam2MultiScaleAttention,
+ }
+
+ def __init__(self, config: Sam2HieraDetConfig):
+ super().__init__(config)
+
+ self.patch_embed = Sam2PatchEmbeddings(config)
+ # Windowed positional embedding (https://huggingface.co/papers/2311.05613)
+ self.pos_embed = nn.Parameter(
+ torch.zeros(1, config.hidden_size, *config.window_positional_embedding_background_size)
+ )
+ self.pos_embed_window = nn.Parameter(
+ torch.zeros(1, config.hidden_size, config.window_size_per_stage[0], config.window_size_per_stage[0])
+ )
+ self.stage_ends = (np.cumsum(config.blocks_per_stage) - 1).tolist()
+ self.blocks = nn.ModuleList()
+ total_block_idx = 0
+ for stage_idx, blocks_per_stage in enumerate(config.blocks_per_stage):
+ for block_idx in range(blocks_per_stage):
+ block = Sam2MultiScaleBlock(
+ config=config, stage_idx=stage_idx, block_idx=block_idx, total_block_idx=total_block_idx
+ )
+ self.blocks.append(block)
+ total_block_idx += 1
+
+ def get_input_embeddings(self):
+ return self.patch_embed
+
+ def _get_pos_embed(self, hw: tuple[int, int]) -> torch.Tensor:
+ h, w = hw
+ window_embed = self.pos_embed_window
+ pos_embed = F.interpolate(self.pos_embed, size=(h, w), mode="bicubic")
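+ # tile the window-sized positional embedding across the (h, w) grid before adding it to the interpolated global embedding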
+ pos_embed = pos_embed + window_embed.tile([x // y for x, y in zip(pos_embed.shape, window_embed.shape)])
+ pos_embed = pos_embed.permute(0, 2, 3, 1)
+ return pos_embed
+
+ @check_model_inputs
+ def forward(
+ self,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> Union[tuple, Sam2HieraDetModelOutput]:
+ if pixel_values is None:
+ raise ValueError("You have to specify pixel_values")
+
+ hidden_states = self.patch_embed(pixel_values)
+ hidden_states = hidden_states + self._get_pos_embed(hidden_states.shape[1:3])
+
+ intermediate_hidden_states = ()
+ for i, block_module in enumerate(self.blocks):
+ hidden_states = block_module(hidden_states, **kwargs)
+
+ if i in self.stage_ends:
+ intermediate_hidden_states = intermediate_hidden_states + (hidden_states,)
+
+ return Sam2HieraDetModelOutput(
+ last_hidden_state=hidden_states,
+ intermediate_hidden_states=intermediate_hidden_states,
+ )
+
+
+@auto_docstring(
+ custom_intro="""
+ The vision model from SAM 2 without any head or projection on top.
+ """
+)
+class Sam2VisionModel(Sam2PreTrainedModel):
+ config_class = Sam2VisionConfig
+ main_input_name = "pixel_values"
+ _can_record_outputs = {
+ "hidden_states": Sam2MultiScaleBlock,
+ "attentions": Sam2MultiScaleAttention,
+ }
+
+ def __init__(self, config: Sam2VisionConfig):
+ super().__init__(config)
+ self.config = config
+
+ self.backbone = AutoModel.from_config(config.backbone_config)
+
+ self.neck = Sam2VisionNeck(config)
+ self.num_feature_levels = config.num_feature_levels
+
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.backbone.get_input_embeddings()
+
+ @check_model_inputs
+ def forward(
+ self,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> Union[tuple, Sam2VisionEncoderOutput]:
+ if pixel_values is None:
+ raise ValueError("You have to specify pixel_values")
+
+ # Forward through backbone
+ backbone_output = self.backbone(pixel_values, **kwargs)
+ hidden_states = backbone_output.last_hidden_state
+ intermediate_hidden_states = backbone_output.intermediate_hidden_states
+
+ fpn_hidden_states, fpn_position_encoding = self.neck(intermediate_hidden_states)
+ # Select last `num_feature_levels` feature levels from FPN and reverse order to get features from high to low resolution
+ fpn_hidden_states = fpn_hidden_states[-self.num_feature_levels :][::-1]
+ fpn_position_encoding = fpn_position_encoding[-self.num_feature_levels :][::-1]
+
+ return Sam2VisionEncoderOutput(
+ last_hidden_state=hidden_states,
+ fpn_hidden_states=fpn_hidden_states,
+ fpn_position_encoding=fpn_position_encoding,
+ )
+
+
+class Sam2PositionalEmbedding(nn.Module):
+ def __init__(self, config: Sam2PromptEncoderConfig):
+ super().__init__()
+ self.scale = config.scale
+ positional_embedding = self.scale * torch.randn((2, config.hidden_size // 2))
+ self.register_buffer("positional_embedding", positional_embedding)
+
+ def forward(self, input_coords, input_shape=None):
+ """Positionally encode points that are normalized to [0,1]."""
+ coordinates = input_coords.clone()
+
+ if input_shape is not None:
+ coordinates[:, :, :, 0] = coordinates[:, :, :, 0] / input_shape[1]
+ coordinates[:, :, :, 1] = coordinates[:, :, :, 1] / input_shape[0]
+ coordinates = coordinates.to(torch.float32)
+
+ # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
+ coordinates = 2 * coordinates - 1
+ coordinates = coordinates.to(self.positional_embedding.dtype)
+ coordinates = coordinates @ self.positional_embedding
+ coordinates = 2 * np.pi * coordinates
+ # outputs d_1 x ... x d_n x channel shape
+ return torch.cat([torch.sin(coordinates), torch.cos(coordinates)], dim=-1)
+
+
+class Sam2MaskEmbedding(SamMaskEmbedding):
+ pass
+
+
+class Sam2PromptEncoder(SamPromptEncoder):
+ def __init__(self, config: Sam2PromptEncoderConfig):
+ nn.Module.__init__(self)
+ self.shared_embedding = Sam2PositionalEmbedding(config)
+ self.mask_embed = Sam2MaskEmbedding(config)
+ self.no_mask_embed = nn.Embedding(1, config.hidden_size)
+
+ self.image_embedding_size = (config.image_size // config.patch_size, config.image_size // config.patch_size)
+ self.mask_input_size = (4 * config.image_size // config.patch_size, 4 * config.image_size // config.patch_size)
+ self.input_image_size = config.image_size
+
+ self.point_embed = nn.Embedding(config.num_point_embeddings, config.hidden_size)
+ self.hidden_size = config.hidden_size
+ self.not_a_point_embed = nn.Embedding(1, config.hidden_size)
+
+ def _embed_points(self, points: torch.Tensor, labels: torch.Tensor, pad: bool) -> torch.Tensor:
+ """Embeds point prompts."""
+ points = points + 0.5 # Shift to center of pixel
+ if pad:
+ points = torch.nn.functional.pad(points, (0, 0, 0, 1), mode="constant", value=0)
+ labels = torch.nn.functional.pad(labels, (0, 1), mode="constant", value=-1)
+ input_shape = (self.input_image_size, self.input_image_size)
+ point_embedding = self.shared_embedding(points, input_shape)
+
+ # torch.where and expanding the labels tensor is required by the ONNX export
+ point_embedding = torch.where(labels[..., None] == -1, self.not_a_point_embed.weight, point_embedding)
+
+ # This is required for the ONNX export. The dtype, device need to be explicitly
+ # specified as otherwise torch.onnx.export interprets as double
+ point_embedding = torch.where(
+ labels[..., None] != -10,
+ point_embedding,
+ torch.zeros_like(point_embedding),
+ )
+
+ # Add point embeddings for labels >= 0
+ point_embedding = point_embedding + self.point_embed(labels.clamp(min=0)) * (labels >= 0).unsqueeze(-1)
+
+ return point_embedding
+
+ def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
+ """Embeds box prompts."""
+ boxes += 0.5 # Shift to center of pixel
+ coords = boxes.view(*boxes.shape[:2], 2, 2)
+ # add padding point for consistency with the original implementation
+ coords = torch.nn.functional.pad(coords, (0, 0, 0, 1), mode="constant", value=0)
+ corner_embedding = self.shared_embedding(coords, (self.input_image_size, self.input_image_size))
+ corner_embedding[:, :, 0, :] += self.point_embed.weight[2]
+ corner_embedding[:, :, 1, :] += self.point_embed.weight[3]
+ corner_embedding[:, :, 2, :] = self.not_a_point_embed.weight.expand_as(corner_embedding[:, :, 2, :])
+ return corner_embedding
+
+
+class Sam2Attention(nn.Module):
+ """
+ SAM2's attention layer that allows for downscaling the size of the embedding after projection to queries, keys, and
+ values.
+ """
+
+ def __init__(self, config, downsample_rate=None):
+ super().__init__()
+ downsample_rate = config.attention_downsample_rate if downsample_rate is None else downsample_rate
+ self.config = config
+ self.hidden_size = config.hidden_size
+ self.internal_dim = config.hidden_size // downsample_rate
+ self.num_attention_heads = config.num_attention_heads
+ self.head_dim = self.internal_dim // config.num_attention_heads
+ self.scaling = self.head_dim**-0.5
+ self.is_causal = False
+
+ self.q_proj = nn.Linear(self.hidden_size, self.internal_dim)
+ self.k_proj = nn.Linear(self.hidden_size, self.internal_dim)
+ self.v_proj = nn.Linear(self.hidden_size, self.internal_dim)
+ self.o_proj = nn.Linear(self.internal_dim, self.hidden_size)
+
+ def forward(
+ self,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_similarity: Optional[torch.Tensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ # Input projections
+ batch_size, point_batch_size = query.shape[:2]
+ new_shape = (batch_size * point_batch_size, -1, self.num_attention_heads, self.head_dim)
+
+ query = self.q_proj(query).view(*new_shape).transpose(1, 2)
+ key = self.k_proj(key).view(*new_shape).transpose(1, 2)
+ value = self.v_proj(value).view(*new_shape).transpose(1, 2)
+
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ attn_output, attn_weights = attention_interface(
+ self,
+ query,
+ key,
+ value,
+ attention_mask=attention_similarity,
+ dropout=0.0,
+ scaling=self.scaling,
+ is_causal=self.is_causal,
+ **kwargs,
+ )
+
+ attn_output = attn_output.reshape(
+ batch_size, point_batch_size, -1, self.num_attention_heads * self.head_dim
+ ).contiguous()
+ attn_output = self.o_proj(attn_output)
+
+ return attn_output, attn_weights
+
+
+class Sam2TwoWayAttentionBlock(SamTwoWayAttentionBlock, GradientCheckpointingLayer):
+ def __init__(self, config: Sam2MaskDecoderConfig, skip_first_layer_pe: bool = False):
+ nn.Module.__init__(self)
+ self.self_attn = Sam2Attention(config, downsample_rate=1)
+ self.layer_norm1 = nn.LayerNorm(config.hidden_size)
+
+ self.cross_attn_token_to_image = Sam2Attention(config)
+ self.layer_norm2 = nn.LayerNorm(config.hidden_size)
+
+ self.mlp = Sam2FeedForward(
+ config.hidden_size, config.mlp_dim, config.hidden_size, num_layers=config.num_hidden_layers
+ )
+ self.layer_norm3 = nn.LayerNorm(config.hidden_size)
+
+ self.layer_norm4 = nn.LayerNorm(config.hidden_size)
+ self.cross_attn_image_to_token = Sam2Attention(config)
+
+ self.skip_first_layer_pe = skip_first_layer_pe
+
+
+class Sam2TwoWayTransformer(SamTwoWayTransformer):
+ pass
+
+
+class Sam2LayerNorm(SamLayerNorm):
+ pass
+
+
+class Sam2MaskDecoder(SamMaskDecoder):
+ def __init__(self, config: Sam2MaskDecoderConfig):
+ super().__init__(config)
+ del self.iou_prediction_head
+ self.iou_prediction_head = Sam2FeedForward(
+ self.hidden_size,
+ config.iou_head_hidden_dim,
+ self.num_mask_tokens,
+ config.iou_head_depth,
+ sigmoid_output=True,
+ )
+
+ self.conv_s0 = nn.Conv2d(config.hidden_size, config.hidden_size // 8, kernel_size=1, stride=1)
+ self.conv_s1 = nn.Conv2d(config.hidden_size, config.hidden_size // 4, kernel_size=1, stride=1)
+
+ self.obj_score_token = nn.Embedding(1, self.hidden_size)
+ self.pred_obj_score_head = Sam2FeedForward(self.hidden_size, self.hidden_size, 1, 3)
+
+ self.dynamic_multimask_via_stability = config.dynamic_multimask_via_stability
+ self.dynamic_multimask_stability_delta = config.dynamic_multimask_stability_delta
+ self.dynamic_multimask_stability_thresh = config.dynamic_multimask_stability_thresh
+
+ def _get_stability_scores(self, mask_logits):
+ """
+ Compute stability scores of the mask logits based on the IoU between upper and
+ lower thresholds.
+ """
+ mask_logits = mask_logits.flatten(-2)
+ stability_delta = self.dynamic_multimask_stability_delta
+ area_i = torch.sum(mask_logits > stability_delta, dim=-1).float()
+ area_u = torch.sum(mask_logits > -stability_delta, dim=-1).float()
+ stability_scores = torch.where(area_u > 0, area_i / area_u, 1.0)
+ return stability_scores
+
+ def _dynamic_multimask_via_stability(self, all_mask_logits, all_iou_scores):
+ """
+ When outputting a single mask, if the stability score from the current single-mask
+ output (based on output token 0) falls below a threshold, we instead select from
+ multi-mask outputs (based on output token 1~3) the mask with the highest predicted
+ IoU score. This is intended to ensure a valid mask for both clicking and tracking.
+ """
+ # The best mask from multimask output tokens (1~3)
+ multimask_logits = all_mask_logits[:, :, 1:, :, :]
+ multimask_iou_scores = all_iou_scores[:, :, 1:]
+ best_scores_inds = torch.argmax(multimask_iou_scores, dim=-1) # [B, P]
+ best_scores_inds_expanded = best_scores_inds.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
+ best_scores_inds_expanded = best_scores_inds_expanded.expand(
+ -1, -1, 1, multimask_logits.size(-2), multimask_logits.size(-1)
+ )
+ best_multimask_logits = torch.gather(multimask_logits, 2, best_scores_inds_expanded) # [B, P, 1, H, W]
+ best_multimask_iou_scores = torch.gather(multimask_iou_scores, 2, best_scores_inds.unsqueeze(-1)) # [B, P, 1]
+
+ # The mask from singlemask output token 0 and its stability score
+ singlemask_logits = all_mask_logits[:, :, 0:1, :, :]
+ singlemask_iou_scores = all_iou_scores[:, :, 0:1]
+ stability_scores = self._get_stability_scores(singlemask_logits)
+ is_stable = stability_scores >= self.dynamic_multimask_stability_thresh
+
+ # Dynamically fall back to best multimask output upon low stability scores.
+ mask_logits_out = torch.where(
+ is_stable[..., None, None].expand_as(singlemask_logits),
+ singlemask_logits,
+ best_multimask_logits,
+ )
+ iou_scores_out = torch.where(
+ is_stable.expand_as(singlemask_iou_scores),
+ singlemask_iou_scores,
+ best_multimask_iou_scores,
+ )
+ return mask_logits_out, iou_scores_out
+
+ def forward(
+ self,
+ image_embeddings: torch.Tensor,
+ image_positional_embeddings: torch.Tensor,
+ sparse_prompt_embeddings: torch.Tensor,
+ dense_prompt_embeddings: torch.Tensor,
+ multimask_output: bool,
+ high_resolution_features: list[torch.Tensor],
+ attention_similarity: Optional[torch.Tensor] = None,
+ target_embedding: Optional[torch.Tensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ """
+ Predict masks given image and prompt embeddings.
+
+ Args:
+ image_embeddings (`torch.Tensor`):
+ The embeddings from the image encoder.
+ image_positional_embeddings (`torch.Tensor`):
+ Positional encoding with the shape of image_embeddings.
+ sparse_prompt_embeddings (`torch.Tensor`):
+ The embeddings of the points and boxes.
+ dense_prompt_embeddings (`torch.Tensor`):
+ The embeddings of the mask inputs.
+ multimask_output (`bool`):
+ Whether to return multiple masks or a single mask.
+ high_resolution_features (`list[torch.Tensor]`, *optional*):
+ The high-resolution features from the vision encoder.
+ attention_similarity (`torch.Tensor`, *optional*):
+ The attention similarity tensor.
+ target_embedding (`torch.Tensor`, *optional*):
+ The target embedding.
+ """
+ batch_size, num_channels, height, width = image_embeddings.shape
+ point_batch_size = sparse_prompt_embeddings.shape[1]
+ # Concatenate output tokens
+ output_tokens = torch.cat(
+ [
+ self.obj_score_token.weight,
+ self.iou_token.weight,
+ self.mask_tokens.weight,
+ ],
+ dim=0,
+ )
+ output_tokens = output_tokens.repeat(batch_size, point_batch_size, 1, 1)
+
+ if sparse_prompt_embeddings.shape[0] != 0:
+ tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=2)
+ else:
+ tokens = output_tokens
+ point_embeddings = tokens.to(self.iou_token.weight.dtype)
+
+ # Expand per-image data in batch direction to be per-mask
+ image_embeddings = image_embeddings + dense_prompt_embeddings
+ image_embeddings = image_embeddings.repeat_interleave(point_batch_size, dim=0)
+ image_positional_embeddings = image_positional_embeddings.repeat_interleave(point_batch_size, 0)
+ # Run the transformer
+ point_embeddings, image_embeddings = self.transformer(
+ point_embeddings=point_embeddings,
+ image_embeddings=image_embeddings,
+ image_positional_embeddings=image_positional_embeddings,
+ attention_similarity=attention_similarity,
+ target_embedding=target_embedding,
+ **kwargs,
+ )
+ iou_token_out = point_embeddings[:, :, 1, :]
+ mask_tokens_out = point_embeddings[:, :, 2 : (2 + self.num_mask_tokens), :]
+
+ # Upscale mask embeddings and predict masks using the mask tokens
+ image_embeddings = image_embeddings.transpose(2, 3).view(
+ batch_size * point_batch_size, num_channels, height, width
+ )
+
+ feat_s0, feat_s1 = high_resolution_features
+ feat_s0 = feat_s0.repeat_interleave(point_batch_size, dim=0)
+ feat_s1 = feat_s1.repeat_interleave(point_batch_size, dim=0)
+ upscaled_embedding = self.upscale_conv1(image_embeddings) + feat_s1
+ upscaled_embedding = self.activation(self.upscale_layer_norm(upscaled_embedding))
+ upscaled_embedding = self.activation(self.upscale_conv2(upscaled_embedding) + feat_s0)
+
+ hyper_in_list: list[torch.Tensor] = []
+ for i in range(self.num_mask_tokens):
+ current_mlp = self.output_hypernetworks_mlps[i]
+ hyper_in_list += [current_mlp(mask_tokens_out[:, :, i, :])]
+ hyper_in = torch.stack(hyper_in_list, dim=2)
+
+ _, num_channels, height, width = upscaled_embedding.shape
+ upscaled_embedding = upscaled_embedding.view(batch_size, point_batch_size, num_channels, height * width)
+ masks = (hyper_in @ upscaled_embedding).view(batch_size, point_batch_size, -1, height, width)
+
+ # Generate mask quality predictions
+ iou_pred = self.iou_prediction_head(iou_token_out)
+ object_score_logits = self.pred_obj_score_head(point_embeddings[:, :, 0, :])
+
+ # Select the correct mask or masks for output
+ if multimask_output:
+ mask_slice = slice(1, None)
+ masks = masks[:, :, mask_slice, :, :]
+ iou_pred = iou_pred[:, :, mask_slice]
+ elif self.dynamic_multimask_via_stability and not self.training:
+ mask_slice = slice(0, 1)
+ masks, iou_pred = self._dynamic_multimask_via_stability(masks, iou_pred)
+ else:
+ mask_slice = slice(0, 1)
+ masks = masks[:, :, mask_slice, :, :]
+ iou_pred = iou_pred[:, :, mask_slice]
+
+ sam_tokens_out = mask_tokens_out[:, :, mask_slice] # [b, 3, c] shape
+
+ return masks, iou_pred, sam_tokens_out, object_score_logits
+
+
+@auto_docstring(
+ custom_intro="""
+ Segment Anything Model 2 (SAM 2) for generating segmentation masks, given an input image and
+ input points and labels, boxes, or masks.
+ """
+)
+class Sam2Model(SamModel):
+ _keys_to_ignore_on_load_unexpected = [
+ r"^memory_.*",
+ r"^mask_downsample.*",
+ r"^object_pointer_proj.*",
+ r"^temporal_positional_encoding_projection_layer.*",
+ "no_memory_positional_encoding",
+ "no_object_pointer",
+ "occlusion_spatial_embedding_parameter",
+ ]
+
+ def __init__(self, config: Sam2Config):
+ PreTrainedModel.__init__(self, config)
+ self.shared_image_embedding = Sam2PositionalEmbedding(config.prompt_encoder_config)
+ self.vision_encoder = AutoModel.from_config(config.vision_config)
+ self.prompt_encoder = Sam2PromptEncoder(config.prompt_encoder_config)
+ # The module using it is not a PreTrainedModel subclass so we need this
+ config.mask_decoder_config._attn_implementation = config._attn_implementation
+ self.mask_decoder = Sam2MaskDecoder(config.mask_decoder_config)
+
+ self.num_feature_levels = config.vision_config.num_feature_levels
+ self.backbone_feature_sizes = config.vision_config.backbone_feature_sizes
+ # a single token to indicate no memory embedding from previous frames
+ self.hidden_dim = config.vision_config.fpn_hidden_size
+ self.no_memory_embedding = torch.nn.Parameter(torch.zeros(1, 1, self.hidden_dim))
+
+ self.post_init()
+
+ def get_image_wide_positional_embeddings(self) -> torch.Tensor:
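+ # Build a grid of normalized (x, y) coordinates over the prompt encoder's image embedding size and encode it
+ # with the shared positional embedding, yielding one positional vector per spatial location.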
+ size = self.prompt_encoder.image_embedding_size
+ target_device = self.shared_image_embedding.positional_embedding.device
+ target_dtype = self.shared_image_embedding.positional_embedding.dtype
+ grid = torch.ones(size, device=target_device, dtype=target_dtype)
+ y_embed = grid.cumsum(dim=0) - 0.5
+ x_embed = grid.cumsum(dim=1) - 0.5
+ y_embed = y_embed / size[0]
+ x_embed = x_embed / size[1]
+
+ positional_embedding = self.shared_image_embedding(torch.stack([x_embed, y_embed], dim=-1))
+ return positional_embedding.permute(2, 0, 1).unsqueeze(0) # channel x height x width
+
+ @torch.no_grad()
+ def get_image_embeddings(
+ self,
+ pixel_values: torch.FloatTensor,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> list[torch.Tensor]:
+ r"""
+ Returns the image embeddings by passing the pixel values through the vision encoder.
+
+ Args:
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+ Input pixel values
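+
+ Example (a minimal sketch of caching image embeddings so that several prompts can reuse them; the checkpoint
+ name and processor keys follow the usage example of `Sam2Model.forward` and may differ in practice):
+
+ ```python
+ >>> from PIL import Image
+ >>> import requests
+ >>> from transformers import AutoModel, AutoProcessor
+
+ >>> model = AutoModel.from_pretrained("danelcsb/sam2.1_hiera_tiny")
+ >>> processor = AutoProcessor.from_pretrained("danelcsb/sam2.1_hiera_tiny")
+
+ >>> img_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/sam-car.png"
+ >>> raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
+ >>> inputs = processor(images=raw_image, input_points=[[[400, 650]]], return_tensors="pt")
+
+ >>> # Compute the image embeddings once ...
+ >>> image_embeddings = model.get_image_embeddings(inputs["pixel_values"])
+ >>> # ... then reuse them for any number of prompts without re-running the vision encoder
+ >>> outputs = model(input_points=inputs["input_points"], image_embeddings=image_embeddings)
+ ```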
+ """
+ batch_size = pixel_values.shape[0]
+ feature_maps, _, _, _ = self.get_image_features(pixel_values, **kwargs)
+
+ # add no memory embedding to the last feature map
+ feature_maps[-1] = feature_maps[-1] + self.no_memory_embedding
+
+ # reshape feature maps to the same shape as the backbone feature sizes
+ image_embeddings = [
+ feat.permute(1, 2, 0).view(batch_size, -1, *feat_size)
+ for feat, feat_size in zip(feature_maps, self.backbone_feature_sizes)
+ ]
+
+ return image_embeddings
+
+ def get_image_features(
+ self,
+ pixel_values: torch.FloatTensor,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> tuple[
+ list[torch.Tensor],
+ list[torch.Tensor],
+ Optional[tuple[torch.FloatTensor, ...]],
+ Optional[tuple[torch.FloatTensor, ...]],
+ ]:
+ r"""
+ Extract and preprocess image features using the vision encoder.
+
+ Args:
+ pixel_values (`torch.FloatTensor`):
+ Input pixel values of shape `(batch_size, num_channels, height, width)`.
+
+ Returns:
+ `tuple`: A tuple containing:
+ - feature_maps (`list[torch.Tensor]`): List of feature maps from different levels.
+ - feature_maps_position_embeddings (`list[torch.Tensor]`): List of positional embeddings for each feature level.
+ - vision_hidden_states (`tuple[torch.FloatTensor]`, *optional*): Hidden states from the vision encoder.
+ - vision_attentions (`tuple[torch.FloatTensor]`, *optional*): Attention weights from the vision encoder.
+ """
+ vision_outputs: Sam2VisionEncoderOutput = self.vision_encoder(
+ pixel_values,
+ **kwargs,
+ )
+
+ feature_maps = vision_outputs.fpn_hidden_states
+ feature_maps_position_embeddings = vision_outputs.fpn_position_encoding
+
+ # precompute projected level 0 and level 1 features in SAM decoder
+ # to avoid running it again on every SAM click
+ feature_maps = list(feature_maps)
+ feature_maps[0] = self.mask_decoder.conv_s0(feature_maps[0])
+ feature_maps[1] = self.mask_decoder.conv_s1(feature_maps[1])
+
+ # flatten NxCxHxW to HWxNxC
+ feature_maps = [feature_map.flatten(2).permute(2, 0, 1) for feature_map in feature_maps]
+ feature_maps_position_embeddings = [
+ feature_map_position_embedding.flatten(2).permute(2, 0, 1)
+ for feature_map_position_embedding in feature_maps_position_embeddings
+ ]
+
+ return feature_maps, feature_maps_position_embeddings, vision_outputs.hidden_states, vision_outputs.attentions
+
+ @check_model_inputs
+ @auto_docstring
+ def forward(
+ self,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ input_points: Optional[torch.FloatTensor] = None,
+ input_labels: Optional[torch.LongTensor] = None,
+ input_boxes: Optional[torch.FloatTensor] = None,
+ input_masks: Optional[torch.LongTensor] = None,
+ image_embeddings: Optional[torch.FloatTensor] = None,
+ multimask_output: bool = True,
+ attention_similarity: Optional[torch.FloatTensor] = None,
+ target_embedding: Optional[torch.FloatTensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> Sam2ImageSegmentationOutput:
+ r"""
+ input_points (`torch.FloatTensor` of shape `(batch_size, num_points, 2)`):
+ Input 2D spatial points, this is used by the prompt encoder to encode the prompt. Generally yields much
+ better results. The points can be obtained by passing a list of list of list to the processor that will
+ create corresponding `torch` tensors of dimension 4. The first dimension is the image batch size, the
+ second dimension is the point batch size (i.e. how many segmentation masks do we want the model to predict
+ per input point), the third dimension is the number of points per segmentation mask (it is possible to pass
+ multiple points for a single mask), and the last dimension is the x (vertical) and y (horizontal)
+ coordinates of the point. If a different number of points is passed either for each image, or for each
+ mask, the processor will create "PAD" points that will correspond to the (0, 0) coordinate, and the
+ computation of the embedding will be skipped for these points using the labels.
+ input_labels (`torch.LongTensor` of shape `(batch_size, point_batch_size, num_points)`):
+ Input labels for the points, this is used by the prompt encoder to encode the prompt. According to the
+ official implementation, there are 3 types of labels
+
+ - `1`: the point is a point that contains the object of interest
+ - `0`: the point is a point that does not contain the object of interest
+ - `-1`: the point corresponds to the background
+
+ We added the label:
+
+ - `-10`: the point is a padding point, thus should be ignored by the prompt encoder
+
+ The padding labels should be automatically done by the processor.
+ input_boxes (`torch.FloatTensor` of shape `(batch_size, num_boxes, 4)`):
+ Input boxes for the points, this is used by the prompt encoder to encode the prompt. Generally yields
+ much better generated masks. The boxes can be obtained by passing a list of list of list to the processor,
+ that will generate a `torch` tensor, with each dimension corresponding respectively to the image batch
+ size, the number of boxes per image and the coordinates of the top left and bottom right point of the box.
+ In the order (`x1`, `y1`, `x2`, `y2`):
+
+ - `x1`: the x coordinate of the top left point of the input box
+ - `y1`: the y coordinate of the top left point of the input box
+ - `x2`: the x coordinate of the bottom right point of the input box
+ - `y2`: the y coordinate of the bottom right point of the input box
+ input_masks (`torch.FloatTensor` of shape `(batch_size, image_size, image_size)`):
+ SAM model also accepts segmentation masks as input. The mask will be embedded by the prompt encoder to
+ generate a corresponding embedding that will be fed later on to the mask decoder. These masks need to be
+ manually fed by the user, and they need to be of shape (`batch_size`, `image_size`, `image_size`).
+ image_embeddings (`torch.FloatTensor` of shape `(batch_size, output_channels, window_size, window_size)`):
+ Image embeddings, this is used by the mask decoder to generate masks and iou scores. For more memory
+ efficient computation, users can first retrieve the image embeddings using the `get_image_embeddings`
+ method, and then feed them to the `forward` method instead of feeding the `pixel_values`.
+ multimask_output (`bool`, *optional*):
+ In the original implementation and paper, the model always outputs 3 masks per image (or per point / per
+ bounding box if relevant). However, it is possible to just output a single mask, that corresponds to the
+ "best" mask, by specifying `multimask_output=False`.
+ attention_similarity (`torch.FloatTensor`, *optional*):
+ Attention similarity tensor, to be provided to the mask decoder for target-guided attention in case the
+ model is used for personalization as introduced in [PerSAM](https://huggingface.co/papers/2305.03048).
+ target_embedding (`torch.FloatTensor`, *optional*):
+ Embedding of the target concept, to be provided to the mask decoder for target-semantic prompting in case
+ the model is used for personalization as introduced in [PerSAM](https://huggingface.co/papers/2305.03048).
+
+ Example:
+
+ ```python
+ >>> from PIL import Image
+ >>> import requests
+ >>> from transformers import AutoModel, AutoProcessor
+
+ >>> model = AutoModel.from_pretrained("danelcsb/sam2.1_hiera_tiny")
+ >>> processor = AutoProcessor.from_pretrained("danelcsb/sam2.1_hiera_tiny")
+
+ >>> img_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/sam-car.png"
+ >>> raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
+ >>> input_points = [[[400, 650]]] # 2D location of a window on the car
+ >>> inputs = processor(images=raw_image, input_points=input_points, return_tensors="pt")
+
+ >>> # Get segmentation mask
+ >>> outputs = model(**inputs)
+
+ >>> # Postprocess masks
+ >>> masks = processor.post_process_masks(
+ ... outputs.pred_masks, inputs["original_sizes"], inputs["reshaped_input_sizes"]
+ ... )
+ ```
+ """
+ if not ((pixel_values is None) ^ (image_embeddings is None)):
+ raise ValueError("Exactly one of pixel_values or image_embeddings must be provided.")
+ if input_points is not None and input_boxes is not None:
+ if input_points.shape[1] != input_boxes.shape[1]:
+ raise ValueError(
+ f"You should provide as many bounding boxes as input points per box. Got {input_points.shape[1]} and {input_boxes.shape[1]}."
+ )
+
+ image_positional_embeddings = self.get_image_wide_positional_embeddings()
+ # repeat with batch size
+ batch_size = pixel_values.shape[0] if pixel_values is not None else image_embeddings[-1].shape[0]
+ image_positional_embeddings = image_positional_embeddings.repeat(batch_size, 1, 1, 1)
+
+ vision_attentions = None
+ vision_hidden_states = None
+
+ if pixel_values is not None:
+ feature_maps, _, vision_hidden_states, vision_attentions = self.get_image_features(
+ pixel_values,
+ **kwargs,
+ )
+
+ # add no memory embedding to the last feature map
+ feature_maps[-1] = feature_maps[-1] + self.no_memory_embedding
+
+ # reshape feature maps to the same shape as the backbone feature sizes
+ image_embeddings = [
+ feat.permute(1, 2, 0).view(batch_size, -1, *feat_size)
+ for feat, feat_size in zip(feature_maps, self.backbone_feature_sizes)
+ ]
+
+ if input_points is not None and input_labels is None:
+ input_labels = torch.ones_like(input_points[:, :, :, 0], dtype=torch.int, device=input_points.device)
+
+ if input_points is None and input_boxes is None:
+ # If no points are provided, pad with an empty point (with label -1)
+ input_points = torch.zeros(
+ batch_size, 1, 1, 2, dtype=image_embeddings[-1].dtype, device=image_embeddings[-1].device
+ )
+ input_labels = -torch.ones(batch_size, 1, 1, dtype=torch.int32, device=image_embeddings[-1].device)
+
+ if input_masks is not None:
+ # If mask_inputs is provided, downsize it into low-res mask input if needed
+ # and feed it as a dense mask prompt into the SAM mask encoder
+ if input_masks.shape[-2:] != self.prompt_encoder.mask_input_size:
+ input_masks = F.interpolate(
+ input_masks.float(),
+ size=self.prompt_encoder.mask_input_size,
+ align_corners=False,
+ mode="bilinear",
+ antialias=True, # use antialias for downsampling
+ ).to(input_masks.dtype)
+
+ sparse_embeddings, dense_embeddings = self.prompt_encoder(
+ input_points=input_points,
+ input_labels=input_labels,
+ input_boxes=input_boxes,
+ input_masks=input_masks,
+ )
+ low_res_multimasks, iou_scores, _, object_score_logits = self.mask_decoder(
+ image_embeddings=image_embeddings[-1],
+ image_positional_embeddings=image_positional_embeddings,
+ sparse_prompt_embeddings=sparse_embeddings,
+ dense_prompt_embeddings=dense_embeddings,
+ multimask_output=multimask_output,
+ high_resolution_features=image_embeddings[:-1],
+ attention_similarity=attention_similarity,
+ target_embedding=target_embedding,
+ **kwargs,
+ )
+
+ return Sam2ImageSegmentationOutput(
+ iou_scores=iou_scores,
+ pred_masks=low_res_multimasks,
+ object_score_logits=object_score_logits,
+ image_embeddings=image_embeddings,
+ vision_hidden_states=vision_hidden_states,
+ vision_attentions=vision_attentions,
+ )
+
+
+__all__ = [
+ "Sam2Model",
+ "Sam2VisionModel",
+ "Sam2PreTrainedModel",
+ "Sam2ImageProcessorFast",
+ "Sam2HieraDetModel",
+]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/sam2/processing_sam2.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/sam2/processing_sam2.py
new file mode 100644
index 0000000000000000000000000000000000000000..5f147aab8dfadaeaa6790246e2275a76b0dd67be
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/sam2/processing_sam2.py
@@ -0,0 +1,526 @@
+# coding=utf-8
+# Copyright 2025 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Processor class for SAM2.
+"""
+
+from copy import deepcopy
+from typing import Optional, Union
+
+import numpy as np
+
+from ...image_utils import ImageInput
+from ...processing_utils import ProcessorMixin
+from ...tokenization_utils_base import BatchEncoding
+from ...utils import TensorType, is_torch_available, logging
+from ...utils.import_utils import requires
+
+
+logger = logging.get_logger(__name__)
+
+if is_torch_available():
+ import torch
+
+
+@requires(backends=("torch",))
+class Sam2Processor(ProcessorMixin):
+ r"""
+ Constructs a SAM2 processor which wraps a SAM2 image processor and a 2D points & bounding boxes processor into a
+ single processor.
+
+ [`Sam2Processor`] offers all the functionalities of [`Sam2ImageProcessorFast`] and [`Sam2VideoProcessor`]. See the docstring of
+ [`~Sam2ImageProcessorFast.__call__`] and [`~Sam2VideoProcessor.__call__`] for more information.
+
+ Args:
+ image_processor (`Sam2ImageProcessorFast`):
+ An instance of [`Sam2ImageProcessorFast`].
+ target_size (`int`, *optional*):
+ The target size (target_size, target_size) to which the image will be resized.
+ point_pad_value (`int`, *optional*, defaults to -10):
+ The value used for padding input points.
+ """
+
+ attributes = ["image_processor"]
+ image_processor_class = "Sam2ImageProcessorFast"
+
+ def __init__(self, image_processor, target_size: Optional[int] = None, point_pad_value: int = -10, **kwargs):
+ super().__init__(image_processor, **kwargs)
+ self.point_pad_value = point_pad_value
+ self.target_size = target_size if target_size is not None else self.image_processor.size["height"]
+
+ def __call__(
+ self,
+ images: Optional[ImageInput] = None,
+ segmentation_maps: Optional[ImageInput] = None,
+ input_points: Optional[Union[list[list[list[list[float]]]], torch.Tensor]] = None,
+ input_labels: Optional[Union[list[list[list[int]]], torch.Tensor]] = None,
+ input_boxes: Optional[Union[list[list[list[float]]], torch.Tensor]] = None,
+ original_sizes: Optional[Union[list[list[float]], torch.Tensor]] = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ **kwargs,
+ ) -> BatchEncoding:
+ r"""
+ This method uses [`Sam2ImageProcessorFast.__call__`] to prepare image(s) for the model. It also prepares 2D
+ points and bounding boxes for the model if they are provided.
+
+ Args:
+ images (`ImageInput`, *optional*):
+ The image(s) to process.
+ segmentation_maps (`ImageInput`, *optional*):
+ The segmentation maps to process.
+ input_points (`list[list[list[list[float]]]]`, `torch.Tensor`, *optional*):
+ The points to add to the frame.
+ input_labels (`list[list[list[int]]]`, `torch.Tensor`, *optional*):
+ The labels for the points.
+ input_boxes (`list[list[list[float]]]`, `torch.Tensor`, *optional*):
+ The bounding boxes to add to the frame.
+ original_sizes (`list[list[float]]`, `torch.Tensor`, *optional*):
+ The original sizes of the images.
+ return_tensors (`str` or `TensorType`, *optional*):
+ The type of tensors to return.
+ **kwargs:
+ Additional keyword arguments to pass to the image processor.
+
+ Returns:
+ A [`BatchEncoding`] with the following fields:
+ - `pixel_values` (`torch.Tensor`): The processed image(s).
+ - `original_sizes` (`list[list[float]]`): The original sizes of the images.
+ - `reshaped_input_sizes` (`torch.Tensor`): The reshaped input sizes of the images.
+ - `labels` (`torch.Tensor`): The processed segmentation maps (if provided).
+ - `input_points` (`torch.Tensor`): The processed points.
+ - `input_labels` (`torch.Tensor`): The processed labels.
+ - `input_boxes` (`torch.Tensor`): The processed bounding boxes.
+ """
+ if images is not None:
+ encoding_image_processor = self.image_processor(
+ images,
+ segmentation_maps=segmentation_maps,
+ return_tensors=return_tensors,
+ **kwargs,
+ )
+ elif original_sizes is not None:
+ if isinstance(original_sizes, torch.Tensor):
+ original_sizes = original_sizes.cpu().tolist()
+ encoding_image_processor = BatchEncoding({"original_sizes": original_sizes}, tensor_type=return_tensors)
+ else:
+ raise ValueError("Either images or original_sizes must be provided")
+
+ # original_sizes is not a model input, but it is needed below to normalize the prompt coordinates
+ original_sizes = encoding_image_processor["original_sizes"]
+ # Check original_sizes is of length 1 or len(images)
+ if images is not None and len(original_sizes) != 1 and len(original_sizes) != len(images):
+ raise ValueError(
+ "original_sizes must be of length 1 or len(images). If you are passing a single image, you must pass a single original_size."
+ )
+
+ # Process input points, labels, and boxes if provided
+ if input_points is not None or input_labels is not None or input_boxes is not None:
+ # Validate and convert inputs to standardized format
+ processed_points = self._validate_single_input(
+ input_points,
+ expected_depth=4,
+ input_name="points",
+ expected_format="[image level, object level, point level, point coordinates]",
+ expected_coord_size=2,
+ )
+ processed_labels = self._validate_single_input(
+ input_labels,
+ expected_depth=3,
+ input_name="labels",
+ expected_format="[image level, object level, point level]",
+ )
+ processed_boxes = self._validate_single_input(
+ input_boxes,
+ expected_depth=3,
+ input_name="boxes",
+ expected_format="[image level, box level, box coordinates]",
+ expected_coord_size=4,
+ )
+
+ # Get padding requirements for all inputs
+ if processed_points is not None:
+ points_max_dims = self._get_nested_dimensions(processed_points)[:3]
+ if processed_labels is not None:
+ labels_max_dims = self._get_nested_dimensions(processed_labels)[:3]
+ if processed_boxes is not None:
+ boxes_max_dims = self._get_nested_dimensions(processed_boxes)[:2]
+
+ # Ensure points and labels have consistent dimensions
+ if processed_points is not None and processed_labels is not None:
+ if points_max_dims != labels_max_dims:
+ raise ValueError(
+ "Input points and labels have inconsistent dimensions. Please ensure they have the same dimensions."
+ )
+
+ # Check that boxes don't need padding (model limitation)
+ if processed_boxes is not None and len(processed_boxes) >= 2:
+ if any(len(img_boxes) < boxes_max_dims[1] for img_boxes in processed_boxes):
+ raise ValueError(
+ "Input boxes have inconsistent dimensions that would require padding, "
+ "but boxes cannot be padded due to model limitations. "
+ "Please ensure all images have the same number of boxes."
+ )
+
+ # Pad and normalize all inputs to final tensor format
+ if processed_points is not None:
+ padded_points = self._pad_nested_list(processed_points, points_max_dims + [2])
+ final_points = torch.tensor(padded_points, dtype=torch.float32)
+ self._normalize_tensor_coordinates(final_points, original_sizes, preserve_padding=True)
+ encoding_image_processor.update({"input_points": final_points})
+
+ if processed_labels is not None:
+ padded_labels = self._pad_nested_list(processed_labels, labels_max_dims)
+ final_labels = torch.tensor(padded_labels, dtype=torch.int64)
+ encoding_image_processor.update({"input_labels": final_labels})
+
+ if processed_boxes is not None:
+ final_boxes = torch.tensor(processed_boxes, dtype=torch.float32)
+ self._normalize_tensor_coordinates(final_boxes, original_sizes, is_bounding_box=True)
+ encoding_image_processor.update({"input_boxes": final_boxes})
+
+ return encoding_image_processor
+
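+ # Illustrative call sketch (argument values are assumptions): prompts are nested per image ->
+ # per object -> per point, so a single point prompt on one image looks like:
+ #   inputs = processor(images=image, input_points=[[[[500, 375]]]], input_labels=[[[1]]], return_tensors="pt")
+ # and a single box prompt (x1, y1, x2, y2) looks like:
+ #   inputs = processor(images=image, input_boxes=[[[100, 100, 400, 400]]], return_tensors="pt")
+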
+ def _normalize_coordinates(
+ self, target_size: int, coords: "torch.Tensor", original_size, is_bounding_box=False
+ ) -> "torch.Tensor":
+ """
+ Expects a coordinate tensor with 2 values in the final dimension. Requires the original image size in (H, W) format.
+
+ Args:
+ target_size (`int`):
+ The target size of the image.
+ coords (`torch.Tensor`):
+ The coordinates to be normalized.
+ original_size (`tuple`):
+ The original size of the image.
+ is_bounding_box (`bool`, *optional*, defaults to `False`):
+ Whether the coordinates are bounding boxes.
+ """
+ old_h, old_w = original_size
+ new_h, new_w = target_size, target_size
+ coords = deepcopy(coords).float()
+
+ if is_bounding_box:
+ coords = coords.reshape(-1, 2, 2)
+ coords[..., 0] = coords[..., 0] * (new_w / old_w)
+ coords[..., 1] = coords[..., 1] * (new_h / old_h)
+
+ if is_bounding_box:
+ coords = coords.reshape(-1, 4)
+
+ return coords
+
+ def _convert_to_nested_list(self, data, expected_depth, current_depth=0):
+ """
+ Recursively convert various input formats (tensors, numpy arrays, lists) to nested lists.
+
+ Args:
+ data: Input data in any format
+ expected_depth: Expected nesting depth
+ current_depth: Current depth in recursion
+
+ Returns:
+ Nested list representation of the data
+ """
+ if data is None:
+ return None
+
+ # Convert tensor/numpy to list if we're at a leaf level or if it's a multi-dimensional array
+ if isinstance(data, torch.Tensor): # PyTorch tensor
+ if current_depth == expected_depth - 2 or len(data.shape) <= 2: # At coordinate level or small tensor
+ return data.numpy().tolist()
+ else:
+ return [self._convert_to_nested_list(item, expected_depth, current_depth + 1) for item in data]
+ elif isinstance(data, np.ndarray): # NumPy array
+ if current_depth == expected_depth - 2 or len(data.shape) <= 2: # At coordinate level or small array
+ return data.tolist()
+ else:
+ return [self._convert_to_nested_list(item, expected_depth, current_depth + 1) for item in data]
+ elif isinstance(data, list):
+ if current_depth == expected_depth:
+ # We've reached the expected depth, return as is
+ return data
+ else:
+ # Continue recursion
+ return [self._convert_to_nested_list(item, expected_depth, current_depth + 1) for item in data]
+ elif isinstance(data, (int, float)):
+ return data
+ else:
+ raise ValueError(f"Unsupported data type: {type(data)}")
+
+ def _get_nested_dimensions(self, nested_list, max_dims=None):
+ """
+ Get the maximum dimensions at each level of nesting.
+
+ Args:
+ nested_list (`list`):
+ Nested list structure.
+ max_dims (`list`, *optional*):
+ Current maximum dimensions (for recursion).
+
+ Returns:
+ `list`: A list of maximum dimensions for each nesting level.
+ """
+ if max_dims is None:
+ max_dims = []
+
+ if not isinstance(nested_list, list):
+ return max_dims
+
+ if len(max_dims) == 0:
+ max_dims.append(len(nested_list))
+ else:
+ max_dims[0] = max(max_dims[0], len(nested_list))
+
+ if len(nested_list) > 0:
+ for item in nested_list:
+ if isinstance(item, list):
+ sub_dims = self._get_nested_dimensions(item)
+ # Merge sub_dims into max_dims
+ for i, dim in enumerate(sub_dims):
+ if i + 1 >= len(max_dims):
+ max_dims.append(dim)
+ else:
+ max_dims[i + 1] = max(max_dims[i + 1], dim)
+
+ return max_dims
+
+ def _pad_nested_list(self, nested_list, target_dims, current_level=0, pad_value=None):
+ """
+ Recursively pad a nested list to match target dimensions.
+
+ Args:
+ nested_list (`list`):
+ Nested list to pad.
+ target_dims (`list`):
+ Target dimensions for each level.
+ current_level (`int`, *optional*, defaults to 0):
+ Current nesting level.
+ pad_value (`int`, *optional*):
+ Value to use for padding.
+
+ Returns:
+ `list`: The padded nested list.
+ """
+ if pad_value is None:
+ pad_value = self.point_pad_value
+
+ if current_level >= len(target_dims):
+ return nested_list
+
+ # Ensure we have a list
+ if not isinstance(nested_list, list):
+ nested_list = [nested_list]
+
+ # Pad current level
+ current_size = len(nested_list)
+ target_size = target_dims[current_level]
+
+ # Pad with appropriate values
+ if current_level == len(target_dims) - 1:
+ # At the coordinate level, pad with pad_value
+ nested_list.extend([pad_value] * (target_size - current_size))
+ else:
+ # At higher levels, pad with nested structures
+ if current_size > 0:
+ # Create appropriately sized template
+ if current_level < len(target_dims) - 2:
+ # For non-coordinate levels, create empty nested structure
+ template_dims = target_dims[current_level + 1 :]
+ template = self._create_empty_nested_structure(template_dims, pad_value)
+ else:
+ # For coordinate level, create list of pad_values
+ template = [pad_value] * target_dims[current_level + 1]
+
+ nested_list.extend([deepcopy(template) for _ in range(target_size - current_size)])
+ else:
+ # Create from scratch
+ template_dims = target_dims[current_level + 1 :]
+ template = self._create_empty_nested_structure(template_dims, pad_value)
+ nested_list.extend([deepcopy(template) for _ in range(target_size)])
+
+ # Recursively pad sublists
+ if current_level < len(target_dims) - 1:
+ for i in range(len(nested_list)):
+ if isinstance(nested_list[i], list):
+ nested_list[i] = self._pad_nested_list(nested_list[i], target_dims, current_level + 1, pad_value)
+
+ return nested_list
+
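+ # Example of the padding behaviour (illustrative): with point_pad_value=-10, ragged per-image
+ # point lists such as [[[[1, 2]]], [[[3, 4], [5, 6]]]] are padded to a common shape, the first
+ # image receiving an extra [-10, -10] point so both images end up with two points per object.
+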
+ def _create_empty_nested_structure(self, dims, pad_value):
+ """
+ Create an empty nested structure with given dimensions filled with pad_value.
+
+ Args:
+ dims (`list`):
+ The dimensions of the nested structure.
+ pad_value (`int`):
+ The value to fill the structure with.
+ """
+ if len(dims) == 1:
+ return [pad_value] * dims[0]
+ else:
+ return [self._create_empty_nested_structure(dims[1:], pad_value) for _ in range(dims[0])]
+
+ def _get_nesting_level(self, input_list):
+ """
+ Get the nesting level of a list structure.
+
+ Args:
+ input_list (`list`):
+ The list to get the nesting level of.
+ """
+ if isinstance(input_list, list):
+ if len(input_list) == 0:
+ return 1
+ return 1 + self._get_nesting_level(input_list[0])
+ elif isinstance(input_list, (np.ndarray, torch.Tensor)):
+ # For arrays/tensors, the nesting level is the number of dimensions
+ return len(input_list.shape)
+ return 0
+
+ def _validate_single_input(
+ self,
+ data: Union[torch.Tensor, np.ndarray, list],
+ expected_depth: int,
+ input_name: str,
+ expected_format: str,
+ expected_coord_size: Optional[int] = None,
+ ) -> list:
+ """
+ Validate a single input by ensuring proper nesting and raising an error if the input is not valid.
+
+ Args:
+ data (`torch.Tensor`, `np.ndarray`, or `list`):
+ Input data to process.
+ expected_depth (`int`):
+ Expected nesting depth.
+ input_name (`str`):
+ Name of the input for error messages.
+ expected_format (`str`):
+ The expected format of the input.
+ expected_coord_size (`int`, *optional*):
+ Expected coordinate size (2 for points, 4 for boxes, None for labels).
+ """
+ if data is None:
+ return None
+
+ # Handle tensors and numpy arrays first
+ if isinstance(data, (torch.Tensor, np.ndarray)):
+ # For tensors/arrays, we can directly check the number of dimensions
+ if data.ndim != expected_depth:
+ raise ValueError(
+ f"Input {input_name} must be a tensor/array with {expected_depth} dimensions. The expected nesting format is {expected_format}. Got {data.ndim} dimensions."
+ )
+ elif expected_coord_size is not None:
+ if data.shape[-1] != expected_coord_size:
+ raise ValueError(
+ f"Input {input_name} must be a tensor/array with {expected_coord_size} as the last dimension, got {data.shape[-1]}."
+ )
+ return self._convert_to_nested_list(data, expected_depth)
+
+ # Handle nested lists
+ if isinstance(data, list):
+ current_depth = self._get_nesting_level(data)
+ if current_depth != expected_depth:
+ raise ValueError(
+ f"Input {input_name} must be a nested list with {expected_depth} levels. The expected nesting format is {expected_format}. Got {current_depth} levels."
+ )
+ return self._convert_to_nested_list(data, expected_depth)
+
+ def _normalize_tensor_coordinates(self, tensor, original_sizes, is_bounding_box=False, preserve_padding=False):
+ """
+ Helper method to normalize coordinates in a tensor across multiple images.
+
+ Args:
+ tensor (`torch.Tensor`):
+ Input tensor with coordinates.
+ original_sizes (`list`):
+ Original image sizes.
+ is_bounding_box (`bool`, *optional*, defaults to `False`):
+ Whether coordinates are bounding boxes.
+ preserve_padding (`bool`, *optional*, defaults to `False`):
+ Whether to preserve padding values (for points).
+ """
+ if preserve_padding:
+ # For points: avoid normalizing pad values
+ mask = tensor != self.point_pad_value
+ coord_mask = mask.all(dim=-1, keepdim=True)
+
+ for img_idx in range(len(original_sizes)):
+ if img_idx < tensor.shape[0]:
+ original_size = original_sizes[img_idx] if img_idx < len(original_sizes) else original_sizes[0]
+ normalized_coords = self._normalize_coordinates(
+ self.target_size, tensor[img_idx], original_size, is_bounding_box=is_bounding_box
+ )
+
+ if preserve_padding:
+ # Only update non-padded values
+ img_mask = coord_mask[img_idx]
+ tensor[img_idx] = torch.where(
+ img_mask.expand_as(tensor[img_idx]), normalized_coords, tensor[img_idx]
+ )
+ else:
+ tensor[img_idx] = normalized_coords
+
+ def post_process_masks(
+ self,
+ masks,
+ original_sizes,
+ mask_threshold=0.0,
+ binarize=True,
+ max_hole_area=0.0,
+ max_sprinkle_area=0.0,
+ apply_non_overlapping_constraints=False,
+ **kwargs,
+ ):
+ """
+ Remove padding and upscale masks to the original image size.
+
+ Args:
+ masks (`Union[List[torch.Tensor], List[np.ndarray]]`):
+ Batched masks from the mask_decoder in (batch_size, num_channels, height, width) format.
+ original_sizes (`Union[torch.Tensor, List[Tuple[int,int]]]`):
+ The original sizes of each image before it was resized to the model's expected input shape, in (height,
+ width) format.
+ mask_threshold (`float`, *optional*, defaults to 0.0):
+ Threshold for binarization and post-processing operations.
+ binarize (`bool`, *optional*, defaults to `True`):
+ Whether to binarize the masks.
+ max_hole_area (`float`, *optional*, defaults to 0.0):
+ The maximum area of a hole to fill.
+ max_sprinkle_area (`float`, *optional*, defaults to 0.0):
+ The maximum area of a sprinkle to fill.
+ apply_non_overlapping_constraints (`bool`, *optional*, defaults to `False`):
+ Whether to apply non-overlapping constraints to the masks.
+
+ Returns:
+ (`torch.Tensor`): Batched masks in (batch_size, num_channels, height, width) format, where (height, width)
+ is given by original_size.
+ """
+ return self.image_processor.post_process_masks(
+ masks,
+ original_sizes,
+ mask_threshold,
+ binarize,
+ max_hole_area,
+ max_sprinkle_area,
+ apply_non_overlapping_constraints,
+ **kwargs,
+ )
+
+
+__all__ = ["Sam2Processor"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/sam2_video/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/sam2_video/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..565e8bcaf4d0d327ff1dbf21343008b24e0e844b
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/sam2_video/__init__.py
@@ -0,0 +1,29 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_sam2_video import *
+ from .modeling_sam2_video import *
+ from .processing_sam2_video import *
+ from .video_processing_sam2_video import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/sam2_video/modeling_sam2_video.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/sam2_video/modeling_sam2_video.py
new file mode 100644
index 0000000000000000000000000000000000000000..79d5b015f889933acd7c569deffcb806e5e04985
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/sam2_video/modeling_sam2_video.py
@@ -0,0 +1,2640 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/sam2_video/modular_sam2_video.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_sam2_video.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2025 The Meta AI Authors and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from collections import OrderedDict
+from collections.abc import Iterator
+from dataclasses import dataclass
+from typing import Any, Callable, Optional, Union
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch import Tensor
+from tqdm import tqdm
+
+from ...activations import ACT2FN
+from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_outputs import BaseModelOutput
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
+from ...pytorch_utils import compile_compatible_method_lru_cache
+from ...utils import ModelOutput, auto_docstring
+from ...utils.generic import OutputRecorder, TransformersKwargs
+from ..auto import AutoModel
+from .configuration_sam2_video import Sam2VideoConfig, Sam2VideoMaskDecoderConfig, Sam2VideoPromptEncoderConfig
+
+
+class Sam2VideoInferenceCache:
+ """Cache for vision features and model constants."""
+
+ def __init__(
+ self,
+ inference_device: Union[torch.device, str] = "cpu",
+ inference_state_device: Union[torch.device, str] = "cpu",
+ max_vision_features_cache_size: int = 1,
+ ):
+ self.inference_device = inference_device
+ self.inference_state_device = inference_state_device
+ self.max_vision_features_cache_size = max_vision_features_cache_size
+
+ self._vision_features = {}
+
+ def cache_vision_features(self, frame_idx: int, features: dict):
+ """Cache vision features with automatic device management."""
+ cached = {}
+ if len(self._vision_features) >= self.max_vision_features_cache_size:
+ # remove the oldest frame
+ self._vision_features.pop(min(self._vision_features.keys()))
+
+ for key, value in features.items():
+ if isinstance(value, torch.Tensor):
+ cached[key] = value.to(self.inference_state_device, non_blocking=True)
+ elif isinstance(value, (list, tuple)) and value and isinstance(value[0], torch.Tensor):
+ cached[key] = [v.to(self.inference_state_device, non_blocking=True) for v in value]
+ else:
+ cached[key] = value
+ self._vision_features[frame_idx] = cached
+
+ def get_vision_features(self, frame_idx: int) -> Optional[dict]:
+ """Get cached vision features, automatically moved to inference device."""
+ if frame_idx not in self._vision_features:
+ return None
+
+ cached = self._vision_features[frame_idx]
+ moved = {}
+ for key, value in cached.items():
+ if isinstance(value, torch.Tensor):
+ moved[key] = value.to(self.inference_device, non_blocking=True)
+ elif isinstance(value, (list, tuple)) and value and isinstance(value[0], torch.Tensor):
+ moved[key] = [v.to(self.inference_device, non_blocking=True) for v in value]
+ else:
+ moved[key] = value
+ return moved
+
+ def clear_all(self):
+ """Clear all cached data."""
+ self._vision_features.clear()
+
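+# Illustrative cache behaviour (sketch; the feature dict key is an assumption): with
+# max_vision_features_cache_size=1 (the default), caching a new frame evicts the oldest cached
+# frame, and cached tensors are moved back to the inference device on retrieval, e.g.
+#   cache.cache_vision_features(0, {"feature_maps": feature_maps})
+#   features = cache.get_vision_features(0)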
+
+class Sam2VideoInferenceSession:
+ r"""
+ Manages video inference session parameters, state and cache.
+
+ Args:
+ video (`torch.FloatTensor`, *optional*):
+ The video to process. No need to provide when streaming.
+ video_height (`int`, *optional*):
+ The height of the video.
+ video_width (`int`, *optional*):
+ The width of the video.
+ inference_device (`torch.device`, *optional*, defaults to `"cpu"`):
+ The device to use for inference.
+ inference_state_device (`torch.device`, *optional*, defaults to `"cpu"`):
+ The device to store the inference state on.
+ video_storage_device (`torch.device`, *optional*, defaults to `"cpu"`):
+ The device to store the video on.
+ dtype (`torch.dtype`, *optional*, defaults to `"float32"`):
+ The dtype to use for the video.
+ max_vision_features_cache_size (`int`, *optional*, defaults to 1):
+ The maximum number of vision features to cache.
+ """
+
+ def __init__(
+ self,
+ video: Optional[torch.FloatTensor] = None,
+ video_height: Optional[int] = None,
+ video_width: Optional[int] = None,
+ inference_device: Union[torch.device, str] = "cpu",
+ inference_state_device: Union[torch.device, str] = "cpu",
+ video_storage_device: Union[torch.device, str] = "cpu",
+ dtype: Union[torch.dtype, str] = "float32",
+ max_vision_features_cache_size: int = 1,
+ ):
+ # store as a dictionary to avoid double memory allocation with torch.cat when adding new frames
+ self.processed_frames = (
+ dict(enumerate(video.to(video_storage_device, dtype=dtype))) if video is not None else None
+ )
+ self.video_height = video_height
+ self.video_width = video_width
+
+ self.inference_device = inference_device
+ self.inference_state_device = inference_state_device
+ self.video_storage_device = video_storage_device
+ self.dtype = dtype
+ self.max_vision_features_cache_size = max_vision_features_cache_size
+
+ # Cache for computed features
+ self.cache = Sam2VideoInferenceCache(
+ inference_device=self.inference_device,
+ inference_state_device=self.inference_state_device,
+ max_vision_features_cache_size=self.max_vision_features_cache_size,
+ )
+
+ # Persistent object tracking state
+ self._obj_id_to_idx = OrderedDict()
+ self._obj_idx_to_id = OrderedDict()
+ self.obj_ids = []
+
+ # Persistent user inputs
+ self.point_inputs_per_obj = {}
+ self.mask_inputs_per_obj = {}
+
+ # Persistent model outputs/history
+ self.output_dict_per_obj = {}
+ self.frames_tracked_per_obj = {}
+
+ # Session state flags
+ self.obj_with_new_inputs = []
+
+ @property
+ def num_frames(self) -> Optional[int]:
+ return len(self.processed_frames) if self.processed_frames is not None else None
+
+ # Object management
+ def obj_id_to_idx(self, obj_id: int) -> int:
+ """Map object ID to index, creating new entry if needed."""
+ obj_idx = self._obj_id_to_idx.get(obj_id, None)
+ if obj_idx is not None:
+ return obj_idx
+
+ obj_idx = len(self._obj_id_to_idx)
+ self._obj_id_to_idx[obj_id] = obj_idx
+ self._obj_idx_to_id[obj_idx] = obj_id
+ self.obj_ids = list(self._obj_id_to_idx)
+
+ self.point_inputs_per_obj[obj_idx] = {}
+ self.mask_inputs_per_obj[obj_idx] = {}
+ self.output_dict_per_obj[obj_idx] = {
+ "cond_frame_outputs": {},
+ "non_cond_frame_outputs": {},
+ }
+ self.frames_tracked_per_obj[obj_idx] = {}
+
+ return obj_idx
+
+ # Video Inference specific functions
+ def obj_idx_to_id(self, obj_idx: int) -> int:
+ """Map model-side object index to client-side object id."""
+ return self._obj_idx_to_id[obj_idx]
+
+ def get_obj_num(self) -> int:
+ """Get the total number of unique object ids received so far in this session."""
+ return len(self._obj_idx_to_id)
+
+ # Input management with device handling
+ def add_point_inputs(self, obj_idx: int, frame_idx: int, inputs: dict):
+ """Add point inputs with automatic device placement."""
+ device_inputs = {}
+ for key, value in inputs.items():
+ if isinstance(value, torch.Tensor):
+ device_inputs[key] = value.to(self.inference_device, non_blocking=True)
+ else:
+ device_inputs[key] = value
+ self.point_inputs_per_obj[obj_idx][frame_idx] = device_inputs
+
+ def remove_point_inputs(self, obj_idx: int, frame_idx: int):
+ """Remove point inputs."""
+ self.point_inputs_per_obj[obj_idx].pop(frame_idx, None)
+
+ def add_mask_inputs(self, obj_idx: int, frame_idx: int, inputs: torch.Tensor):
+ """Add mask inputs with automatic device placement."""
+ self.mask_inputs_per_obj[obj_idx][frame_idx] = inputs.to(
+ self.inference_device, dtype=self.dtype, non_blocking=True
+ )
+
+ def remove_mask_inputs(self, obj_idx: int, frame_idx: int):
+ """Remove mask inputs."""
+ self.mask_inputs_per_obj[obj_idx].pop(frame_idx, None)
+
+ # Output management with smart device placement
+ def store_output(
+ self,
+ obj_idx: int,
+ frame_idx: int,
+ output_key: Optional[str] = None,
+ output_value: Optional[Union[torch.Tensor, dict]] = None,
+ is_conditioning_frame: bool = True,
+ ):
+ """
+ Store output with smart device management.
+ If output_key is None, the output is stored as a dictionary.
+
+ Args:
+ obj_idx (int): The index of the object.
+ frame_idx (int): The index of the frame.
+ output_key (Optional[str]): The key of the output. If None, the output is stored as a dictionary.
+ output_value (Optional[Union[torch.Tensor, dict]]): The value of the output.
+ is_conditioning_frame (bool): Whether the output is for a conditioning frame.
+ """
+ storage_key = "cond_frame_outputs" if is_conditioning_frame else "non_cond_frame_outputs"
+
+ if output_key is None and isinstance(output_value, dict):
+ self.output_dict_per_obj[obj_idx][storage_key][frame_idx] = {}
+ for key, value in output_value.items():
+ self.store_output(obj_idx, frame_idx, key, value, is_conditioning_frame)
+ return
+
+ # Device placement: small tensors stay on inference device, large ones go to inference state device
+ if output_key in ["object_pointer", "object_score_logits"]: # Small tensors
+ self.output_dict_per_obj[obj_idx][storage_key][frame_idx][output_key] = output_value
+ elif isinstance(output_value, torch.Tensor): # Large tensors like masks, features
+ self.output_dict_per_obj[obj_idx][storage_key][frame_idx][output_key] = output_value.to(
+ self.inference_state_device, non_blocking=True
+ )
+ else:
+ self.output_dict_per_obj[obj_idx][storage_key][frame_idx][output_key] = output_value
+
+ def get_output(
+ self,
+ obj_idx: int,
+ frame_idx: int,
+ output_key: str,
+ is_conditioning_frame: bool = True,
+ ):
+ """
+ Get output with smart device management.
+
+ Args:
+ obj_idx (int): The index of the object.
+ frame_idx (int): The index of the frame.
+ output_key (str): The key of the output.
+ is_conditioning_frame (bool): Whether the output is for a conditioning frame.
+ """
+ storage_key = "cond_frame_outputs" if is_conditioning_frame else "non_cond_frame_outputs"
+ out = self.output_dict_per_obj[obj_idx][storage_key].get(frame_idx, None)
+ # move to inference device if needed
+ if out is None:
+ return None
+ value = out[output_key]
+ if isinstance(value, torch.Tensor):
+ value = value.to(self.inference_device, non_blocking=True)
+ return value
+
+ # Video frame management
+ def add_new_frame(self, pixel_values: torch.Tensor, frame_idx: Optional[int] = None) -> int:
+ """Add new frame with automatic device placement."""
+ pixel_values = pixel_values.to(self.video_storage_device, dtype=self.dtype, non_blocking=True)
+ if pixel_values.dim() == 4:
+ pixel_values = pixel_values.squeeze(0)
+
+ if frame_idx is None:
+ frame_idx = len(self.processed_frames) if self.processed_frames is not None else 0
+
+ if self.processed_frames is None:
+ self.processed_frames = {frame_idx: pixel_values}
+ else:
+ self.processed_frames[frame_idx] = pixel_values
+
+ return frame_idx
+
+ def get_frame(self, frame_idx: int) -> torch.Tensor:
+ """Get frame from video."""
+ return self.processed_frames[frame_idx].to(self.inference_device, non_blocking=True)
+
+ def reset_tracking_data(self):
+ """Reset tracking data but keep cache."""
+ self._obj_id_to_idx.clear()
+ self._obj_idx_to_id.clear()
+ self.obj_ids.clear()
+ self.point_inputs_per_obj.clear()
+ self.mask_inputs_per_obj.clear()
+ self.output_dict_per_obj.clear()
+ self.frames_tracked_per_obj.clear()
+ self.obj_with_new_inputs = []
+ # Note: cache and video data are preserved
+
+ def reset_inference_session(self):
+ """Reset tracking data and cache."""
+ self._obj_id_to_idx.clear()
+ self._obj_idx_to_id.clear()
+ self.obj_ids.clear()
+ self.point_inputs_per_obj.clear()
+ self.mask_inputs_per_obj.clear()
+ self.output_dict_per_obj.clear()
+ self.frames_tracked_per_obj.clear()
+ self.obj_with_new_inputs = []
+ self.cache.clear_all()
+
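+# Illustrative session setup (sketch; tensor names and devices are assumptions):
+#   session = Sam2VideoInferenceSession(
+#       video=video_frames, video_height=480, video_width=640,
+#       inference_device="cuda", video_storage_device="cpu",
+#   )
+#   frame_idx = session.add_new_frame(new_frame)  # streaming: frames can also be appended one by one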
+
+class Sam2VideoLayerNorm(nn.LayerNorm):
+ r"""LayerNorm that supports two data formats for the ordering of dimensions in the inputs: channels_last
+ (default) and channels_first. channels_last corresponds to inputs with shape (batch_size, height,
+ width, channels) while channels_first corresponds to inputs with shape (batch_size, channels, height, width).
+ """
+
+ def __init__(self, normalized_shape, *, eps=1e-6, data_format="channels_last", **kwargs):
+ super().__init__(normalized_shape, eps=eps, **kwargs)
+ if data_format not in ["channels_last", "channels_first"]:
+ raise NotImplementedError(f"Unsupported data format: {data_format}")
+ self.data_format = data_format
+
+ def forward(self, features: torch.Tensor) -> torch.Tensor:
+ """
+ Args:
+ features: Tensor of shape (batch_size, channels, height, width) OR (batch_size, height, width, channels)
+ """
+ if self.data_format == "channels_first":
+ features = features.permute(0, 2, 3, 1)
+ features = super().forward(features)
+ features = features.permute(0, 3, 1, 2)
+ else:
+ features = super().forward(features)
+ return features
+
+
+# copied and adapted from original implementation, also practically equal to DetrSinePositionEmbedding
+class Sam2VideoPositionEmbeddingSine(nn.Module):
+ """
+ This is a more standard version of the position embedding, very similar to the one used by the Attention is all you
+ need paper, generalized to work on images.
+ """
+
+ def __init__(
+ self, num_pos_feats: int = 64, temperature: int = 10000, normalize: bool = False, scale: Optional[float] = None
+ ):
+ super().__init__()
+ if scale is not None and normalize is False:
+ raise ValueError("normalize should be True if scale is passed")
+ self.num_pos_feats = num_pos_feats
+ self.temperature = temperature
+ self.normalize = normalize
+ self.scale = 2 * math.pi if scale is None else scale
+
+ @compile_compatible_method_lru_cache(maxsize=1)
+ def forward(
+ self,
+ shape: torch.Size,
+ device: Union[torch.device, str],
+ dtype: torch.dtype,
+ mask: Optional[Tensor] = None,
+ ) -> Tensor:
+ if mask is None:
+ mask = torch.zeros((shape[0], shape[2], shape[3]), device=device, dtype=torch.bool)
+ not_mask = (~mask).to(dtype)
+ y_embed = not_mask.cumsum(1)
+ x_embed = not_mask.cumsum(2)
+ if self.normalize:
+ eps = 1e-6
+ y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
+ x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
+
+ dim_t = torch.arange(self.num_pos_feats, dtype=torch.int64, device=device).to(dtype)
+ dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / self.num_pos_feats)
+
+ pos_x = x_embed[:, :, :, None] / dim_t
+ pos_y = y_embed[:, :, :, None] / dim_t
+ pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
+ pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
+ pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
+ return pos
+
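+# Output shape note (illustrative): for shape == (batch, C, H, W), the embedding returned above
+# has shape (batch, 2 * num_pos_feats, H, W), with the y frequencies first and the x frequencies second.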
+
+def eager_attention_forward(
+ module: nn.Module,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_mask: Optional[torch.Tensor],
+ scaling: float,
+ dropout: float = 0.0,
+ **kwargs,
+):
+ attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
+ if attention_mask is not None:
+ attn_weights = attn_weights + attention_mask
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+ attn_output = torch.matmul(attn_weights, value)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ return attn_output, attn_weights
+
+
+class Sam2VideoAttention(nn.Module):
+ """
+ SAM2_VIDEO's attention layer that allows for downscaling the size of the embedding after projection to queries, keys, and
+ values.
+ """
+
+ def __init__(self, config, downsample_rate=None):
+ super().__init__()
+ downsample_rate = config.attention_downsample_rate if downsample_rate is None else downsample_rate
+ self.config = config
+ self.hidden_size = config.hidden_size
+ self.internal_dim = config.hidden_size // downsample_rate
+ self.num_attention_heads = config.num_attention_heads
+ self.head_dim = self.internal_dim // config.num_attention_heads
+ self.scaling = self.head_dim**-0.5
+ self.is_causal = False
+
+ self.q_proj = nn.Linear(self.hidden_size, self.internal_dim)
+ self.k_proj = nn.Linear(self.hidden_size, self.internal_dim)
+ self.v_proj = nn.Linear(self.hidden_size, self.internal_dim)
+ self.o_proj = nn.Linear(self.internal_dim, self.hidden_size)
+
+ def forward(
+ self,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_similarity: Optional[torch.Tensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ # Input projections
+ batch_size, point_batch_size = query.shape[:2]
+ new_shape = (batch_size * point_batch_size, -1, self.num_attention_heads, self.head_dim)
+
+ query = self.q_proj(query).view(*new_shape).transpose(1, 2)
+ key = self.k_proj(key).view(*new_shape).transpose(1, 2)
+ value = self.v_proj(value).view(*new_shape).transpose(1, 2)
+
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ attn_output, attn_weights = attention_interface(
+ self,
+ query,
+ key,
+ value,
+ attention_mask=attention_similarity,
+ dropout=0.0,
+ scaling=self.scaling,
+ is_causal=self.is_causal,
+ **kwargs,
+ )
+
+ attn_output = attn_output.reshape(
+ batch_size, point_batch_size, -1, self.num_attention_heads * self.head_dim
+ ).contiguous()
+ attn_output = self.o_proj(attn_output)
+
+ return attn_output, attn_weights
+
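+# Shape sketch (illustrative, using assumed config values): with hidden_size=256, downsample_rate=2
+# and 8 attention heads, the (batch, point_batch, seq, 256) inputs are projected to an internal
+# dim of 128 (head_dim 16) for attention and projected back to 256 by `o_proj`.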
+
+class Sam2VideoTwoWayAttentionBlock(nn.Module):
+ def __init__(self, config: Sam2VideoMaskDecoderConfig, skip_first_layer_pe: bool = False):
+ """
+ A transformer block with four layers:
+ (1) self-attention of sparse inputs (2) cross attention of sparse inputs -> dense inputs (3) mlp block on
+ sparse inputs (4) cross attention of dense inputs -> sparse inputs
+
+ Arguments:
+ config (`Sam2VideoMaskDecoderConfig`):
+ The configuration file used to instantiate the block
+ attention_downsample_rate (`int`, *optional*, defaults to 2):
+ The downsample ratio of the block used to reduce the inner dim of the attention.
+ skip_first_layer_pe (`bool`, *optional*, defaults to `False`):
+ Whether or not to skip the addition of the query_point_embedding on the first layer.
+ """
+ super().__init__()
+ self.self_attn = Sam2VideoAttention(config, downsample_rate=1)
+ self.layer_norm1 = nn.LayerNorm(config.hidden_size)
+
+ self.cross_attn_token_to_image = Sam2VideoAttention(config)
+ self.layer_norm2 = nn.LayerNorm(config.hidden_size)
+
+ self.mlp = Sam2VideoFeedForward(
+ config.hidden_size, config.mlp_dim, config.hidden_size, num_layers=config.num_hidden_layers
+ )
+ self.layer_norm3 = nn.LayerNorm(config.hidden_size)
+
+ self.layer_norm4 = nn.LayerNorm(config.hidden_size)
+ self.cross_attn_image_to_token = Sam2VideoAttention(config)
+
+ self.skip_first_layer_pe = skip_first_layer_pe
+
+ def forward(
+ self,
+ queries: Tensor,
+ keys: Tensor,
+ query_point_embedding: Tensor,
+ key_point_embedding: Tensor,
+ attention_similarity: Tensor,
+ **kwargs: Unpack[TransformersKwargs],
+ ):
+ # Self attention block
+ if self.skip_first_layer_pe:
+ queries, _ = self.self_attn(query=queries, key=queries, value=queries)
+ else:
+ query = queries + query_point_embedding
+ attn_out, _ = self.self_attn(query=query, key=query, value=queries)
+ queries = queries + attn_out
+ queries = self.layer_norm1(queries)
+
+ # Cross attention block, tokens attending to image embedding
+ query = queries + query_point_embedding
+ key = keys + key_point_embedding
+
+ attn_out, _ = self.cross_attn_token_to_image(
+ query=query, key=key, value=keys, attention_similarity=attention_similarity
+ )
+ queries = queries + attn_out
+
+ queries = self.layer_norm2(queries)
+
+ # MLP block
+ mlp_out = self.mlp(queries)
+ queries = queries + mlp_out
+ queries = self.layer_norm3(queries)
+
+ # Cross attention block, image embedding attending to tokens
+ query = queries + query_point_embedding
+ key = keys + key_point_embedding
+
+ attn_out, _ = self.cross_attn_image_to_token(query=key, key=query, value=queries)
+ keys = keys + attn_out
+
+ keys = self.layer_norm4(keys)
+ return queries, keys, attn_out
+
+
+class Sam2VideoFeedForward(nn.Module):
+ def __init__(
+ self,
+ input_dim: int,
+ hidden_dim: int,
+ output_dim: int,
+ num_layers: int,
+ activation: str = "relu",
+ sigmoid_output: bool = False,
+ ):
+ super().__init__()
+ self.num_layers = num_layers
+ self.activation = ACT2FN[activation]
+ self.proj_in = nn.Linear(input_dim, hidden_dim)
+ self.proj_out = nn.Linear(hidden_dim, output_dim)
+ self.layers = nn.ModuleList([nn.Linear(hidden_dim, hidden_dim) for _ in range(num_layers - 2)])
+ self.sigmoid_output = sigmoid_output
+
+ def forward(self, hidden_states):
+ hidden_states = self.proj_in(hidden_states)
+ hidden_states = self.activation(hidden_states)
+ for layer in self.layers:
+ hidden_states = self.activation(layer(hidden_states))
+
+ hidden_states = self.proj_out(hidden_states)
+ if self.sigmoid_output:
+ hidden_states = F.sigmoid(hidden_states)
+ return hidden_states
+
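+# Layer count note (illustrative): `num_layers` counts all linear layers, so `num_layers=2`
+# reduces the block to `proj_in -> activation -> proj_out` with no intermediate hidden layers.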
+
+@dataclass
+@auto_docstring(custom_intro="Base class for the Sam2Video model's output.")
+class Sam2VideoImageSegmentationOutput(ModelOutput):
+ r"""
+ iou_scores (`torch.FloatTensor` of shape `(batch_size, point_batch_size, num_masks)`):
+ The Intersection over Union (IoU) scores of the predicted masks.
+ pred_masks (`torch.FloatTensor` of shape `(batch_size, point_batch_size, num_masks, height, width)`):
+ The predicted low-resolution masks. This is an alias for `low_res_masks`. These masks need to be post-processed
+ by the processor to be brought to the original image size.
+ object_score_logits (`torch.FloatTensor` of shape `(batch_size, point_batch_size, 1)`):
+ Logits for the object score, indicating if an object is present.
+ image_embeddings (`tuple(torch.FloatTensor)`):
+ The features from the FPN, which are used by the mask decoder. This is a tuple of `torch.FloatTensor` where each
+ tensor has shape `(batch_size, channels, height, width)`.
+ vision_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of each stage) of shape `(batch_size, height, width, hidden_size)`.
+ Hidden-states of the vision model at the output of each stage.
+ vision_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, sequence_length)`.
+ Attentions weights of the vision model.
+ mask_decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, sequence_length)`.
+ Attentions weights of the mask decoder.
+ high_res_masks (`torch.FloatTensor` of shape `(batch_size, point_batch_size, num_masks, image_size, image_size)`, *optional*):
+ The predicted masks, upscaled to the original image size. Only used for Sam2VideoModel.
+ object_pointer (`torch.FloatTensor` of shape `(batch_size, point_batch_size, hidden_size)`, *optional*):
+ A tensor representing the object pointer, used for tracking in videos. Only used for Sam2VideoModel.
+ """
+
+ iou_scores: Optional[torch.FloatTensor] = None
+ pred_masks: Optional[torch.FloatTensor] = None
+ object_score_logits: Optional[torch.FloatTensor] = None
+ image_embeddings: tuple[torch.FloatTensor, ...] = None
+ vision_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+ vision_attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+ mask_decoder_attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+
+ high_res_masks: Optional[torch.FloatTensor] = None
+ object_pointer: Optional[torch.FloatTensor] = None
+
+
+@dataclass
+@auto_docstring(custom_intro="Base class for the Sam2Video model's segmentation output.")
+class Sam2VideoSegmentationOutput(ModelOutput):
+ r"""
+ pred_masks (`torch.FloatTensor` of shape `(batch_size, num_masks, height, width)`):
+ The predicted masks stored at the model's resolution.
+ frame_idx (`int`):
+ The frame index of the video.
+ """
+
+ pred_masks: Optional[torch.FloatTensor] = None
+ frame_idx: Optional[int] = None
+
+
+@auto_docstring
+class Sam2VideoPreTrainedModel(PreTrainedModel):
+ config_class = Sam2VideoConfig
+ base_model_prefix = "sam2_video"
+ main_input_name = "pixel_values"
+ _supports_sdpa = True
+ _supports_flash_attn_2 = True
+ _supports_attention_backend = True
+
+ def _init_weights(self, module):
+ std = self.config.initializer_range
+ if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+ elif isinstance(module, (nn.LayerNorm, Sam2VideoLayerNorm)):
+ module.weight.data.fill_(1.0)
+ module.bias.data.zero_()
+ elif isinstance(module, Sam2VideoModel):
+ if module.no_memory_positional_encoding is not None:
+ module.no_memory_positional_encoding.data.zero_()
+ if module.memory_temporal_positional_encoding is not None:
+ module.memory_temporal_positional_encoding.data.zero_()
+ if module.no_object_pointer is not None:
+ module.no_object_pointer.data.zero_()
+ if module.occlusion_spatial_embedding_parameter is not None:
+ module.occlusion_spatial_embedding_parameter.data.zero_()
+ if isinstance(module, Sam2VideoMemoryFuserCXBlock):
+ if module.scale is not None:
+ module.scale.data.zero_()
+
+
+class Sam2VideoVisionRotaryEmbedding(nn.Module):
+ """
+ Vision Rotary Position Embedding for SAM2, following transformers library standards.
+ Supports 2D (axial) rotary embeddings for spatial dimensions.
+ """
+
+ def __init__(self, config: Sam2VideoConfig):
+ super().__init__()
+ dim = config.memory_attention_hidden_size // (
+ config.memory_attention_downsample_rate * config.memory_attention_num_attention_heads
+ )
+ # Ensure the dimension is divisible by 4 so it can be split across the two spatial axes and sin/cos pairs
+ if dim % 4 != 0:
+ raise ValueError("Dimension must be divisible by 4 for axial RoPE")
+ end_x, end_y = config.memory_attention_rope_feat_sizes
+ freqs = 1.0 / (config.memory_attention_rope_theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
+
+ # Generate 2D position indices for axial rotary embedding
+ flattened_indices = torch.arange(end_x * end_y, dtype=torch.long)
+ x_positions = flattened_indices % end_x
+ y_positions = torch.div(flattened_indices, end_x, rounding_mode="floor")
+ freqs_x = torch.outer(x_positions, freqs).float()
+ freqs_y = torch.outer(y_positions, freqs).float()
+ inv_freq = torch.cat([freqs_x, freqs_y], dim=-1)
+ inv_freq = inv_freq.repeat_interleave(2, dim=-1)
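+ # Layout of the last dimension: the first dim // 2 entries hold the x-axis frequencies and the last
+ # dim // 2 the y-axis frequencies, each repeated twice so they line up with the pairwise rotation in
+ # `rotate_pairwise`, giving `inv_freq` a final shape of (end_x * end_y, dim).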
+ # directly register the cos and sin embeddings as we have a fixed feature shape
+ self.register_buffer("rope_embeddings_cos", inv_freq.cos(), persistent=False)
+ self.register_buffer("rope_embeddings_sin", inv_freq.sin(), persistent=False)
+
+ @torch.no_grad()
+ def forward(self) -> tuple[torch.Tensor, torch.Tensor]:
+ # As the feature map size is fixed, we can just return the pre-computed embeddings.
+ return self.rope_embeddings_cos, self.rope_embeddings_sin
+
+
+def rotate_pairwise(x):
+ """
+ Pairwise rotation of the hidden dims of the input. Different from the Llama half-tensor rotation.
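+ For example, `[x0, x1, x2, x3]` becomes `[-x1, x0, -x3, x2]`.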
+
+ This is an optimized version of the following more explicit implementation:
+ ```python
+ x_rotated = torch.zeros_like(x, dtype=x.dtype, device=x.device)
+ x_rotated[..., ::2] = -x[..., 1::2]
+ x_rotated[..., 1::2] = x[..., ::2]
+ return x_rotated
+ ```
+ """
+ x = x.view(*x.shape[:-1], -1, 2)
+ x1, x2 = x.unbind(dim=-1)
+ x = torch.stack((-x2, x1), dim=-1)
+ return x.flatten(start_dim=-2)
+
+
+# TODO: This leads to ~1e-07 max diff and ~1e-09 avg diff for q_embed and k_embed from the original implementation, most likely due to the use of complex tensors in the original implementation.
+def apply_rotary_pos_emb_2d(
+ q: torch.Tensor,
+ k: torch.Tensor,
+ cos: torch.Tensor,
+ sin: torch.Tensor,
+ num_k_exclude_rope: int = 0,
+ repeat_freqs_k: bool = False,
+) -> tuple[torch.Tensor, torch.Tensor]:
+ """
+ Apply rotary position embedding to query and key tensors for vision models.
+ Follows the standard transformers library pattern.
+
+ Args:
+ q: Query tensor of shape (..., seq_len, head_dim)
+ k: Key tensor of shape (..., seq_len, head_dim)
+ cos: Cosine position embedding of shape (seq_len, head_dim)
+ sin: Sine position embedding of shape (seq_len, head_dim)
+ num_k_exclude_rope: Number of tokens at the end of the key sequence to exclude from rotary embedding
+ (e.g. object pointer tokens appended to the memory keys)
+ repeat_freqs_k: Whether to repeat frequencies for keys (for cross-attention)
+
+ Returns:
+ Rotated (q, k) tensors
+ """
+ k_rot, k_pass = k[..., : k.shape[-2] - num_k_exclude_rope, :], k[..., k.shape[-2] - num_k_exclude_rope :, :]
+ q_embed = q.float() # force upscale to float32 as in the original implementation
+ q_embed = (q_embed * cos) + (rotate_pairwise(q_embed) * sin)
+ if k_rot.shape[-2] == 0:
+ # Handle case where keys might be empty due to dropout
+ return q_embed.type_as(q), torch.cat([k_rot, k_pass], dim=-2)
+
+ # Handle key tensor - may need to repeat frequencies if different sequence length
+ if repeat_freqs_k and k_rot.shape[-2] != q.shape[-2]:
+ # Repeat cos/sin to match key sequence length
+ repeat_factor = k_rot.shape[-2] // q.shape[-2]
+ cos_k = cos.repeat(1, 1, repeat_factor, 1)
+ sin_k = sin.repeat(1, 1, repeat_factor, 1)
+ else:
+ cos_k = cos
+ sin_k = sin
+
+ # Apply rotary embedding to keys
+ k_embed = k_rot.float() # force upscale to float32 as in the original implementation
+ k_embed = (k_embed * cos_k) + (rotate_pairwise(k_embed) * sin_k)
+ # Concatenate back to full shape
+ k_embed = torch.cat([k_embed.type_as(k), k_pass], dim=-2)
+ return q_embed.type_as(q), k_embed
+
+
+class Sam2VideoRoPEAttention(nn.Module):
+ """Attention with rotary position encoding."""
+
+ def __init__(
+ self,
+ config: Sam2VideoConfig,
+ kv_in_dim: Optional[int] = None,
+ rope_k_repeat=False,
+ ):
+ super().__init__()
+ self.config = config
+ self.hidden_size = config.memory_attention_hidden_size
+ self.internal_dim = self.hidden_size // config.memory_attention_downsample_rate
+ self.num_attention_heads = config.memory_attention_num_attention_heads
+ self.head_dim = self.internal_dim // config.memory_attention_num_attention_heads
+ self.scaling = self.head_dim**-0.5
+ self.is_causal = False
+
+ self.kv_in_dim = kv_in_dim if kv_in_dim is not None else self.hidden_size
+
+ self.q_proj = nn.Linear(self.hidden_size, self.internal_dim)
+ self.k_proj = nn.Linear(self.kv_in_dim, self.internal_dim)
+ self.v_proj = nn.Linear(self.kv_in_dim, self.internal_dim)
+ self.o_proj = nn.Linear(self.internal_dim, self.hidden_size)
+
+ self.rope_k_repeat = rope_k_repeat
+ self.dropout_p = config.memory_attention_rope_dropout
+
+ def forward(
+ self,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
+ num_k_exclude_rope: int = 0,
+ **kwargs: Unpack[FlashAttentionKwargs],
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ # Input projections
+ batch_size, point_batch_size = query.shape[:2]
+ new_shape = (batch_size * point_batch_size, -1, self.num_attention_heads, self.head_dim)
+
+ query = self.q_proj(query).view(*new_shape).transpose(1, 2)
+ key = self.k_proj(key).view(*new_shape).transpose(1, 2)
+ value = self.v_proj(value).view(*new_shape).transpose(1, 2)
+
+ cos, sin = position_embeddings
+ # Apply rotary position encoding, excluding some keys if specified
+ query, key = apply_rotary_pos_emb_2d(
+ query, key, cos, sin, repeat_freqs_k=self.rope_k_repeat, num_k_exclude_rope=num_k_exclude_rope
+ )
+
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ attn_output, attn_weights = attention_interface(
+ self,
+ query,
+ key,
+ value,
+ attention_mask=None,
+ dropout=0.0 if not self.training else self.dropout_p,
+ scaling=self.scaling,
+ is_causal=self.is_causal,
+ **kwargs,
+ )
+ attn_output = attn_output.reshape(
+ batch_size, point_batch_size, -1, self.num_attention_heads * self.head_dim
+ ).contiguous()
+ attn_output = self.o_proj(attn_output)
+ return attn_output, attn_weights
+
+
+class Sam2VideoMemoryAttentionLayer(nn.Module):
+ def __init__(self, config: Sam2VideoConfig):
+ super().__init__()
+ hidden_size = config.memory_attention_hidden_size
+ self.self_attn = Sam2VideoRoPEAttention(config)
+ self.cross_attn_image = Sam2VideoRoPEAttention(config, kv_in_dim=64, rope_k_repeat=True)
+
+ # Implementation of Feedforward model
+ self.linear1 = nn.Linear(hidden_size, config.memory_attention_feed_forward_hidden_size)
+ self.dropout = nn.Dropout(config.memory_attention_dropout)
+ self.linear2 = nn.Linear(config.memory_attention_feed_forward_hidden_size, hidden_size)
+
+ self.layer_norm1 = nn.LayerNorm(hidden_size)
+ self.layer_norm2 = nn.LayerNorm(hidden_size)
+ self.layer_norm3 = nn.LayerNorm(hidden_size)
+ self.dropout1 = nn.Dropout(config.memory_attention_dropout)
+ self.dropout2 = nn.Dropout(config.memory_attention_dropout)
+ self.dropout3 = nn.Dropout(config.memory_attention_dropout)
+
+ self.activation = ACT2FN[config.memory_attention_feed_forward_hidden_act]
+
+ def forward(
+ self,
+ queries: Tensor,
+ keys: Tensor,
+ key_point_embedding: Tensor,
+ rope_position_embeddings: tuple[Tensor, Tensor],
+ num_k_exclude_rope: int = 0,
+ ) -> torch.Tensor:
+ # Self-Attention
+ query = self.layer_norm1(queries)
+ query, _ = self.self_attn(query=query, key=query, value=query, position_embeddings=rope_position_embeddings)
+ queries = queries + self.dropout1(query)
+
+ # Cross-Attention
+ query = self.layer_norm2(queries)
+ query, _ = self.cross_attn_image(
+ query=query,
+ key=keys + key_point_embedding,
+ value=keys,
+ position_embeddings=rope_position_embeddings,
+ num_k_exclude_rope=num_k_exclude_rope,
+ )
+ queries = queries + self.dropout2(query)
+ # MLP
+ query = self.layer_norm3(queries)
+ query = self.linear2(self.dropout(self.activation(self.linear1(query))))
+ queries = queries + self.dropout3(query)
+ return queries
+
+
+class Sam2VideoMemoryAttention(nn.Module):
+ def __init__(self, config: Sam2VideoConfig):
+ super().__init__()
+ self.layers = nn.ModuleList(
+ [Sam2VideoMemoryAttentionLayer(config) for _ in range(config.memory_attention_num_layers)]
+ )
+ self.layer_norm = nn.LayerNorm(config.memory_attention_hidden_size)
+ self.rotary_emb = Sam2VideoVisionRotaryEmbedding(config=config)
+
+ def forward(
+ self,
+ current_vision_features: torch.Tensor,
+ memory: torch.Tensor,
+ current_vision_position_embeddings: Optional[Tensor] = None,
+ memory_posision_embeddings: Optional[Tensor] = None,
+ num_object_pointer_tokens: int = 0,
+ ):
+ """
+ Args:
+ current_vision_features (`torch.FloatTensor`):
+ The current vision features used for self-attention.
+ memory (`torch.FloatTensor`):
+ The memory features used for cross-attention.
+ current_vision_position_embeddings (`torch.FloatTensor`, *optional*):
+ The position embeddings for the current vision features.
+ memory_posision_embeddings (`torch.FloatTensor`, *optional*):
+ The position embeddings for the memory features.
+ num_object_pointer_tokens (`int`, *optional*, defaults to 0):
+ The number of object pointer tokens.
+ """
+ output = current_vision_features
+ if current_vision_position_embeddings is not None:
+ output = output + 0.1 * current_vision_position_embeddings
+
+ # Convert to batch first
+ output = output.transpose(0, 1)
+ memory = memory.transpose(0, 1).unsqueeze(1)
+ memory_posision_embeddings = memory_posision_embeddings.transpose(0, 1).unsqueeze(1)
+ rope_position_embeddings = self.rotary_emb()
+ for layer in self.layers:
+ output = layer(
+ queries=output.unsqueeze(1) if output.ndim == 3 else output,
+ keys=memory,
+ key_point_embedding=memory_posision_embeddings,
+ rope_position_embeddings=rope_position_embeddings,
+ num_k_exclude_rope=num_object_pointer_tokens,
+ )
+
+ normed_output = self.layer_norm(output)
+
+ # Convert back to seq first
+ normed_output = normed_output.transpose(0, 1)
+
+ return normed_output
+
+
+# Lightly adapted from ConvNext (https://github.com/facebookresearch/ConvNeXt)
+class Sam2VideoMemoryFuserCXBlock(GradientCheckpointingLayer):
+ def __init__(self, config: Sam2VideoConfig):
+ super().__init__()
+ self.depthwise_conv = nn.Conv2d(
+ config.memory_fuser_embed_dim,
+ config.memory_fuser_embed_dim,
+ kernel_size=config.memory_fuser_kernel_size,
+ padding=config.memory_fuser_padding,
+ groups=config.memory_fuser_embed_dim,
+ ) # depthwise conv
+ self.layer_norm = Sam2VideoLayerNorm(config.memory_fuser_embed_dim, eps=1e-6, data_format="channels_first")
+ self.activation = ACT2FN[config.memory_fuser_hidden_act]
+ self.pointwise_conv1 = nn.Linear(
+ config.memory_fuser_embed_dim, config.memory_fuser_intermediate_dim
+ ) # pointwise/1x1 convs, implemented with linear layers
+ self.pointwise_conv2 = nn.Linear(config.memory_fuser_intermediate_dim, config.memory_fuser_embed_dim)
+ self.scale = nn.Parameter(
+ config.memory_fuser_layer_scale_init_value * torch.ones(config.memory_fuser_embed_dim),
+ requires_grad=True,
+ )
+
+ def forward(self, hidden_states):
+ input = hidden_states
+ hidden_states = self.depthwise_conv(hidden_states)
+ hidden_states = self.layer_norm(hidden_states)
+ hidden_states = hidden_states.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
+ hidden_states = self.pointwise_conv1(hidden_states)
+ hidden_states = self.activation(hidden_states)
+ hidden_states = self.pointwise_conv2(hidden_states)
+ hidden_states = self.scale * hidden_states
+ hidden_states = hidden_states.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
+
+ hidden_states = input + hidden_states
+ return hidden_states
+
+
+class Sam2VideoMemoryFuser(nn.Module):
+ def __init__(self, config: Sam2VideoConfig):
+ super().__init__()
+ self.layers = nn.ModuleList(
+ [Sam2VideoMemoryFuserCXBlock(config) for _ in range(config.memory_fuser_num_layers)]
+ )
+
+ def forward(self, hidden_states):
+ # normally hidden_states: (N, C, H, W)
+ for layer in self.layers:
+ hidden_states = layer(hidden_states)
+ return hidden_states
+
+
+class Sam2VideoMaskDownSamplerLayer(nn.Module):
+ def __init__(self, config: Sam2VideoConfig, in_channels: int, out_channels: int):
+ super().__init__()
+ self.conv = nn.Conv2d(
+ in_channels,
+ out_channels,
+ kernel_size=config.mask_downsampler_kernel_size,
+ stride=config.mask_downsampler_stride,
+ padding=config.mask_downsampler_padding,
+ )
+ self.layer_norm = Sam2VideoLayerNorm(out_channels, eps=1e-6, data_format="channels_first")
+ self.activation = ACT2FN[config.mask_downsampler_hidden_act]
+
+ def forward(self, x):
+ return self.activation(self.layer_norm(self.conv(x)))
+
+
+class Sam2VideoMaskDownSampler(nn.Module):
+ """
+ Progressively downsample a mask by total_stride, each time by stride.
+ Note that LayerNorm is applied per *token*, like in ViT.
+
+ With each downsample (by a factor stride**2), channel capacity increases by the same factor.
+ In the end, we linearly project to embed_dim channels.
+ """
+
+ def __init__(self, config: Sam2VideoConfig):
+ super().__init__()
+
+ num_layers = int(math.log2(config.mask_downsampler_total_stride) // math.log2(config.mask_downsampler_stride))
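+ # For example, with total_stride=16 and stride=4 (illustrative values, not necessarily the config defaults),
+ # this gives 2 layers and the channel count grows 1 -> 16 -> 256 before the final 1x1 projection to embed_dim.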
+
+ self.layers = nn.ModuleList()
+ self.activation = ACT2FN[config.mask_downsampler_hidden_act]
+ mask_in_chans, mask_out_chans = 1, 1
+ for _ in range(num_layers):
+ mask_out_chans = mask_in_chans * (config.mask_downsampler_stride**2)
+ self.layers.append(Sam2VideoMaskDownSamplerLayer(config, mask_in_chans, mask_out_chans))
+ mask_in_chans = mask_out_chans
+
+ self.final_conv = nn.Conv2d(mask_out_chans, config.mask_downsampler_embed_dim, kernel_size=1)
+
+ def forward(self, x):
+ for layer in self.layers:
+ x = layer(x)
+ x = self.final_conv(x)
+ return x
+
+
+class Sam2VideoMemoryEncoder(nn.Module):
+ def __init__(self, config: Sam2VideoConfig):
+ super().__init__()
+
+ hidden_size = config.memory_encoder_hidden_size
+ output_channels = config.memory_encoder_output_channels
+ self.mask_downsampler = Sam2VideoMaskDownSampler(config)
+ self.feature_projection = nn.Conv2d(hidden_size, hidden_size, kernel_size=1)
+ self.memory_fuser = Sam2VideoMemoryFuser(config)
+ self.position_encoding = Sam2VideoPositionEmbeddingSine(num_pos_feats=output_channels // 2, normalize=True)
+ self.projection = nn.Conv2d(hidden_size, output_channels, kernel_size=1)
+
+ def forward(
+ self,
+ vision_features: torch.Tensor,
+ masks: torch.Tensor,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ ## Process masks
+ masks = self.mask_downsampler(masks)
+ ## Fuse pixel_features and downsampled masks
+
+ vision_features = self.feature_projection(vision_features)
+ vision_features = vision_features + masks
+ vision_features = self.memory_fuser(vision_features)
+ vision_features = self.projection(vision_features)
+
+ vision_pos_enc = self.position_encoding(vision_features.shape, vision_features.device, vision_features.dtype)
+
+ return vision_features, vision_pos_enc
+
+
+@dataclass
+@auto_docstring(custom_intro="Base class for the vision encoder's outputs.")
+class Sam2VideoVisionEncoderOutput(ModelOutput):
+ r"""
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, height, width, hidden_size)`):
+ Sequence of hidden-states at the output of the last layer of the model.
+ fpn_hidden_states (`tuple(torch.FloatTensor)`):
+ Tuple of `torch.FloatTensor` (one for each feature level, from high to low resolution) of shape
+ `(batch_size, hidden_size, height, width)`. Feature maps from the Feature Pyramid Network neck.
+ fpn_position_encoding (`tuple(torch.FloatTensor)`):
+ Tuple of `torch.FloatTensor` (one for each feature level, from high to low resolution) of shape
+ `(batch_size, hidden_size, height, width)`. Positional encodings corresponding to the `fpn_hidden_states`.
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each stage) of shape `(batch_size, height, width, hidden_size)`. Hidden-states of the
+ model at the output of each stage.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
+ the self-attention heads.
+ """
+
+ last_hidden_state: Optional[torch.FloatTensor] = None
+ fpn_hidden_states: Optional[torch.FloatTensor] = None
+ fpn_position_encoding: Optional[torch.FloatTensor] = None
+ hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+ attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+
+
+class Sam2VideoPositionalEmbedding(nn.Module):
+ def __init__(self, config: Sam2VideoPromptEncoderConfig):
+ super().__init__()
+ self.scale = config.scale
+ positional_embedding = self.scale * torch.randn((2, config.hidden_size // 2))
+ self.register_buffer("positional_embedding", positional_embedding)
+
+ def forward(self, input_coords, input_shape=None):
+ """Positionally encode points that are normalized to [0,1]."""
+ coordinates = input_coords.clone()
+
+ if input_shape is not None:
+ coordinates[:, :, :, 0] = coordinates[:, :, :, 0] / input_shape[1]
+ coordinates[:, :, :, 1] = coordinates[:, :, :, 1] / input_shape[0]
+ coordinates = coordinates.to(torch.float32)
+
+ # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
+ coordinates = 2 * coordinates - 1
+ coordinates = coordinates.to(self.positional_embedding.dtype)
+ coordinates = coordinates @ self.positional_embedding
+ coordinates = 2 * np.pi * coordinates
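+ # Random-Fourier-feature style encoding: coordinates mapped to [-1, 1] are projected by a fixed random
+ # Gaussian matrix, scaled by 2 * pi, and passed through sin/cos to produce the positional embedding.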
+ # outputs d_1 x ... x d_n x channel shape
+ return torch.cat([torch.sin(coordinates), torch.cos(coordinates)], dim=-1)
+
+
+class Sam2VideoMaskEmbedding(nn.Module):
+ def __init__(self, config: Sam2VideoPromptEncoderConfig):
+ super().__init__()
+ self.mask_input_channels = config.mask_input_channels // 4
+ self.activation = ACT2FN[config.hidden_act]
+ self.conv1 = nn.Conv2d(1, self.mask_input_channels, kernel_size=2, stride=2)
+ self.conv2 = nn.Conv2d(self.mask_input_channels, config.mask_input_channels, kernel_size=2, stride=2)
+ self.conv3 = nn.Conv2d(config.mask_input_channels, config.hidden_size, kernel_size=1)
+ self.layer_norm1 = Sam2VideoLayerNorm(
+ self.mask_input_channels, eps=config.layer_norm_eps, data_format="channels_first"
+ )
+ self.layer_norm2 = Sam2VideoLayerNorm(
+ self.mask_input_channels * 4, eps=config.layer_norm_eps, data_format="channels_first"
+ )
+
+ def forward(self, masks):
+ hidden_states = self.conv1(masks)
+ hidden_states = self.layer_norm1(hidden_states)
+ hidden_states = self.activation(hidden_states)
+
+ hidden_states = self.conv2(hidden_states)
+ hidden_states = self.layer_norm2(hidden_states)
+ hidden_states = self.activation(hidden_states)
+ dense_embeddings = self.conv3(hidden_states)
+ return dense_embeddings
+
+
+class Sam2VideoPromptEncoder(nn.Module):
+ def __init__(self, config: Sam2VideoPromptEncoderConfig):
+ super().__init__()
+ self.shared_embedding = Sam2VideoPositionalEmbedding(config)
+ self.mask_embed = Sam2VideoMaskEmbedding(config)
+ self.no_mask_embed = nn.Embedding(1, config.hidden_size)
+
+ self.image_embedding_size = (config.image_size // config.patch_size, config.image_size // config.patch_size)
+ self.mask_input_size = (4 * config.image_size // config.patch_size, 4 * config.image_size // config.patch_size)
+ self.input_image_size = config.image_size
+
+ self.point_embed = nn.Embedding(config.num_point_embeddings, config.hidden_size)
+ self.hidden_size = config.hidden_size
+ self.not_a_point_embed = nn.Embedding(1, config.hidden_size)
+
+ def _embed_points(self, points: torch.Tensor, labels: torch.Tensor, pad: bool) -> torch.Tensor:
+ """Embeds point prompts."""
+ points = points + 0.5 # Shift to center of pixel
+ if pad:
+ points = torch.nn.functional.pad(points, (0, 0, 0, 1), mode="constant", value=0)
+ labels = torch.nn.functional.pad(labels, (0, 1), mode="constant", value=-1)
+ input_shape = (self.input_image_size, self.input_image_size)
+ point_embedding = self.shared_embedding(points, input_shape)
+
+ # torch.where and expanding the labels tensor is required by the ONNX export
+ point_embedding = torch.where(labels[..., None] == -1, self.not_a_point_embed.weight, point_embedding)
+
+ # This is required for the ONNX export. The dtype, device need to be explicitly
+ # specified as otherwise torch.onnx.export interprets as double
+ point_embedding = torch.where(
+ labels[..., None] != -10,
+ point_embedding,
+ torch.zeros_like(point_embedding),
+ )
+
+ # Add point embeddings for labels >= 0
+ point_embedding = point_embedding + self.point_embed(labels.clamp(min=0)) * (labels >= 0).unsqueeze(-1)
+
+ return point_embedding
+
+ def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
+ """Embeds box prompts."""
+ boxes = boxes + 0.5  # Shift to center of pixel (avoid in-place modification of the input boxes)
+ coords = boxes.view(*boxes.shape[:2], 2, 2)
+ # add padding point for consistency with the original implementation
+ coords = torch.nn.functional.pad(coords, (0, 0, 0, 1), mode="constant", value=0)
+ corner_embedding = self.shared_embedding(coords, (self.input_image_size, self.input_image_size))
+ corner_embedding[:, :, 0, :] += self.point_embed.weight[2]
+ corner_embedding[:, :, 1, :] += self.point_embed.weight[3]
+ corner_embedding[:, :, 2, :] = self.not_a_point_embed.weight.expand_as(corner_embedding[:, :, 2, :])
+ return corner_embedding
+
+ def forward(
+ self,
+ input_points: Optional[tuple[torch.Tensor, torch.Tensor]],
+ input_labels: Optional[torch.Tensor],
+ input_boxes: Optional[torch.Tensor],
+ input_masks: Optional[torch.Tensor],
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ """
+ Embeds different types of prompts, returning both sparse and dense embeddings.
+
+ Args:
+ input_points (`torch.Tensor`, *optional*):
+ point coordinates to embed.
+ input_labels (`torch.Tensor`, *optional*):
+ labels for the input points.
+ input_boxes (`torch.Tensor`, *optional*):
+ boxes to embed.
+ input_masks (`torch.Tensor`, *optional*):
+ masks to embed.
+ """
+ sparse_embeddings = None
+ batch_size = 1
+ if input_points is not None:
+ batch_size = input_points.shape[0]
+ if input_labels is None:
+ raise ValueError("If points are provided, labels must also be provided.")
+ point_embeddings = self._embed_points(input_points, input_labels, pad=(input_boxes is None))
+ sparse_embeddings = point_embeddings
+ if input_boxes is not None:
+ batch_size = input_boxes.shape[0]
+ box_embeddings = self._embed_boxes(input_boxes)
+ if sparse_embeddings is None:
+ sparse_embeddings = box_embeddings
+ else:
+ sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=2)
+ if input_masks is not None:
+ dense_embeddings = self.mask_embed(input_masks)
+ else:
+ dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
+ batch_size, -1, self.image_embedding_size[0], self.image_embedding_size[1]
+ )
+
+ return sparse_embeddings, dense_embeddings
+
+
+class Sam2VideoTwoWayTransformer(nn.Module):
+ def __init__(self, config: Sam2VideoMaskDecoderConfig):
+ super().__init__()
+ self.config = config
+
+ self.num_hidden_layers = config.num_hidden_layers
+ self.layers = nn.ModuleList()
+
+ for i in range(self.num_hidden_layers):
+ self.layers.append(Sam2VideoTwoWayAttentionBlock(config, skip_first_layer_pe=(i == 0)))
+
+ self.final_attn_token_to_image = Sam2VideoAttention(config)
+ self.layer_norm_final_attn = nn.LayerNorm(config.hidden_size)
+
+ def forward(
+ self,
+ point_embeddings: Tensor,
+ image_embeddings: Tensor,
+ image_positional_embeddings: Tensor,
+ attention_similarity: Tensor,
+ target_embedding=None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> Union[tuple, BaseModelOutput]:
+ if image_embeddings is None:
+ raise ValueError("You have to specify an image_embedding")
+
+ image_embeddings = image_embeddings.flatten(2).permute(0, 2, 1).unsqueeze(1)
+ image_positional_embeddings = image_positional_embeddings.flatten(2).permute(0, 2, 1).unsqueeze(1)
+
+ # Prepare queries
+ queries = point_embeddings
+ keys = image_embeddings
+
+ # Apply transformer blocks and final layernorm
+ for layer in self.layers:
+ if target_embedding is not None:
+ queries += target_embedding
+
+ queries, keys, _ = layer(
+ queries=queries,
+ keys=keys,
+ query_point_embedding=point_embeddings,
+ key_point_embedding=image_positional_embeddings,
+ attention_similarity=attention_similarity,
+ **kwargs,
+ )
+ # Apply the final attention layer from the points to the image
+ query = queries + point_embeddings
+ key = keys + image_positional_embeddings
+
+ attn_out, _ = self.final_attn_token_to_image(query=query, key=key, value=keys)
+
+ queries = queries + attn_out
+ queries = self.layer_norm_final_attn(queries)
+ return queries, keys
+
+
+class Sam2VideoMaskDecoder(nn.Module):
+ def __init__(self, config: Sam2VideoMaskDecoderConfig):
+ super().__init__()
+ self.config = config
+ self.hidden_size = config.hidden_size
+
+ self.num_multimask_outputs = config.num_multimask_outputs
+ self.num_mask_tokens = config.num_multimask_outputs + 1
+
+ self.iou_token = nn.Embedding(1, self.hidden_size)
+ self.mask_tokens = nn.Embedding(self.num_mask_tokens, self.hidden_size)
+
+ self.transformer = Sam2VideoTwoWayTransformer(config)
+
+ # should we create a new class for this?
+ self.upscale_conv1 = nn.ConvTranspose2d(self.hidden_size, self.hidden_size // 4, kernel_size=2, stride=2)
+ self.upscale_conv2 = nn.ConvTranspose2d(self.hidden_size // 4, self.hidden_size // 8, kernel_size=2, stride=2)
+ self.upscale_layer_norm = Sam2VideoLayerNorm(self.hidden_size // 4, data_format="channels_first")
+ self.activation = nn.GELU()
+
+ mlps_list = []
+ for _ in range(self.num_mask_tokens):
+ mlps_list += [Sam2VideoFeedForward(self.hidden_size, self.hidden_size, self.hidden_size // 8, 3)]
+ self.output_hypernetworks_mlps = nn.ModuleList(mlps_list)
+ self.iou_prediction_head = Sam2VideoFeedForward(
+ self.hidden_size,
+ config.iou_head_hidden_dim,
+ self.num_mask_tokens,
+ config.iou_head_depth,
+ sigmoid_output=True,
+ )
+
+ self.conv_s0 = nn.Conv2d(config.hidden_size, config.hidden_size // 8, kernel_size=1, stride=1)
+ self.conv_s1 = nn.Conv2d(config.hidden_size, config.hidden_size // 4, kernel_size=1, stride=1)
+
+ self.obj_score_token = nn.Embedding(1, self.hidden_size)
+ self.pred_obj_score_head = Sam2VideoFeedForward(self.hidden_size, self.hidden_size, 1, 3)
+
+ self.dynamic_multimask_via_stability = config.dynamic_multimask_via_stability
+ self.dynamic_multimask_stability_delta = config.dynamic_multimask_stability_delta
+ self.dynamic_multimask_stability_thresh = config.dynamic_multimask_stability_thresh
+
+ def forward(
+ self,
+ image_embeddings: torch.Tensor,
+ image_positional_embeddings: torch.Tensor,
+ sparse_prompt_embeddings: torch.Tensor,
+ dense_prompt_embeddings: torch.Tensor,
+ multimask_output: bool,
+ high_resolution_features: list[torch.Tensor],
+ attention_similarity: Optional[torch.Tensor] = None,
+ target_embedding: Optional[torch.Tensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ """
+ Predict masks given image and prompt embeddings.
+
+ Args:
+ image_embeddings (`torch.Tensor`):
+ The embeddings from the image encoder.
+ image_positional_embeddings (`torch.Tensor`):
+ Positional encoding with the shape of image_embeddings.
+ sparse_prompt_embeddings (`torch.Tensor`):
+ The embeddings of the points and boxes.
+ dense_prompt_embeddings (`torch.Tensor`):
+ The embeddings of the mask inputs.
+ multimask_output (`bool`):
+ Whether to return multiple masks or a single mask.
+ high_resolution_features (`list[torch.Tensor]`, *optional*):
+ The high-resolution features from the vision encoder.
+ attention_similarity (`torch.Tensor`, *optional*):
+ The attention similarity tensor.
+ target_embedding (`torch.Tensor`, *optional*):
+ The target embedding.
+ """
+ batch_size, num_channels, height, width = image_embeddings.shape
+ point_batch_size = sparse_prompt_embeddings.shape[1]
+ # Concatenate output tokens
+ output_tokens = torch.cat(
+ [
+ self.obj_score_token.weight,
+ self.iou_token.weight,
+ self.mask_tokens.weight,
+ ],
+ dim=0,
+ )
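+ # Token layout along the token dimension: index 0 is the object score token, index 1 the IoU token, and
+ # the next num_mask_tokens entries are the mask tokens; sparse prompt tokens are appended after them.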
+ output_tokens = output_tokens.repeat(batch_size, point_batch_size, 1, 1)
+
+ if sparse_prompt_embeddings.shape[0] != 0:
+ tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=2)
+ else:
+ tokens = output_tokens
+ point_embeddings = tokens.to(self.iou_token.weight.dtype)
+
+ # Expand per-image data in batch direction to be per-mask
+ image_embeddings = image_embeddings + dense_prompt_embeddings
+ image_embeddings = image_embeddings.repeat_interleave(point_batch_size, dim=0)
+ image_positional_embeddings = image_positional_embeddings.repeat_interleave(point_batch_size, 0)
+ # Run the transformer
+ point_embeddings, image_embeddings = self.transformer(
+ point_embeddings=point_embeddings,
+ image_embeddings=image_embeddings,
+ image_positional_embeddings=image_positional_embeddings,
+ attention_similarity=attention_similarity,
+ target_embedding=target_embedding,
+ **kwargs,
+ )
+ iou_token_out = point_embeddings[:, :, 1, :]
+ mask_tokens_out = point_embeddings[:, :, 2 : (2 + self.num_mask_tokens), :]
+
+ # Upscale mask embeddings and predict masks using the mask tokens
+ image_embeddings = image_embeddings.transpose(2, 3).view(
+ batch_size * point_batch_size, num_channels, height, width
+ )
+
+ feat_s0, feat_s1 = high_resolution_features
+ feat_s0 = feat_s0.repeat_interleave(point_batch_size, dim=0)
+ feat_s1 = feat_s1.repeat_interleave(point_batch_size, dim=0)
+ upscaled_embedding = self.upscale_conv1(image_embeddings) + feat_s1
+ upscaled_embedding = self.activation(self.upscale_layer_norm(upscaled_embedding))
+ upscaled_embedding = self.activation(self.upscale_conv2(upscaled_embedding) + feat_s0)
+
+ hyper_in_list: list[torch.Tensor] = []
+ for i in range(self.num_mask_tokens):
+ current_mlp = self.output_hypernetworks_mlps[i]
+ hyper_in_list += [current_mlp(mask_tokens_out[:, :, i, :])]
+ hyper_in = torch.stack(hyper_in_list, dim=2)
+
+ _, num_channels, height, width = upscaled_embedding.shape
+ upscaled_embedding = upscaled_embedding.view(batch_size, point_batch_size, num_channels, height * width)
+ masks = (hyper_in @ upscaled_embedding).view(batch_size, point_batch_size, -1, height, width)
+
+ # Generate mask quality predictions
+ iou_pred = self.iou_prediction_head(iou_token_out)
+ object_score_logits = self.pred_obj_score_head(point_embeddings[:, :, 0, :])
+
+ # Select the correct mask or masks for output
+ if multimask_output:
+ mask_slice = slice(1, None)
+ masks = masks[:, :, mask_slice, :, :]
+ iou_pred = iou_pred[:, :, mask_slice]
+ elif self.dynamic_multimask_via_stability and not self.training:
+ mask_slice = slice(0, 1)
+ masks, iou_pred = self._dynamic_multimask_via_stability(masks, iou_pred)
+ else:
+ mask_slice = slice(0, 1)
+ masks = masks[:, :, mask_slice, :, :]
+ iou_pred = iou_pred[:, :, mask_slice]
+
+ sam_tokens_out = mask_tokens_out[:, :, mask_slice]  # [batch_size, point_batch_size, num_selected_masks, hidden_size]
+
+ return masks, iou_pred, sam_tokens_out, object_score_logits
+
+ def _get_stability_scores(self, mask_logits):
+ """
+ Compute stability scores of the mask logits based on the IoU between upper and
+ lower thresholds.
+ """
+ mask_logits = mask_logits.flatten(-2)
+ stability_delta = self.dynamic_multimask_stability_delta
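+ # area_i counts pixels confidently inside the mask (logits above +delta), area_u counts pixels possibly
+ # inside (logits above -delta); their ratio measures how stable the mask is to small shifts of the
+ # binarization threshold.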
+ area_i = torch.sum(mask_logits > stability_delta, dim=-1).float()
+ area_u = torch.sum(mask_logits > -stability_delta, dim=-1).float()
+ stability_scores = torch.where(area_u > 0, area_i / area_u, 1.0)
+ return stability_scores
+
+ def _dynamic_multimask_via_stability(self, all_mask_logits, all_iou_scores):
+ """
+ When outputting a single mask, if the stability score from the current single-mask
+ output (based on output token 0) falls below a threshold, we instead select from
+ multi-mask outputs (based on output token 1~3) the mask with the highest predicted
+ IoU score. This is intended to ensure a valid mask for both clicking and tracking.
+ """
+ # The best mask from multimask output tokens (1~3)
+ multimask_logits = all_mask_logits[:, :, 1:, :, :]
+ multimask_iou_scores = all_iou_scores[:, :, 1:]
+ best_scores_inds = torch.argmax(multimask_iou_scores, dim=-1) # [B, P]
+ best_scores_inds_expanded = best_scores_inds.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
+ best_scores_inds_expanded = best_scores_inds_expanded.expand(
+ -1, -1, 1, multimask_logits.size(-2), multimask_logits.size(-1)
+ )
+ best_multimask_logits = torch.gather(multimask_logits, 2, best_scores_inds_expanded) # [B, P, 1, H, W]
+ best_multimask_iou_scores = torch.gather(multimask_iou_scores, 2, best_scores_inds.unsqueeze(-1)) # [B, P, 1]
+
+ # The mask from singlemask output token 0 and its stability score
+ singlemask_logits = all_mask_logits[:, :, 0:1, :, :]
+ singlemask_iou_scores = all_iou_scores[:, :, 0:1]
+ stability_scores = self._get_stability_scores(singlemask_logits)
+ is_stable = stability_scores >= self.dynamic_multimask_stability_thresh
+
+ # Dynamically fall back to best multimask output upon low stability scores.
+ mask_logits_out = torch.where(
+ is_stable[..., None, None].expand_as(singlemask_logits),
+ singlemask_logits,
+ best_multimask_logits,
+ )
+ iou_scores_out = torch.where(
+ is_stable.expand_as(singlemask_iou_scores),
+ singlemask_iou_scores,
+ best_multimask_iou_scores,
+ )
+ return mask_logits_out, iou_scores_out
+
+
+# a large negative value as a placeholder score for missing objects
+NO_OBJ_SCORE = -1024.0
+
+
+def get_1d_sine_pe(pos_inds, dim, temperature=10000):
+ """
+ Get 1D sine positional embedding as in the original Transformer paper.
+ """
+ pe_dim = dim // 2
+ dim_t = torch.arange(pe_dim, dtype=torch.float32, device=pos_inds.device)
+ dim_t = temperature ** (2 * (dim_t // 2) / pe_dim)
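+ # Frequencies follow the standard sinusoidal schedule temperature^(2i / pe_dim); the sin and cos halves
+ # are concatenated along the last dimension (rather than interleaved), giving an embedding of size 2 * pe_dim.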
+
+ pos_embed = pos_inds.unsqueeze(-1) / dim_t
+ pos_embed = torch.cat([pos_embed.sin(), pos_embed.cos()], dim=-1)
+ return pos_embed
+
+
+@auto_docstring
+class Sam2VideoModel(Sam2VideoPreTrainedModel):
+ _tied_weights_keys = ["prompt_encoder.shared_embedding.positional_embedding"]
+ # need to be ignored, as it's a buffer and will not be correctly detected as tied weight
+ _keys_to_ignore_on_load_missing = ["prompt_encoder.shared_embedding.positional_embedding"]
+ _can_record_outputs = {"mask_decoder_attentions": OutputRecorder(Sam2VideoTwoWayAttentionBlock, index=2)}
+ _keys_to_ignore_on_load_unexpected = []
+
+ def __init__(self, config: Sam2VideoConfig):
+ super().__init__(config)
+ self.shared_image_embedding = Sam2VideoPositionalEmbedding(config.prompt_encoder_config)
+ self.vision_encoder = AutoModel.from_config(config.vision_config)
+ self.prompt_encoder = Sam2VideoPromptEncoder(config.prompt_encoder_config)
+ # The module using it is not a PreTrainedModel subclass so we need this
+ config.mask_decoder_config._attn_implementation = config._attn_implementation
+ self.mask_decoder = Sam2VideoMaskDecoder(config.mask_decoder_config)
+
+ self.num_feature_levels = config.vision_config.num_feature_levels
+ self.backbone_feature_sizes = config.vision_config.backbone_feature_sizes
+ # a single token to indicate no memory embedding from previous frames
+ self.hidden_dim = config.vision_config.fpn_hidden_size
+ self.no_memory_embedding = torch.nn.Parameter(torch.zeros(1, 1, self.hidden_dim))
+ self.config = config
+ # For video sequence inference
+ self.image_size = config.image_size
+ self.memory_attention = Sam2VideoMemoryAttention(config)
+ self.memory_encoder = Sam2VideoMemoryEncoder(config)
+ self.no_memory_positional_encoding = torch.nn.Parameter(
+ torch.zeros(1, 1, config.vision_config.fpn_hidden_size)
+ )
+ self.mem_dim = config.memory_encoder_output_channels
+ self.num_maskmem = config.num_maskmem # Number of memories accessible
+ # Temporal encoding of the memories
+ self.memory_temporal_positional_encoding = torch.nn.Parameter(
+ torch.zeros(self.num_maskmem, 1, 1, self.mem_dim)
+ )
+
+ self.no_object_pointer = torch.nn.Parameter(torch.zeros(1, self.hidden_dim))
+ # A conv layer to downsample the mask prompt to stride 4 (the same stride as
+ # low-res SAM mask logits) and to change its scales from 0~1 to SAM logit scale,
+ # so that it can be fed into the SAM mask decoder to generate a pointer.
+ self.mask_downsample = torch.nn.Conv2d(1, 1, kernel_size=4, stride=4)
+ # a feedforward layer on SAM output tokens to turn them into object pointers
+ self.object_pointer_proj = Sam2VideoFeedForward(self.hidden_dim, self.hidden_dim, self.hidden_dim, 3)
+
+ if self.config.enable_temporal_pos_encoding_for_object_pointers:
+ # a linear projection on temporal positional encoding in object pointers to
+ # avoid potential interference with spatial positional encoding
+ self.temporal_positional_encoding_projection_layer = torch.nn.Linear(self.hidden_dim, self.mem_dim)
+ else:
+ self.temporal_positional_encoding_projection_layer = torch.nn.Identity()
+
+ self.occlusion_spatial_embedding_parameter = None # compatibility with Sam2
+ if config.enable_occlusion_spatial_embedding:
+ self.occlusion_spatial_embedding_parameter = torch.nn.Parameter(torch.zeros(1, self.mem_dim))
+
+ self.post_init()
+
+ def _tie_weights(self):
+ self.prompt_encoder.shared_embedding.positional_embedding.data = (
+ self.shared_image_embedding.positional_embedding.data
+ )
+
+ def get_input_embeddings(self):
+ return self.vision_encoder.get_input_embeddings()
+
+ def get_image_wide_positional_embeddings(self) -> torch.Tensor:
+ size = self.prompt_encoder.image_embedding_size
+ target_device = self.shared_image_embedding.positional_embedding.device
+ target_dtype = self.shared_image_embedding.positional_embedding.dtype
+ grid = torch.ones(size, device=target_device, dtype=target_dtype)
+ y_embed = grid.cumsum(dim=0) - 0.5
+ x_embed = grid.cumsum(dim=1) - 0.5
+ y_embed = y_embed / size[0]
+ x_embed = x_embed / size[1]
+
+ positional_embedding = self.shared_image_embedding(torch.stack([x_embed, y_embed], dim=-1))
+ return positional_embedding.permute(2, 0, 1).unsqueeze(0) # channel x height x width
+
+ @torch.no_grad()
+ def get_image_embeddings(
+ self,
+ pixel_values: torch.FloatTensor,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> list[torch.Tensor]:
+ r"""
+ Returns the image embeddings by passing the pixel values through the vision encoder.
+
+ Args:
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+ Input pixel values
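+
+ Example (a minimal sketch; it assumes `model` is an already-loaded `Sam2VideoModel` and `pixel_values`
+ comes from the corresponding video processor, neither of which is defined here):
+
+ ```python
+ image_embeddings = model.get_image_embeddings(pixel_values)
+ # one tensor per feature level, each of shape (batch_size, channels, height, width)
+ print([embedding.shape for embedding in image_embeddings])
+ ```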
+ """
+ batch_size = pixel_values.shape[0]
+ feature_maps, _, _, _ = self.get_image_features(pixel_values, **kwargs)
+
+ # add no memory embedding to the last feature map
+ feature_maps[-1] = feature_maps[-1] + self.no_memory_embedding
+
+ # reshape feature maps to the same shape as the backbone feature sizes
+ image_embeddings = [
+ feat.permute(1, 2, 0).view(batch_size, -1, *feat_size)
+ for feat, feat_size in zip(feature_maps, self.backbone_feature_sizes)
+ ]
+
+ return image_embeddings
+
+ @torch.no_grad()
+ def get_prompt_embeddings(
+ self,
+ input_points: Optional[torch.FloatTensor] = None,
+ input_labels: Optional[torch.LongTensor] = None,
+ input_boxes: Optional[torch.FloatTensor] = None,
+ input_masks: Optional[torch.LongTensor] = None,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ r"""
+ Returns the prompt embeddings by passing the input points, labels, boxes and masks through the prompt encoder.
+
+ Args:
+ input_points (`torch.FloatTensor` of shape `(batch_size, point_batch_size, num_points_per_image, 2)`):
+ Optional input points for the prompt encoder. The padding of the point is automatically done by the
+ processor. `point_batch_size` refers to the number of masks that we want the model to predict per
+ point. The model will output `point_batch_size` times 3 masks in total.
+ input_labels (`torch.LongTensor` of shape `(batch_size, point_batch_size, num_points_per_image)`):
+ Optional input labels for the prompt encoder. The padding of the labels is automatically done by the
+ processor, or can be fed by the user.
+ input_boxes (`torch.FloatTensor` of shape `(batch_size, num_boxes_per_image, 4)`):
+ Optional input boxes for the prompt encoder. The padding of the boxes is automatically done by the
+ processor. Users can also pass the input boxes manually.
+ input_masks (`torch.LongTensor` of shape `(batch_size, image_size, image_size)`):
+ Optional input masks for the prompt encoder.
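+
+ Example (a minimal sketch; the click coordinates are illustrative and `model` is assumed to be an
+ already-loaded `Sam2VideoModel`):
+
+ ```python
+ # one image, one object, one positive click (label 1) at an arbitrary (x, y) position
+ input_points = torch.tensor([[[[500.0, 375.0]]]])
+ input_labels = torch.tensor([[[1]]])
+ sparse_embeddings, dense_embeddings = model.get_prompt_embeddings(input_points, input_labels)
+ ```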
+ """
+ prompt_output = self.prompt_encoder(
+ input_points=input_points,
+ input_labels=input_labels,
+ input_boxes=input_boxes,
+ input_masks=input_masks,
+ )
+ return prompt_output
+
+ @torch.inference_mode()
+ @auto_docstring(custom_intro="Propagate the objects through a streamed video frame.")
+ def forward(
+ self,
+ inference_session: Sam2VideoInferenceSession,
+ frame_idx: Optional[int] = None,
+ frame: Optional[torch.Tensor] = None,
+ reverse: bool = False,
+ ) -> Sam2VideoSegmentationOutput:
+ r"""
+ inference_session (`Sam2VideoInferenceSession`):
+ The video inference session object.
+ frame_idx (`int`, *optional*):
+ The index of the frame on which to run inference. No need to provide when inferring
+ on a new streamed frame.
+ frame (`torch.Tensor`, *optional*):
+ The frame to process. Provide when streaming.
+ reverse (`bool`, *optional*, defaults to `False`):
+ Whether to propagate in reverse.
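+
+ Example (a minimal propagation sketch; creating the `inference_session`, adding object prompts, and
+ `num_frames` are assumed to be handled elsewhere and are not shown here):
+
+ ```python
+ # assumes the session already holds the video frames and at least one object prompt
+ for frame_idx in range(num_frames):
+     output = model(inference_session=inference_session, frame_idx=frame_idx)
+     low_res_masks = output.pred_masks  # mask logits at the model's resolution for this frame
+ ```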
+ """
+ if frame is not None:
+ frame_idx = inference_session.add_new_frame(frame, frame_idx)
+
+ if frame is not None and inference_session.get_obj_num() == 0:
+ raise ValueError("No objects are provided for tracking; please add inputs first.")
+
+ num_objects = inference_session.get_obj_num()
+ pred_masks_per_obj = [None] * num_objects
+ # Note: We avoid batched inference here because per-object inputs (clicks/masks)
+ # can differ across objects.
+ for obj_idx in range(num_objects):
+ obj_id = inference_session.obj_idx_to_id(obj_idx)
+ has_new_inputs = obj_id in inference_session.obj_with_new_inputs
+ has_cond_output = frame_idx in inference_session.output_dict_per_obj[obj_idx]["cond_frame_outputs"]
+ # If this object has no new inputs and this frame already has a
+ # conditioning output, reuse the cached masks instead of recomputing.
+ if (not has_new_inputs) and has_cond_output:
+ pred_masks = inference_session.get_output(obj_idx, frame_idx, "pred_masks", is_conditioning_frame=True)
+ is_init_cond_frame = True
+ else:
+ # Defaults when there are no new inputs
+ is_init_cond_frame = False
+ point_inputs = None
+ mask_inputs = None
+
+ if has_new_inputs:
+ is_init_cond_frame = frame_idx not in inference_session.frames_tracked_per_obj[obj_idx]
+ if is_init_cond_frame:
+ reverse = False
+ point_inputs = inference_session.point_inputs_per_obj[obj_idx].get(frame_idx, None)
+ mask_inputs = inference_session.mask_inputs_per_obj[obj_idx].get(frame_idx, None)
+ if point_inputs is not None or mask_inputs is not None:
+ inference_session.obj_with_new_inputs.remove(obj_id)
+
+ current_out = self._run_single_frame_inference(
+ inference_session=inference_session,
+ obj_idx=obj_idx,
+ frame_idx=frame_idx,
+ batch_size=1, # run on the slice of a single object
+ is_init_cond_frame=is_init_cond_frame,
+ point_inputs=point_inputs,
+ mask_inputs=mask_inputs,
+ reverse=reverse,
+ run_mem_encoder=True,
+ streaming=frame is not None,
+ )
+ inference_session.store_output(
+ obj_idx, frame_idx, output_value=current_out, is_conditioning_frame=is_init_cond_frame
+ )
+ pred_masks = current_out["pred_masks"]
+
+ pred_masks_per_obj[obj_idx] = pred_masks
+ if not is_init_cond_frame:
+ # only for tracked frames, not for initial conditioning frames
+ inference_session.frames_tracked_per_obj[obj_idx][frame_idx] = {"reverse": reverse}
+
+ # Gather the per-object mask predictions for this frame (we directly use the mask scores on GPU
+ # for output to avoid any CPU conversion in between)
+ if len(pred_masks_per_obj) > 1:
+ all_pred_masks = torch.cat(pred_masks_per_obj, dim=0)
+ else:
+ all_pred_masks = pred_masks_per_obj[0]
+
+ return Sam2VideoSegmentationOutput(pred_masks=all_pred_masks, frame_idx=frame_idx)
+
+ def get_image_features(
+ self,
+ pixel_values: torch.FloatTensor,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> tuple[
+ list[torch.Tensor],
+ list[torch.Tensor],
+ Optional[tuple[torch.FloatTensor, ...]],
+ Optional[tuple[torch.FloatTensor, ...]],
+ ]:
+ r"""
+ Extract and preprocess image features using the vision encoder.
+
+ Args:
+ pixel_values (`torch.FloatTensor`):
+ Input pixel values of shape `(batch_size, num_channels, height, width)`.
+
+ Returns:
+ `tuple`: A tuple containing:
+ - feature_maps (`list[torch.Tensor]`): List of feature maps from different levels.
+ - feature_maps_position_embeddings (`list[torch.Tensor]`): List of positional embeddings for each feature level.
+ - vision_hidden_states (`tuple[torch.FloatTensor]`, *optional*): Hidden states from the vision encoder.
+ - vision_attentions (`tuple[torch.FloatTensor]`, *optional*): Attention weights from the vision encoder.
+ """
+ vision_outputs: Sam2VideoVisionEncoderOutput = self.vision_encoder(
+ pixel_values,
+ **kwargs,
+ )
+
+ feature_maps = vision_outputs.fpn_hidden_states
+ feature_maps_position_embeddings = vision_outputs.fpn_position_encoding
+
+ # precompute projected level 0 and level 1 features in SAM decoder
+ # to avoid running it again on every SAM click
+ feature_maps = list(feature_maps)
+ feature_maps[0] = self.mask_decoder.conv_s0(feature_maps[0])
+ feature_maps[1] = self.mask_decoder.conv_s1(feature_maps[1])
+
+ # flatten NxCxHxW to HWxNxC
+ feature_maps = [feature_map.flatten(2).permute(2, 0, 1) for feature_map in feature_maps]
+ feature_maps_position_embeddings = [
+ feature_map_position_embedding.flatten(2).permute(2, 0, 1)
+ for feature_map_position_embedding in feature_maps_position_embeddings
+ ]
+
+ return feature_maps, feature_maps_position_embeddings, vision_outputs.hidden_states, vision_outputs.attentions
+
+ def _prepare_vision_features(
+ self,
+ inference_session: Sam2VideoInferenceSession,
+ frame_idx: int,
+ batch_size: int,
+ ) -> tuple[torch.Tensor, list[torch.Tensor]]:
+ """Prepare vision features for a frame."""
+
+ # Check if features are cached
+ if cached_features := inference_session.cache.get_vision_features(frame_idx):
+ vision_feats = cached_features["vision_feats"]
+ vision_pos_embeds = cached_features["vision_pos_embeds"]
+ else:
+ # Compute features using image encoder
+ image_batch = inference_session.get_frame(frame_idx).unsqueeze(0) # Add batch dimension
+ vision_feats, vision_pos_embeds, _, _ = self.get_image_features(image_batch)
+ # Cache features
+ inference_session.cache.cache_vision_features(
+ frame_idx, {"vision_feats": vision_feats, "vision_pos_embeds": vision_pos_embeds}
+ )
+
+ # Expand to batch size if needed
+ if batch_size > 1:
+ vision_feats = vision_feats.expand(batch_size, -1, -1, -1)
+ vision_pos_embeds = [pe.expand(batch_size, -1, -1, -1) for pe in vision_pos_embeds]
+
+ return vision_feats, vision_pos_embeds
+
+ def _single_frame_forward(
+ self,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ input_points: Optional[torch.FloatTensor] = None,
+ input_labels: Optional[torch.LongTensor] = None,
+ input_boxes: Optional[torch.FloatTensor] = None,
+ input_masks: Optional[torch.LongTensor] = None,
+ image_embeddings: Optional[torch.FloatTensor] = None,
+ multimask_output: bool = True,
+ attention_similarity: Optional[torch.FloatTensor] = None,
+ target_embedding: Optional[torch.FloatTensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> Sam2VideoImageSegmentationOutput:
+ """
+ input_points (`torch.FloatTensor` of shape `(batch_size, point_batch_size, num_points_per_image, 2)`):
+ Input 2D spatial points; this is used by the prompt encoder to encode the prompt. Generally yields much
+ better results. The points can be obtained by passing a list of list of list to the processor that will
+ create corresponding `torch` tensors of dimension 4. The first dimension is the image batch size, the
+ second dimension is the point batch size (i.e. how many segmentation masks do we want the model to predict
+ per input point), the third dimension is the number of points per segmentation mask (it is possible to pass
+ multiple points for a single mask), and the last dimension is the x (horizontal) and y (vertical)
+ coordinates of the point. If a different number of points is passed either for each image, or for each
+ mask, the processor will create "PAD" points that will correspond to the (0, 0) coordinate, and the
+ computation of the embedding will be skipped for these points using the labels.
+ input_labels (`torch.LongTensor` of shape `(batch_size, point_batch_size, num_points)`):
+ Input labels for the points, this is used by the prompt encoder to encode the prompt. According to the
+ official implementation, there are 3 types of labels
+
+ - `1`: the point is a point that contains the object of interest
+ - `0`: the point is a point that does not contain the object of interest
+ - `-1`: the point corresponds to the background
+
+ We added the label:
+
+ - `-10`: the point is a padding point, thus should be ignored by the prompt encoder
+
+ The padding labels should be automatically done by the processor.
+ input_boxes (`torch.FloatTensor` of shape `(batch_size, num_boxes, 4)`):
+ Input boxes for the points; this is used by the prompt encoder to encode the prompt. Generally yields
+ much better generated masks. The boxes can be obtained by passing a list of list of list to the processor,
+ that will generate a `torch` tensor, with each dimension corresponding respectively to the image batch
+ size, the number of boxes per image and the coordinates of the top left and bottom right point of the box.
+ In the order (`x1`, `y1`, `x2`, `y2`):
+
+ - `x1`: the x coordinate of the top left point of the input box
+ - `y1`: the y coordinate of the top left point of the input box
+ - `x2`: the x coordinate of the bottom right point of the input box
+ - `y2`: the y coordinate of the bottom right point of the input box
+ input_masks (`torch.FloatTensor` of shape `(batch_size, image_size, image_size)`):
+ SAM model also accepts segmentation masks as input. The mask will be embedded by the prompt encoder to
+ generate a corresponding embedding that will be fed later on to the mask decoder. These masks need to be
+ manually fed by the user, and they need to be of shape (`batch_size`, `image_size`, `image_size`).
+ image_embeddings (`torch.FloatTensor` of shape `(batch_size, output_channels, window_size, window_size)`):
+ Image embeddings, this is used by the mask decoder to generate masks and iou scores. For more memory
+ efficient computation, users can first retrieve the image embeddings using the `get_image_embeddings`
+ method, and then feed them to the `forward` method instead of feeding the `pixel_values`.
+ multimask_output (`bool`, *optional*):
+ In the original implementation and paper, the model always outputs 3 masks per image (or per point / per
+ bounding box if relevant). However, it is possible to just output a single mask, that corresponds to the
+ "best" mask, by specifying `multimask_output=False`.
+ attention_similarity (`torch.FloatTensor`, *optional*):
+ Attention similarity tensor, to be provided to the mask decoder for target-guided attention in case the
+ model is used for personalization as introduced in [PerSAM](https://huggingface.co/papers/2305.03048).
+ target_embedding (`torch.FloatTensor`, *optional*):
+ Embedding of the target concept, to be provided to the mask decoder for target-semantic prompting in case
+ the model is used for personalization as introduced in [PerSAM](https://huggingface.co/papers/2305.03048).
+ """
+ if not ((pixel_values is None) ^ (image_embeddings is None)):
+ raise ValueError("Exactly one of pixel_values or image_embeddings must be provided.")
+ if input_points is not None and input_boxes is not None:
+ if input_points.shape[1] != input_boxes.shape[1]:
+ raise ValueError(
+ f"You should provide as many bounding boxes as input points per box. Got {input_points.shape[1]} and {input_boxes.shape[1]}."
+ )
+ elif input_points is not None:
+ num_objects = input_points.shape[1]
+ elif input_boxes is not None:
+ num_objects = input_boxes.shape[1]
+ elif input_masks is not None:
+ num_objects = input_masks.shape[1]
+ else:
+ num_objects = 1
+
+ image_positional_embeddings = self.get_image_wide_positional_embeddings()
+ # repeat with batch size
+ batch_size = pixel_values.shape[0] if pixel_values is not None else image_embeddings[-1].shape[0]
+ image_positional_embeddings = image_positional_embeddings.repeat(batch_size, 1, 1, 1)
+
+ vision_attentions = None
+ vision_hidden_states = None
+
+ if pixel_values is not None:
+ feature_maps, _, vision_hidden_states, vision_attentions = self.get_image_features(
+ pixel_values,
+ **kwargs,
+ )
+
+ # add no memory embedding to the last feature map
+ feature_maps[-1] = feature_maps[-1] + self.no_memory_embedding
+
+ # reshape feature maps to the same shape as the backbone feature sizes
+ image_embeddings = [
+ feat.permute(1, 2, 0).view(batch_size, -1, *feat_size)
+ for feat, feat_size in zip(feature_maps, self.backbone_feature_sizes)
+ ]
+
+ if input_points is not None and input_labels is None:
+ input_labels = torch.ones_like(input_points[:, :, :, 0], dtype=torch.int, device=input_points.device)
+
+ if input_points is None and input_boxes is None:
+ # If no points are provided, pad with an empty point (with label -1)
+ input_points = torch.zeros(
+ batch_size, 1, 1, 2, dtype=image_embeddings[-1].dtype, device=image_embeddings[-1].device
+ )
+ input_labels = -torch.ones(batch_size, 1, 1, dtype=torch.int32, device=image_embeddings[-1].device)
+
+ if input_masks is not None:
+ # If mask_inputs is provided, downsize it into low-res mask input if needed
+ # and feed it as a dense mask prompt into the SAM mask encoder
+ if input_masks.shape[-2:] != self.prompt_encoder.mask_input_size:
+ input_masks = F.interpolate(
+ input_masks.float(),
+ size=self.prompt_encoder.mask_input_size,
+ align_corners=False,
+ mode="bilinear",
+ antialias=True, # use antialias for downsampling
+ ).to(input_masks.dtype)
+
+ sparse_embeddings, dense_embeddings = self.prompt_encoder(
+ input_points=input_points,
+ input_labels=input_labels,
+ input_boxes=input_boxes,
+ input_masks=input_masks,
+ )
+ low_res_multimasks, iou_scores, sam_output_tokens, object_score_logits = self.mask_decoder(
+ image_embeddings=image_embeddings[-1],
+ image_positional_embeddings=image_positional_embeddings,
+ sparse_prompt_embeddings=sparse_embeddings,
+ dense_prompt_embeddings=dense_embeddings,
+ multimask_output=multimask_output,
+ high_resolution_features=image_embeddings[:-1],
+ attention_similarity=attention_similarity,
+ target_embedding=target_embedding,
+ **kwargs,
+ )
+
+ is_obj_appearing = object_score_logits > 0
+ # Mask used for spatial memories is always a *hard* choice between obj and no obj,
+ # consistent with the actual mask prediction
+ low_res_multimasks = torch.where(
+ is_obj_appearing[:, None, None],
+ low_res_multimasks,
+ NO_OBJ_SCORE,
+ )
+
+ # convert masks from possibly bfloat16 (or float16) to float32
+ # (older PyTorch versions before 2.1 don't support `interpolate` on bf16)
+ high_res_multimasks = (
+ F.interpolate(
+ low_res_multimasks.squeeze(1).float(),
+ size=(self.image_size, self.image_size),
+ mode="bilinear",
+ align_corners=False,
+ )
+ .unsqueeze(1)
+ .to(low_res_multimasks.dtype)
+ )
+ sam_output_token = sam_output_tokens[:, :, 0]
+ if multimask_output:
+ # take the best mask prediction (with the highest IoU estimation)
+ best_iou_inds = torch.argmax(iou_scores, dim=-1)
+ batch_inds = torch.arange(batch_size, device=high_res_multimasks.device)
+ object_batch_inds = torch.arange(num_objects, device=high_res_multimasks.device)
+ low_res_masks = low_res_multimasks[batch_inds, object_batch_inds, best_iou_inds]
+ high_res_masks = high_res_multimasks[batch_inds, object_batch_inds, best_iou_inds]
+ if sam_output_tokens.size(2) > 1:
+ sam_output_token = sam_output_tokens[batch_inds, object_batch_inds, best_iou_inds]
+ else:
+ low_res_masks, high_res_masks = low_res_multimasks[:, :, 0], high_res_multimasks[:, :, 0]
+
+ # Extract object pointer from the SAM output token (with occlusion handling)
+ object_pointer = self.object_pointer_proj(sam_output_token)
+ lambda_is_obj_appearing = is_obj_appearing.to(object_pointer.dtype)
+
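+ # hard 0/1 gate: when the object is predicted absent, the projected SAM output token is replaced
+ # by the learned `no_object_pointer` embedding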
+ object_pointer = lambda_is_obj_appearing * object_pointer
+ object_pointer = object_pointer + (1 - lambda_is_obj_appearing) * self.no_object_pointer
+
+ return Sam2VideoImageSegmentationOutput(
+ iou_scores=iou_scores,
+ pred_masks=low_res_masks,
+ high_res_masks=high_res_masks,
+ object_pointer=object_pointer,
+ object_score_logits=object_score_logits,
+ image_embeddings=image_embeddings,
+ vision_hidden_states=vision_hidden_states,
+ vision_attentions=vision_attentions,
+ )
+
+ def _use_mask_as_output(
+ self,
+ backbone_features: torch.Tensor,
+ high_res_features: list[torch.Tensor],
+ mask_inputs: torch.Tensor,
+ ) -> Sam2VideoImageSegmentationOutput:
+ """
+ Directly turn binary `mask_inputs` into output mask logits without using SAM
+ (same input and output shapes as in `forward` above).
+ """
+ # Use -10/+10 as logits for neg/pos pixels (very close to 0/1 in probability after sigmoid).
+ out_scale, out_bias = 20.0, -10.0 # sigmoid(-10.0)=4.5398e-05
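+ # (for positive pixels: sigmoid(1 * 20.0 - 10.0) = sigmoid(10.0) ≈ 9.9995e-01)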
+ mask_inputs_float = mask_inputs.to(backbone_features[0].dtype)
+ high_res_masks = mask_inputs_float * out_scale + out_bias
+ low_res_masks = F.interpolate(
+ high_res_masks.float(),
+ size=(high_res_masks.size(-2) // 4, high_res_masks.size(-1) // 4),
+ align_corners=False,
+ mode="bilinear",
+ antialias=True, # use antialias for downsampling
+ ).to(backbone_features[0].dtype)
+ # a dummy IoU prediction of all 1's under mask input
+ iou_scores = mask_inputs.new_ones(mask_inputs.size(0), 1).to(backbone_features[0].dtype)
+ # produce an object pointer using the SAM decoder from the mask input
+ object_pointer = self._single_frame_forward(
+ input_masks=self.mask_downsample(mask_inputs_float.to(backbone_features[0].dtype)),
+ image_embeddings=high_res_features + [backbone_features],
+ ).object_pointer
+ # In this method, we are treating mask_input as output, e.g. using it directly to create spatial mem;
+ # Below, we follow the same design axiom to use mask_input to decide if obj appears or not instead of relying
+ # on the object_scores from the SAM decoder.
+ is_obj_appearing = torch.any(mask_inputs.flatten(1).float() > 0.0, dim=1)
+ is_obj_appearing = is_obj_appearing[..., None]
+ lambda_is_obj_appearing = is_obj_appearing.to(backbone_features[0].dtype)
+ object_score_logits = out_scale * lambda_is_obj_appearing + out_bias
+ object_pointer = lambda_is_obj_appearing * object_pointer
+ object_pointer = object_pointer + (1 - lambda_is_obj_appearing) * self.no_object_pointer
+ return Sam2VideoImageSegmentationOutput(
+ iou_scores=iou_scores,
+ pred_masks=low_res_masks,
+ high_res_masks=high_res_masks,
+ object_pointer=object_pointer,
+ object_score_logits=object_score_logits,
+ image_embeddings=high_res_features + [backbone_features],
+ )
+
+ def _gather_memory_frame_outputs(
+ self,
+ inference_session: Sam2VideoInferenceSession,
+ obj_idx: int,
+ frame_idx: int,
+ track_in_reverse_time: bool = False,
+ ) -> list[tuple[int, dict]]:
+ """
+ Get memory frames from conditioning and non-conditioning outputs.
+
+ Returns:
+ List of (relative_temporal_offset, output_data) tuples.
+ """
+ temporal_positions_and_previous_outputs = []
+
+ # Add conditioning frame outputs (no limit here, as is the case in the original checkpoints)
+ conditioning_outputs = inference_session.output_dict_per_obj[obj_idx]["cond_frame_outputs"]
+ if not conditioning_outputs:
+ raise ValueError(
+ "maskmem_features in conditioning outputs cannot be empty when not is_initial_conditioning_frame"
+ )
+
+ # Store (temporal_position, output_data) tuples
+ temporal_positions_and_previous_outputs = [(0, out) for out in conditioning_outputs.values()]
+
+ # Add non-conditioning memory frames (up to self.num_maskmem - 1)
+ # These are typically frames tracked by the model without direct user input.
+ # Frames are selected with a stride, prioritizing the most recent ones. Here we only support stride = 1 for simplicity.
+ for relative_temporal_offset in range(self.num_maskmem - 1, 0, -1):
+ # relative_temporal_offset: how many frames before (or after if reversing) the current frame
+ if not track_in_reverse_time:
+ previous_frame_idx = frame_idx - relative_temporal_offset
+ else:
+ previous_frame_idx = frame_idx + relative_temporal_offset
+
+ # check if the output is already stored without using get_output to avoid unnecessary memory transfers between CPU and GPU
+ output_data = inference_session.output_dict_per_obj[obj_idx]["non_cond_frame_outputs"].get(
+ previous_frame_idx, None
+ )
+
+ temporal_positions_and_previous_outputs.append((relative_temporal_offset, output_data))
+
+ return temporal_positions_and_previous_outputs
+
+ def _build_memory_attention_inputs(
+ self,
+ temporal_positions_and_previous_outputs: list[tuple[int, dict]],
+ device: torch.device,
+ ) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
+ """
+ Concatenate memory features and positional embeddings from previous frames.
+
+ Returns:
+ Tuple of (memories_to_concatenate, memory_positional_embeddings_to_concatenate).
+ """
+ memories_to_concatenate = []
+ memory_positional_embeddings_to_concatenate = []
+
+ for relative_temporal_offset, prev_output_data in temporal_positions_and_previous_outputs:
+ if prev_output_data is None:
+ continue # Skip if no output data for this temporal position (e.g., padding frames)
+
+ # Load memory features (potentially from CPU to GPU)
+ # Features are flattened: (Batch, Channels, H, W) -> (H*W, Batch, Channels)
+ memory_features = prev_output_data["maskmem_features"].to(device, non_blocking=True)
+ memories_to_concatenate.append(memory_features)
+
+ # Spatial positional encoding (potentially from CPU to GPU)
+ spatial_memory_pos_embed = prev_output_data["maskmem_pos_enc"].to(device, non_blocking=True)
+
+ # Add temporal positional encoding
+ # self.memory_temporal_positional_encoding shape: (NumMaskMem, 1, 1, MemDim)
+ combined_memory_pos_embed = (
+ spatial_memory_pos_embed + self.memory_temporal_positional_encoding[relative_temporal_offset - 1]
+ )
+ memory_positional_embeddings_to_concatenate.append(combined_memory_pos_embed)
+
+ return memories_to_concatenate, memory_positional_embeddings_to_concatenate
+
+ def _get_object_pointers(
+ self,
+ inference_session: Sam2VideoInferenceSession,
+ obj_idx: int,
+ frame_idx: int,
+ num_total_frames: int,
+ device: torch.device,
+ track_in_reverse_time: bool = False,
+ streaming: bool = False,
+ ) -> tuple[list[int], list[torch.Tensor], int]:
+ """
+ Get object pointers and their positional embeddings from past frames.
+
+ Returns:
+ Tuple of (temporal_offsets, pointer_tokens, max_object_pointers_to_use).
+ """
+ temporal_position_sign_multiplier = -1 if track_in_reverse_time else 1
+
+ # Determine max object pointers to use
+ if streaming:
+ max_object_pointers_to_use = self.config.max_object_pointers_in_encoder
+ else:
+ max_object_pointers_to_use = min(num_total_frames, self.config.max_object_pointers_in_encoder)
+
+ temporal_offsets: list[int] = []
+ pointer_tokens: list[torch.Tensor] = []
+
+ # Add object pointers from selected conditioning frames
+ # Optionally, only include pointers from past frames during evaluation
+ conditioning_outputs = inference_session.output_dict_per_obj[obj_idx]["cond_frame_outputs"]
+ eligible_conditioning_outputs = conditioning_outputs
+ if not self.training:
+ eligible_conditioning_outputs = {
+ temporal_idx: out
+ for temporal_idx, out in conditioning_outputs.items()
+ if (temporal_idx >= frame_idx if track_in_reverse_time else temporal_idx <= frame_idx)
+ }
+
+ for temporal_idx, out_data in eligible_conditioning_outputs.items():
+ temporal_difference = (frame_idx - temporal_idx) * temporal_position_sign_multiplier
+ temporal_offsets.append(temporal_difference)
+ pointer_tokens.append(out_data["object_pointer"].to(device))
+
+ # Add object pointers from non-conditioning frames (up to max_object_pointers_to_use - 1)
+ for t_diff_offset in range(1, max_object_pointers_to_use):
+ ref_frame_idx = frame_idx + t_diff_offset if track_in_reverse_time else frame_idx - t_diff_offset
+ if ref_frame_idx < 0 or (
+ not streaming and num_total_frames is not None and ref_frame_idx >= num_total_frames
+ ):
+ break # Stop if frame index is out of bounds
+
+ # check if the output is already stored without using get_output to avoid unnecessary memory transfers between CPU and GPU
+ out_data = inference_session.output_dict_per_obj[obj_idx]["non_cond_frame_outputs"].get(
+ ref_frame_idx, None
+ )
+ if out_data is not None:
+ temporal_offsets.append(t_diff_offset)
+ pointer_tokens.append(out_data["object_pointer"].to(device))
+
+ return temporal_offsets, pointer_tokens, max_object_pointers_to_use
+
+ def _process_object_pointers(
+ self,
+ temporal_offsets: list[int],
+ pointer_tokens: list[torch.Tensor],
+ max_object_pointers_to_use: int,
+ batch_size: int,
+ num_channels: int,
+ device: torch.device,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ """
+ Process object pointers and compute their positional embeddings.
+
+ Returns:
+ Tuple of (object_pointers, object_pointers_pos_embed).
+ """
+ if not pointer_tokens:
+ return None, None
+
+ # Stack object pointers: List of (Batch, Channels) -> (SeqLen_ptr, Batch, Channels)
+ object_pointers = torch.stack(pointer_tokens, dim=0)
+
+ if self.config.enable_temporal_pos_encoding_for_object_pointers:
+ max_temporal_diff = float(max_object_pointers_to_use - 1)
+ # Determine dimensionality for temporal positional encoding of pointers
+ pointer_tpos_dim = num_channels
+
+ # Normalize temporal differences before sine PE calculation
+ normalized_temporal_diffs = (
+ torch.tensor(temporal_offsets, device=device, dtype=torch.float32) / max_temporal_diff
+ )
+ sine_pe = get_1d_sine_pe(normalized_temporal_diffs, dim=pointer_tpos_dim).to(object_pointers.dtype)
+ projected_sine_pe = self.temporal_positional_encoding_projection_layer(sine_pe)
+ object_pointers_pos_embed = projected_sine_pe.unsqueeze(1).expand(-1, batch_size, self.mem_dim)
+ else:
+ object_pointers_pos_embed = object_pointers.new_zeros(
+ len(temporal_offsets), batch_size, self.mem_dim, dtype=object_pointers.dtype
+ )
+
+ if self.mem_dim < num_channels:
+ # If memory dimension is smaller, reshape/split pointers and repeat positional encoding
+ num_splits = num_channels // self.mem_dim
+ object_pointers = object_pointers.reshape(-1, batch_size, num_splits, self.mem_dim)
+ object_pointers = object_pointers.permute(0, 2, 1, 3).flatten(
+ 0, 1
+ ) # (SeqLen_ptr*num_splits, Batch, MemDim)
+ object_pointers_pos_embed = object_pointers_pos_embed.repeat_interleave(num_splits, dim=0)
+
+ return object_pointers, object_pointers_pos_embed
+
+ def _prepare_memory_conditioned_features(
+ self,
+ inference_session: Sam2VideoInferenceSession,
+ frame_idx: int,
+ obj_idx: int,
+ is_initial_conditioning_frame: bool,
+ current_vision_features: list[torch.Tensor],
+ current_vision_positional_embeddings: list[torch.Tensor],
+ num_total_frames: int,
+ track_in_reverse_time: bool = False,
+ streaming: bool = False,
+ ) -> torch.Tensor:
+ """
+ Fuse current frame's visual features with memory from previous frames for enhanced object tracking.
+
+ This method conditions the current frame's visual features on temporal memory from previous frames,
+ enabling consistent object tracking across video sequences. For initial conditioning frames, it uses
+ no-memory embeddings. For subsequent frames, it retrieves and integrates memory features from both
+ conditioning frames (user interactions) and non-conditioning frames (tracked results) via cross-attention.
+
+ Args:
+ inference_session (`Sam2VideoInferenceSession`):
+ The video inference session object.
+ frame_idx (`int`):
+ Index of the current frame being processed.
+ obj_idx (`int`):
+ Index of the object being processed.
+ is_initial_conditioning_frame (`bool`):
+ Whether this is an initial conditioning frame with user inputs (True) or a subsequent
+ tracking frame (False).
+ current_vision_features (`torch.Tensor`):
+ Highest-level vision features of shape `(seq_len, batch_size, channels)`.
+ current_vision_positional_embeddings (`torch.Tensor`):
+ Positional embedding tensors corresponding to the highest-level vision features.
+ num_total_frames (`int`):
+ Total number of frames in the video sequence.
+ track_in_reverse_time (`bool`, *optional*, defaults to `False`):
+ Whether tracking is performed in reverse temporal order.
+ streaming (`bool`, *optional*, defaults to `False`):
+ Whether this is streaming inference mode.
+
+ Returns:
+ `torch.Tensor`: Memory-conditioned feature tensor of shape `(batch_size, channels, height, width)`
+ suitable for input to the SAM decoder.
+ """
+ # Get dimensions from the highest-level (lowest-resolution) feature map
+ batch_size = current_vision_features.size(1)
+ num_channels = self.hidden_dim
+ height, width = self.backbone_feature_sizes[-1]
+ device = current_vision_features.device
+
+ # If memory is disabled (e.g., for single image SAM), return current features directly.
+ if self.num_maskmem == 0:
+ # Permute (SeqLen, Batch, Channels) -> (Batch, Channels, SeqLen) then view as (Batch, Channels, Height, Width)
+ # Assuming SeqLen = Height * Width for the last feature map
+ current_feature_map = current_vision_features.permute(1, 2, 0).view(
+ batch_size, num_channels, height, width
+ )
+ return current_feature_map
+
+ # Step 1: Handle initial conditioning frames
+ if is_initial_conditioning_frame:
+ # For initial conditioning frames, no prior memory is used;
+ # instead, a learnable "no memory" embedding is added directly to the features.
+ # current_vision_features has shape (SeqLen, Batch, Channels)
+ conditioned_feature_map_flat = current_vision_features + self.no_memory_embedding
+ # Reshape to (Batch, Channels, Height, Width)
+ conditioned_feature_map = conditioned_feature_map_flat.permute(1, 2, 0).view(
+ batch_size, num_channels, height, width
+ )
+ return conditioned_feature_map
+
+ # Step 2: Get memory frames and concatenate their features
+ temporal_positions_and_previous_outputs = self._gather_memory_frame_outputs(
+ inference_session, obj_idx, frame_idx, track_in_reverse_time
+ )
+
+ memories_to_concatenate, memory_positional_embeddings_to_concatenate = self._build_memory_attention_inputs(
+ temporal_positions_and_previous_outputs, device
+ )
+
+ # Step 3: Get and process object pointers
+ temporal_offsets, pointer_tokens, max_object_pointers_to_use = self._get_object_pointers(
+ inference_session, obj_idx, frame_idx, num_total_frames, device, track_in_reverse_time, streaming
+ )
+
+ num_object_pointer_tokens = 0
+ if pointer_tokens:
+ object_pointers, object_pointers_pos_embed = self._process_object_pointers(
+ temporal_offsets, pointer_tokens, max_object_pointers_to_use, batch_size, num_channels, device
+ )
+
+ if object_pointers is not None:
+ memories_to_concatenate.append(object_pointers)
+ memory_positional_embeddings_to_concatenate.append(object_pointers_pos_embed)
+ num_object_pointer_tokens = object_pointers.shape[0]
+
+ # Step 4: Concatenate all retrieved memories and their positional embeddings
+ combined_memory = torch.cat(memories_to_concatenate, dim=0)
+ combined_memory_positional_embeddings = torch.cat(memory_positional_embeddings_to_concatenate, dim=0)
+
+ # Step 5: Forward through the memory attention mechanism
+ conditioned_feature_map_flat = self.memory_attention(
+ current_vision_features=current_vision_features,
+ current_vision_position_embeddings=current_vision_positional_embeddings,
+ memory=combined_memory,
+ memory_posision_embeddings=combined_memory_positional_embeddings, # keyword keeps the (misspelled) name expected by the memory attention API
+ num_object_pointer_tokens=num_object_pointer_tokens,
+ )
+
+ # Reshape from (Batch, H*W, Channels) to (Batch, Channels, Height, Width)
+ conditioned_feature_map = (
+ conditioned_feature_map_flat.squeeze(1).permute(0, 2, 1).view(batch_size, num_channels, height, width)
+ )
+ return conditioned_feature_map
+
+ def _use_multimask(self, is_init_cond_frame: bool, point_inputs: Optional[dict]) -> bool:
+ """Whether to use multimask output in the SAM head."""
+ num_pts = 0 if point_inputs is None else point_inputs["point_labels"].size(2)
+ multimask_output = (
+ self.config.multimask_output_in_sam
+ and (is_init_cond_frame or self.config.multimask_output_for_tracking)
+ and (self.config.multimask_min_pt_num <= num_pts <= self.config.multimask_max_pt_num)
+ )
+ return multimask_output
+
+ def _run_single_frame_inference(
+ self,
+ inference_session: Sam2VideoInferenceSession,
+ frame_idx: int,
+ obj_idx: int,
+ batch_size: int,
+ is_init_cond_frame: bool,
+ point_inputs: Optional[torch.Tensor],
+ mask_inputs: Optional[torch.Tensor],
+ reverse: bool,
+ run_mem_encoder: bool,
+ prev_sam_mask_logits: Optional[torch.Tensor] = None,
+ streaming: bool = False,
+ ) -> dict[str, Any]:
+ """
+ Perform a single tracking step for video object segmentation.
+
+ Args:
+ inference_session (`Sam2VideoInferenceSession`):
+ The video inference session object.
+ frame_idx (`int`):
+ Index of the current frame.
+ obj_idx (`int`):
+ Index of the current object.
+ batch_size (`int`):
+ Batch size of the current frame.
+ is_init_cond_frame (`bool`):
+ Whether this is an initial conditioning frame with user inputs.
+ point_inputs (`dict`, *optional*):
+ Point prompt inputs for the current frame.
+ mask_inputs (`torch.Tensor`, *optional*):
+ Mask prompt inputs for the current frame.
+ reverse (`bool`, *optional*, defaults to `False`):
+ Whether to track in reverse time order.
+ run_mem_encoder (`bool`, *optional*, defaults to `True`):
+ Whether to run the memory encoder on predicted masks.
+ prev_sam_mask_logits (`torch.Tensor`, *optional*):
+ Previously predicted SAM mask logits that can be fed with new clicks.
+ streaming (`bool`, *optional*, defaults to `False`):
+ Whether this is streaming inference.
+
+ Returns:
+ `dict`: Dictionary containing the tracking results for the current frame, including:
+ - pred_masks: Predicted low-resolution masks.
+ - object_pointer: Object pointer for memory.
+ - object_score_logits: Object score logits (inference only).
+ - maskmem_features: Memory features for future frames.
+ - maskmem_pos_enc: Memory positional encodings.
+ """
+ # Retrieve correct image features
+ current_vision_feats, current_vision_pos_embeds = self._prepare_vision_features(
+ inference_session, frame_idx, batch_size
+ )
+ # point and mask should not appear as input simultaneously on the same frame
+ if point_inputs is not None and mask_inputs is not None:
+ raise ValueError(
+ "point_inputs and mask_inputs should not appear as input simultaneously on the same frame"
+ )
+ # High-resolution feature maps for the SAM head, reshape (HW)BC => BCHW
+ if len(current_vision_feats) > 1:
+ high_res_features = [
+ x.permute(1, 2, 0).view(x.size(1), x.size(2), *s)
+ for x, s in zip(current_vision_feats[:-1], self.backbone_feature_sizes[:-1])
+ ]
+ else:
+ high_res_features = None
+ if mask_inputs is not None:
+ # We directly output the mask input (see it as a GT mask) without using a SAM prompt encoder + mask decoder.
+ pix_feat = current_vision_feats[-1].permute(1, 2, 0)
+ pix_feat = pix_feat.view(-1, self.hidden_dim, *self.backbone_feature_sizes[-1])
+ sam_outputs = self._use_mask_as_output(pix_feat, high_res_features, mask_inputs)
+ else:
+ # fuse the visual features with previous memory features from the memory bank
+ pix_feat = self._prepare_memory_conditioned_features(
+ inference_session=inference_session,
+ frame_idx=frame_idx,
+ obj_idx=obj_idx,
+ is_initial_conditioning_frame=is_init_cond_frame,
+ current_vision_features=current_vision_feats[-1],
+ current_vision_positional_embeddings=current_vision_pos_embeds[-1],
+ num_total_frames=inference_session.num_frames,
+ track_in_reverse_time=reverse,
+ streaming=streaming,
+ )
+ # apply SAM-style segmentation head
+ # here we might feed previously predicted low-res SAM mask logits into the SAM mask decoder,
+ # e.g. in demo where such logits come from earlier interaction instead of correction sampling
+ # (in this case, any `mask_inputs` shouldn't reach here as they are sent to _use_mask_as_output instead)
+ if prev_sam_mask_logits is not None:
+ mask_inputs = prev_sam_mask_logits
+ multimask_output = self._use_multimask(is_init_cond_frame, point_inputs)
+ sam_outputs = self._single_frame_forward(
+ pixel_values=None, # Vision features already computed
+ input_points=point_inputs["point_coords"] if point_inputs is not None else None,
+ input_labels=point_inputs["point_labels"] if point_inputs is not None else None,
+ input_masks=mask_inputs,
+ image_embeddings=high_res_features + [pix_feat],
+ multimask_output=multimask_output,
+ )
+
+ # Finally run the memory encoder on the predicted mask to encode
+ # it into a new memory feature (which will be used to condition vision features in future frames)
+ maskmem_features = None
+ maskmem_pos_enc = None
+ if run_mem_encoder and self.num_maskmem > 0:
+ maskmem_features, maskmem_pos_enc = self._encode_new_memory(
+ current_vision_feats=current_vision_feats[-1],
+ pred_masks_high_res=sam_outputs.high_res_masks,
+ object_score_logits=sam_outputs.object_score_logits,
+ is_mask_from_pts=(point_inputs is not None or mask_inputs is not None),
+ )
+
+ current_out = {
+ "pred_masks": sam_outputs.pred_masks,
+ "object_pointer": sam_outputs.object_pointer,
+ "maskmem_features": maskmem_features if maskmem_features is not None else None,
+ "maskmem_pos_enc": maskmem_pos_enc,
+ }
+ if not self.training:
+ current_out["object_score_logits"] = sam_outputs.object_score_logits
+
+ return current_out
+
+ def _encode_new_memory(
+ self,
+ current_vision_feats: torch.Tensor,
+ pred_masks_high_res: torch.Tensor,
+ object_score_logits: torch.Tensor,
+ is_mask_from_pts: bool,
+ ) -> tuple[torch.Tensor, list[torch.Tensor]]:
+ """Encode the current image and its prediction into a memory feature."""
+ batch_size = current_vision_feats.size(1) # batch size on this frame
+ channels = self.hidden_dim
+ height, width = self.backbone_feature_sizes[-1] # top-level (lowest-resolution) feature size
+ # top-level feature, (HW)BC => BCHW
+ pix_feat = current_vision_feats.permute(1, 2, 0).view(batch_size, channels, height, width)
+ if is_mask_from_pts and not self.training:
+ # binarize the mask logits
+ mask_for_mem = (pred_masks_high_res > 0).to(pred_masks_high_res.dtype)
+ else:
+ # apply sigmoid on the raw mask logits to turn them into range (0, 1)
+ mask_for_mem = torch.sigmoid(pred_masks_high_res)
+ # apply scale and bias terms to the sigmoid probabilities
+ mask_for_mem = mask_for_mem * self.config.sigmoid_scale_for_mem_enc
+ mask_for_mem = mask_for_mem + self.config.sigmoid_bias_for_mem_enc
+
+ maskmem_features, maskmem_pos_enc = self.memory_encoder(
+ pix_feat,
+ mask_for_mem,
+ )
+ # add a no-object embedding to the spatial memory to indicate that the frame
+ # is predicted to be occluded (i.e. no object is appearing in the frame)
+ if self.occlusion_spatial_embedding_parameter is not None:
+ is_obj_appearing = (object_score_logits > 0).float()
+ maskmem_features += (1 - is_obj_appearing[..., None]) * self.occlusion_spatial_embedding_parameter[
+ ..., None, None
+ ].expand(*maskmem_features.shape)
+
+ # convert to bfloat16 to save memory, and for consistency with the original implementation
+ maskmem_features = maskmem_features.to(torch.bfloat16).flatten(2).permute(2, 0, 1)
+ maskmem_pos_enc = maskmem_pos_enc.to(pred_masks_high_res.dtype).flatten(2).permute(2, 0, 1)
+
+ return maskmem_features, maskmem_pos_enc
+
+ @torch.inference_mode()
+ @auto_docstring(
+ custom_intro="""
+ Propagate the objects through the video frames. Used when initializing an inference session with a whole video.
+ Yields Sam2VideoSegmentationOutput for each frame.
+ """
+ )
+ def propagate_in_video_iterator(
+ self,
+ inference_session: Sam2VideoInferenceSession,
+ start_frame_idx: Optional[int] = None,
+ max_frame_num_to_track: Optional[int] = None,
+ reverse: bool = False,
+ ) -> Iterator[Sam2VideoSegmentationOutput]:
+ r"""
+ inference_session (`Sam2VideoInferenceSession`):
+ The video inference session object.
+ start_frame_idx (`int`, *optional*):
+ The starting frame index for propagation.
+ Needs to be provided if `forward` hasn't been called on new inputs yet.
+ If not provided, the starting frame index will be the earliest frame with input points.
+ max_frame_num_to_track (`int`, *optional*):
+ The maximum number of frames to track.
+ reverse (`bool`, *optional*, defaults to `False`):
+ Whether to propagate in reverse.
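+
+ Example (illustrative sketch; `model` and `inference_session` are assumed to already exist):
+
+ ```python
+ >>> # hypothetical usage: iterate over per-frame segmentation outputs
+ >>> for sam2_video_output in model.propagate_in_video_iterator(inference_session):
+ ... pass # consume sam2_video_output here
+ ```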
+ """
+ num_frames = inference_session.num_frames
+
+ # set start index, end index, and processing order
+ if start_frame_idx is None:
+ # default: start from the earliest frame with input points
+ frames_with_inputs = [
+ frame_idx
+ for obj_output_dict in inference_session.output_dict_per_obj.values()
+ for frame_idx in obj_output_dict["cond_frame_outputs"]
+ ]
+ if not frames_with_inputs:
+ raise ValueError(
+ "Cannot determine the starting frame index; please specify it manually, or run inference on a frame with inputs first."
+ )
+ start_frame_idx = min(frames_with_inputs)
+ if max_frame_num_to_track is None:
+ # default: track all the frames in the video
+ max_frame_num_to_track = num_frames
+ if reverse:
+ end_frame_idx = max(start_frame_idx - max_frame_num_to_track, 0)
+ if start_frame_idx > 0:
+ processing_order = range(start_frame_idx, end_frame_idx - 1, -1)
+ else:
+ processing_order = [] # skip reverse tracking if starting from frame 0
+ else:
+ end_frame_idx = min(start_frame_idx + max_frame_num_to_track, num_frames - 1)
+ processing_order = range(start_frame_idx, end_frame_idx + 1)
+
+ for frame_idx in tqdm(processing_order, desc="propagate in video"):
+ sam2_video_output = self(inference_session, frame_idx=frame_idx, reverse=reverse)
+ yield sam2_video_output
+
+
+__all__ = ["Sam2VideoModel", "Sam2VideoInferenceSession", "Sam2VideoPreTrainedModel"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/sam2_video/modular_sam2_video.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/sam2_video/modular_sam2_video.py
new file mode 100644
index 0000000000000000000000000000000000000000..b95a9f778251a6ced837d74c5ba9ae343fc608f8
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/sam2_video/modular_sam2_video.py
@@ -0,0 +1,2430 @@
+# coding=utf-8
+# Copyright 2025 The Meta AI Authors and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch SAM 2 model."""
+
+import math
+from collections import OrderedDict
+from collections.abc import Iterator
+from dataclasses import dataclass
+from typing import Any, Callable, Optional, Union
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch import Tensor
+from tqdm import tqdm
+
+from ...activations import ACT2FN
+from ...configuration_utils import PretrainedConfig
+from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import ProcessorMixin, Unpack
+from ...utils import (
+ ModelOutput,
+ auto_docstring,
+ logging,
+)
+from ...utils.generic import OutputRecorder, TransformersKwargs
+from ...video_utils import VideoInput
+from ..auto import CONFIG_MAPPING, AutoConfig
+from ..sam2.configuration_sam2 import (
+ Sam2MaskDecoderConfig,
+ Sam2PromptEncoderConfig,
+)
+from ..sam2.modeling_sam2 import (
+ Sam2FeedForward,
+ Sam2ImageSegmentationOutput,
+ Sam2LayerNorm,
+ Sam2Model,
+ Sam2SinePositionEmbedding,
+ Sam2TwoWayAttentionBlock,
+ eager_attention_forward,
+)
+from ..sam2.processing_sam2 import Sam2Processor
+
+
+logger = logging.get_logger(__name__)
+
+
+class Sam2VideoPromptEncoderConfig(Sam2PromptEncoderConfig):
+ pass
+
+
+class Sam2VideoMaskDecoderConfig(Sam2MaskDecoderConfig):
+ pass
+
+
+class Sam2VideoConfig(PretrainedConfig):
+ r"""
+ [`Sam2VideoConfig`] is the configuration class to store the configuration of a [`Sam2VideoModel`]. It is used to
+ instantiate a SAM2 video model according to the specified arguments, defining the memory attention, memory encoder,
+ and image encoder configs. Instantiating a configuration with the defaults will yield a similar configuration to
+ that of the SAM 2.1 Hiera-tiny
+ [facebook/sam2.1-hiera-tiny](https://huggingface.co/facebook/sam2.1-hiera-tiny) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ vision_config (Union[`dict`, `Sam2VisionConfig`], *optional*):
+ Dictionary of configuration options used to initialize [`Sam2VisionConfig`].
+ prompt_encoder_config (Union[`dict`, `Sam2PromptEncoderConfig`], *optional*):
+ Dictionary of configuration options used to initialize [`Sam2PromptEncoderConfig`].
+ mask_decoder_config (Union[`dict`, `Sam2MaskDecoderConfig`], *optional*):
+ Dictionary of configuration options used to initialize [`Sam2MaskDecoderConfig`].
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ Standard deviation for parameter initialization.
+ num_maskmem (`int`, *optional*, defaults to 7):
+ The number of memory slots for the mask memory.
+ image_size (`int`, *optional*, defaults to 1024):
+ The size of the input images.
+ sigmoid_scale_for_mem_enc (`float`, *optional*, defaults to 20.0):
+ Scale factor for the sigmoid function in the memory encoder.
+ sigmoid_bias_for_mem_enc (`float`, *optional*, defaults to -10.0):
+ Bias for the sigmoid function in the memory encoder.
+ enable_occlusion_spatial_embedding (`bool`, *optional*, defaults to `True`):
+ Whether to enable spatial embedding for occlusions.
+ multimask_output_in_sam (`bool`, *optional*, defaults to `True`):
+ Whether to output multiple masks from the SAM head.
+ multimask_min_pt_num (`int`, *optional*, defaults to 0):
+ The minimum number of points to trigger multimask output.
+ multimask_max_pt_num (`int`, *optional*, defaults to 1):
+ The maximum number of points to trigger multimask output.
+ multimask_output_for_tracking (`bool`, *optional*, defaults to `True`):
+ Whether to use multimask output for tracking.
+ max_object_pointers_in_encoder (`int`, *optional*, defaults to 16):
+ The maximum number of object pointers in the encoder.
+ enable_temporal_pos_encoding_for_object_pointers (`bool`, *optional*, defaults to `True`):
+ Whether to enable temporal positional encoding for object pointers.
+ memory_attention_hidden_size (`int`, *optional*, defaults to 256):
+ Dimensionality of the memory attention hidden states.
+ memory_attention_num_layers (`int`, *optional*, defaults to 4):
+ The number of layers in the memory attention module.
+ memory_attention_num_attention_heads (`int`, *optional*, defaults to 1):
+ Number of attention heads for each attention layer in the memory attention.
+ memory_attention_downsample_rate (`int`, *optional*, defaults to 1):
+ The downsample rate for the attention layers.
+ memory_attention_feed_forward_hidden_size (`int`, *optional*, defaults to 2048):
+ The dimension of the feedforward network in the memory attention module.
+ memory_attention_feed_forward_hidden_act (`str`, *optional*, defaults to `"relu"`):
+ The non-linear activation function in the feedforward network in the memory attention module.
+ memory_attention_dropout (`float`, *optional*, defaults to 0.1):
+ The dropout rate for the memory attention module.
+ memory_attention_rope_theta (`float`, *optional*, defaults to 10000):
+ The Rope theta parameter.
+ memory_attention_rope_feat_sizes (`list[int]`, *optional*, defaults to `[64, 64]`):
+ The feature sizes for the Rope positional encoding.
+ memory_attention_rope_dropout (`float`, *optional*, defaults to 0.1):
+ The dropout rate for the Rope positional encoding.
+ memory_encoder_hidden_size (`int`, *optional*, defaults to 256):
+ Dimensionality of the memory encoder hidden states.
+ memory_encoder_output_channels (`int`, *optional*, defaults to 64):
+ The number of output channels for the memory encoder.
+ mask_downsampler_embed_dim (`int`, *optional*, defaults to 256):
+ The dimension of the mask downsampler embedding.
+ mask_downsampler_kernel_size (`int`, *optional*, defaults to 3):
+ The kernel size for the mask downsampler.
+ mask_downsampler_stride (`int`, *optional*, defaults to 2):
+ The stride for the mask downsampler.
+ mask_downsampler_padding (`int`, *optional*, defaults to 1):
+ The padding for the mask downsampler.
+ mask_downsampler_total_stride (`int`, *optional*, defaults to 16):
+ The total stride for the mask downsampler.
+ mask_downsampler_hidden_act (`str`, *optional*, defaults to `"gelu"`):
+ The non-linear activation function in the mask downsampler.
+ memory_fuser_num_layers (`int`, *optional*, defaults to 2):
+ The number of layers in the memory fuser.
+ memory_fuser_embed_dim (`int`, *optional*, defaults to 256):
+ The dimension of the embedding layer in the memory fuser.
+ memory_fuser_intermediate_dim (`int`, *optional*, defaults to 1024):
+ The dimension of the intermediate layer in the memory fuser.
+ memory_fuser_kernel_size (`int`, *optional*, defaults to 7):
+ The kernel size for the memory fuser.
+ memory_fuser_padding (`int`, *optional*, defaults to 3):
+ The padding for the memory fuser.
+ memory_fuser_layer_scale_init_value (`float`, *optional*, defaults to 1e-06):
+ The initial value for the layer scale in the memory fuser.
+ memory_fuser_hidden_act (`str`, *optional*, defaults to `"gelu"`):
+ The non-linear activation function in the memory fuser.
+ kwargs (*optional*):
+ Dictionary of keyword arguments.
+
+ Example:
+
+ ```python
+ >>> from transformers import (
+ ... Sam2VisionConfig,
+ ... Sam2PromptEncoderConfig,
+ ... Sam2MaskDecoderConfig,
+ ... Sam2Model,
+ ... )
+
+ >>> # Initializing a Sam2Config with `"facebook/sam2.1_hiera_tiny"` style configuration
+ >>> configuration = Sam2Config()
+
+ >>> # Initializing a Sam2Model (with random weights) from the `"facebook/sam2.1_hiera_tiny"` style configuration
+ >>> model = Sam2Model(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+
+ >>> # We can also initialize a Sam2Config from a Sam2VisionConfig, Sam2PromptEncoderConfig, and Sam2MaskDecoderConfig
+
+ >>> # Initializing SAM2 vision encoder, prompt encoder, and mask decoder configurations
+ >>> vision_config = Sam2VisionConfig()
+ >>> prompt_encoder_config = Sam2PromptEncoderConfig()
+ >>> mask_decoder_config = Sam2MaskDecoderConfig()
+
+ >>> config = Sam2Config(vision_config, prompt_encoder_config, mask_decoder_config)
+ ```"""
+
+ model_type = "sam2_video"
+ sub_configs = {
+ "vision_config": AutoConfig,
+ "prompt_encoder_config": Sam2VideoPromptEncoderConfig,
+ "mask_decoder_config": Sam2VideoMaskDecoderConfig,
+ }
+
+ def __init__(
+ self,
+ vision_config=None,
+ prompt_encoder_config=None,
+ mask_decoder_config=None,
+ initializer_range=0.02,
+ num_maskmem=7,
+ image_size=1024,
+ sigmoid_scale_for_mem_enc=20.0,
+ sigmoid_bias_for_mem_enc=-10.0,
+ enable_occlusion_spatial_embedding=True,
+ multimask_output_in_sam=True,
+ multimask_min_pt_num=0,
+ multimask_max_pt_num=1,
+ multimask_output_for_tracking=True,
+ max_object_pointers_in_encoder=16,
+ enable_temporal_pos_encoding_for_object_pointers=True,
+ # memory attention
+ memory_attention_hidden_size=256,
+ memory_attention_num_layers=4,
+ memory_attention_num_attention_heads=1,
+ memory_attention_downsample_rate=1,
+ memory_attention_feed_forward_hidden_size=2048,
+ memory_attention_feed_forward_hidden_act="relu",
+ memory_attention_dropout=0.1,
+ memory_attention_rope_theta=10000,
+ memory_attention_rope_feat_sizes=None,
+ memory_attention_rope_dropout=0.1,
+ # memory encoder
+ memory_encoder_hidden_size=256,
+ memory_encoder_output_channels=64,
+ mask_downsampler_embed_dim=256,
+ mask_downsampler_kernel_size=3,
+ mask_downsampler_stride=2,
+ mask_downsampler_padding=1,
+ mask_downsampler_total_stride=16,
+ mask_downsampler_hidden_act="gelu",
+ memory_fuser_num_layers=2,
+ memory_fuser_embed_dim=256,
+ memory_fuser_intermediate_dim=1024,
+ memory_fuser_kernel_size=7,
+ memory_fuser_padding=3,
+ memory_fuser_layer_scale_init_value=1e-6,
+ memory_fuser_hidden_act="gelu",
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+
+ vision_config = vision_config if vision_config is not None else {}
+ prompt_encoder_config = prompt_encoder_config if prompt_encoder_config is not None else {}
+ mask_decoder_config = mask_decoder_config if mask_decoder_config is not None else {}
+ memory_attention_rope_feat_sizes = (
+ [64, 64] if memory_attention_rope_feat_sizes is None else memory_attention_rope_feat_sizes
+ )
+
+ if isinstance(vision_config, dict):
+ vision_config["model_type"] = vision_config.get("model_type", "sam2_vision_model")
+ vision_config = CONFIG_MAPPING[vision_config["model_type"]](**vision_config)
+ if isinstance(prompt_encoder_config, Sam2VideoPromptEncoderConfig):
+ prompt_encoder_config = prompt_encoder_config.to_dict()
+ if isinstance(mask_decoder_config, Sam2VideoMaskDecoderConfig):
+ mask_decoder_config = mask_decoder_config.to_dict()
+
+ self.vision_config = vision_config
+ self.prompt_encoder_config = Sam2VideoPromptEncoderConfig(**prompt_encoder_config)
+ self.mask_decoder_config = Sam2VideoMaskDecoderConfig(**mask_decoder_config)
+
+ self.initializer_range = initializer_range
+ self.num_maskmem = num_maskmem # default 1 input frame + 6 previous frames
+ self.image_size = image_size
+ self.sigmoid_scale_for_mem_enc = sigmoid_scale_for_mem_enc
+ self.sigmoid_bias_for_mem_enc = sigmoid_bias_for_mem_enc
+ self.multimask_output_in_sam = multimask_output_in_sam
+ self.multimask_min_pt_num = multimask_min_pt_num
+ self.multimask_max_pt_num = multimask_max_pt_num
+ self.multimask_output_for_tracking = multimask_output_for_tracking
+ self.max_object_pointers_in_encoder = max_object_pointers_in_encoder
+ # The next 2 are True for sam2.1 and False for sam2
+ self.enable_occlusion_spatial_embedding = enable_occlusion_spatial_embedding
+ self.enable_temporal_pos_encoding_for_object_pointers = enable_temporal_pos_encoding_for_object_pointers
+
+ # memory attention
+ self.memory_attention_hidden_size = memory_attention_hidden_size
+ self.memory_attention_num_layers = memory_attention_num_layers
+ self.memory_attention_num_attention_heads = memory_attention_num_attention_heads
+ self.memory_attention_downsample_rate = memory_attention_downsample_rate
+ self.memory_attention_feed_forward_hidden_size = memory_attention_feed_forward_hidden_size
+ self.memory_attention_feed_forward_hidden_act = memory_attention_feed_forward_hidden_act
+ self.memory_attention_dropout = memory_attention_dropout
+ self.memory_attention_rope_theta = memory_attention_rope_theta
+ self.memory_attention_rope_feat_sizes = memory_attention_rope_feat_sizes
+ self.memory_attention_rope_dropout = memory_attention_rope_dropout
+
+ # memory encoder
+ self.memory_encoder_hidden_size = memory_encoder_hidden_size
+ self.memory_encoder_output_channels = memory_encoder_output_channels
+ self.mask_downsampler_embed_dim = mask_downsampler_embed_dim
+ self.mask_downsampler_kernel_size = mask_downsampler_kernel_size
+ self.mask_downsampler_stride = mask_downsampler_stride
+ self.mask_downsampler_padding = mask_downsampler_padding
+ self.mask_downsampler_total_stride = mask_downsampler_total_stride
+ self.mask_downsampler_hidden_act = mask_downsampler_hidden_act
+ self.memory_fuser_num_layers = memory_fuser_num_layers
+ self.memory_fuser_embed_dim = memory_fuser_embed_dim
+ self.memory_fuser_intermediate_dim = memory_fuser_intermediate_dim
+ self.memory_fuser_kernel_size = memory_fuser_kernel_size
+ self.memory_fuser_padding = memory_fuser_padding
+ self.memory_fuser_layer_scale_init_value = memory_fuser_layer_scale_init_value
+ self.memory_fuser_hidden_act = memory_fuser_hidden_act
+
+
+class Sam2VideoInferenceCache:
+ """Cache for vision features and model constants."""
+
+ def __init__(
+ self,
+ inference_device: Union[torch.device, str] = "cpu",
+ inference_state_device: Union[torch.device, str] = "cpu",
+ max_vision_features_cache_size: int = 1,
+ ):
+ self.inference_device = inference_device
+ self.inference_state_device = inference_state_device
+ self.max_vision_features_cache_size = max_vision_features_cache_size
+
+ self._vision_features = {}
+
+ def cache_vision_features(self, frame_idx: int, features: dict):
+ """Cache vision features with automatic device management."""
+ cached = {}
+ if len(self._vision_features) >= self.max_vision_features_cache_size:
+ # remove the oldest frame
+ self._vision_features.pop(min(self._vision_features.keys()))
+
+ for key, value in features.items():
+ if isinstance(value, torch.Tensor):
+ cached[key] = value.to(self.inference_state_device, non_blocking=True)
+ elif isinstance(value, (list, tuple)) and value and isinstance(value[0], torch.Tensor):
+ cached[key] = [v.to(self.inference_state_device, non_blocking=True) for v in value]
+ else:
+ cached[key] = value
+ self._vision_features[frame_idx] = cached
+
+ def get_vision_features(self, frame_idx: int) -> Optional[dict]:
+ """Get cached vision features, automatically moved to inference device."""
+ if frame_idx not in self._vision_features:
+ return None
+
+ cached = self._vision_features[frame_idx]
+ moved = {}
+ for key, value in cached.items():
+ if isinstance(value, torch.Tensor):
+ moved[key] = value.to(self.inference_device, non_blocking=True)
+ elif isinstance(value, (list, tuple)) and value and isinstance(value[0], torch.Tensor):
+ moved[key] = [v.to(self.inference_device, non_blocking=True) for v in value]
+ else:
+ moved[key] = value
+ return moved
+
+ def clear_all(self):
+ """Clear all cached data."""
+ self._vision_features.clear()
+
+
+class Sam2VideoInferenceSession:
+ r"""
+ Manages video inference session parameters, state and cache.
+
+ Args:
+ video (`torch.FloatTensor`, *optional*):
+ The video to process. No need to provide when streaming.
+ video_height (`int`, *optional*):
+ The height of the video.
+ video_width (`int`, *optional*):
+ The width of the video.
+ inference_device (`torch.device`, *optional*, defaults to `"cpu"`):
+ The device to use for inference.
+ inference_state_device (`torch.device`, *optional*, defaults to `"cpu"`):
+ The device to store the inference state on.
+ video_storage_device (`torch.device`, *optional*, defaults to `"cpu"`):
+ The device to store the video on.
+ dtype (`torch.dtype`, *optional*, defaults to `"float32"`):
+ The dtype to use for the video.
+ max_vision_features_cache_size (`int`, *optional*, defaults to 1):
+ The maximum number of vision features to cache.
+ """
+
+ def __init__(
+ self,
+ video: Optional[torch.FloatTensor] = None,
+ video_height: Optional[int] = None,
+ video_width: Optional[int] = None,
+ inference_device: Union[torch.device, str] = "cpu",
+ inference_state_device: Union[torch.device, str] = "cpu",
+ video_storage_device: Union[torch.device, str] = "cpu",
+ dtype: Union[torch.dtype, str] = "float32",
+ max_vision_features_cache_size: int = 1,
+ ):
+ # store as a dictionary to avoid double memory allocation with torch.cat when adding new frames
+ self.processed_frames = (
+ dict(enumerate(video.to(video_storage_device, dtype=dtype))) if video is not None else None
+ )
+ self.video_height = video_height
+ self.video_width = video_width
+
+ self.inference_device = inference_device
+ self.inference_state_device = inference_state_device
+ self.video_storage_device = video_storage_device
+ self.dtype = dtype
+ self.max_vision_features_cache_size = max_vision_features_cache_size
+
+ # Cache for computed features
+ self.cache = Sam2VideoInferenceCache(
+ inference_device=self.inference_device,
+ inference_state_device=self.inference_state_device,
+ max_vision_features_cache_size=self.max_vision_features_cache_size,
+ )
+
+ # Persistent object tracking state
+ self._obj_id_to_idx = OrderedDict()
+ self._obj_idx_to_id = OrderedDict()
+ self.obj_ids = []
+
+ # Persistent user inputs
+ self.point_inputs_per_obj = {}
+ self.mask_inputs_per_obj = {}
+
+ # Persistent model outputs/history
+ self.output_dict_per_obj = {}
+ self.frames_tracked_per_obj = {}
+
+ # Session state flags
+ self.obj_with_new_inputs = []
+
+ @property
+ def num_frames(self) -> Optional[int]:
+ return len(self.processed_frames) if self.processed_frames is not None else None
+
+ # Object management
+ def obj_id_to_idx(self, obj_id: int) -> int:
+ """Map object ID to index, creating new entry if needed."""
+ obj_idx = self._obj_id_to_idx.get(obj_id, None)
+ if obj_idx is not None:
+ return obj_idx
+
+ obj_idx = len(self._obj_id_to_idx)
+ self._obj_id_to_idx[obj_id] = obj_idx
+ self._obj_idx_to_id[obj_idx] = obj_id
+ self.obj_ids = list(self._obj_id_to_idx)
+
+ self.point_inputs_per_obj[obj_idx] = {}
+ self.mask_inputs_per_obj[obj_idx] = {}
+ self.output_dict_per_obj[obj_idx] = {
+ "cond_frame_outputs": {},
+ "non_cond_frame_outputs": {},
+ }
+ self.frames_tracked_per_obj[obj_idx] = {}
+
+ return obj_idx
+
+ # Video Inference specific functions
+ def obj_idx_to_id(self, obj_idx: int) -> int:
+ """Map model-side object index to client-side object id."""
+ return self._obj_idx_to_id[obj_idx]
+
+ def get_obj_num(self) -> int:
+ """Get the total number of unique object ids received so far in this session."""
+ return len(self._obj_idx_to_id)
+
+ # Input management with device handling
+ def add_point_inputs(self, obj_idx: int, frame_idx: int, inputs: dict):
+ """Add point inputs with automatic device placement."""
+ device_inputs = {}
+ for key, value in inputs.items():
+ if isinstance(value, torch.Tensor):
+ device_inputs[key] = value.to(self.inference_device, non_blocking=True)
+ else:
+ device_inputs[key] = value
+ self.point_inputs_per_obj[obj_idx][frame_idx] = device_inputs
+
+ def remove_point_inputs(self, obj_idx: int, frame_idx: int):
+ """Remove point inputs."""
+ self.point_inputs_per_obj[obj_idx].pop(frame_idx, None)
+
+ def add_mask_inputs(self, obj_idx: int, frame_idx: int, inputs: torch.Tensor):
+ """Add mask inputs with automatic device placement."""
+ self.mask_inputs_per_obj[obj_idx][frame_idx] = inputs.to(
+ self.inference_device, dtype=self.dtype, non_blocking=True
+ )
+
+ def remove_mask_inputs(self, obj_idx: int, frame_idx: int):
+ """Remove mask inputs."""
+ self.mask_inputs_per_obj[obj_idx].pop(frame_idx, None)
+
+ # Output management with smart device placement
+ def store_output(
+ self,
+ obj_idx: int,
+ frame_idx: int,
+ output_key: Optional[str] = None,
+ output_value: Optional[Union[torch.Tensor, dict]] = None,
+ is_conditioning_frame: bool = True,
+ ):
+ """
+ Store output with smart device management.
+ If output_key is None, the output is stored as a dictionary.
+
+ Args:
+ obj_idx (int): The index of the object.
+ frame_idx (int): The index of the frame.
+ output_key (Optional[str]): The key of the output. If None, the output is stored as a dictionary.
+ output_value (Optional[Union[torch.Tensor, dict]]): The value of the output.
+ is_conditioning_frame (bool): Whether the output is for a conditioning frame.
+ """
+ storage_key = "cond_frame_outputs" if is_conditioning_frame else "non_cond_frame_outputs"
+
+ if output_key is None and isinstance(output_value, dict):
+ self.output_dict_per_obj[obj_idx][storage_key][frame_idx] = {}
+ for key, value in output_value.items():
+ self.store_output(obj_idx, frame_idx, key, value, is_conditioning_frame)
+ return
+
+ # Device placement: small tensors stay on inference device, large ones go to inference state device
+ if output_key in ["object_pointer", "object_score_logits"]: # Small tensors
+ self.output_dict_per_obj[obj_idx][storage_key][frame_idx][output_key] = output_value
+ elif isinstance(output_value, torch.Tensor): # Large tensors like masks, features
+ self.output_dict_per_obj[obj_idx][storage_key][frame_idx][output_key] = output_value.to(
+ self.inference_state_device, non_blocking=True
+ )
+ else:
+ self.output_dict_per_obj[obj_idx][storage_key][frame_idx][output_key] = output_value
+
+ def get_output(
+ self,
+ obj_idx: int,
+ frame_idx: int,
+ output_key: str,
+ is_conditioning_frame: bool = True,
+ ):
+ """
+ Get output with smart device management.
+
+ Args:
+ obj_idx (int): The index of the object.
+ frame_idx (int): The index of the frame.
+ output_key (str): The key of the output.
+ is_conditioning_frame (bool): Whether the output is for a conditioning frame.
+ """
+ storage_key = "cond_frame_outputs" if is_conditioning_frame else "non_cond_frame_outputs"
+ out = self.output_dict_per_obj[obj_idx][storage_key].get(frame_idx, None)
+ # move to inference device if needed
+ if out is None:
+ return None
+ value = out[output_key]
+ if isinstance(value, torch.Tensor):
+ value = value.to(self.inference_device, non_blocking=True)
+ return value
+
+ # Video frame management
+ def add_new_frame(self, pixel_values: torch.Tensor, frame_idx: Optional[int] = None) -> int:
+ """Add new frame with automatic device placement."""
+ pixel_values = pixel_values.to(self.video_storage_device, dtype=self.dtype, non_blocking=True)
+ if pixel_values.dim() == 4:
+ pixel_values = pixel_values.squeeze(0)
+
+ if frame_idx is None:
+ frame_idx = len(self.processed_frames) if self.processed_frames is not None else 0
+
+ if self.processed_frames is None:
+ self.processed_frames = {frame_idx: pixel_values}
+ else:
+ self.processed_frames[frame_idx] = pixel_values
+
+ return frame_idx
+
+ def get_frame(self, frame_idx: int) -> torch.Tensor:
+ """Get frame from video."""
+ return self.processed_frames[frame_idx].to(self.inference_device, non_blocking=True)
+
+ def reset_tracking_data(self):
+ """Reset tracking data but keep cache."""
+ self._obj_id_to_idx.clear()
+ self._obj_idx_to_id.clear()
+ self.obj_ids.clear()
+ self.point_inputs_per_obj.clear()
+ self.mask_inputs_per_obj.clear()
+ self.output_dict_per_obj.clear()
+ self.frames_tracked_per_obj.clear()
+ self.obj_with_new_inputs = []
+ # Note: cache and video data are preserved
+
+ def reset_inference_session(self):
+ """Reset tracking data and cache."""
+ self._obj_id_to_idx.clear()
+ self._obj_idx_to_id.clear()
+ self.obj_ids.clear()
+ self.point_inputs_per_obj.clear()
+ self.mask_inputs_per_obj.clear()
+ self.output_dict_per_obj.clear()
+ self.frames_tracked_per_obj.clear()
+ self.obj_with_new_inputs = []
+ self.cache.clear_all()
+
+
+class Sam2VideoProcessor(Sam2Processor):
+ r"""
+ Constructs a SAM2 video processor which wraps a SAM2 image processor and a 2D points & bounding boxes processor into a
+ single processor.
+
+ [`Sam2VideoProcessor`] offers all the functionalities of [`Sam2ImageProcessorFast`] and [`Sam2VideoVideoProcessor`]. See the docstring of
+ [`~Sam2ImageProcessorFast.__call__`] and [`~Sam2VideoVideoProcessor.__call__`] for more information.
+
+ Args:
+ image_processor (`Sam2ImageProcessorFast`):
+ An instance of [`Sam2ImageProcessorFast`].
+ video_processor (`Sam2VideoVideoProcessor`):
+ An instance of [`Sam2VideoVideoProcessor`].
+ target_size (`int`, *optional*):
+ The target size (target_size, target_size) to which the image will be resized.
+ point_pad_value (`int`, *optional*, defaults to -10):
+ The value used for padding input points.
+ """
+
+ attributes = ["image_processor", "video_processor"]
+ image_processor_class = "Sam2ImageProcessorFast"
+ video_processor_class = "Sam2VideoVideoProcessor"
+
+ def __init__(
+ self, image_processor, video_processor, target_size: Optional[int] = None, point_pad_value: int = -10, **kwargs
+ ):
+ ProcessorMixin.__init__(self, image_processor, video_processor, **kwargs)
+ self.point_pad_value = point_pad_value
+ self.target_size = target_size if target_size is not None else self.image_processor.size["height"]
+
+ def init_video_session(
+ self,
+ video: Optional[VideoInput] = None,
+ inference_device: Union[str, "torch.device"] = "cpu",
+ inference_state_device: Union[str, "torch.device"] = None,
+ processing_device: Union[str, "torch.device"] = None,
+ video_storage_device: Union[str, "torch.device"] = None,
+ max_vision_features_cache_size: int = 1,
+ dtype: torch.dtype = torch.float32,
+ ):
+ """
+ Initializes a video session for inference.
+ If a video is provided (async inference), the video will be processed and stored on the `video_storage_device`.
+
+ Args:
+ video (`VideoInput`, *optional*):
+ The video to process. No need to provide when streaming.
+ inference_device (`str` or `torch.device`, *optional*, defaults to "cpu"):
+ The device to use for inference.
+ inference_state_device (`str` or `torch.device`, *optional*):
+ The device to store the inference state on.
+ processing_device (`str` or `torch.device`, *optional*):
+ The device to use for video processing.
+ video_storage_device (`str` or `torch.device`, *optional*):
+ The device to store the processed video frames on.
+ max_vision_features_cache_size (`int`, *optional*, defaults to 1):
+ The maximum number of vision features to cache.
+ dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
+ The torch dtype to use for the whole session.
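+
+ Example (illustrative sketch; the checkpoint name and `video_frames` are placeholders):
+
+ ```python
+ >>> # hypothetical usage: build a session for a pre-loaded batch of video frames
+ >>> processor = Sam2VideoProcessor.from_pretrained("facebook/sam2.1-hiera-tiny")
+ >>> inference_session = processor.init_video_session(video=video_frames, inference_device="cuda")
+ ```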
+ """
+ video_storage_device = video_storage_device if video_storage_device is not None else inference_device
+ inference_state_device = inference_state_device if inference_state_device is not None else inference_device
+ processing_device = processing_device if processing_device is not None else inference_device
+ pixel_values_video = None
+ video_height = None
+ video_width = None
+ if video is not None:
+ processed_video = self.video_processor(videos=video, device=processing_device, return_tensors="pt")
+ pixel_values_video = processed_video.pixel_values_videos[0]
+ video_height = processed_video.original_sizes[0][0]
+ video_width = processed_video.original_sizes[0][1]
+ inference_session = Sam2VideoInferenceSession(
+ video=pixel_values_video,
+ video_height=video_height,
+ video_width=video_width,
+ inference_device=inference_device,
+ video_storage_device=video_storage_device,
+ inference_state_device=inference_state_device,
+ dtype=dtype,
+ max_vision_features_cache_size=max_vision_features_cache_size,
+ )
+ return inference_session
+
+ def add_inputs_to_inference_session(
+ self,
+ inference_session: Sam2VideoInferenceSession,
+ frame_idx: int,
+ obj_ids: Union[list[int], int],
+ input_points: Optional[Union[list[list[list[list[float]]]], torch.Tensor]] = None,
+ input_labels: Optional[Union[list[list[list[int]]], torch.Tensor]] = None,
+ input_boxes: Optional[Union[list[list[list[float]]], torch.Tensor]] = None,
+ input_masks: Optional[Union[np.ndarray, torch.Tensor, list[np.ndarray], list[torch.Tensor]]] = None,
+ original_size: Optional[tuple[int, int]] = None,
+ clear_old_inputs: bool = True,
+ ) -> Sam2VideoInferenceSession:
+ """
+ Process new points, boxes, or masks for a video frame and add them to the inference session.
+
+ Args:
+ inference_session (`Sam2VideoInferenceSession`):
+ The inference session for the video.
+ frame_idx (`int`):
+ The index of the frame to process.
+ obj_ids (`list[int]` or `int`):
+ The object ID(s) to associate with the points or box.
+ These can be any integers and can be reused later on to specify an object.
+ input_points (`list[list[list[list[float]]]]`, `torch.Tensor`, *optional*):
+ The points to add to the frame.
+ input_labels (`list[list[list[int]]]`, `torch.Tensor`, *optional*):
+ The labels for the points.
+ input_boxes (`list[list[list[float]]]`, `torch.Tensor`, *optional*):
+ The bounding boxes to add to the frame.
+ input_masks (`np.ndarray`, `torch.Tensor`, `list[np.ndarray]`, or `list[torch.Tensor]`, *optional*):
+ The mask(s) to add to the frame.
+ original_size (`tuple[int, int]`, *optional*):
+ The original size of the video. Provide when streaming.
+ clear_old_inputs (`bool`, *optional*, defaults to `True`):
+ Whether to clear old inputs for the object.
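+
+ Example (illustrative sketch; the coordinates, object id and `inference_session` are placeholders):
+
+ ```python
+ >>> # hypothetical usage: add a single positive click for object 1 on frame 0
+ >>> processor.add_inputs_to_inference_session(
+ ... inference_session=inference_session,
+ ... frame_idx=0,
+ ... obj_ids=1,
+ ... input_points=[[[[210.0, 350.0]]]],
+ ... input_labels=[[[1]]],
+ ... )
+ ```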
+ """
+
+ if isinstance(obj_ids, int):
+ obj_ids = [obj_ids]
+
+ # Validate inputs
+ if (input_points is not None) != (input_labels is not None):
+ raise ValueError("points and labels must be provided together")
+ if input_points is None and input_boxes is None and input_masks is None:
+ raise ValueError("at least one of points, boxes, or masks must be provided as input")
+ if input_masks is not None and (input_points is not None or input_boxes is not None):
+ raise ValueError("masks cannot be provided together with points or boxes")
+
+ if input_masks is not None:
+ return self.process_new_mask_for_video_frame(inference_session, frame_idx, obj_ids, input_masks)
+ else:
+ return self.process_new_points_or_boxes_for_video_frame(
+ inference_session,
+ frame_idx,
+ obj_ids,
+ input_points,
+ input_labels,
+ input_boxes,
+ original_size,
+ clear_old_inputs,
+ )
+
+ def process_new_points_or_boxes_for_video_frame(
+ self,
+ inference_session: Sam2VideoInferenceSession,
+ frame_idx: int,
+ obj_ids: list[int],
+ input_points: Optional[Union[list[list[list[list[float]]]], torch.Tensor]] = None,
+ input_labels: Optional[Union[list[list[list[int]]], torch.Tensor]] = None,
+ input_boxes: Optional[Union[list[list[list[float]]], torch.Tensor]] = None,
+ original_size: Optional[tuple[int, int]] = None,
+ clear_old_inputs: bool = True,
+ ) -> Sam2VideoInferenceSession:
+ """
+ Process new points or boxes for a video frame and add them to the inference session.
+
+ Args:
+ inference_session (`Sam2VideoInferenceSession`):
+ The inference session for the video.
+ frame_idx (`int`):
+ The index of the frame to process.
+ obj_ids (`list[int]`):
+ The object ID(s) to associate with the points or box.
+ These can be any integers and can be reused later on to specify an object.
+ input_points (`list[list[list[list[float]]]]`, `torch.Tensor`, *optional*):
+ The points to add to the frame.
+ input_labels (`list[list[list[int]]]`, `torch.Tensor`, *optional*):
+ The labels for the points.
+ input_boxes (`list[list[list[float]]]`, `torch.Tensor`, *optional*):
+ The bounding boxes to add to the frame.
+ original_size (`tuple[int, int]`, *optional*):
+ The original size of the video. Provide when streaming.
+ clear_old_inputs (`bool`, *optional*, defaults to `True`):
+ Whether to clear old inputs for the object.
+ """
+ if original_size is not None:
+ inference_session.video_height = original_size[0]
+ inference_session.video_width = original_size[1]
+ elif inference_session.video_height is None or inference_session.video_width is None:
+ raise ValueError("original_size must be provided when adding points or boxes on a first streamed frame")
+
+ original_sizes = [[inference_session.video_height, inference_session.video_width]]
+
+ encoded_inputs = self(
+ input_points=input_points,
+ input_labels=input_labels,
+ input_boxes=input_boxes,
+ original_sizes=original_sizes,
+ return_tensors="pt",
+ )
+ input_points = encoded_inputs.get("input_points", None)
+ input_labels = encoded_inputs.get("input_labels", None)
+ input_boxes = encoded_inputs.get("input_boxes", None)
+
+ if input_points is not None:
+ if input_points.shape[1] != len(obj_ids):
+ raise ValueError(
+ f"Number of object ids ({len(obj_ids)}) does not match number of points ({input_points.shape[1]})"
+ )
+ else:
+ input_points = torch.zeros(1, len(obj_ids), 0, 2, dtype=torch.float32)
+ if input_labels is not None:
+ if input_labels.shape[1] != len(obj_ids):
+ raise ValueError(
+ f"Number of object ids ({len(obj_ids)}) does not match number of labels ({input_labels.shape[1]})"
+ )
+ else:
+ input_labels = torch.zeros(1, len(obj_ids), 0, dtype=torch.int32)
+ if input_boxes is not None:
+ if input_boxes.shape[1] != len(obj_ids):
+ raise ValueError(
+ f"Number of object ids ({len(obj_ids)}) does not match number of boxes ({input_boxes.shape[1]})"
+ )
+
+ if input_boxes is not None:
+ if not clear_old_inputs:
+ raise ValueError(
+ "cannot add box without clearing old points, since "
+ "box prompt must be provided before any point prompt "
+ "(please use clear_old_points=True instead)"
+ )
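+ # A box is encoded as two corner points with the reserved labels 2 (top-left) and 3 (bottom-right),
+ # placed before any point prompts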
+ box_coords = input_boxes.reshape(1, -1, 2, 2)
+ box_labels = torch.tensor([2, 3], dtype=torch.int32).repeat(1, box_coords.shape[1], 1)
+ input_points = torch.cat([box_coords, input_points], dim=2)
+ input_labels = torch.cat([box_labels, input_labels], dim=2)
+
+ for idx, obj_id in enumerate(obj_ids):
+ obj_idx = inference_session.obj_id_to_idx(obj_id)
+ input_points_for_obj = input_points[:, idx, :, :].unsqueeze(1)
+ input_labels_for_obj = input_labels[:, idx, :].unsqueeze(1)
+ # Handle existing points
+ if not clear_old_inputs:
+ existing_points = inference_session.point_inputs_per_obj[obj_idx].get(frame_idx, None)
+ if existing_points is not None:
+ # Concatenate with existing points
+ input_points_for_obj = torch.cat(
+ [existing_points["point_coords"].to(input_points_for_obj.device), input_points_for_obj], dim=2
+ )
+ input_labels_for_obj = torch.cat(
+ [existing_points["point_labels"].to(input_labels_for_obj.device), input_labels_for_obj], dim=2
+ )
+ point_inputs = {
+ "point_coords": input_points_for_obj,
+ "point_labels": input_labels_for_obj,
+ }
+
+ inference_session.add_point_inputs(obj_idx, frame_idx, point_inputs)
+ inference_session.remove_mask_inputs(obj_idx, frame_idx) # Clear any mask inputs
+
+ inference_session.obj_with_new_inputs = obj_ids
+ return inference_session
+
+ def process_new_mask_for_video_frame(
+ self,
+ inference_session: Sam2VideoInferenceSession,
+ frame_idx: int,
+ obj_ids: list[int],
+ input_masks: Union[np.ndarray, torch.Tensor, list[np.ndarray], list[torch.Tensor]],
+ ) -> Sam2VideoInferenceSession:
+ """
+ Process new mask(s) for a video frame and add them to the inference session.
+
+ Args:
+ inference_session (`Sam2VideoInferenceSession`):
+ The inference session for the video.
+ frame_idx (`int`):
+ The index of the frame to process.
+ obj_ids (`list[int]`):
+ The object ID(s) to associate with the mask.
+ These can be any integers and can be reused later on to specify an object.
+ input_masks (`np.ndarray`, `torch.Tensor`, `list[np.ndarray]`, or `list[torch.Tensor]`):
+ The mask(s) to add to the frame.
+ """
+ if not isinstance(input_masks, list):
+ input_masks = [input_masks]
+ if len(input_masks) != len(obj_ids):
+ raise ValueError(
+ f"Number of object ids ({len(obj_ids)}) does not match number of masks ({len(input_masks)})"
+ )
+
+ for obj_id, mask in zip(obj_ids, input_masks):
+ obj_idx = inference_session.obj_id_to_idx(obj_id)
+
+ device = inference_session.inference_device
+
+ # Process mask
+ if not isinstance(mask, torch.Tensor):
+ mask = torch.tensor(mask, dtype=torch.bool)
+ nb_dim = mask.dim()
+ if nb_dim > 4 or nb_dim < 2:
+ raise ValueError(f"Mask has an unsupported number of dimensions: {nb_dim}")
+ for i in range(4 - nb_dim):
+ mask = mask.unsqueeze(0)
+
+ mask_H, mask_W = mask.shape[-2:]
+ mask_inputs_orig = mask.float().to(device)
+
+ # Resize mask if needed
+ if mask_H != self.target_size or mask_W != self.target_size:
+ mask_inputs = torch.nn.functional.interpolate(
+ mask_inputs_orig,
+ size=(self.target_size, self.target_size),
+ align_corners=False,
+ mode="bilinear",
+ antialias=True,
+ )
+ mask_inputs = (mask_inputs >= 0.5).float()
+ else:
+ mask_inputs = mask_inputs_orig
+
+ inference_session.add_mask_inputs(obj_idx, frame_idx, mask_inputs)
+ inference_session.remove_point_inputs(obj_idx, frame_idx) # Clear any point inputs
+
+ inference_session.obj_with_new_inputs = obj_ids
+ return inference_session
+
+
+class Sam2VideoLayerNorm(Sam2LayerNorm):
+ pass
+
+
+class Sam2VideoPositionEmbeddingSine(Sam2SinePositionEmbedding):
+ pass
+
+
+class Sam2VideoTwoWayAttentionBlock(Sam2TwoWayAttentionBlock):
+ pass
+
+
+class Sam2VideoFeedForward(Sam2FeedForward):
+ pass
+
+
+class Sam2VideoImageSegmentationOutput(Sam2ImageSegmentationOutput):
+ r"""
+ iou_scores (`torch.FloatTensor` of shape `(batch_size, point_batch_size, num_masks)`):
+ The Intersection over Union (IoU) scores of the predicted masks.
+ pred_masks (`torch.FloatTensor` of shape `(batch_size, point_batch_size, num_masks, height, width)`):
+ The predicted low-resolution masks. This is an alias for `low_res_masks`. These masks need to be post-processed
+ by the processor to be brought to the original image size.
+ object_score_logits (`torch.FloatTensor` of shape `(batch_size, point_batch_size, 1)`):
+ Logits for the object score, indicating if an object is present.
+ image_embeddings (`tuple(torch.FloatTensor)`):
+ The features from the FPN, which are used by the mask decoder. This is a tuple of `torch.FloatTensor` where each
+ tensor has shape `(batch_size, channels, height, width)`.
+ vision_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of each stage) of shape `(batch_size, height, width, hidden_size)`.
+ Hidden-states of the vision model at the output of each stage.
+ vision_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, sequence_length)`.
+ Attentions weights of the vision model.
+ mask_decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, sequence_length)`.
+ Attentions weights of the mask decoder.
+ high_res_masks (`torch.FloatTensor` of shape `(batch_size, point_batch_size, num_masks, image_size, image_size)`, *optional*):
+ The predicted masks, upscaled to the original image size. Only used for Sam2VideoModel.
+ object_pointer (`torch.FloatTensor` of shape `(batch_size, point_batch_size, hidden_size)`, *optional*):
+ A tensor representing the object pointer, used for tracking in videos. Only used for Sam2VideoModel.
+ """
+
+ high_res_masks: Optional[torch.FloatTensor] = None
+ object_pointer: Optional[torch.FloatTensor] = None
+
+
+@dataclass
+@auto_docstring(custom_intro="Base class for the Sam2 model's output.")
+class Sam2VideoSegmentationOutput(ModelOutput):
+ r"""
+ pred_masks (`torch.FloatTensor` of shape `(batch_size, num_masks, height, width)`):
+ The predicted masks stored at the model's resolution.
+ frame_idx (`int`):
+ The frame index of the video.
+ """
+
+ pred_masks: Optional[torch.FloatTensor] = None
+ frame_idx: Optional[int] = None
+
+
+@auto_docstring
+class Sam2VideoPreTrainedModel(PreTrainedModel):
+ config_class = Sam2VideoConfig
+ base_model_prefix = "sam2_video"
+ main_input_name = "pixel_values"
+ _supports_sdpa = True
+ _supports_flash_attn_2 = True
+ _supports_attention_backend = True
+
+ def _init_weights(self, module):
+ std = self.config.initializer_range
+ if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+ elif isinstance(module, (nn.LayerNorm, Sam2VideoLayerNorm)):
+ module.weight.data.fill_(1.0)
+ module.bias.data.zero_()
+ elif isinstance(module, Sam2VideoModel):
+ if module.no_memory_positional_encoding is not None:
+ module.no_memory_positional_encoding.data.zero_()
+ if module.memory_temporal_positional_encoding is not None:
+ module.memory_temporal_positional_encoding.data.zero_()
+ if module.no_object_pointer is not None:
+ module.no_object_pointer.data.zero_()
+ if module.occlusion_spatial_embedding_parameter is not None:
+ module.occlusion_spatial_embedding_parameter.data.zero_()
+ if isinstance(module, Sam2VideoMemoryFuserCXBlock):
+ if module.scale is not None:
+ module.scale.data.zero_()
+
+
+class Sam2VideoVisionRotaryEmbedding(nn.Module):
+ """
+ Vision Rotary Position Embedding for SAM2, following transformers library standards.
+ Supports 2D (axial) rotary embeddings for spatial dimensions.
+ """
+
+ def __init__(self, config: Sam2VideoConfig):
+ super().__init__()
+ dim = config.memory_attention_hidden_size // (
+ config.memory_attention_downsample_rate * config.memory_attention_num_attention_heads
+ )
+ # Ensure the dimension is divisible by 4 for axial (x/y) splitting with pairwise rotation
+ if dim % 4 != 0:
+ raise ValueError("Dimension must be divisible by 4 for axial RoPE")
+ end_x, end_y = config.memory_attention_rope_feat_sizes
+ freqs = 1.0 / (config.memory_attention_rope_theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
+
+ # Generate 2D position indices for axial rotary embedding
+ flattened_indices = torch.arange(end_x * end_y, dtype=torch.long)
+ x_positions = flattened_indices % end_x
+ y_positions = torch.div(flattened_indices, end_x, rounding_mode="floor")
+ freqs_x = torch.outer(x_positions, freqs).float()
+ freqs_y = torch.outer(y_positions, freqs).float()
+ inv_freq = torch.cat([freqs_x, freqs_y], dim=-1)
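+ # duplicate each frequency so that consecutive (even, odd) channel pairs share the same rotation angle,
+ # matching the pairwise rotation applied in `rotate_pairwise`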
+ inv_freq = inv_freq.repeat_interleave(2, dim=-1)
+ # directly register the cos and sin embeddings as we have a fixed feature shape
+ self.register_buffer("rope_embeddings_cos", inv_freq.cos(), persistent=False)
+ self.register_buffer("rope_embeddings_sin", inv_freq.sin(), persistent=False)
+
+ @torch.no_grad()
+ def forward(self) -> tuple[torch.Tensor, torch.Tensor]:
+ # As the feature map size is fixed, we can just return the pre-computed embeddings.
+ return self.rope_embeddings_cos, self.rope_embeddings_sin
+
+
+def rotate_pairwise(x):
+ """
+ Pairwise rotation of the hidden dims of the input. Different from the Llama half-tensor rotation.
+
+ This is an optimized version of the following more explicit implementation:
+ ```python
+ x_rotated = torch.zeros_like(x, dtype=x.dtype, device=x.device)
+ x_rotated[..., ::2] = -x[..., 1::2]
+ x_rotated[..., 1::2] = x[..., ::2]
+ return x_rotated
+ ```
+ """
+ x = x.view(*x.shape[:-1], -1, 2)
+ x1, x2 = x.unbind(dim=-1)
+ x = torch.stack((-x2, x1), dim=-1)
+ return x.flatten(start_dim=-2)
+
+
+# TODO: This leads to ~1e-07 max diff and ~1e-09 avg diff for q_embed and k_embed from the original implementation, most likely due to the use of complex tensors in the original implementation.
+def apply_rotary_pos_emb_2d(
+ q: torch.Tensor,
+ k: torch.Tensor,
+ cos: torch.Tensor,
+ sin: torch.Tensor,
+ num_k_exclude_rope: int = 0,
+ repeat_freqs_k: bool = False,
+) -> tuple[torch.Tensor, torch.Tensor]:
+ """
+ Apply rotary position embedding to query and key tensors for vision models.
+ Follows the standard transformers library pattern.
+
+ Args:
+ q: Query tensor of shape (..., seq_len, head_dim)
+ k: Key tensor of shape (..., seq_len, head_dim)
+ cos: Cosine position embedding of shape (seq_len, head_dim)
+ sin: Sine position embedding of shape (seq_len, head_dim)
+ num_k_exclude_rope: Number of key tokens at the end of the sequence (e.g. object pointer tokens) excluded from rotary embedding
+ repeat_freqs_k: Whether to repeat frequencies for keys (for cross-attention)
+
+ Returns:
+ Rotated (q, k) tensors
+ """
+ k_rot, k_pass = k[..., : k.shape[-2] - num_k_exclude_rope, :], k[..., k.shape[-2] - num_k_exclude_rope :, :]
+ q_embed = q.float() # force upscale to float32 as in the original implementation
+ q_embed = (q_embed * cos) + (rotate_pairwise(q_embed) * sin)
+ if k_rot.shape[-2] == 0:
+ # Handle case where keys might be empty due to dropout
+ return q_embed.type_as(q), torch.cat([k_rot, k_pass], dim=-2)
+
+ # Handle key tensor - may need to repeat frequencies if different sequence length
+ if repeat_freqs_k and k_rot.shape[-2] != q.shape[-2]:
+ # Repeat cos/sin to match key sequence length
+ repeat_factor = k_rot.shape[-2] // q.shape[-2]
+ cos_k = cos.repeat(1, 1, repeat_factor, 1)
+ sin_k = sin.repeat(1, 1, repeat_factor, 1)
+ else:
+ cos_k = cos
+ sin_k = sin
+
+ # Apply rotary embedding to keys
+ k_embed = k_rot.float() # force upscale to float32 as in the original implementation
+ k_embed = (k_embed * cos_k) + (rotate_pairwise(k_embed) * sin_k)
+ # Concatenate back to full shape
+ k_embed = torch.cat([k_embed.type_as(k), k_pass], dim=-2)
+ return q_embed.type_as(q), k_embed
+
+
+class Sam2VideoRoPEAttention(nn.Module):
+ """Attention with rotary position encoding."""
+
+ def __init__(
+ self,
+ config: Sam2VideoConfig,
+ kv_in_dim: Optional[int] = None,
+ rope_k_repeat=False,
+ ):
+ super().__init__()
+ self.config = config
+ self.hidden_size = config.memory_attention_hidden_size
+ self.internal_dim = self.hidden_size // config.memory_attention_downsample_rate
+ self.num_attention_heads = config.memory_attention_num_attention_heads
+ self.head_dim = self.internal_dim // config.memory_attention_num_attention_heads
+ self.scaling = self.head_dim**-0.5
+ self.is_causal = False
+
+ self.kv_in_dim = kv_in_dim if kv_in_dim is not None else self.hidden_size
+
+ self.q_proj = nn.Linear(self.hidden_size, self.internal_dim)
+ self.k_proj = nn.Linear(self.kv_in_dim, self.internal_dim)
+ self.v_proj = nn.Linear(self.kv_in_dim, self.internal_dim)
+ self.o_proj = nn.Linear(self.internal_dim, self.hidden_size)
+
+ self.rope_k_repeat = rope_k_repeat
+ self.dropout_p = config.memory_attention_rope_dropout
+
+ def forward(
+ self,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
+ num_k_exclude_rope: int = 0,
+ **kwargs: Unpack[FlashAttentionKwargs],
+ ) -> Tensor:
+ # Input projections
+ batch_size, point_batch_size = query.shape[:2]
+ new_shape = (batch_size * point_batch_size, -1, self.num_attention_heads, self.head_dim)
+
+ query = self.q_proj(query).view(*new_shape).transpose(1, 2)
+ key = self.k_proj(key).view(*new_shape).transpose(1, 2)
+ value = self.v_proj(value).view(*new_shape).transpose(1, 2)
+
+ cos, sin = position_embeddings
+ # Apply rotary position encoding, excluding some keys if specified
+ query, key = apply_rotary_pos_emb_2d(
+ query, key, cos, sin, repeat_freqs_k=self.rope_k_repeat, num_k_exclude_rope=num_k_exclude_rope
+ )
+
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ attn_output, attn_weights = attention_interface(
+ self,
+ query,
+ key,
+ value,
+ attention_mask=None,
+ dropout=0.0 if not self.training else self.dropout_p,
+ scaling=self.scaling,
+ is_causal=self.is_causal,
+ **kwargs,
+ )
+ attn_output = attn_output.reshape(
+ batch_size, point_batch_size, -1, self.num_attention_heads * self.head_dim
+ ).contiguous()
+ attn_output = self.o_proj(attn_output)
+ return attn_output, attn_weights
+
+
+class Sam2VideoMemoryAttentionLayer(nn.Module):
+ def __init__(self, config: Sam2VideoConfig):
+ super().__init__()
+ hidden_size = config.memory_attention_hidden_size
+ self.self_attn = Sam2VideoRoPEAttention(config)
+ self.cross_attn_image = Sam2VideoRoPEAttention(config, kv_in_dim=64, rope_k_repeat=True)
+
+ # Implementation of Feedforward model
+ self.linear1 = nn.Linear(hidden_size, config.memory_attention_feed_forward_hidden_size)
+ self.dropout = nn.Dropout(config.memory_attention_dropout)
+ self.linear2 = nn.Linear(config.memory_attention_feed_forward_hidden_size, hidden_size)
+
+ self.layer_norm1 = nn.LayerNorm(hidden_size)
+ self.layer_norm2 = nn.LayerNorm(hidden_size)
+ self.layer_norm3 = nn.LayerNorm(hidden_size)
+ self.dropout1 = nn.Dropout(config.memory_attention_dropout)
+ self.dropout2 = nn.Dropout(config.memory_attention_dropout)
+ self.dropout3 = nn.Dropout(config.memory_attention_dropout)
+
+ self.activation = ACT2FN[config.memory_attention_feed_forward_hidden_act]
+
+ def forward(
+ self,
+ queries: Tensor,
+ keys: Tensor,
+ key_point_embedding: Tensor,
+ rope_position_embeddings: tuple[Tensor, Tensor],
+ num_k_exclude_rope: int = 0,
+ ) -> torch.Tensor:
+ # Self-Attention
+ query = self.layer_norm1(queries)
+ query, _ = self.self_attn(query=query, key=query, value=query, position_embeddings=rope_position_embeddings)
+ queries = queries + self.dropout1(query)
+
+ # Cross-Attention
+ query = self.layer_norm2(queries)
+ query, _ = self.cross_attn_image(
+ query=query,
+ key=keys + key_point_embedding,
+ value=keys,
+ position_embeddings=rope_position_embeddings,
+ num_k_exclude_rope=num_k_exclude_rope,
+ )
+ queries = queries + self.dropout2(query)
+ # MLP
+ query = self.layer_norm3(queries)
+ query = self.linear2(self.dropout(self.activation(self.linear1(query))))
+ queries = queries + self.dropout3(query)
+ return queries
+
+
+class Sam2VideoMemoryAttention(nn.Module):
+ def __init__(self, config: Sam2VideoConfig):
+ super().__init__()
+ self.layers = nn.ModuleList(
+ [Sam2VideoMemoryAttentionLayer(config) for _ in range(config.memory_attention_num_layers)]
+ )
+ self.layer_norm = nn.LayerNorm(config.memory_attention_hidden_size)
+ self.rotary_emb = Sam2VideoVisionRotaryEmbedding(config=config)
+
+ def forward(
+ self,
+ current_vision_features: torch.Tensor,
+ memory: torch.Tensor,
+ current_vision_position_embeddings: Optional[Tensor] = None,
+ memory_posision_embeddings: Optional[Tensor] = None,
+ num_object_pointer_tokens: int = 0,
+ ):
+ """
+ Args:
+ current_vision_features (`torch.FloatTensor`):
+ The current vision features used for self-attention.
+ memory (`torch.FloatTensor`):
+ The memory features used for cross-attention.
+ current_vision_position_embeddings (`torch.FloatTensor`, *optional*):
+ The position embeddings for the current vision features.
+ memory_posision_embeddings (`torch.FloatTensor`, *optional*):
+ The position embeddings for the memory features.
+ num_object_pointer_tokens (`int`, *optional*, defaults to 0):
+ The number of object pointer tokens.
+ """
+ output = current_vision_features
+ if current_vision_position_embeddings is not None:
+ output = output + 0.1 * current_vision_position_embeddings
+
+ # Convert to batch first
+ output = output.transpose(0, 1)
+ memory = memory.transpose(0, 1).unsqueeze(1)
+ memory_posision_embeddings = memory_posision_embeddings.transpose(0, 1).unsqueeze(1)
+ rope_position_embeddings = self.rotary_emb()
+ for layer in self.layers:
+ output = layer(
+ queries=output.unsqueeze(1) if output.ndim == 3 else output,
+ keys=memory,
+ key_point_embedding=memory_posision_embeddings,
+ rope_position_embeddings=rope_position_embeddings,
+ num_k_exclude_rope=num_object_pointer_tokens,
+ )
+
+ normed_output = self.layer_norm(output)
+
+ # Convert back to seq first
+ normed_output = normed_output.transpose(0, 1)
+
+ return normed_output
+
+
+# Lightly adapted from ConvNext (https://github.com/facebookresearch/ConvNeXt)
+class Sam2VideoMemoryFuserCXBlock(GradientCheckpointingLayer):
+ def __init__(self, config: Sam2VideoConfig):
+ super().__init__()
+ self.depthwise_conv = nn.Conv2d(
+ config.memory_fuser_embed_dim,
+ config.memory_fuser_embed_dim,
+ kernel_size=config.memory_fuser_kernel_size,
+ padding=config.memory_fuser_padding,
+ groups=config.memory_fuser_embed_dim,
+ ) # depthwise conv
+ self.layer_norm = Sam2VideoLayerNorm(config.memory_fuser_embed_dim, eps=1e-6, data_format="channels_first")
+ self.activation = ACT2FN[config.memory_fuser_hidden_act]
+ self.pointwise_conv1 = nn.Linear(
+ config.memory_fuser_embed_dim, config.memory_fuser_intermediate_dim
+ ) # pointwise/1x1 convs, implemented with linear layers
+ self.pointwise_conv2 = nn.Linear(config.memory_fuser_intermediate_dim, config.memory_fuser_embed_dim)
+ self.scale = nn.Parameter(
+ config.memory_fuser_layer_scale_init_value * torch.ones(config.memory_fuser_embed_dim),
+ requires_grad=True,
+ )
+
+ def forward(self, hidden_states):
+ input = hidden_states
+ hidden_states = self.depthwise_conv(hidden_states)
+ hidden_states = self.layer_norm(hidden_states)
+ hidden_states = hidden_states.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
+ hidden_states = self.pointwise_conv1(hidden_states)
+ hidden_states = self.activation(hidden_states)
+ hidden_states = self.pointwise_conv2(hidden_states)
+ hidden_states = self.scale * hidden_states
+ hidden_states = hidden_states.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
+
+ hidden_states = input + hidden_states
+ return hidden_states
+
+
+class Sam2VideoMemoryFuser(nn.Module):
+ def __init__(self, config: Sam2VideoConfig):
+ super().__init__()
+ self.layers = nn.ModuleList(
+ [Sam2VideoMemoryFuserCXBlock(config) for _ in range(config.memory_fuser_num_layers)]
+ )
+
+ def forward(self, hidden_states):
+ # normally hidden_states: (N, C, H, W)
+ for layer in self.layers:
+ hidden_states = layer(hidden_states)
+ return hidden_states
+
+
+class Sam2VideoMaskDownSamplerLayer(nn.Module):
+ def __init__(self, config: Sam2VideoConfig, in_channels: int, out_channels: int):
+ super().__init__()
+ self.conv = nn.Conv2d(
+ in_channels,
+ out_channels,
+ kernel_size=config.mask_downsampler_kernel_size,
+ stride=config.mask_downsampler_stride,
+ padding=config.mask_downsampler_padding,
+ )
+ self.layer_norm = Sam2VideoLayerNorm(out_channels, eps=1e-6, data_format="channels_first")
+ self.activation = ACT2FN[config.mask_downsampler_hidden_act]
+
+ def forward(self, x):
+ return self.activation(self.layer_norm(self.conv(x)))
+
+
+class Sam2VideoMaskDownSampler(nn.Module):
+ """
+ Progressively downsample a mask by total_stride, each time by stride.
+ Note that LayerNorm is applied per *token*, like in ViT.
+
+ With each downsample (by a factor stride**2), channel capacity increases by the same factor.
+ In the end, we linearly project to embed_dim channels.
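+
+ As an illustrative example (assuming `mask_downsampler_total_stride=16` and `mask_downsampler_stride=4`), two
+ downsampling layers are used and the channel count grows 1 -> 16 -> 256 before the final 1x1 projection to
+ `mask_downsampler_embed_dim`.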
+ """
+
+ def __init__(self, config: Sam2VideoConfig):
+ super().__init__()
+
+ num_layers = int(math.log2(config.mask_downsampler_total_stride) // math.log2(config.mask_downsampler_stride))
+
+ self.layers = nn.ModuleList()
+ self.activation = ACT2FN[config.mask_downsampler_hidden_act]
+ mask_in_chans, mask_out_chans = 1, 1
+ for _ in range(num_layers):
+ mask_out_chans = mask_in_chans * (config.mask_downsampler_stride**2)
+ self.layers.append(Sam2VideoMaskDownSamplerLayer(config, mask_in_chans, mask_out_chans))
+ mask_in_chans = mask_out_chans
+
+ self.final_conv = nn.Conv2d(mask_out_chans, config.mask_downsampler_embed_dim, kernel_size=1)
+
+ def forward(self, x):
+ for layer in self.layers:
+ x = layer(x)
+ x = self.final_conv(x)
+ return x
+
+
+class Sam2VideoMemoryEncoder(nn.Module):
+ def __init__(self, config: Sam2VideoConfig):
+ super().__init__()
+
+ hidden_size = config.memory_encoder_hidden_size
+ output_channels = config.memory_encoder_output_channels
+ self.mask_downsampler = Sam2VideoMaskDownSampler(config)
+ self.feature_projection = nn.Conv2d(hidden_size, hidden_size, kernel_size=1)
+ self.memory_fuser = Sam2VideoMemoryFuser(config)
+ self.position_encoding = Sam2VideoPositionEmbeddingSine(num_pos_feats=output_channels // 2, normalize=True)
+ self.projection = nn.Conv2d(hidden_size, output_channels, kernel_size=1)
+
+ def forward(
+ self,
+ vision_features: torch.Tensor,
+ masks: torch.Tensor,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ ## Process masks
+ masks = self.mask_downsampler(masks)
+ ## Fuse pixel_features and downsampled masks
+
+ vision_features = self.feature_projection(vision_features)
+ vision_features = vision_features + masks
+ vision_features = self.memory_fuser(vision_features)
+ vision_features = self.projection(vision_features)
+
+ vision_pos_enc = self.position_encoding(vision_features.shape, vision_features.device, vision_features.dtype)
+
+ return vision_features, vision_pos_enc
+
+
+# a large negative value as a placeholder score for missing objects
+NO_OBJ_SCORE = -1024.0
+
+
+def get_1d_sine_pe(pos_inds, dim, temperature=10000):
+ """
+ Get 1D sine positional embedding as in the original Transformer paper.
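+ The first `dim // 2` channels hold the sine terms and the last `dim // 2` channels hold the cosine terms
+ (concatenated along the last dimension rather than interleaved).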
+ """
+ pe_dim = dim // 2
+ dim_t = torch.arange(pe_dim, dtype=torch.float32, device=pos_inds.device)
+ dim_t = temperature ** (2 * (dim_t // 2) / pe_dim)
+
+ pos_embed = pos_inds.unsqueeze(-1) / dim_t
+ pos_embed = torch.cat([pos_embed.sin(), pos_embed.cos()], dim=-1)
+ return pos_embed
+
+
+@auto_docstring
+class Sam2VideoModel(Sam2Model):
+ _tied_weights_keys = ["prompt_encoder.shared_embedding.positional_embedding"]
+ # need to be ignored, as it's a buffer and will not be correctly detected as tied weight
+ _keys_to_ignore_on_load_missing = ["prompt_encoder.shared_embedding.positional_embedding"]
+ _keys_to_ignore_on_load_unexpected = []
+ _can_record_outputs = {"mask_decoder_attentions": OutputRecorder(Sam2VideoTwoWayAttentionBlock, index=2)}
+
+ def __init__(self, config: Sam2VideoConfig):
+ super().__init__(config)
+ self.config = config
+ # For video sequence inference
+ self.image_size = config.image_size
+ self.memory_attention = Sam2VideoMemoryAttention(config)
+ self.memory_encoder = Sam2VideoMemoryEncoder(config)
+ self.no_memory_positional_encoding = torch.nn.Parameter(
+ torch.zeros(1, 1, config.vision_config.fpn_hidden_size)
+ )
+ self.mem_dim = config.memory_encoder_output_channels
+ self.num_maskmem = config.num_maskmem # Number of memories accessible
+ # Temporal encoding of the memories
+ self.memory_temporal_positional_encoding = torch.nn.Parameter(
+ torch.zeros(self.num_maskmem, 1, 1, self.mem_dim)
+ )
+
+ self.no_object_pointer = torch.nn.Parameter(torch.zeros(1, self.hidden_dim))
+ # A conv layer to downsample the mask prompt to stride 4 (the same stride as
+ # low-res SAM mask logits) and to change its scales from 0~1 to SAM logit scale,
+ # so that it can be fed into the SAM mask decoder to generate a pointer.
+ self.mask_downsample = torch.nn.Conv2d(1, 1, kernel_size=4, stride=4)
+ # a feedforward layer on SAM output tokens to turn them into object pointers
+ self.object_pointer_proj = Sam2VideoFeedForward(self.hidden_dim, self.hidden_dim, self.hidden_dim, 3)
+
+ if self.config.enable_temporal_pos_encoding_for_object_pointers:
+ # a linear projection on temporal positional encoding in object pointers to
+ # avoid potential interference with spatial positional encoding
+ self.temporal_positional_encoding_projection_layer = torch.nn.Linear(self.hidden_dim, self.mem_dim)
+ else:
+ self.temporal_positional_encoding_projection_layer = torch.nn.Identity()
+
+ self.occlusion_spatial_embedding_parameter = None # compatibility with Sam2
+ if config.enable_occlusion_spatial_embedding:
+ self.occlusion_spatial_embedding_parameter = torch.nn.Parameter(torch.zeros(1, self.mem_dim))
+
+ self.post_init()
+
+ @torch.no_grad()
+ def get_prompt_embeddings(
+ self,
+ input_points: Optional[torch.FloatTensor] = None,
+ input_labels: Optional[torch.LongTensor] = None,
+ input_boxes: Optional[torch.FloatTensor] = None,
+ input_masks: Optional[torch.LongTensor] = None,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ r"""
+ Returns the prompt embeddings by passing the input points, labels, boxes and masks through the prompt encoder.
+
+ Args:
+ input_points (`torch.FloatTensor` of shape `(batch_size, point_batch_size, num_points_per_image, 2)`):
+ Optional input points for the prompt encoder. The padding of the point is automatically done by the
+ processor. `point_batch_size` refers to the number of masks that we want the model to predict per
+ point. The model will output `point_batch_size` times 3 masks in total.
+ input_labels (`torch.LongTensor` of shape `(batch_size, point_batch_size, num_points_per_image)`):
+ Optional input labels for the prompt encoder. The padding of the labels is automatically done by the
+ processor, or can be fed by the user.
+ input_boxes (`torch.FloatTensor` of shape `(batch_size, num_boxes_per_image, 4)`):
+ Optional input boxes for the prompt encoder. The padding of the boxes is automatically done by the
+ processor. Users can also pass the input boxes manually.
+ input_masks (`torch.LongTensor` of shape `(batch_size, image_size, image_size)`):
+ Optional input masks for the prompt encoder.
+ """
+ prompt_output = self.prompt_encoder(
+ input_points=input_points,
+ input_labels=input_labels,
+ input_boxes=input_boxes,
+ input_masks=input_masks,
+ )
+ return prompt_output
+
+ def _prepare_vision_features(
+ self,
+ inference_session: Sam2VideoInferenceSession,
+ frame_idx: int,
+ batch_size: int,
+ ) -> tuple[torch.Tensor, list[torch.Tensor]]:
+ """Prepare vision features for a frame."""
+
+ # Check if features are cached
+ if cached_features := inference_session.cache.get_vision_features(frame_idx):
+ vision_feats = cached_features["vision_feats"]
+ vision_pos_embeds = cached_features["vision_pos_embeds"]
+ else:
+ # Compute features using image encoder
+ image_batch = inference_session.get_frame(frame_idx).unsqueeze(0) # Add batch dimension
+ vision_feats, vision_pos_embeds, _, _ = self.get_image_features(image_batch)
+ # Cache features
+ inference_session.cache.cache_vision_features(
+ frame_idx, {"vision_feats": vision_feats, "vision_pos_embeds": vision_pos_embeds}
+ )
+
+ # Expand to batch size if needed
+ if batch_size > 1:
+ vision_feats = vision_feats.expand(batch_size, -1, -1, -1)
+ vision_pos_embeds = [pe.expand(batch_size, -1, -1, -1) for pe in vision_pos_embeds]
+
+ return vision_feats, vision_pos_embeds
+
+ def _single_frame_forward(
+ self,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ input_points: Optional[torch.FloatTensor] = None,
+ input_labels: Optional[torch.LongTensor] = None,
+ input_boxes: Optional[torch.FloatTensor] = None,
+ input_masks: Optional[torch.LongTensor] = None,
+ image_embeddings: Optional[torch.FloatTensor] = None,
+ multimask_output: bool = True,
+ attention_similarity: Optional[torch.FloatTensor] = None,
+ target_embedding: Optional[torch.FloatTensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> Sam2VideoImageSegmentationOutput:
+ """
+ input_points (`torch.FloatTensor` of shape `(batch_size, num_points, 2)`):
+ Input 2D spatial points, this is used by the prompt encoder to encode the prompt. Generally yields much
+ better results. The points can be obtained by passing a list of list of list to the processor that will
+ create corresponding `torch` tensors of dimension 4. The first dimension is the image batch size, the
+ second dimension is the point batch size (i.e. how many segmentation masks do we want the model to predict
+ per input point), the third dimension is the number of points per segmentation mask (it is possible to pass
+ multiple points for a single mask), and the last dimension is the x (vertical) and y (horizontal)
+ coordinates of the point. If a different number of points is passed either for each image, or for each
+ mask, the processor will create "PAD" points that will correspond to the (0, 0) coordinate, and the
+ computation of the embedding will be skipped for these points using the labels.
+ input_labels (`torch.LongTensor` of shape `(batch_size, point_batch_size, num_points)`):
+ Input labels for the points, this is used by the prompt encoder to encode the prompt. According to the
+ official implementation, there are 3 types of labels
+
+ - `1`: the point is a point that contains the object of interest
+ - `0`: the point is a point that does not contain the object of interest
+ - `-1`: the point corresponds to the background
+
+ We added the label:
+
+ - `-10`: the point is a padding point, thus should be ignored by the prompt encoder
+
+ The padding labels should be automatically done by the processor.
+ input_boxes (`torch.FloatTensor` of shape `(batch_size, num_boxes, 4)`):
+ Input boxes for the points, this is used by the prompt encoder to encode the prompt. Generally yields
+ much better generated masks. The boxes can be obtained by passing a list of list of list to the processor,
+ that will generate a `torch` tensor, with each dimension corresponding respectively to the image batch
+ size, the number of boxes per image and the coordinates of the top left and bottom right point of the box.
+ In the order (`x1`, `y1`, `x2`, `y2`):
+
+ - `x1`: the x coordinate of the top left point of the input box
+ - `y1`: the y coordinate of the top left point of the input box
+ - `x2`: the x coordinate of the bottom right point of the input box
+ - `y2`: the y coordinate of the bottom right point of the input box
+ input_masks (`torch.FloatTensor` of shape `(batch_size, image_size, image_size)`):
+ SAM model also accepts segmentation masks as input. The mask will be embedded by the prompt encoder to
+ generate a corresponding embedding that will be fed later on to the mask decoder. These masks need to be
+ manually fed by the user, and they need to be of shape (`batch_size`, `image_size`, `image_size`).
+ image_embeddings (`torch.FloatTensor` of shape `(batch_size, output_channels, window_size, window_size)`):
+ Image embeddings, this is used by the mask decoder to generate masks and iou scores. For more memory
+ efficient computation, users can first retrieve the image embeddings using the `get_image_embeddings`
+ method, and then feed them to the `forward` method instead of feeding the `pixel_values`.
+ multimask_output (`bool`, *optional*):
+ In the original implementation and paper, the model always outputs 3 masks per image (or per point / per
+ bounding box if relevant). However, it is possible to just output a single mask, that corresponds to the
+ "best" mask, by specifying `multimask_output=False`.
+ attention_similarity (`torch.FloatTensor`, *optional*):
+ Attention similarity tensor, to be provided to the mask decoder for target-guided attention in case the
+ model is used for personalization as introduced in [PerSAM](https://huggingface.co/papers/2305.03048).
+ target_embedding (`torch.FloatTensor`, *optional*):
+ Embedding of the target concept, to be provided to the mask decoder for target-semantic prompting in case
+ the model is used for personalization as introduced in [PerSAM](https://huggingface.co/papers/2305.03048).
+ """
+ if not ((pixel_values is None) ^ (image_embeddings is None)):
+ raise ValueError("Exactly one of pixel_values or image_embeddings must be provided.")
+ if input_points is not None and input_boxes is not None:
+ if input_points.shape[1] != input_boxes.shape[1]:
+ raise ValueError(
+ f"You should provide as many bounding boxes as input points per box. Got {input_points.shape[1]} and {input_boxes.shape[1]}."
+ )
+ elif input_points is not None:
+ num_objects = input_points.shape[1]
+ elif input_boxes is not None:
+ num_objects = input_boxes.shape[1]
+ elif input_masks is not None:
+ num_objects = input_masks.shape[1]
+ else:
+ num_objects = 1
+
+ image_positional_embeddings = self.get_image_wide_positional_embeddings()
+ # repeat with batch size
+ batch_size = pixel_values.shape[0] if pixel_values is not None else image_embeddings[-1].shape[0]
+ image_positional_embeddings = image_positional_embeddings.repeat(batch_size, 1, 1, 1)
+
+ vision_attentions = None
+ vision_hidden_states = None
+
+ if pixel_values is not None:
+ feature_maps, _, vision_hidden_states, vision_attentions = self.get_image_features(
+ pixel_values,
+ **kwargs,
+ )
+
+ # add no memory embedding to the last feature map
+ feature_maps[-1] = feature_maps[-1] + self.no_memory_embedding
+
+ # reshape feature maps to the same shape as the backbone feature sizes
+ image_embeddings = [
+ feat.permute(1, 2, 0).view(batch_size, -1, *feat_size)
+ for feat, feat_size in zip(feature_maps, self.backbone_feature_sizes)
+ ]
+
+ if input_points is not None and input_labels is None:
+ input_labels = torch.ones_like(input_points[:, :, :, 0], dtype=torch.int, device=input_points.device)
+
+ if input_points is None and input_boxes is None:
+ # If no points are provided, pad with an empty point (with label -1)
+ input_points = torch.zeros(
+ batch_size, 1, 1, 2, dtype=image_embeddings[-1].dtype, device=image_embeddings[-1].device
+ )
+ input_labels = -torch.ones(batch_size, 1, 1, dtype=torch.int32, device=image_embeddings[-1].device)
+
+ if input_masks is not None:
+ # If mask_inputs is provided, downsize it into low-res mask input if needed
+ # and feed it as a dense mask prompt into the SAM mask encoder
+ if input_masks.shape[-2:] != self.prompt_encoder.mask_input_size:
+ input_masks = F.interpolate(
+ input_masks.float(),
+ size=self.prompt_encoder.mask_input_size,
+ align_corners=False,
+ mode="bilinear",
+ antialias=True, # use antialias for downsampling
+ ).to(input_masks.dtype)
+
+ sparse_embeddings, dense_embeddings = self.prompt_encoder(
+ input_points=input_points,
+ input_labels=input_labels,
+ input_boxes=input_boxes,
+ input_masks=input_masks,
+ )
+ low_res_multimasks, iou_scores, sam_output_tokens, object_score_logits = self.mask_decoder(
+ image_embeddings=image_embeddings[-1],
+ image_positional_embeddings=image_positional_embeddings,
+ sparse_prompt_embeddings=sparse_embeddings,
+ dense_prompt_embeddings=dense_embeddings,
+ multimask_output=multimask_output,
+ high_resolution_features=image_embeddings[:-1],
+ attention_similarity=attention_similarity,
+ target_embedding=target_embedding,
+ **kwargs,
+ )
+
+ is_obj_appearing = object_score_logits > 0
+ # Mask used for spatial memories is always a *hard* choice between obj and no obj,
+ # consistent with the actual mask prediction
+ low_res_multimasks = torch.where(
+ is_obj_appearing[:, None, None],
+ low_res_multimasks,
+ NO_OBJ_SCORE,
+ )
+
+ # convert masks from possibly bfloat16 (or float16) to float32
+ # (older PyTorch versions before 2.1 don't support `interpolate` on bf16)
+ high_res_multimasks = (
+ F.interpolate(
+ low_res_multimasks.squeeze(1).float(),
+ size=(self.image_size, self.image_size),
+ mode="bilinear",
+ align_corners=False,
+ )
+ .unsqueeze(1)
+ .to(low_res_multimasks.dtype)
+ )
+ sam_output_token = sam_output_tokens[:, :, 0]
+ if multimask_output:
+ # take the best mask prediction (with the highest IoU estimation)
+ best_iou_inds = torch.argmax(iou_scores, dim=-1)
+ batch_inds = torch.arange(batch_size, device=high_res_multimasks.device)
+ object_batch_inds = torch.arange(num_objects, device=high_res_multimasks.device)
+ low_res_masks = low_res_multimasks[batch_inds, object_batch_inds, best_iou_inds]
+ high_res_masks = high_res_multimasks[batch_inds, object_batch_inds, best_iou_inds]
+ if sam_output_tokens.size(2) > 1:
+ sam_output_token = sam_output_tokens[batch_inds, object_batch_inds, best_iou_inds]
+ else:
+ low_res_masks, high_res_masks = low_res_multimasks[:, :, 0], high_res_multimasks[:, :, 0]
+
+ # Extract object pointer from the SAM output token (with occlusion handling)
+ object_pointer = self.object_pointer_proj(sam_output_token)
+ lambda_is_obj_appearing = is_obj_appearing.to(object_pointer.dtype)
+
+ object_pointer = lambda_is_obj_appearing * object_pointer
+ object_pointer = object_pointer + (1 - lambda_is_obj_appearing) * self.no_object_pointer
+
+ return Sam2VideoImageSegmentationOutput(
+ iou_scores=iou_scores,
+ pred_masks=low_res_masks,
+ high_res_masks=high_res_masks,
+ object_pointer=object_pointer,
+ object_score_logits=object_score_logits,
+ image_embeddings=image_embeddings,
+ vision_hidden_states=vision_hidden_states,
+ vision_attentions=vision_attentions,
+ )
+
+ def _use_mask_as_output(
+ self,
+ backbone_features: torch.Tensor,
+ high_res_features: list[torch.Tensor],
+ mask_inputs: torch.Tensor,
+ ) -> Sam2VideoImageSegmentationOutput:
+ """
+ Directly turn binary `mask_inputs` into output mask logits without using SAM.
+ (same input and output shapes as in forward above).
+ """
+ # Use -10/+20 as logits for neg/pos pixels (very close to 0/1 in prob after sigmoid).
+ out_scale, out_bias = 20.0, -10.0 # sigmoid(-10.0)=4.5398e-05
+ mask_inputs_float = mask_inputs.to(backbone_features[0].dtype)
+ high_res_masks = mask_inputs_float * out_scale + out_bias
+ low_res_masks = F.interpolate(
+ high_res_masks.float(),
+ size=(high_res_masks.size(-2) // 4, high_res_masks.size(-1) // 4),
+ align_corners=False,
+ mode="bilinear",
+ antialias=True, # use antialias for downsampling
+ ).to(backbone_features[0].dtype)
+ # a dummy IoU prediction of all 1's under mask input
+ iou_scores = mask_inputs.new_ones(mask_inputs.size(0), 1).to(backbone_features[0].dtype)
+ # produce an object pointer using the SAM decoder from the mask input
+ object_pointer = self._single_frame_forward(
+ input_masks=self.mask_downsample(mask_inputs_float.to(backbone_features[0].dtype)),
+ image_embeddings=high_res_features + [backbone_features],
+ ).object_pointer
+ # In this method, we are treating mask_input as output, e.g. using it directly to create spatial mem;
+ # Below, we follow the same design axiom to use mask_input to decide if obj appears or not instead of relying
+ # on the object_scores from the SAM decoder.
+ is_obj_appearing = torch.any(mask_inputs.flatten(1).float() > 0.0, dim=1)
+ is_obj_appearing = is_obj_appearing[..., None]
+ lambda_is_obj_appearing = is_obj_appearing.to(backbone_features[0].dtype)
+ object_score_logits = out_scale * lambda_is_obj_appearing + out_bias
+ object_pointer = lambda_is_obj_appearing * object_pointer
+ object_pointer = object_pointer + (1 - lambda_is_obj_appearing) * self.no_object_pointer
+ return Sam2VideoImageSegmentationOutput(
+ iou_scores=iou_scores,
+ pred_masks=low_res_masks,
+ high_res_masks=high_res_masks,
+ object_pointer=object_pointer,
+ object_score_logits=object_score_logits,
+ image_embeddings=high_res_features + [backbone_features],
+ )
+
+ def _gather_memory_frame_outputs(
+ self,
+ inference_session: Sam2VideoInferenceSession,
+ obj_idx: int,
+ frame_idx: int,
+ track_in_reverse_time: bool = False,
+ ) -> list[tuple[int, dict]]:
+ """
+ Get memory frames from conditioning and non-conditioning outputs.
+
+ Returns:
+ List of (relative_temporal_offset, output_data) tuples.
+ """
+
+ # Add conditioning frame outputs (no limit here, as is the case in the original checkpoints)
+ conditioning_outputs = inference_session.output_dict_per_obj[obj_idx]["cond_frame_outputs"]
+ if not conditioning_outputs:
+ raise ValueError(
+ "maskmem_features in conditioning outputs cannot be empty when not is_initial_conditioning_frame"
+ )
+
+ # Store (temporal_position, output_data) tuples
+ temporal_positions_and_previous_outputs = [(0, out) for out in conditioning_outputs.values()]
+
+ # Add non-conditioning memory frames (up to self.num_maskmem - 1)
+ # These are typically frames tracked by the model without direct user input.
+ # Frames are selected with a stride, prioritizing the most recent ones. Here we only support stride = 1 for simplicity.
+ for relative_temporal_offset in range(self.num_maskmem - 1, 0, -1):
+ # relative_temporal_offset: how many frames before (or after if reversing) the current frame
+ if not track_in_reverse_time:
+ previous_frame_idx = frame_idx - relative_temporal_offset
+ else:
+ previous_frame_idx = frame_idx + relative_temporal_offset
+
+ # check if the output is already stored without using get_output to avoid unnecessary memory transfers between CPU and GPU
+ output_data = inference_session.output_dict_per_obj[obj_idx]["non_cond_frame_outputs"].get(
+ previous_frame_idx, None
+ )
+
+ temporal_positions_and_previous_outputs.append((relative_temporal_offset, output_data))
+
+ return temporal_positions_and_previous_outputs
+
+ def _build_memory_attention_inputs(
+ self,
+ temporal_positions_and_previous_outputs: list[tuple[int, dict]],
+ device: torch.device,
+ ) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
+ """
+ Concatenate memory features and positional embeddings from previous frames.
+
+ Returns:
+ Tuple of (memories_to_concatenate, memory_positional_embeddings_to_concatenate).
+ """
+ memories_to_concatenate = []
+ memory_positional_embeddings_to_concatenate = []
+
+ for relative_temporal_offset, prev_output_data in temporal_positions_and_previous_outputs:
+ if prev_output_data is None:
+ continue # Skip if no output data for this temporal position (e.g., padding frames)
+
+ # Load memory features (potentially from CPU to GPU)
+ # Features are flattened: (Batch, Channels, H, W) -> (H*W, Batch, Channels)
+ memory_features = prev_output_data["maskmem_features"].to(device, non_blocking=True)
+ memories_to_concatenate.append(memory_features)
+
+ # Spatial positional encoding (potentially from CPU to GPU)
+ spatial_memory_pos_embed = prev_output_data["maskmem_pos_enc"].to(device, non_blocking=True)
+
+ # Add temporal positional encoding
+ # self.memory_temporal_positional_encoding shape: (NumMaskMem, 1, 1, MemDim)
+ combined_memory_pos_embed = (
+ spatial_memory_pos_embed + self.memory_temporal_positional_encoding[relative_temporal_offset - 1]
+ )
+ memory_positional_embeddings_to_concatenate.append(combined_memory_pos_embed)
+
+ return memories_to_concatenate, memory_positional_embeddings_to_concatenate
+
+ def _get_object_pointers(
+ self,
+ inference_session: Sam2VideoInferenceSession,
+ obj_idx: int,
+ frame_idx: int,
+ num_total_frames: int,
+ device: torch.device,
+ track_in_reverse_time: bool = False,
+ streaming: bool = False,
+ ) -> tuple[list[int], list[torch.Tensor], int]:
+ """
+ Get object pointers and their positional embeddings from past frames.
+
+ Returns:
+ Tuple of (temporal_offsets, pointer_tokens, max_object_pointers_to_use).
+ """
+ temporal_position_sign_multiplier = -1 if track_in_reverse_time else 1
+
+ # Determine max object pointers to use
+ if streaming:
+ max_object_pointers_to_use = self.config.max_object_pointers_in_encoder
+ else:
+ max_object_pointers_to_use = min(num_total_frames, self.config.max_object_pointers_in_encoder)
+
+ temporal_offsets: list[int] = []
+ pointer_tokens: list[torch.Tensor] = []
+
+ # Add object pointers from selected conditioning frames
+ # Optionally, only include pointers from past frames during evaluation
+ conditioning_outputs = inference_session.output_dict_per_obj[obj_idx]["cond_frame_outputs"]
+ eligible_conditioning_outputs = conditioning_outputs
+ if not self.training:
+ eligible_conditioning_outputs = {
+ temporal_idx: out
+ for temporal_idx, out in conditioning_outputs.items()
+ if (temporal_idx >= frame_idx if track_in_reverse_time else temporal_idx <= frame_idx)
+ }
+
+ for temporal_idx, out_data in eligible_conditioning_outputs.items():
+ temporal_difference = (frame_idx - temporal_idx) * temporal_position_sign_multiplier
+ temporal_offsets.append(temporal_difference)
+ pointer_tokens.append(out_data["object_pointer"].to(device))
+
+ # Add object pointers from non-conditioning frames (up to max_object_pointers_to_use - 1)
+ for t_diff_offset in range(1, max_object_pointers_to_use):
+ ref_frame_idx = frame_idx + t_diff_offset if track_in_reverse_time else frame_idx - t_diff_offset
+ if ref_frame_idx < 0 or (
+ not streaming and num_total_frames is not None and ref_frame_idx >= num_total_frames
+ ):
+ break # Stop if frame index is out of bounds
+
+ # check if the output is already stored without using get_output to avoid unnecessary memory transfers between CPU and GPU
+ out_data = inference_session.output_dict_per_obj[obj_idx]["non_cond_frame_outputs"].get(
+ ref_frame_idx, None
+ )
+ if out_data is not None:
+ temporal_offsets.append(t_diff_offset)
+ pointer_tokens.append(out_data["object_pointer"].to(device))
+
+ return temporal_offsets, pointer_tokens, max_object_pointers_to_use
+
+ def _process_object_pointers(
+ self,
+ temporal_offsets: list[int],
+ pointer_tokens: list[torch.Tensor],
+ max_object_pointers_to_use: int,
+ batch_size: int,
+ num_channels: int,
+ device: torch.device,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ """
+ Process object pointers and compute their positional embeddings.
+
+ Returns:
+ Tuple of (object_pointers, object_pointers_pos_embed).
+ """
+ if not pointer_tokens:
+ return None, None
+
+ # Stack object pointers: List of (Batch, Channels) -> (SeqLen_ptr, Batch, Channels)
+ object_pointers = torch.stack(pointer_tokens, dim=0)
+
+ if self.config.enable_temporal_pos_encoding_for_object_pointers:
+ max_temporal_diff = float(max_object_pointers_to_use - 1)
+ # Determine dimensionality for temporal positional encoding of pointers
+ pointer_tpos_dim = num_channels
+
+ # Normalize temporal differences before sine PE calculation
+ normalized_temporal_diffs = (
+ torch.tensor(temporal_offsets, device=device, dtype=torch.float32) / max_temporal_diff
+ )
+ sine_pe = get_1d_sine_pe(normalized_temporal_diffs, dim=pointer_tpos_dim).to(object_pointers.dtype)
+ projected_sine_pe = self.temporal_positional_encoding_projection_layer(sine_pe)
+ object_pointers_pos_embed = projected_sine_pe.unsqueeze(1).expand(-1, batch_size, self.mem_dim)
+ else:
+ object_pointers_pos_embed = object_pointers.new_zeros(
+ len(temporal_offsets), batch_size, self.mem_dim, dtype=object_pointers.dtype
+ )
+
+ if self.mem_dim < num_channels:
+ # If memory dimension is smaller, reshape/split pointers and repeat positional encoding
+ num_splits = num_channels // self.mem_dim
+ object_pointers = object_pointers.reshape(-1, batch_size, num_splits, self.mem_dim)
+ object_pointers = object_pointers.permute(0, 2, 1, 3).flatten(
+ 0, 1
+ ) # (SeqLen_ptr*num_splits, Batch, MemDim)
+ object_pointers_pos_embed = object_pointers_pos_embed.repeat_interleave(num_splits, dim=0)
+
+ return object_pointers, object_pointers_pos_embed
+
+ def _prepare_memory_conditioned_features(
+ self,
+ inference_session: Sam2VideoInferenceSession,
+ frame_idx: int,
+ obj_idx: int,
+ is_initial_conditioning_frame: bool,
+ current_vision_features: list[torch.Tensor],
+ current_vision_positional_embeddings: list[torch.Tensor],
+ num_total_frames: int,
+ track_in_reverse_time: bool = False,
+ streaming: bool = False,
+ ) -> torch.Tensor:
+ """
+ Fuse current frame's visual features with memory from previous frames for enhanced object tracking.
+
+ This method conditions the current frame's visual features on temporal memory from previous frames,
+ enabling consistent object tracking across video sequences. For initial conditioning frames, it uses
+ no-memory embeddings. For subsequent frames, it retrieves and integrates memory features from both
+ conditioning frames (user interactions) and non-conditioning frames (tracked results) via cross-attention.
+
+ Args:
+ inference_session (`Sam2VideoInferenceSession`):
+ The video inference session object.
+ frame_idx (`int`):
+ Index of the current frame being processed.
+ obj_idx (`int`):
+ Index of the object being processed.
+ is_initial_conditioning_frame (`bool`):
+ Whether this is an initial conditioning frame with user inputs (True) or a subsequent
+ tracking frame (False).
+ current_vision_features (`torch.Tensor`):
+ Highest-level vision features of shape `(seq_len, batch_size, channels)`.
+ current_vision_positional_embeddings (`torch.Tensor`):
+ Positional embedding tensors corresponding to the highest-level vision features.
+ num_total_frames (`int`):
+ Total number of frames in the video sequence.
+ track_in_reverse_time (`bool`, *optional*, defaults to `False`):
+ Whether tracking is performed in reverse temporal order.
+ streaming (`bool`, *optional*, defaults to `False`):
+ Whether this is streaming inference mode.
+
+ Returns:
+ `torch.Tensor`: Memory-conditioned feature tensor of shape `(batch_size, channels, height, width)`
+ suitable for input to the SAM decoder.
+ """
+ # Get dimensions from the highest-level (lowest-resolution) feature map
+ batch_size = current_vision_features.size(1)
+ num_channels = self.hidden_dim
+ height, width = self.backbone_feature_sizes[-1]
+ device = current_vision_features.device
+
+ # If memory is disabled (e.g., for single image SAM), return current features directly.
+ if self.num_maskmem == 0:
+ # Permute (SeqLen, Batch, Channels) -> (Batch, Channels, SeqLen) then view as (Batch, Channels, Height, Width)
+ # Assuming SeqLen = Height * Width for the last feature map
+ current_feature_map = current_vision_features.permute(1, 2, 0).view(
+ batch_size, num_channels, height, width
+ )
+ return current_feature_map
+
+ # Step 1: Handle initial conditioning frames
+ if is_initial_conditioning_frame:
+ # For initial conditioning frames, no prior memory is used directly in this block.
+ # If configured, directly add a learnable "no memory" embedding.
+ # current_vision_features has shape (SeqLen, Batch, Channels)
+ conditioned_feature_map_flat = current_vision_features + self.no_memory_embedding
+ # Reshape to (Batch, Channels, Height, Width)
+ conditioned_feature_map = conditioned_feature_map_flat.permute(1, 2, 0).view(
+ batch_size, num_channels, height, width
+ )
+ return conditioned_feature_map
+
+ # Step 2: Get memory frames and concatenate their features
+ temporal_positions_and_previous_outputs = self._gather_memory_frame_outputs(
+ inference_session, obj_idx, frame_idx, track_in_reverse_time
+ )
+
+ memories_to_concatenate, memory_positional_embeddings_to_concatenate = self._build_memory_attention_inputs(
+ temporal_positions_and_previous_outputs, device
+ )
+
+ # Step 3: Get and process object pointers
+ temporal_offsets, pointer_tokens, max_object_pointers_to_use = self._get_object_pointers(
+ inference_session, obj_idx, frame_idx, num_total_frames, device, track_in_reverse_time, streaming
+ )
+
+ num_object_pointer_tokens = 0
+ if pointer_tokens:
+ object_pointers, object_pointers_pos_embed = self._process_object_pointers(
+ temporal_offsets, pointer_tokens, max_object_pointers_to_use, batch_size, num_channels, device
+ )
+
+ if object_pointers is not None:
+ memories_to_concatenate.append(object_pointers)
+ memory_positional_embeddings_to_concatenate.append(object_pointers_pos_embed)
+ num_object_pointer_tokens = object_pointers.shape[0]
+
+ # Step 4: Concatenate all retrieved memories and their positional embeddings
+ combined_memory = torch.cat(memories_to_concatenate, dim=0)
+ combined_memory_positional_embeddings = torch.cat(memory_positional_embeddings_to_concatenate, dim=0)
+
+ # Step 5: Forward through the memory attention mechanism
+ conditioned_feature_map_flat = self.memory_attention(
+ current_vision_features=current_vision_features,
+ current_vision_position_embeddings=current_vision_positional_embeddings,
+ memory=combined_memory,
+ memory_posision_embeddings=combined_memory_positional_embeddings, # keyword matches the parameter name as spelled in the memory attention API
+ num_object_pointer_tokens=num_object_pointer_tokens,
+ )
+
+ # Reshape from (Batch, H*W, Channels) to (Batch, Channels, Height, Width)
+ conditioned_feature_map = (
+ conditioned_feature_map_flat.squeeze(1).permute(0, 2, 1).view(batch_size, num_channels, height, width)
+ )
+ return conditioned_feature_map
+
+ def _use_multimask(self, is_init_cond_frame: bool, point_inputs: Optional[dict]) -> bool:
+ """Whether to use multimask output in the SAM head."""
+ num_pts = 0 if point_inputs is None else point_inputs["point_labels"].size(2)
+ multimask_output = (
+ self.config.multimask_output_in_sam
+ and (is_init_cond_frame or self.config.multimask_output_for_tracking)
+ and (self.config.multimask_min_pt_num <= num_pts <= self.config.multimask_max_pt_num)
+ )
+ return multimask_output
+
+ def _run_single_frame_inference(
+ self,
+ inference_session: Sam2VideoInferenceSession,
+ frame_idx: int,
+ obj_idx: int,
+ batch_size: int,
+ is_init_cond_frame: bool,
+ point_inputs: Optional[torch.Tensor],
+ mask_inputs: Optional[torch.Tensor],
+ reverse: bool,
+ run_mem_encoder: bool,
+ prev_sam_mask_logits: Optional[torch.Tensor] = None,
+ streaming: bool = False,
+ ) -> dict[str, Any]:
+ """
+ Perform a single tracking step for video object segmentation.
+
+ Args:
+ inference_session (`Sam2VideoInferenceSession`):
+ The video inference session object.
+ frame_idx (`int`):
+ Index of the current frame.
+ obj_idx (`int`):
+ Index of the current object.
+ batch_size (`int`):
+ Batch size of the current frame.
+ is_init_cond_frame (`bool`):
+ Whether this is an initial conditioning frame with user inputs.
+ point_inputs (`dict`, *optional*):
+ Point prompt inputs for the current frame.
+ mask_inputs (`torch.Tensor`, *optional*):
+ Mask prompt inputs for the current frame.
+ reverse (`bool`):
+ Whether to track in reverse time order.
+ run_mem_encoder (`bool`):
+ Whether to run the memory encoder on predicted masks.
+ prev_sam_mask_logits (`torch.Tensor`, *optional*):
+ Previously predicted SAM mask logits that can be fed with new clicks.
+ streaming (`bool`, *optional*, defaults to `False`):
+ Whether this is streaming inference.
+
+ Returns:
+ `dict`: Dictionary containing the tracking results for the current frame, including:
+ - pred_masks: Predicted low-resolution masks.
+ - object_pointer: Object pointer for memory.
+ - object_score_logits: Object score logits (inference only).
+ - maskmem_features: Memory features for future frames.
+ - maskmem_pos_enc: Memory positional encodings.
+ """
+ # Retrieve correct image features
+ current_vision_feats, current_vision_pos_embeds = self._prepare_vision_features(
+ inference_session, frame_idx, batch_size
+ )
+ # point and mask should not appear as input simultaneously on the same frame
+ if point_inputs is not None and mask_inputs is not None:
+ raise ValueError(
+ "point_inputs and mask_inputs should not appear as input simultaneously on the same frame"
+ )
+ # High-resolution feature maps for the SAM head, reshape (HW)BC => BCHW
+ if len(current_vision_feats) > 1:
+ high_res_features = [
+ x.permute(1, 2, 0).view(x.size(1), x.size(2), *s)
+ for x, s in zip(current_vision_feats[:-1], self.backbone_feature_sizes[:-1])
+ ]
+ else:
+ high_res_features = None
+ if mask_inputs is not None:
+ # We directly output the mask input (see it as a GT mask) without using a SAM prompt encoder + mask decoder.
+ pix_feat = current_vision_feats[-1].permute(1, 2, 0)
+ pix_feat = pix_feat.view(-1, self.hidden_dim, *self.backbone_feature_sizes[-1])
+ sam_outputs = self._use_mask_as_output(pix_feat, high_res_features, mask_inputs)
+ else:
+ # fuse the visual features with previous memory features in the memory bank
+ pix_feat = self._prepare_memory_conditioned_features(
+ inference_session=inference_session,
+ frame_idx=frame_idx,
+ obj_idx=obj_idx,
+ is_initial_conditioning_frame=is_init_cond_frame,
+ current_vision_features=current_vision_feats[-1],
+ current_vision_positional_embeddings=current_vision_pos_embeds[-1],
+ num_total_frames=inference_session.num_frames,
+ track_in_reverse_time=reverse,
+ streaming=streaming,
+ )
+ # apply SAM-style segmentation head
+ # here we might feed previously predicted low-res SAM mask logits into the SAM mask decoder,
+ # e.g. in demo where such logits come from earlier interaction instead of correction sampling
+ # (in this case, any `mask_inputs` shouldn't reach here as they are sent to _use_mask_as_output instead)
+ if prev_sam_mask_logits is not None:
+ mask_inputs = prev_sam_mask_logits
+ multimask_output = self._use_multimask(is_init_cond_frame, point_inputs)
+ sam_outputs = self._single_frame_forward(
+ pixel_values=None, # Vision features already computed
+ input_points=point_inputs["point_coords"] if point_inputs is not None else None,
+ input_labels=point_inputs["point_labels"] if point_inputs is not None else None,
+ input_masks=mask_inputs,
+ image_embeddings=high_res_features + [pix_feat],
+ multimask_output=multimask_output,
+ )
+
+ # Finally run the memory encoder on the predicted mask to encode
+ # it into a new memory feature (which will be used to condition vision features in future frames)
+ maskmem_features = None
+ maskmem_pos_enc = None
+ if run_mem_encoder and self.num_maskmem > 0:
+ maskmem_features, maskmem_pos_enc = self._encode_new_memory(
+ current_vision_feats=current_vision_feats[-1],
+ pred_masks_high_res=sam_outputs.high_res_masks,
+ object_score_logits=sam_outputs.object_score_logits,
+ is_mask_from_pts=(point_inputs is not None or mask_inputs is not None),
+ )
+
+ current_out = {
+ "pred_masks": sam_outputs.pred_masks,
+ "object_pointer": sam_outputs.object_pointer,
+ "maskmem_features": maskmem_features if maskmem_features is not None else None,
+ "maskmem_pos_enc": maskmem_pos_enc,
+ }
+ if not self.training:
+ current_out["object_score_logits"] = sam_outputs.object_score_logits
+
+ return current_out
+
+ def _encode_new_memory(
+ self,
+ current_vision_feats: torch.Tensor,
+ pred_masks_high_res: torch.Tensor,
+ object_score_logits: torch.Tensor,
+ is_mask_from_pts: bool,
+ ) -> tuple[torch.Tensor, list[torch.Tensor]]:
+ """Encode the current image and its prediction into a memory feature."""
+ batch_size = current_vision_feats.size(1) # batch size on this frame
+ channels = self.hidden_dim
+ height, width = self.backbone_feature_sizes[-1] # top-level (lowest-resolution) feature size
+ # top-level feature, (HW)BC => BCHW
+ pix_feat = current_vision_feats.permute(1, 2, 0).view(batch_size, channels, height, width)
+ if is_mask_from_pts and not self.training:
+ # binarize the mask logits
+ mask_for_mem = (pred_masks_high_res > 0).to(pred_masks_high_res.dtype)
+ else:
+ # apply sigmoid on the raw mask logits to turn them into range (0, 1)
+ mask_for_mem = torch.sigmoid(pred_masks_high_res)
+ # apply scale and bias terms to the sigmoid probabilities
+ mask_for_mem = mask_for_mem * self.config.sigmoid_scale_for_mem_enc
+ mask_for_mem = mask_for_mem + self.config.sigmoid_bias_for_mem_enc
+
+ maskmem_features, maskmem_pos_enc = self.memory_encoder(
+ pix_feat,
+ mask_for_mem,
+ )
+ # add a no-object embedding to the spatial memory to indicate that the frame
+ # is predicted to be occluded (i.e. no object is appearing in the frame)
+ if self.occlusion_spatial_embedding_parameter is not None:
+ is_obj_appearing = (object_score_logits > 0).float()
+ maskmem_features += (1 - is_obj_appearing[..., None]) * self.occlusion_spatial_embedding_parameter[
+ ..., None, None
+ ].expand(*maskmem_features.shape)
+
+ # convert to bfloat16 to save memory, and for consistency with the original implementation
+ maskmem_features = maskmem_features.to(torch.bfloat16).flatten(2).permute(2, 0, 1)
+ maskmem_pos_enc = maskmem_pos_enc.to(pred_masks_high_res.dtype).flatten(2).permute(2, 0, 1)
+
+ return maskmem_features, maskmem_pos_enc
+
+ @torch.inference_mode()
+ @auto_docstring(custom_intro="Propagate the objects through a streamed video frame.")
+ def forward(
+ self,
+ inference_session: Sam2VideoInferenceSession,
+ frame_idx: Optional[int] = None,
+ frame: Optional[torch.Tensor] = None,
+ reverse: bool = False,
+ ) -> Sam2VideoSegmentationOutput:
+ r"""
+ inference_session (`Sam2VideoInferenceSession`):
+ The video inference session object.
+ frame_idx (`int`, *optional*):
+ The index of the frame on which to run inference. No need to provide when inferring
+ on a new streamed frame.
+ frame (`torch.Tensor`, *optional*):
+ The frame to process. Provide when streaming.
+ reverse (`bool`, *optional*, defaults to `False`):
+ Whether to propagate in reverse.
+ """
+ if frame is not None:
+ frame_idx = inference_session.add_new_frame(frame, frame_idx)
+
+ if frame is not None and inference_session.get_obj_num() == 0:
+ raise ValueError("No objects are provided for tracking; please add inputs first.")
+
+ num_objects = inference_session.get_obj_num()
+ pred_masks_per_obj = [None] * num_objects
+ # Note: We avoid batched inference here because per-object inputs (clicks/masks)
+ # can differ across objects.
+ for obj_idx in range(num_objects):
+ obj_id = inference_session.obj_idx_to_id(obj_idx)
+ has_new_inputs = obj_id in inference_session.obj_with_new_inputs
+ has_cond_output = frame_idx in inference_session.output_dict_per_obj[obj_idx]["cond_frame_outputs"]
+ # If this object has no new inputs and this frame already has a
+ # conditioning output, reuse the cached masks instead of recomputing.
+ if (not has_new_inputs) and has_cond_output:
+ pred_masks = inference_session.get_output(obj_idx, frame_idx, "pred_masks", is_conditioning_frame=True)
+ is_init_cond_frame = True
+ else:
+ # Defaults when there are no new inputs
+ is_init_cond_frame = False
+ point_inputs = None
+ mask_inputs = None
+
+ if has_new_inputs:
+ is_init_cond_frame = frame_idx not in inference_session.frames_tracked_per_obj[obj_idx]
+ if is_init_cond_frame:
+ reverse = False
+ point_inputs = inference_session.point_inputs_per_obj[obj_idx].get(frame_idx, None)
+ mask_inputs = inference_session.mask_inputs_per_obj[obj_idx].get(frame_idx, None)
+ if point_inputs is not None or mask_inputs is not None:
+ inference_session.obj_with_new_inputs.remove(obj_id)
+
+ current_out = self._run_single_frame_inference(
+ inference_session=inference_session,
+ obj_idx=obj_idx,
+ frame_idx=frame_idx,
+ batch_size=1, # run on the slice of a single object
+ is_init_cond_frame=is_init_cond_frame,
+ point_inputs=point_inputs,
+ mask_inputs=mask_inputs,
+ reverse=reverse,
+ run_mem_encoder=True,
+ streaming=frame is not None,
+ )
+ inference_session.store_output(
+ obj_idx, frame_idx, output_value=current_out, is_conditioning_frame=is_init_cond_frame
+ )
+ pred_masks = current_out["pred_masks"]
+
+ pred_masks_per_obj[obj_idx] = pred_masks
+ if not is_init_cond_frame:
+ # only for tracked frames, not for initial conditioning frames
+ inference_session.frames_tracked_per_obj[obj_idx][frame_idx] = {"reverse": reverse}
+
+ # Resize the output mask to the original video resolution (we directly use
+ # the mask scores on GPU for output to avoid any CPU conversion in between)
+ if len(pred_masks_per_obj) > 1:
+ all_pred_masks = torch.cat(pred_masks_per_obj, dim=0)
+ else:
+ all_pred_masks = pred_masks_per_obj[0]
+
+ return Sam2VideoSegmentationOutput(pred_masks=all_pred_masks, frame_idx=frame_idx)
+
+ @torch.inference_mode()
+ @auto_docstring(
+ custom_intro="""
+ Propagate the objects through the video frames. Used when initializing an inference session with a whole video.
+ Yields Sam2VideoSegmentationOutput for each frame.
+ """
+ )
+ def propagate_in_video_iterator(
+ self,
+ inference_session: Sam2VideoInferenceSession,
+ start_frame_idx: Optional[int] = None,
+ max_frame_num_to_track: Optional[int] = None,
+ reverse: bool = False,
+ ) -> Iterator[Sam2VideoSegmentationOutput]:
+ r"""
+ inference_session (`Sam2VideoInferenceSession`):
+ The video inference session object.
+ start_frame_idx (`int`, *optional*):
+ The starting frame index for propagation.
+ Needs to be provided if `forward` hasn't been called on new inputs yet.
+ If not provided, the starting frame index will be the earliest frame with input points.
+ max_frame_num_to_track (`int`, *optional*):
+ The maximum number of frames to track.
+ reverse (`bool`, *optional*, defaults to `False`):
+ Whether to propagate in reverse.
+ """
+ num_frames = inference_session.num_frames
+
+ # set start index, end index, and processing order
+ if start_frame_idx is None:
+ # default: start from the earliest frame with input points
+ frames_with_inputs = [
+ frame_idx
+ for obj_output_dict in inference_session.output_dict_per_obj.values()
+ for frame_idx in obj_output_dict["cond_frame_outputs"]
+ ]
+ if not frames_with_inputs:
+ raise ValueError(
+ "Cannot determine the starting frame index; please specify it manually, or run inference on a frame with inputs first."
+ )
+ start_frame_idx = min(frames_with_inputs)
+ if max_frame_num_to_track is None:
+ # default: track all the frames in the video
+ max_frame_num_to_track = num_frames
+ if reverse:
+ end_frame_idx = max(start_frame_idx - max_frame_num_to_track, 0)
+ if start_frame_idx > 0:
+ processing_order = range(start_frame_idx, end_frame_idx - 1, -1)
+ else:
+ processing_order = [] # skip reverse tracking if starting from frame 0
+ else:
+ end_frame_idx = min(start_frame_idx + max_frame_num_to_track, num_frames - 1)
+ processing_order = range(start_frame_idx, end_frame_idx + 1)
+
+ for frame_idx in tqdm(processing_order, desc="propagate in video"):
+ sam2_video_output = self(inference_session, frame_idx=frame_idx, reverse=reverse)
+ yield sam2_video_output
+
+
+__all__ = [
+ "Sam2VideoModel",
+ "Sam2VideoInferenceSession",
+ "Sam2VideoPreTrainedModel",
+ "Sam2VideoMaskDecoderConfig",
+ "Sam2VideoPromptEncoderConfig",
+ "Sam2VideoProcessor",
+ "Sam2VideoConfig",
+]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/siglip2/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/siglip2/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..fe5b732b951367cd130d7edd3299eda74f9dc267
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/siglip2/__init__.py
@@ -0,0 +1,30 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_siglip2 import *
+ from .image_processing_siglip2 import *
+ from .image_processing_siglip2_fast import *
+ from .modeling_siglip2 import *
+ from .processing_siglip2 import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/siglip2/configuration_siglip2.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/siglip2/configuration_siglip2.py
new file mode 100644
index 0000000000000000000000000000000000000000..67ef9df8f4f8f03608d7ef5c09b269c4cc74433b
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/siglip2/configuration_siglip2.py
@@ -0,0 +1,265 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/siglip2/modular_siglip2.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_siglip2.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2025 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class Siglip2TextConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`Siglip2TextModel`]. It is used to instantiate a
+ Siglip2 text encoder according to the specified arguments, defining the model architecture. Instantiating a
+ configuration with the defaults will yield a similar configuration to that of the text encoder of the Siglip2
+ [google/siglip2-base-patch16-224](https://huggingface.co/google/siglip2-base-patch16-224) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 32000):
+ Vocabulary size of the Siglip2 text model. Defines the number of different tokens that can be represented by
+ the `inputs_ids` passed when calling [`Siglip2Model`].
+ hidden_size (`int`, *optional*, defaults to 768):
+ Dimensionality of the encoder layers and the pooler layer.
+ intermediate_size (`int`, *optional*, defaults to 3072):
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
+ num_hidden_layers (`int`, *optional*, defaults to 12):
+ Number of hidden layers in the Transformer encoder.
+ num_attention_heads (`int`, *optional*, defaults to 12):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ max_position_embeddings (`int`, *optional*, defaults to 64):
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
+ just in case (e.g., 512 or 1024 or 2048).
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+ `"relu"`, `"selu"` and `"gelu_new"` `"quick_gelu"` are supported.
+ layer_norm_eps (`float`, *optional*, defaults to 1e-06):
+ The epsilon used by the layer normalization layers.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ pad_token_id (`int`, *optional*, defaults to 1):
+ The id of the padding token in the vocabulary.
+ bos_token_id (`int`, *optional*, defaults to 49406):
+ The id of the beginning-of-sequence token in the vocabulary.
+ eos_token_id (`int`, *optional*, defaults to 49407):
+ The id of the end-of-sequence token in the vocabulary.
+ projection_size (`int`, *optional*, defaults to `hidden_size`):
+ The size of the projection head.
+
+ Example:
+
+ ```python
+ >>> from transformers import Siglip2TextConfig, Siglip2TextModel
+
+ >>> # Initializing a Siglip2TextConfig with google/siglip2-base-patch16-224 style configuration
+ >>> configuration = Siglip2TextConfig()
+
+ >>> # Initializing a Siglip2TextModel (with random weights) from the google/siglip2-base-patch16-224 style configuration
+ >>> model = Siglip2TextModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "siglip2_text_model"
+ base_config_key = "text_config"
+
+ def __init__(
+ self,
+ vocab_size=32000,
+ hidden_size=768,
+ intermediate_size=3072,
+ num_hidden_layers=12,
+ num_attention_heads=12,
+ max_position_embeddings=64,
+ hidden_act="gelu_pytorch_tanh",
+ layer_norm_eps=1e-6,
+ attention_dropout=0.0,
+ # This differs from `CLIPTokenizer`'s default and from openai/clip
+ # See https://github.com/huggingface/transformers/pull/24773#issuecomment-1632287538
+ pad_token_id=1,
+ bos_token_id=49406,
+ eos_token_id=49407,
+ projection_size=None,
+ **kwargs,
+ ):
+ super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
+
+ self.vocab_size = vocab_size
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.max_position_embeddings = max_position_embeddings
+ self.layer_norm_eps = layer_norm_eps
+ self.hidden_act = hidden_act
+ self.attention_dropout = attention_dropout
+ self.projection_size = projection_size if projection_size is not None else hidden_size
+
+
+class Siglip2VisionConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`Siglip2VisionModel`]. It is used to instantiate a
+ Siglip2 vision encoder according to the specified arguments, defining the model architecture. Instantiating a
+ configuration with the defaults will yield a similar configuration to that of the vision encoder of the Siglip2
+ [google/siglip2-base-patch16-naflex](https://huggingface.co/google/siglip2-base-patch16-naflex) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ hidden_size (`int`, *optional*, defaults to 768):
+ Dimensionality of the encoder layers and the pooler layer.
+ intermediate_size (`int`, *optional*, defaults to 3072):
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
+ num_hidden_layers (`int`, *optional*, defaults to 12):
+ Number of hidden layers in the Transformer encoder.
+ num_attention_heads (`int`, *optional*, defaults to 12):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ num_channels (`int`, *optional*, defaults to 3):
+ Number of channels in the input images.
+ num_patches (`int`, *optional*, defaults to 256):
+ The number of patches of size (`patch_size`, `patch_size`) in the image.
+ The image is resized, preserving its aspect ratio, to contain at most this number of patches.
+ If the resulting number of patches is lower, the image is padded along the "patch" dimension.
+ patch_size (`int`, *optional*, defaults to 16):
+ The size (resolution) of each patch.
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+ `"relu"`, `"selu"` and `"gelu_new"` `"quick_gelu"` are supported.
+ layer_norm_eps (`float`, *optional*, defaults to 1e-06):
+ The epsilon used by the layer normalization layers.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+
+ Example:
+
+ ```python
+ >>> from transformers import Siglip2VisionConfig, Siglip2VisionModel
+
+ >>> # Initializing a Siglip2VisionConfig with google/siglip2-base-patch16-naflex style configuration
+ >>> configuration = Siglip2VisionConfig()
+
+ >>> # Initializing a Siglip2VisionModel (with random weights) from the google/siglip2-base-patch16-naflex style configuration
+ >>> model = Siglip2VisionModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "siglip2_vision_model"
+ base_config_key = "vision_config"
+
+ def __init__(
+ self,
+ hidden_size=768,
+ intermediate_size=3072,
+ num_hidden_layers=12,
+ num_attention_heads=12,
+ num_channels=3,
+ num_patches=256,
+ patch_size=16,
+ hidden_act="gelu_pytorch_tanh",
+ layer_norm_eps=1e-6,
+ attention_dropout=0.0,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.num_channels = num_channels
+ self.patch_size = patch_size
+ self.attention_dropout = attention_dropout
+ self.layer_norm_eps = layer_norm_eps
+ self.hidden_act = hidden_act
+ self.num_patches = num_patches
+
+
+class Siglip2Config(PretrainedConfig):
+ r"""
+ [`Siglip2Config`] is the configuration class to store the configuration of a [`Siglip2Model`]. It is used to
+ instantiate a Siglip2 model according to the specified arguments, defining the text model and vision model configs.
+ Instantiating a configuration with the defaults will yield a similar configuration to that of the Siglip2
+ [google/siglip2-base-patch16-224](https://huggingface.co/google/siglip2-base-patch16-224) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ text_config (`dict`, *optional*):
+ Dictionary of configuration options used to initialize [`Siglip2TextConfig`].
+ vision_config (`dict`, *optional*):
+ Dictionary of configuration options used to initialize [`Siglip2VisionConfig`].
+ kwargs (*optional*):
+ Dictionary of keyword arguments.
+
+ Example:
+
+ ```python
+ >>> from transformers import Siglip2Config, Siglip2Model
+
+ >>> # Initializing a Siglip2Config with google/siglip2-base-patch16-224 style configuration
+ >>> configuration = Siglip2Config()
+
+ >>> # Initializing a Siglip2Model (with random weights) from the google/siglip2-base-patch16-224 style configuration
+ >>> model = Siglip2Model(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+
+ >>> # We can also initialize a Siglip2Config from a Siglip2TextConfig and a Siglip2VisionConfig
+ >>> from transformers import Siglip2TextConfig, Siglip2VisionConfig
+
+ >>> # Initializing a Siglip2Text and Siglip2Vision configuration
+ >>> config_text = Siglip2TextConfig()
+ >>> config_vision = Siglip2VisionConfig()
+
+ >>> config = Siglip2Config.from_text_vision_configs(config_text, config_vision)
+ ```"""
+
+ model_type = "siglip2"
+ sub_configs = {"text_config": Siglip2TextConfig, "vision_config": Siglip2VisionConfig}
+
+ def __init__(self, text_config=None, vision_config=None, **kwargs):
+ super().__init__(**kwargs)
+
+ if text_config is None:
+ text_config = {}
+ logger.info("`text_config` is `None`. Initializing the `Siglip2TextConfig` with default values.")
+
+ if vision_config is None:
+ vision_config = {}
+ logger.info("`vision_config` is `None`. initializing the `Siglip2VisionConfig` with default values.")
+
+ self.text_config = Siglip2TextConfig(**text_config)
+ self.vision_config = Siglip2VisionConfig(**vision_config)
+
+ self.initializer_factor = 1.0
+
+
+__all__ = ["Siglip2Config", "Siglip2TextConfig", "Siglip2VisionConfig"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/siglip2/image_processing_siglip2.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/siglip2/image_processing_siglip2.py
new file mode 100644
index 0000000000000000000000000000000000000000..30b5f1b958af2dadd2c7dcc79619f418857086fc
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/siglip2/image_processing_siglip2.py
@@ -0,0 +1,344 @@
+# coding=utf-8
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Image processor class for SigLIP2."""
+
+import math
+from functools import lru_cache
+from typing import Optional, Union
+
+import numpy as np
+
+from ...image_processing_utils import BaseImageProcessor, BatchFeature
+from ...image_transforms import (
+ convert_to_rgb,
+ resize,
+ to_channel_dimension_format,
+)
+from ...image_utils import (
+ ChannelDimension,
+ ImageInput,
+ PILImageResampling,
+ infer_channel_dimension_format,
+ is_scaled_image,
+ make_flat_list_of_images,
+ to_numpy_array,
+ valid_images,
+ validate_preprocess_arguments,
+)
+from ...utils import TensorType, filter_out_non_signature_kwargs, is_vision_available, logging
+
+
+logger = logging.get_logger(__name__)
+
+
+if is_vision_available():
+ from PIL import Image
+
+
+@lru_cache(maxsize=256)
+def get_image_size_for_max_num_patches(
+ image_height: int, image_width: int, patch_size: int, max_num_patches: int, eps: float = 1e-5
+) -> tuple[int, int]:
+ """
+ Determine the target image size for at most `max_num_patches` patches, ensuring the dimensions are divisible by the patch size and the image contains at least one patch.
+
+ Args:
+ image_height (`int`):
+ Original image height.
+ image_width (`int`):
+ Original image width.
+ patch_size (`int`):
+ Patch size for processing.
+ max_num_patches (`int`):
+ Maximum number of patches.
+ eps (`float`):
+ Small threshold for binary search.
+
+ Returns:
+ `tuple[int, int]`: `(target_height, target_width)` of the resized image.
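+
+ Example (an illustrative call; the exact values follow from the ceiling-based rounding to `patch_size`):
+
+ ```python
+ >>> get_image_size_for_max_num_patches(image_height=300, image_width=400, patch_size=16, max_num_patches=256)
+ (224, 288)
+ ```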
+ """
+
+ def get_scaled_image_size(scale: float, size: int, patch_size: int) -> int:
+ scaled_size = size * scale
+ scaled_size = math.ceil(scaled_size / patch_size) * patch_size # make divisible by patch_size
+ scaled_size = max(patch_size, scaled_size) # ensure at least 1 patch
+ return int(scaled_size)
+
+ # Binary search for optimal scale
+ scale_min, scale_max = eps / 10, 100.0
+ while (scale_max - scale_min) >= eps:
+ scale = (scale_min + scale_max) / 2
+ target_height = get_scaled_image_size(scale, image_height, patch_size)
+ target_width = get_scaled_image_size(scale, image_width, patch_size)
+ num_patches = (target_height / patch_size) * (target_width / patch_size)
+
+ if num_patches <= max_num_patches:
+ scale_min = scale
+ else:
+ scale_max = scale
+
+ scale = scale_min
+ target_height = get_scaled_image_size(scale, image_height, patch_size)
+ target_width = get_scaled_image_size(scale, image_width, patch_size)
+ return target_height, target_width
+
+
+def convert_image_to_patches(image: np.ndarray, patch_size: int) -> np.ndarray:
+ """
+ Convert 3D array image of shape (image_height, image_width, num_channels) into 2D array of patches of shape
+ (num_patches_height * num_patches_width, patch_size * patch_size * num_channels).
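+
+ Example (shape-only sketch):
+
+ ```python
+ >>> import numpy as np
+ >>> convert_image_to_patches(np.zeros((32, 48, 3)), patch_size=16).shape
+ (6, 768)
+ ```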
+ """
+ image_height, image_width, num_channels = image.shape
+ num_patches_height = image_height // patch_size
+ num_patches_width = image_width // patch_size
+ patched_image = image.reshape(num_patches_height, patch_size, num_patches_width, patch_size, num_channels)
+ patched_image = patched_image.transpose(0, 2, 1, 3, 4)
+ patched_image = patched_image.reshape(num_patches_height * num_patches_width, -1)
+ return patched_image
+
+
+def pad_along_first_dim(array: np.ndarray, target_length: int, pad_value: int = 0) -> tuple[np.ndarray, np.ndarray]:
+ """
+ Pad the array along the first dimension.
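+
+ Example (illustrative; shows the returned padding mask):
+
+ ```python
+ >>> import numpy as np
+ >>> padded, mask = pad_along_first_dim(np.ones((3, 4)), target_length=5)
+ >>> padded.shape, mask.tolist()
+ ((5, 4), [1, 1, 1, 0, 0])
+ ```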
+ """
+ current_length = array.shape[0]
+ padding_length = target_length - current_length
+ mask = np.ones((target_length,), dtype=np.int32)
+ if padding_length > 0:
+ paddings = [(0, padding_length)] + [(0, 0)] * (array.ndim - 1)
+ array = np.pad(array, paddings, mode="constant", constant_values=pad_value)
+ mask[-padding_length:] = 0
+ return array, mask
+
+
+class Siglip2ImageProcessor(BaseImageProcessor):
+ r"""
+ Constructs a SigLIP2 image processor.
+
+ Args:
+ do_resize (`bool`, *optional*, defaults to `True`):
+ Whether to resize the image's dimensions to fit `max_num_patches` according to given `patch_size`.
+ Can be overridden by `do_resize` in the `preprocess` method.
+ resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`):
+ Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method.
+ do_rescale (`bool`, *optional*, defaults to `True`):
+ Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in
+ the `preprocess` method.
+ rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
+ Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess`
+ method.
+ do_normalize (`bool`, *optional*, defaults to `True`):
+ Whether to normalize the image by the specified mean and standard deviation. Can be overridden by
+ `do_normalize` in the `preprocess` method.
+ image_mean (`float` or `list[float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`):
+ Mean to use if normalizing the image. This is a float or list of floats the length of the number of
+ channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
+ image_std (`float` or `list[float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`):
+ Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
+ number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
+ do_convert_rgb (`bool`, *optional*, defaults to `True`):
+ Whether to convert the image to RGB.
+ patch_size (`int`, *optional*, defaults to 16):
+ The size (resolution) of each patch the image will be split to.
+ max_num_patches (`int`, *optional*, defaults to 256):
+ The image will be resized to have at most this number of patches,
+ and then padded along the "patch" dimension to match this number exactly.
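+
+ Example (a minimal usage sketch; the uniform gray image is only a placeholder input):
+
+ ```python
+ >>> import numpy as np
+ >>> from PIL import Image
+ >>> from transformers import Siglip2ImageProcessor
+
+ >>> image_processor = Siglip2ImageProcessor(patch_size=16, max_num_patches=256)
+ >>> image = Image.fromarray(np.full((300, 400, 3), 128, dtype=np.uint8))
+ >>> inputs = image_processor(images=image, return_tensors="np")
+ >>> list(inputs.keys())
+ ['pixel_values', 'pixel_attention_mask', 'spatial_shapes']
+ ```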
+ """
+
+ model_input_names = ["pixel_values", "pixel_attention_mask", "spatial_shapes"]
+
+ def __init__(
+ self,
+ do_resize: bool = True,
+ resample: "PILImageResampling" = PILImageResampling.BILINEAR,
+ do_rescale: bool = True,
+ rescale_factor: float = 1 / 255,
+ do_normalize: bool = True,
+ image_mean: Optional[Union[float, list[float]]] = None,
+ image_std: Optional[Union[float, list[float]]] = None,
+ do_convert_rgb: Optional[bool] = None,
+ patch_size: int = 16,
+ max_num_patches: int = 256,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+
+ image_mean = image_mean if image_mean is not None else [0.5, 0.5, 0.5]
+ image_std = image_std if image_std is not None else [0.5, 0.5, 0.5]
+
+ self.do_resize = do_resize
+ self.resample = resample
+ self.do_rescale = do_rescale
+ self.rescale_factor = rescale_factor
+ self.do_normalize = do_normalize
+ self.image_mean = image_mean
+ self.image_std = image_std
+ self.do_convert_rgb = do_convert_rgb
+ self.patch_size = patch_size
+ self.max_num_patches = max_num_patches
+
+ @filter_out_non_signature_kwargs()
+ def preprocess(
+ self,
+ images: ImageInput,
+ do_resize: Optional[bool] = None,
+ resample: Optional["PILImageResampling"] = None,
+ do_rescale: Optional[bool] = None,
+ rescale_factor: Optional[float] = None,
+ do_normalize: Optional[bool] = None,
+ image_mean: Optional[Union[float, list[float]]] = None,
+ image_std: Optional[Union[float, list[float]]] = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ do_convert_rgb: Optional[bool] = None,
+ patch_size: Optional[int] = None,
+ max_num_patches: Optional[int] = None,
+ ) -> "Image.Image":
+ """
+ Preprocess an image or batch of images.
+
+ Args:
+ images (`ImageInput`):
+ Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
+ passing in images with pixel values between 0 and 1, set `do_rescale=False`.
+ do_resize (`bool`, *optional*, defaults to `self.do_resize`):
+ Whether to resize the image.
+ resample (`int`, *optional*, defaults to `self.resample`):
+ Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
+ has an effect if `do_resize` is set to `True`.
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
+ Whether to rescale the image.
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
+ Rescale factor to rescale the image by if `do_rescale` is set to `True`.
+ do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
+ Whether to normalize the image.
+ image_mean (`float` or `list[float]`, *optional*, defaults to `self.image_mean`):
+ Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
+ image_std (`float` or `list[float]`, *optional*, defaults to `self.image_std`):
+ Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
+ `True`.
+ return_tensors (`str` or `TensorType`, *optional*):
+ The type of tensors to return. Can be one of:
+ - Unset: Return a list of `np.ndarray`.
+ - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
+ - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
+ input_data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
+ from the input image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
+ do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
+ Whether to convert the image to RGB.
+ patch_size (`int`, *optional*, defaults to `self.patch_size`):
+ Patch size for processing, same as the patch size used in the model.
+ max_num_patches (`int`, *optional*, defaults to `self.max_num_patches`):
+ Maximum number of patches per image, the image will be resized to have at most this number of patches.
+ """
+ do_resize = do_resize if do_resize is not None else self.do_resize
+ resample = resample if resample is not None else self.resample
+ do_rescale = do_rescale if do_rescale is not None else self.do_rescale
+ rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
+ do_normalize = do_normalize if do_normalize is not None else self.do_normalize
+ image_mean = image_mean if image_mean is not None else self.image_mean
+ image_std = image_std if image_std is not None else self.image_std
+ do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
+ patch_size = patch_size if patch_size is not None else self.patch_size
+ max_num_patches = max_num_patches if max_num_patches is not None else self.max_num_patches
+
+ # Explicitly specify data format to be channels last for image preprocessing.
+ # Image processor does not support different output formats, because it returns patches.
+ data_format = ChannelDimension.LAST
+
+ images = self.fetch_images(images)
+ images = make_flat_list_of_images(images)
+
+ if not valid_images(images):
+ raise ValueError(
+ "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
+ "torch.Tensor, tf.Tensor or jax.ndarray."
+ )
+ validate_preprocess_arguments(
+ do_rescale=do_rescale,
+ rescale_factor=rescale_factor,
+ do_normalize=do_normalize,
+ image_mean=image_mean,
+ image_std=image_std,
+ )
+ if do_convert_rgb:
+ images = [convert_to_rgb(image) for image in images]
+
+ # All transformations expect numpy arrays.
+ images = [to_numpy_array(image) for image in images]
+
+ if do_rescale and is_scaled_image(images[0]):
+ logger.warning_once(
+ "It looks like you are trying to rescale already rescaled images. If the input"
+ " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
+ )
+
+ if input_data_format is None:
+ # We assume that all images have the same channel dimension format.
+ input_data_format = infer_channel_dimension_format(images[0])
+
+ pixel_masks = []
+ pixel_values = []
+ spatial_shapes = []
+
+ for image in images:
+ image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
+
+ if do_resize:
+ height, width = get_image_size_for_max_num_patches(
+ image_height=image.shape[0],
+ image_width=image.shape[1],
+ patch_size=patch_size,
+ max_num_patches=max_num_patches,
+ )
+ image = resize(image=image, size=(height, width), resample=resample, input_data_format=data_format)
+
+ if do_rescale:
+ image = self.rescale(image=image, scale=rescale_factor, input_data_format=data_format)
+
+ if do_normalize:
+ image = self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=data_format)
+
+ patches = convert_image_to_patches(image, patch_size)
+ patches, mask = pad_along_first_dim(patches, max_num_patches)
+ num_patches_height = image.shape[0] // patch_size
+ num_patches_width = image.shape[1] // patch_size
+
+ spatial_shapes.append((num_patches_height, num_patches_width))
+ pixel_values.append(patches)
+ pixel_masks.append(mask)
+
+ batch_feature = BatchFeature(
+ data={
+ "pixel_values": pixel_values,
+ "pixel_attention_mask": pixel_masks,
+ "spatial_shapes": spatial_shapes,
+ },
+ tensor_type=return_tensors,
+ )
+
+ return batch_feature
+
+
+__all__ = ["Siglip2ImageProcessor"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/siglip2/image_processing_siglip2_fast.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/siglip2/image_processing_siglip2_fast.py
new file mode 100644
index 0000000000000000000000000000000000000000..45261fab2cd0e7195d05bc0d2a8dafc71b6d2e19
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/siglip2/image_processing_siglip2_fast.py
@@ -0,0 +1,170 @@
+# coding=utf-8
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Fast Image processor class for SigLIP2."""
+
+from typing import Optional, Union
+
+import torch
+from torchvision.transforms.v2 import functional as F
+
+from ...image_processing_utils import BatchFeature
+from ...image_processing_utils_fast import (
+ BaseImageProcessorFast,
+ DefaultFastImageProcessorKwargs,
+ SizeDict,
+)
+from ...image_utils import (
+ ImageInput,
+ PILImageResampling,
+)
+from ...processing_utils import Unpack
+from ...utils import (
+ TensorType,
+ auto_docstring,
+ logging,
+)
+from .image_processing_siglip2 import get_image_size_for_max_num_patches
+
+
+logger = logging.get_logger(__name__)
+
+
+def convert_image_to_patches(image: "torch.Tensor", patch_size: int) -> "torch.Tensor":
+ """
+ Convert 3D tensor image of shape (num_channels, image_height, image_width) into 2D tensor of patches of shape
+ (num_patches_height * num_patches_width, patch_size * patch_size * num_channels).
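+
+ Example (shape-only sketch):
+
+ ```python
+ >>> import torch
+ >>> convert_image_to_patches(torch.zeros(3, 32, 48), patch_size=16).shape
+ torch.Size([6, 768])
+ ```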
+ """
+ num_channels, image_height, image_width = image.shape
+ num_patches_height = image_height // patch_size
+ num_patches_width = image_width // patch_size
+ patched_image = image.reshape(num_channels, num_patches_height, patch_size, num_patches_width, patch_size)
+ patched_image = patched_image.permute(1, 3, 2, 4, 0)
+ patched_image = patched_image.reshape(num_patches_height * num_patches_width, -1)
+ return patched_image
+
+
+def pad_along_first_dim(
+ tensor: "torch.Tensor", target_length: int, pad_value: int = 0
+) -> tuple["torch.Tensor", "torch.Tensor"]:
+ """
+ Pad the tensor along the first dimension.
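+
+ Example (illustrative; shows the returned padding mask):
+
+ ```python
+ >>> import torch
+ >>> padded, mask = pad_along_first_dim(torch.ones(3, 4), target_length=5)
+ >>> padded.shape, mask.tolist()
+ (torch.Size([5, 4]), [1, 1, 1, 0, 0])
+ ```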
+ """
+ current_length = tensor.shape[0]
+ padding_length = target_length - current_length
+ mask = torch.ones((target_length,), dtype=torch.int32)
+ if padding_length > 0:
+ padding = [0, 0] * (tensor.ndim - 1) + [0, padding_length]
+ tensor = torch.nn.functional.pad(tensor, padding, mode="constant", value=pad_value)
+ mask[-padding_length:] = 0
+ return tensor, mask
+
+
+class Siglip2FastImageProcessorKwargs(DefaultFastImageProcessorKwargs):
+ """
+ patch_size (`int`, *optional*, defaults to 16):
+ The size (resolution) of each patch the image will be split to.
+ max_num_patches (`int`, *optional*, defaults to 256):
+ The image will be resized to have at most this number of patches,
+ and then padded along the "patch" dimension to match this number exactly.
+ """
+
+ patch_size: Optional[int]
+ max_num_patches: Optional[int]
+
+
+@auto_docstring
+class Siglip2ImageProcessorFast(BaseImageProcessorFast):
+ resample = PILImageResampling.BILINEAR
+ image_mean = [0.5, 0.5, 0.5]
+ image_std = [0.5, 0.5, 0.5]
+ do_resize = True
+ do_rescale = True
+ do_normalize = True
+ patch_size = 16
+ max_num_patches = 256
+ valid_kwargs = Siglip2FastImageProcessorKwargs
+ unused_kwargs = ["size", "do_center_crop", "crop_size"]
+
+ def __init__(self, **kwargs: Unpack[Siglip2FastImageProcessorKwargs]):
+ super().__init__(**kwargs)
+
+ def _validate_preprocess_kwargs(self, **kwargs) -> tuple:
+ # Remove do_resize from kwargs to not raise an error as size is None
+ kwargs.pop("do_resize", None)
+ return super()._validate_preprocess_kwargs(**kwargs)
+
+ @auto_docstring
+ def preprocess(self, images: ImageInput, **kwargs: Unpack[Siglip2FastImageProcessorKwargs]) -> BatchFeature:
+ return super().preprocess(images, **kwargs)
+
+ def _preprocess(
+ self,
+ images: list["torch.Tensor"],
+ do_resize: bool,
+ patch_size: int,
+ max_num_patches: int,
+ interpolation: Optional["F.InterpolationMode"],
+ do_rescale: bool,
+ rescale_factor: float,
+ do_normalize: bool,
+ image_mean: Optional[Union[float, list[float]]],
+ image_std: Optional[Union[float, list[float]]],
+ return_tensors: Optional[Union[str, TensorType]],
+ **kwargs,
+ ) -> BatchFeature:
+ pixel_masks = []
+ pixel_values = []
+ spatial_shapes = []
+
+ for image in images:
+ if do_resize:
+ height, width = get_image_size_for_max_num_patches(
+ image_height=image.shape[1],
+ image_width=image.shape[2],
+ patch_size=patch_size,
+ max_num_patches=max_num_patches,
+ )
+ side_dict = SizeDict(height=height, width=width)
+ image = self.resize(image=image, size=side_dict, interpolation=interpolation)
+
+ image = self.rescale_and_normalize(image, do_rescale, rescale_factor, do_normalize, image_mean, image_std)
+
+ # (num_channels, height, width) -> (num_patches, patch_size * patch_size * num_channels)
+ patches = convert_image_to_patches(image, patch_size)
+ patches, mask = pad_along_first_dim(patches, max_num_patches)
+
+ num_patches_height = image.shape[1] // patch_size
+ num_patches_width = image.shape[2] // patch_size
+
+ spatial_shapes.append((num_patches_height, num_patches_width))
+ pixel_values.append(patches)
+ pixel_masks.append(mask)
+
+ pixel_values = torch.stack(pixel_values)
+ pixel_masks = torch.stack(pixel_masks)
+ spatial_shapes = torch.tensor(spatial_shapes)
+
+ batch_feature = BatchFeature(
+ data={
+ "pixel_values": pixel_values,
+ "pixel_attention_mask": pixel_masks,
+ "spatial_shapes": spatial_shapes,
+ },
+ tensor_type=return_tensors,
+ )
+ return batch_feature
+
+
+__all__ = ["Siglip2ImageProcessorFast"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/siglip2/modeling_siglip2.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/siglip2/modeling_siglip2.py
new file mode 100644
index 0000000000000000000000000000000000000000..ae34e4d8f61c5d0a397009fc0c8896c9a5634746
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/siglip2/modeling_siglip2.py
@@ -0,0 +1,1195 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/siglip2/modular_siglip2.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_siglip2.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2025 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import math
+import warnings
+from dataclasses import dataclass
+from typing import Any, Callable, Optional, Union
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn.init import _calculate_fan_in_and_fan_out
+
+from ...activations import ACT2FN
+from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
+from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple, filter_out_non_signature_kwargs
+from ...utils.generic import check_model_inputs
+from .configuration_siglip2 import Siglip2Config, Siglip2TextConfig, Siglip2VisionConfig
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states.
+ """
+)
+class Siglip2VisionOutput(ModelOutput):
+ r"""
+ image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when model is initialized with `with_projection=True`):
+ The image embeddings obtained by applying the projection layer to the pooler_output.
+ """
+
+ image_embeds: Optional[torch.FloatTensor] = None
+ last_hidden_state: Optional[torch.FloatTensor] = None
+ hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+ attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Base class for text model's outputs that also contains a pooling of the last hidden states.
+ """
+)
+class Siglip2TextOutput(ModelOutput):
+ r"""
+ text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when model is initialized with `with_projection=True`):
+ The text embeddings obtained by applying the projection layer to the pooler_output.
+ """
+
+ text_embeds: Optional[torch.FloatTensor] = None
+ last_hidden_state: Optional[torch.FloatTensor] = None
+ hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+ attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+
+
+@dataclass
+@auto_docstring
+class Siglip2Output(ModelOutput):
+ r"""
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
+ Contrastive loss for image-text similarity.
+ logits_per_image (`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`):
+ The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text
+ similarity scores.
+ logits_per_text (`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`):
+ The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image
+ similarity scores.
+ text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`):
+ The text embeddings obtained by applying the projection layer to the pooled output of [`Siglip2TextModel`].
+ image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`):
+ The image embeddings obtained by applying the projection layer to the pooled output of [`Siglip2VisionModel`].
+ text_model_output (`BaseModelOutputWithPooling`):
+ The output of the [`Siglip2TextModel`].
+ vision_model_output (`BaseModelOutputWithPooling`):
+ The output of the [`Siglip2VisionModel`].
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ logits_per_image: Optional[torch.FloatTensor] = None
+ logits_per_text: Optional[torch.FloatTensor] = None
+ text_embeds: Optional[torch.FloatTensor] = None
+ image_embeds: Optional[torch.FloatTensor] = None
+ text_model_output: BaseModelOutputWithPooling = None
+ vision_model_output: BaseModelOutputWithPooling = None
+
+ def to_tuple(self) -> tuple[Any]:
+ return tuple(
+ self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple()
+ for k in self.keys()
+ )
+
+
+class Siglip2VisionEmbeddings(nn.Module):
+ def __init__(self, config: Siglip2VisionConfig):
+ super().__init__()
+ self.config = config
+ self.embed_dim = config.hidden_size
+ self.patch_size = config.patch_size
+
+ self.patch_embedding = nn.Linear(
+ in_features=config.num_channels * self.patch_size * self.patch_size,
+ out_features=self.embed_dim,
+ )
+
+ self.num_patches = config.num_patches
+ self.position_embedding_size = int(self.num_patches**0.5)
+ self.position_embedding = nn.Embedding(self.num_patches, self.embed_dim)
+
+ @staticmethod
+ def resize_positional_embeddings(
+ positional_embeddings: torch.Tensor,
+ spatial_shapes: torch.LongTensor,
+ max_length: int,
+ ) -> torch.Tensor:
+ """
+ Resize positional embeddings to image-specific size and pad to a fixed size.
+
+ Args:
+ positional_embeddings (`torch.Tensor`):
+ Position embeddings of shape (height, width, embed_dim)
+ spatial_shapes (`torch.LongTensor`):
+ Spatial shapes of shape (batch_size, 2) to resize the positional embeddings to
+ max_length (`int`):
+ Maximum length of the positional embeddings to pad resized positional embeddings to
+
+ Returns:
+ `torch.Tensor`: Embeddings of shape (batch_size, max_length, embed_dim)
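+
+ Example (shape-only sketch; the zero tensors are placeholders):
+
+ ```python
+ >>> import torch
+ >>> pos = torch.zeros(16, 16, 64)
+ >>> shapes = torch.tensor([[4, 8], [2, 2]])
+ >>> Siglip2VisionEmbeddings.resize_positional_embeddings(pos, shapes, max_length=64).shape
+ torch.Size([2, 64, 64])
+ ```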
+ """
+ batch_size = spatial_shapes.shape[0]
+ embed_dim = positional_embeddings.shape[-1]
+ source_dtype = positional_embeddings.dtype
+
+ resulted_positional_embeddings = torch.empty(
+ (batch_size, max_length, embed_dim),
+ device=positional_embeddings.device,
+ dtype=source_dtype,
+ )
+
+ # (height, width, embed_dim) -> (1, embed_dim, height, width) for interpolation
+ positional_embeddings = positional_embeddings.permute(2, 0, 1).unsqueeze(0)
+
+ # Upcast to float32 on CPU because antialias is not supported for bfloat16/float16 on CPU
+ if positional_embeddings.device.type == "cpu":
+ positional_embeddings = positional_embeddings.to(torch.float32)
+
+ for i in range(batch_size):
+ # (1, dim, height, width) -> (1, dim, target_height, target_width)
+ height, width = spatial_shapes[i]
+ resized_embeddings = F.interpolate(
+ positional_embeddings,
+ size=(height, width),
+ mode="bilinear",
+ align_corners=False,
+ antialias=True,
+ )
+
+ # (1, dim, target_height, target_width) -> (target_height * target_width, dim)
+ resized_embeddings = resized_embeddings.reshape(embed_dim, height * width).transpose(0, 1)
+
+ # Cast to original dtype
+ resized_embeddings = resized_embeddings.to(source_dtype)
+
+ resulted_positional_embeddings[i, : height * width] = resized_embeddings
+ resulted_positional_embeddings[i, height * width :] = resized_embeddings[0]
+
+ return resulted_positional_embeddings
+
+ def forward(self, pixel_values: torch.FloatTensor, spatial_shapes: torch.LongTensor) -> torch.Tensor:
+ """
+ Args:
+ pixel_values (`torch.FloatTensor`):
+ Pixel values of shape (batch_size, max_num_patches, num_channels * patch_size * patch_size)
+ spatial_shapes (`torch.LongTensor`):
+ Spatial shapes of shape (batch_size, 2) to resize the positional embeddings to
+ """
+
+ # Apply patch embeddings to already patchified pixel values
+ target_dtype = self.patch_embedding.weight.dtype
+ patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))
+
+ # Get resized and padded positional embeddings
+ positional_embeddings = self.position_embedding.weight.reshape(
+ self.position_embedding_size, self.position_embedding_size, -1
+ )
+ resized_positional_embeddings = self.resize_positional_embeddings(
+ positional_embeddings, spatial_shapes, max_length=pixel_values.shape[1]
+ )
+
+ # Add positional embeddings to patch embeddings
+ embeddings = patch_embeds + resized_positional_embeddings
+ return embeddings
+
+
+def eager_attention_forward(
+ module: nn.Module,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_mask: Optional[torch.Tensor],
+ scaling: float,
+ dropout: float = 0.0,
+ **kwargs,
+):
+ attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
+ if attention_mask is not None:
+ attn_weights = attn_weights + attention_mask
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+
+ attn_output = torch.matmul(attn_weights, value)
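+ # (batch, num_heads, seq_len, head_dim) -> (batch, seq_len, num_heads, head_dim) so callers can merge the head dimension back into embed_dim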
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ return attn_output, attn_weights
+
+
+class Siglip2Attention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.embed_dim = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.head_dim = self.embed_dim // self.num_heads
+ if self.head_dim * self.num_heads != self.embed_dim:
+ raise ValueError(
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
+ f" {self.num_heads})."
+ )
+ self.scale = self.head_dim**-0.5
+ self.dropout = config.attention_dropout
+ self.is_causal = False
+
+ self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
+ self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
+ self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
+ self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ **kwargs,
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+ """Input shape: Batch x Time x Channel"""
+
+ batch_size, seq_length, embed_dim = hidden_states.shape
+
+ queries = self.q_proj(hidden_states)
+ keys = self.k_proj(hidden_states)
+ values = self.v_proj(hidden_states)
+
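+ # Split heads: (batch_size, seq_length, embed_dim) -> (batch_size, num_heads, seq_length, head_dim)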
+ queries = queries.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
+ keys = keys.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
+ values = values.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
+
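+ # Dispatch to the attention backend configured on the model (e.g. sdpa, flash_attention_2), defaulting to the eager implementation above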
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ attn_output, attn_weights = attention_interface(
+ self,
+ queries,
+ keys,
+ values,
+ attention_mask,
+ is_causal=self.is_causal,
+ scaling=self.scale,
+ dropout=0.0 if not self.training else self.dropout,
+ )
+
+ attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous()
+ attn_output = self.out_proj(attn_output)
+
+ return attn_output, attn_weights
+
+
+class Siglip2MLP(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.activation_fn = ACT2FN[config.hidden_act]
+ self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
+ self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.fc1(hidden_states)
+ hidden_states = self.activation_fn(hidden_states)
+ hidden_states = self.fc2(hidden_states)
+ return hidden_states
+
+
+class Siglip2EncoderLayer(GradientCheckpointingLayer):
+ def __init__(self, config: Union[Siglip2VisionConfig, Siglip2TextConfig]):
+ super().__init__()
+ self.embed_dim = config.hidden_size
+ self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
+ self.self_attn = Siglip2Attention(config)
+ self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
+ self.mlp = Siglip2MLP(config)
+
+ @auto_docstring
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: torch.Tensor,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> torch.FloatTensor:
+ residual = hidden_states
+
+ hidden_states = self.layer_norm1(hidden_states)
+ hidden_states, _ = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ **kwargs,
+ )
+ hidden_states = residual + hidden_states
+
+ residual = hidden_states
+ hidden_states = self.layer_norm2(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = residual + hidden_states
+
+ return hidden_states
+
+
+class Siglip2Encoder(nn.Module):
+ """
+ Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
+ [`Siglip2EncoderLayer`].
+
+ Args:
+ config: Siglip2Config
+ """
+
+ def __init__(self, config: Siglip2Config):
+ super().__init__()
+ self.config = config
+ self.layers = nn.ModuleList([Siglip2EncoderLayer(config) for _ in range(config.num_hidden_layers)])
+ self.gradient_checkpointing = False
+
+ # Ignore copy
+ @auto_docstring
+ def forward(
+ self,
+ inputs_embeds,
+ attention_mask: Optional[torch.Tensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> BaseModelOutput:
+ hidden_states = inputs_embeds
+ for encoder_layer in self.layers:
+ hidden_states = encoder_layer(
+ hidden_states,
+ attention_mask,
+ **kwargs,
+ )
+
+ return BaseModelOutput(last_hidden_state=hidden_states)
+
+
+class Siglip2VisionTransformer(nn.Module):
+ def __init__(self, config: Siglip2VisionConfig):
+ super().__init__()
+ self.config = config
+ embed_dim = config.hidden_size
+
+ self.embeddings = Siglip2VisionEmbeddings(config)
+ self.encoder = Siglip2Encoder(config)
+ self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
+ self.use_head = True if not hasattr(config, "vision_use_head") else config.vision_use_head
+ if self.use_head:
+ self.head = Siglip2MultiheadAttentionPoolingHead(config)
+
+ @auto_docstring
+ def forward(
+ self,
+ pixel_values: torch.FloatTensor,
+ attention_mask: torch.Tensor,
+ spatial_shapes: torch.LongTensor,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ ) -> BaseModelOutputWithPooling:
+ r"""
+ spatial_shapes (`torch.LongTensor` of shape `(batch_size, 2)`):
+ Tensor containing the spatial dimensions (height, width) of the input images.
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+
+ hidden_states = self.embeddings(pixel_values, spatial_shapes)
+
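+ # flash_attention_2 consumes the 2D padding mask directly; other backends expect an expanded 4D additive mask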
+ if attention_mask is not None and self.config._attn_implementation != "flash_attention_2":
+ # [batch_size, seq_len] -> [batch_size, 1, tgt_seq_len, src_seq_len]
+ encoder_attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)
+ else:
+ encoder_attention_mask = attention_mask
+
+ encoder_outputs: BaseModelOutput = self.encoder(
+ inputs_embeds=hidden_states,
+ attention_mask=encoder_attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ )
+
+ last_hidden_state = encoder_outputs.last_hidden_state
+ last_hidden_state = self.post_layernorm(last_hidden_state)
+
+ pooler_output = self.head(last_hidden_state, attention_mask) if self.use_head else None
+
+ return BaseModelOutputWithPooling(
+ last_hidden_state=last_hidden_state,
+ pooler_output=pooler_output,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ )
+
+
+def _trunc_normal_(tensor, mean, std, a, b):
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
+ def norm_cdf(x):
+ # Computes standard normal cumulative distribution function
+ return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
+
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
+ warnings.warn(
+ "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
+ "The distribution of values may be incorrect.",
+ stacklevel=2,
+ )
+
+ # Values are generated by using a truncated uniform distribution and
+ # then using the inverse CDF for the normal distribution.
+ # Get upper and lower cdf values
+ l = norm_cdf((a - mean) / std)
+ u = norm_cdf((b - mean) / std)
+
+ # Uniformly fill tensor with values from [l, u], then translate to
+ # [2l-1, 2u-1].
+ tensor.uniform_(2 * l - 1, 2 * u - 1)
+
+ # Use inverse cdf transform for normal distribution to get truncated
+ # standard normal
+ tensor.erfinv_()
+
+ # Transform to proper mean, std
+ tensor.mul_(std * math.sqrt(2.0))
+ tensor.add_(mean)
+
+ # Clamp to ensure it's in the proper range
+ tensor.clamp_(min=a, max=b)
+
+
+def trunc_normal_tf_(
+ tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0, a: float = -2.0, b: float = 2.0
+) -> torch.Tensor:
+ """Fills the input Tensor with values drawn from a truncated
+ normal distribution. The values are effectively drawn from the
+ normal distribution :math:`\\mathcal{N}(\\text{mean}, \\text{std}^2)`
+ with values outside :math:`[a, b]` redrawn until they are within
+ the bounds. The method used for generating the random values works
+ best when :math:`a \\leq \\text{mean} \\leq b`.
+
+ NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the
+ bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0
+ and the result is subsequently scaled and shifted by the mean and std args.
+
+ Args:
+ tensor: an n-dimensional `torch.Tensor`
+ mean: the mean of the normal distribution
+ std: the standard deviation of the normal distribution
+ a: the minimum cutoff value
+ b: the maximum cutoff value
+ """
+ with torch.no_grad():
+ _trunc_normal_(tensor, 0, 1.0, a, b)
+ tensor.mul_(std).add_(mean)
+
+
+def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"):
+ fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
+ if mode == "fan_in":
+ denom = fan_in
+ elif mode == "fan_out":
+ denom = fan_out
+ elif mode == "fan_avg":
+ denom = (fan_in + fan_out) / 2
+
+ variance = scale / denom
+
+ if distribution == "truncated_normal":
+ # constant is stddev of standard normal truncated to (-2, 2)
+ trunc_normal_tf_(tensor, std=math.sqrt(variance) / 0.87962566103423978)
+ elif distribution == "normal":
+ with torch.no_grad():
+ tensor.normal_(std=math.sqrt(variance))
+ elif distribution == "uniform":
+ bound = math.sqrt(3 * variance)
+ with torch.no_grad():
+ tensor.uniform_(-bound, bound)
+ else:
+ raise ValueError(f"invalid distribution {distribution}")
+
+
+def lecun_normal_(tensor):
+ variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal")
+
+
+def default_flax_embed_init(tensor):
+ variance_scaling_(tensor, mode="fan_in", distribution="normal")
+
+
+@auto_docstring
+class Siglip2PreTrainedModel(PreTrainedModel):
+ config: Siglip2Config
+ base_model_prefix = "siglip2"
+ supports_gradient_checkpointing = True
+
+ _no_split_modules = [
+ "Siglip2TextEmbeddings",
+ "Siglip2VisionEmbeddings",
+ "Siglip2EncoderLayer",
+ "Siglip2MultiheadAttentionPoolingHead",
+ ]
+ _supports_flash_attn = True
+ _supports_sdpa = True
+ _supports_flex_attn = True
+ _supports_attention_backend = True
+
+ _can_record_outputs = {
+ "hidden_states": Siglip2EncoderLayer,
+ "attentions": Siglip2Attention,
+ }
+
+ def _init_weights(self, module):
+ """Initialize the weights"""
+ if isinstance(module, Siglip2VisionEmbeddings):
+ width = (
+ self.config.vision_config.hidden_size
+ if isinstance(self.config, Siglip2Config)
+ else self.config.hidden_size
+ )
+ nn.init.normal_(module.position_embedding.weight, std=1 / np.sqrt(width))
+ elif isinstance(module, nn.Embedding):
+ default_flax_embed_init(module.weight)
+ elif isinstance(module, Siglip2Attention):
+ nn.init.xavier_uniform_(module.q_proj.weight)
+ nn.init.xavier_uniform_(module.k_proj.weight)
+ nn.init.xavier_uniform_(module.v_proj.weight)
+ nn.init.xavier_uniform_(module.out_proj.weight)
+ nn.init.zeros_(module.q_proj.bias)
+ nn.init.zeros_(module.k_proj.bias)
+ nn.init.zeros_(module.v_proj.bias)
+ nn.init.zeros_(module.out_proj.bias)
+ elif isinstance(module, Siglip2MLP):
+ nn.init.xavier_uniform_(module.fc1.weight)
+ nn.init.xavier_uniform_(module.fc2.weight)
+ nn.init.normal_(module.fc1.bias, std=1e-6)
+ nn.init.normal_(module.fc2.bias, std=1e-6)
+ elif isinstance(module, Siglip2MultiheadAttentionPoolingHead):
+ nn.init.xavier_uniform_(module.probe.data)
+ nn.init.xavier_uniform_(module.attention.in_proj_weight.data)
+ nn.init.zeros_(module.attention.in_proj_bias.data)
+ elif isinstance(module, Siglip2Model):
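+ # log(1.0) == 0.0, so exp(logit_scale) starts out as 1.0 before training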
+ logit_scale_init = torch.log(torch.tensor(1.0))
+ module.logit_scale.data.fill_(logit_scale_init)
+ module.logit_bias.data.zero_()
+ elif isinstance(module, Siglip2ForImageClassification):
+ nn.init.normal_(
+ module.classifier.weight,
+ std=self.config.vision_config.hidden_size**-0.5 * self.config.initializer_factor,
+ )
+ elif isinstance(module, (nn.Linear, nn.Conv2d)):
+ lecun_normal_(module.weight)
+ if module.bias is not None:
+ nn.init.zeros_(module.bias)
+ elif isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+
+
+class Siglip2TextEmbeddings(nn.Module):
+ def __init__(self, config: Siglip2TextConfig):
+ super().__init__()
+ embed_dim = config.hidden_size
+
+ self.token_embedding = nn.Embedding(config.vocab_size, embed_dim)
+ self.position_embedding = nn.Embedding(config.max_position_embeddings, embed_dim)
+
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
+ self.register_buffer(
+ "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
+ )
+
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ ) -> torch.Tensor:
+ seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
+ max_position_embedding = self.position_embedding.weight.shape[0]
+
+ if seq_length > max_position_embedding:
+ raise ValueError(
+ f"Sequence length must be less than max_position_embeddings (got `sequence length`: "
+ f"{seq_length} and max_position_embeddings: {max_position_embedding}"
+ )
+
+ if position_ids is None:
+ position_ids = self.position_ids[:, :seq_length]
+
+ if inputs_embeds is None:
+ inputs_embeds = self.token_embedding(input_ids)
+
+ position_embeddings = self.position_embedding(position_ids)
+ embeddings = inputs_embeds + position_embeddings
+
+ return embeddings
+
+
+class Siglip2TextTransformer(nn.Module):
+ def __init__(self, config: Siglip2TextConfig):
+ super().__init__()
+ self.config = config
+ embed_dim = config.hidden_size
+ self.embeddings = Siglip2TextEmbeddings(config)
+ self.encoder = Siglip2Encoder(config)
+ self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
+
+ self.head = nn.Linear(embed_dim, config.projection_size)
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> BaseModelOutputWithPooling:
+ if input_ids is None:
+ raise ValueError("You have to specify input_ids")
+
+ input_shape = input_ids.size()
+ input_ids = input_ids.view(-1, input_shape[-1])
+
+ hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids)
+
+ # note: Siglip2's text model does not use a causal mask, unlike the original CLIP model.
+ # expand attention_mask
+ uses_flash_attention = "flash" in self.config._attn_implementation
+ if uses_flash_attention:
+ attention_mask = None
+ elif attention_mask is not None:
+ # [batch_size, seq_len] -> [batch_size, 1, tgt_seq_len, src_seq_len]
+ attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)
+
+ encoder_outputs: BaseModelOutput = self.encoder(
+ inputs_embeds=hidden_states,
+ attention_mask=attention_mask,
+ **kwargs,
+ )
+
+ last_hidden_state = encoder_outputs.last_hidden_state
+ last_hidden_state = self.final_layer_norm(last_hidden_state)
+
+ # The model uses the last token's hidden state, which may be padding.
+ pooled_output = last_hidden_state[:, -1, :]
+ pooled_output = self.head(pooled_output)
+
+ return BaseModelOutputWithPooling(
+ last_hidden_state=last_hidden_state,
+ pooler_output=pooled_output,
+ )
+
+
+@auto_docstring(
+ custom_intro="""
+ The text model from Siglip2 without any head or projection on top.
+ """
+)
+class Siglip2TextModel(Siglip2PreTrainedModel):
+ config: Siglip2TextConfig
+
+ def __init__(self, config: Siglip2TextConfig):
+ super().__init__(config)
+ self.text_model = Siglip2TextTransformer(config)
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self) -> nn.Module:
+ return self.text_model.embeddings.token_embedding
+
+ def set_input_embeddings(self, value):
+ self.text_model.embeddings.token_embedding = value
+
+ @check_model_inputs(tie_last_hidden_states=False)
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> BaseModelOutputWithPooling:
+ r"""
+ Examples:
+
+ ```python
+ >>> from transformers import AutoTokenizer, Siglip2TextModel
+
+ >>> model = Siglip2TextModel.from_pretrained("google/siglip2-base-patch16-224")
+ >>> tokenizer = AutoTokenizer.from_pretrained("google/siglip2-base-patch16-224")
+
+ >>> # important: make sure to set padding="max_length" as that's how the model was trained
+ >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding="max_length", return_tensors="pt")
+
+ >>> outputs = model(**inputs)
+ >>> last_hidden_state = outputs.last_hidden_state
+ >>> pooled_output = outputs.pooler_output # pooled (EOS token) states
+ ```"""
+
+ return self.text_model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ **kwargs,
+ )
+
+
+class Siglip2MultiheadAttentionPoolingHead(nn.Module):
+ """Multihead Attention Pooling."""
+
+ def __init__(self, config: Siglip2VisionConfig):
+ super().__init__()
+
+ self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size))
+ self.attention = torch.nn.MultiheadAttention(config.hidden_size, config.num_attention_heads, batch_first=True)
+ self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.mlp = Siglip2MLP(config)
+ self.num_heads = config.num_attention_heads
+
+ def forward(self, hidden_state: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+ batch_size = hidden_state.shape[0]
+ probe = self.probe.repeat(batch_size, 1, 1)
+
+ if attention_mask is not None:
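+ # torch.nn.MultiheadAttention expects a 3D float mask of shape (batch_size * num_heads, target_len, source_len)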
+ target_len, source_len = probe.shape[1], hidden_state.shape[1]
+ attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_state.dtype, target_len)
+ attention_mask = attention_mask.repeat(1, self.num_heads, target_len, 1)
+ attention_mask = attention_mask.reshape(-1, target_len, source_len)
+
+ hidden_state = self.attention(probe, hidden_state, hidden_state, attn_mask=attention_mask)[0]
+
+ residual = hidden_state
+ hidden_state = self.layernorm(hidden_state)
+ hidden_state = residual + self.mlp(hidden_state)
+
+ return hidden_state[:, 0]
+
+
+@auto_docstring(
+ custom_intro="""
+ The vision model from Siglip2 without any head or projection on top.
+ """
+)
+class Siglip2VisionModel(Siglip2PreTrainedModel):
+ config: Siglip2VisionConfig
+ main_input_name = "pixel_values"
+
+ def __init__(self, config: Siglip2VisionConfig):
+ super().__init__(config)
+
+ self.vision_model = Siglip2VisionTransformer(config)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self) -> nn.Module:
+ return self.vision_model.embeddings.patch_embedding
+
+ @check_model_inputs(tie_last_hidden_states=False)
+ @auto_docstring
+ def forward(
+ self,
+ pixel_values: torch.FloatTensor,
+ pixel_attention_mask: torch.Tensor,
+ spatial_shapes: torch.LongTensor,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ ) -> BaseModelOutputWithPooling:
+ r"""
+ pixel_attention_mask (`torch.Tensor` of shape `(batch_size, image_size, image_size)`, *optional*):
+ Mask to avoid performing attention on padding pixel indices.
+ spatial_shapes (`torch.LongTensor` of shape `(batch_size, 2)`):
+ Tensor containing the spatial dimensions (height, width) of the input images.
+
+ Examples:
+
+ ```python
+ >>> from PIL import Image
+ >>> import requests
+ >>> from transformers import AutoProcessor, Siglip2VisionModel
+
+ >>> model = Siglip2VisionModel.from_pretrained("google/siglip2-base-patch16-224")
+ >>> processor = AutoProcessor.from_pretrained("google/siglip2-base-patch16-224")
+
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> inputs = processor(images=image, return_tensors="pt")
+
+ >>> outputs = model(**inputs)
+ >>> last_hidden_state = outputs.last_hidden_state
+ >>> pooled_output = outputs.pooler_output # pooled features
+ ```"""
+ return self.vision_model(
+ pixel_values=pixel_values,
+ attention_mask=pixel_attention_mask,
+ spatial_shapes=spatial_shapes,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ )
+
+
+@auto_docstring
+class Siglip2Model(Siglip2PreTrainedModel):
+ config: Siglip2Config
+
+ def __init__(self, config: Siglip2Config):
+ super().__init__(config)
+
+ if not isinstance(config.text_config, Siglip2TextConfig):
+ raise TypeError(
+ "config.text_config is expected to be of type Siglip2TextConfig but is of type"
+ f" {type(config.text_config)}."
+ )
+
+ if not isinstance(config.vision_config, Siglip2VisionConfig):
+ raise TypeError(
+ "config.vision_config is expected to be of type Siglip2VisionConfig but is of type"
+ f" {type(config.vision_config)}."
+ )
+
+ text_config = config.text_config
+ vision_config = config.vision_config
+
+ # First, initialize the text and vision models with proper attention implementation
+ text_model = Siglip2TextModel._from_config(text_config)
+ vision_model = Siglip2VisionModel._from_config(vision_config)
+
+ # Second, get the text and vision submodules (for backward compatibility)
+ self.text_model = text_model.text_model
+ self.vision_model = vision_model.vision_model
+
+ self.logit_scale = nn.Parameter(torch.randn(1))
+ self.logit_bias = nn.Parameter(torch.randn(1))
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @filter_out_non_signature_kwargs()
+ @auto_docstring
+ def get_text_features(
+ self,
+ input_ids: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ ) -> torch.FloatTensor:
+ r"""
+ Returns:
+ text_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The text embeddings obtained by
+ applying the projection layer to the pooled output of [`Siglip2TextModel`].
+
+ Examples:
+
+ ```python
+ >>> from transformers import AutoTokenizer, AutoModel
+ >>> import torch
+
+ >>> model = AutoModel.from_pretrained("google/siglip2-base-patch16-224")
+ >>> tokenizer = AutoTokenizer.from_pretrained("google/siglip2-base-patch16-224")
+
+ >>> # important: make sure to set padding="max_length" as that's how the model was trained
+ >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding="max_length", return_tensors="pt")
+ >>> with torch.no_grad():
+ ... text_features = model.get_text_features(**inputs)
+ ```"""
+ text_outputs: BaseModelOutputWithPooling = self.text_model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ )
+ pooled_output = text_outputs.pooler_output
+
+ return pooled_output
+
+ @filter_out_non_signature_kwargs()
+ @auto_docstring
+ def get_image_features(
+ self,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ pixel_attention_mask: Optional[torch.Tensor] = None,
+ spatial_shapes: Optional[torch.LongTensor] = None,
+ ) -> torch.FloatTensor:
+ r"""
+ pixel_attention_mask (`torch.Tensor` of shape `(batch_size, image_size, image_size)`, *optional*):
+ Mask to avoid performing attention on padding pixel indices.
+ spatial_shapes (`torch.LongTensor` of shape `(batch_size, 2)`):
+ Tensor containing the spatial dimensions (height, width) of the input images.
+
+ Returns:
+ image_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The image embeddings obtained by
+ applying the projection layer to the pooled output of [`Siglip2VisionModel`].
+
+ Examples:
+
+ ```python
+ >>> import torch
+ >>> from transformers import AutoProcessor, AutoModel
+ >>> from transformers.image_utils import load_image
+
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = load_image(url)
+
+ >>> model = AutoModel.from_pretrained("google/siglip2-base-patch16-224")
+ >>> processor = AutoProcessor.from_pretrained("google/siglip2-base-patch16-224")
+
+ >>> inputs = processor(images=image, return_tensors="pt")
+
+ >>> with torch.no_grad():
+ ... image_features = model.get_image_features(**inputs)
+ ```
+ """
+ vision_outputs: BaseModelOutputWithPooling = self.vision_model(
+ pixel_values=pixel_values,
+ attention_mask=pixel_attention_mask,
+ spatial_shapes=spatial_shapes,
+ )
+ pooled_output = vision_outputs.pooler_output
+
+ return pooled_output
+
+ # NOTE: Siglip2Model uses Pretrained backbones, so we don't need to add `check_model_inputs` here
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ pixel_attention_mask: Optional[torch.Tensor] = None,
+ spatial_shapes: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ return_loss: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ ) -> Siglip2Output:
+ r"""
+ pixel_attention_mask (`torch.Tensor` of shape `(batch_size, image_size, image_size)`, *optional*):
+ Mask to avoid performing attention on padding pixel indices.
+ spatial_shapes (`torch.LongTensor` of shape `(batch_size, 2)`):
+ Tensor containing the spatial dimensions (height, width) of the input images.
+ return_loss (`bool`, *optional*):
+ Whether or not to return the contrastive loss.
+
+ Examples:
+
+ ```python
+ >>> from PIL import Image
+ >>> import requests
+ >>> from transformers import AutoProcessor, AutoModel
+ >>> import torch
+
+ >>> model = AutoModel.from_pretrained("google/siglip2-base-patch16-224")
+ >>> processor = AutoProcessor.from_pretrained("google/siglip2-base-patch16-224")
+
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> texts = ["a photo of 2 cats", "a photo of 2 dogs"]
+ >>> # important: we pass `padding=max_length` since the model was trained with this
+ >>> inputs = processor(text=texts, images=image, padding="max_length", return_tensors="pt")
+
+ >>> with torch.no_grad():
+ ... outputs = model(**inputs)
+
+ >>> logits_per_image = outputs.logits_per_image
+ >>> probs = torch.sigmoid(logits_per_image) # these are the probabilities
+ >>> print(f"{probs[0][0]:.1%} that image 0 is '{texts[0]}'")
+ 31.9% that image 0 is 'a photo of 2 cats'
+ ```
+ """
+ # Use Siglip2 model's config for some fields (if specified) instead of those of vision & text components.
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+
+ vision_outputs: BaseModelOutputWithPooling = self.vision_model(
+ pixel_values=pixel_values,
+ attention_mask=pixel_attention_mask,
+ spatial_shapes=spatial_shapes,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ )
+
+ text_outputs: BaseModelOutputWithPooling = self.text_model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ )
+
+ image_embeds = vision_outputs.pooler_output
+ text_embeds = text_outputs.pooler_output
+
+ # normalized features
+ image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
+ text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
+
+ # cosine similarity as logits
+ logits_per_text = torch.matmul(text_embeds, image_embeds.t().to(text_embeds.device))
+
+ logit_scale, logit_bias = self.logit_scale.to(text_embeds.device), self.logit_bias.to(text_embeds.device)
+ logits_per_text = logits_per_text * logit_scale.exp() + logit_bias
+
+ logits_per_image = logits_per_text.t()
+
+ loss = None
+ if return_loss:
+ # Adapted from https://github.com/google-research/big_vision/blob/01edb81a4716f93a48be43b3a4af14e29cdb3a7f/big_vision/trainers/proj/image_text/siglip2.py#L287
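+ # Pairwise sigmoid loss: the implicit label matrix is +1 on the diagonal (matching image-text pairs) and -1 elsewhere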
+ eye = torch.eye(logits_per_text.size(0), device=logits_per_text.device)
+ m1_diag1 = -torch.ones_like(logits_per_text) + 2 * eye
+ loglik = torch.nn.functional.logsigmoid(m1_diag1 * logits_per_text)
+ nll = -torch.sum(loglik, dim=-1)
+ loss = nll.mean()
+
+ return Siglip2Output(
+ loss=loss,
+ logits_per_image=logits_per_image,
+ logits_per_text=logits_per_text,
+ text_embeds=text_embeds,
+ image_embeds=image_embeds,
+ text_model_output=text_outputs,
+ vision_model_output=vision_outputs,
+ )
+
+
+@auto_docstring(
+ custom_intro="""
+ Siglip2 vision encoder with an image classification head on top (a linear layer on top of the pooled final hidden states of
+ the patch tokens) e.g. for ImageNet.
+ """
+)
+class Siglip2ForImageClassification(Siglip2PreTrainedModel):
+ main_input_name = "pixel_values"
+
+ def __init__(self, config: Siglip2Config) -> None:
+ super().__init__(config)
+
+ self.num_labels = config.num_labels
+
+ # Create the vision model with proper attention
+ # and take only vision_model submodule (for backward compatibility)
+ vision_model = Siglip2VisionModel._from_config(config.vision_config)
+ self.vision_model = vision_model.vision_model
+
+ # Classifier head
+ self.classifier = (
+ nn.Linear(config.vision_config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity()
+ )
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @check_model_inputs
+ @auto_docstring
+ def forward(
+ self,
+ pixel_values: Optional[torch.Tensor] = None,
+ pixel_attention_mask: Optional[torch.Tensor] = None,
+ spatial_shapes: Optional[torch.LongTensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ ) -> ImageClassifierOutput:
+ r"""
+ pixel_attention_mask (`torch.Tensor` of shape `(batch_size, image_size, image_size)`, *optional*):
+ Mask to avoid performing attention on padding pixel indices.
+ spatial_shapes (`torch.LongTensor` of shape `(batch_size, 2)`):
+ Tensor containing the spatial dimensions (height, width) of the input images.
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+
+ Examples:
+
+ ```python
+ >>> from transformers import AutoImageProcessor, Siglip2ForImageClassification
+ >>> import torch
+ >>> from PIL import Image
+ >>> import requests
+
+ >>> torch.manual_seed(3) # doctest: +IGNORE_RESULT
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> # note: we are loading a `Siglip2Model` from the hub here,
+ >>> # so the head will be randomly initialized, hence the predictions will be random if seed is not set above.
+ >>> image_processor = AutoImageProcessor.from_pretrained("google/siglip2-base-patch16-224")
+ >>> model = Siglip2ForImageClassification.from_pretrained("google/siglip2-base-patch16-224")
+
+ >>> inputs = image_processor(images=image, return_tensors="pt")
+ >>> outputs = model(**inputs)
+ >>> logits = outputs.logits
+ >>> # model predicts one of the two classes
+ >>> predicted_class_idx = logits.argmax(-1).item()
+ >>> print("Predicted class:", model.config.id2label[predicted_class_idx])
+ Predicted class: LABEL_1
+ ```
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+
+ outputs: BaseModelOutputWithPooling = self.vision_model(
+ pixel_values,
+ attention_mask=pixel_attention_mask,
+ spatial_shapes=spatial_shapes,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ )
+
+ sequence_output = outputs.last_hidden_state
+
+ # average pool the patch tokens
+ if pixel_attention_mask is not None:
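+ # Masked mean over patches: zero out padded patch embeddings and divide by the number of real patches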
+ pool_mask = pixel_attention_mask[..., None].to(sequence_output.device)
+ sequence_output = torch.sum(sequence_output * pool_mask, dim=1) / torch.sum(pool_mask, dim=1)
+ else:
+ sequence_output = torch.mean(sequence_output, dim=1)
+
+ # apply classifier
+ logits = self.classifier(sequence_output)
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(labels, logits, self.config)
+
+ return ImageClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+__all__ = [
+ "Siglip2Model",
+ "Siglip2PreTrainedModel",
+ "Siglip2TextModel",
+ "Siglip2VisionModel",
+ "Siglip2ForImageClassification",
+]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/siglip2/modular_siglip2.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/siglip2/modular_siglip2.py
new file mode 100644
index 0000000000000000000000000000000000000000..260a82e5143e157301f4d590270504e6205ae9ea
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/siglip2/modular_siglip2.py
@@ -0,0 +1,605 @@
+# coding=utf-8
+# Copyright 2025 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Optional
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from transformers.models.siglip.configuration_siglip import SiglipConfig, SiglipTextConfig, SiglipVisionConfig
+from transformers.models.siglip.modeling_siglip import (
+ BaseModelOutput,
+ BaseModelOutputWithPooling,
+ ImageClassifierOutput,
+ SiglipForImageClassification,
+ SiglipModel,
+ SiglipMultiheadAttentionPoolingHead,
+ SiglipOutput,
+ SiglipPreTrainedModel,
+ SiglipTextModel,
+ SiglipTextModelOutput,
+ SiglipVisionModel,
+ SiglipVisionModelOutput,
+ SiglipVisionTransformer,
+)
+
+from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
+from ...utils import auto_docstring, filter_out_non_signature_kwargs
+
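+# This modular file only re-declares what differs from SigLIP (variable-resolution "NaFlex" patching and padded patch sequences);
+# the flattened modeling_siglip2.py is generated from it.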
+
+class Siglip2TextConfig(SiglipTextConfig):
+ pass
+
+
+class Siglip2VisionConfig(SiglipVisionConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`Siglip2VisionModel`]. It is used to instantiate a
+ Siglip2 vision encoder according to the specified arguments, defining the model architecture. Instantiating a
+ configuration with the defaults will yield a similar configuration to that of the vision encoder of the Siglip2
+ [google/siglip2-base-patch16-naflex](https://huggingface.co/google/siglip2-base-patch16-naflex) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ hidden_size (`int`, *optional*, defaults to 768):
+ Dimensionality of the encoder layers and the pooler layer.
+ intermediate_size (`int`, *optional*, defaults to 3072):
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
+ num_hidden_layers (`int`, *optional*, defaults to 12):
+ Number of hidden layers in the Transformer encoder.
+ num_attention_heads (`int`, *optional*, defaults to 12):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ num_channels (`int`, *optional*, defaults to 3):
+ Number of channels in the input images.
+ num_patches (`int`, *optional*, defaults to 256):
+ The number of patches in the image with the size of (`patch_size`, `patch_size`).
+ The image is resized to fit at most this number of patches while preserving
+ the aspect ratio. If the resulting number of patches is lower, the sequence is
+ padded along the "patch" dimension.
+ patch_size (`int`, *optional*, defaults to 16):
+ The size (resolution) of each patch.
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+ `"relu"`, `"selu"` and `"gelu_new"` `"quick_gelu"` are supported.
+ layer_norm_eps (`float`, *optional*, defaults to 1e-06):
+ The epsilon used by the layer normalization layers.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+
+ Example:
+
+ ```python
+ >>> from transformers import Siglip2VisionConfig, Siglip2VisionModel
+
+ >>> # Initializing a Siglip2VisionConfig with google/siglip2-base-patch16-naflex style configuration
+ >>> configuration = Siglip2VisionConfig()
+
+ >>> # Initializing a Siglip2VisionModel (with random weights) from the google/siglip2-base-patch16-naflex style configuration
+ >>> model = Siglip2VisionModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ def __init__(
+ self,
+ hidden_size=768,
+ intermediate_size=3072,
+ num_hidden_layers=12,
+ num_attention_heads=12,
+ num_channels=3,
+ num_patches=256,
+ patch_size=16,
+ hidden_act="gelu_pytorch_tanh",
+ layer_norm_eps=1e-6,
+ attention_dropout=0.0,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.num_patches = num_patches
+ del self.image_size
+
+
+class Siglip2Config(SiglipConfig):
+ pass
+
+
+class Siglip2VisionOutput(SiglipVisionModelOutput):
+ pass
+
+
+class Siglip2TextOutput(SiglipTextModelOutput):
+ pass
+
+
+class Siglip2Output(SiglipOutput):
+ pass
+
+
+class Siglip2VisionEmbeddings(nn.Module):
+ def __init__(self, config: Siglip2VisionConfig):
+ super().__init__()
+ self.config = config
+ self.embed_dim = config.hidden_size
+ self.patch_size = config.patch_size
+
+ self.patch_embedding = nn.Linear(
+ in_features=config.num_channels * self.patch_size * self.patch_size,
+ out_features=self.embed_dim,
+ )
+
+ self.num_patches = config.num_patches
+ self.position_embedding_size = int(self.num_patches**0.5)
+ self.position_embedding = nn.Embedding(self.num_patches, self.embed_dim)
+
+ @staticmethod
+ def resize_positional_embeddings(
+ positional_embeddings: torch.Tensor,
+ spatial_shapes: torch.LongTensor,
+ max_length: int,
+ ) -> torch.Tensor:
+ """
+ Resize positional embeddings to image-specific size and pad to a fixed size.
+
+ Args:
+ positional_embeddings (`torch.Tensor`):
+ Position embeddings of shape (height, width, embed_dim)
+ spatial_shapes (`torch.LongTensor`):
+ Spatial shapes of shape (batch_size, 2) to resize the positional embeddings to
+ max_length (`int`):
+ Maximum length of the positional embeddings to pad resized positional embeddings to
+
+ Returns:
+ `torch.Tensor`: Embeddings of shape (batch_size, max_length, embed_dim)
+ """
+ batch_size = spatial_shapes.shape[0]
+ embed_dim = positional_embeddings.shape[-1]
+ source_dtype = positional_embeddings.dtype
+
+ resulted_positional_embeddings = torch.empty(
+ (batch_size, max_length, embed_dim),
+ device=positional_embeddings.device,
+ dtype=source_dtype,
+ )
+
+ # (height, width, embed_dim) -> (1, embed_dim, height, width) for interpolation
+ positional_embeddings = positional_embeddings.permute(2, 0, 1).unsqueeze(0)
+
+ # Upcast to float32 on CPU because antialias is not supported for bfloat16/float16 on CPU
+ if positional_embeddings.device.type == "cpu":
+ positional_embeddings = positional_embeddings.to(torch.float32)
+
+ for i in range(batch_size):
+ # (1, dim, height, width) -> (1, dim, target_height, target_width)
+ height, width = spatial_shapes[i]
+ resized_embeddings = F.interpolate(
+ positional_embeddings,
+ size=(height, width),
+ mode="bilinear",
+ align_corners=False,
+ antialias=True,
+ )
+
+ # (1, dim, target_height, target_width) -> (target_height * target_width, dim)
+ resized_embeddings = resized_embeddings.reshape(embed_dim, height * width).transpose(0, 1)
+
+ # Cast to original dtype
+ resized_embeddings = resized_embeddings.to(source_dtype)
+
+ resulted_positional_embeddings[i, : height * width] = resized_embeddings
+ resulted_positional_embeddings[i, height * width :] = resized_embeddings[0]
+
+ return resulted_positional_embeddings
+
+ def forward(self, pixel_values: torch.FloatTensor, spatial_shapes: torch.LongTensor) -> torch.Tensor:
+ """
+ Args:
+ pixel_values (`torch.FloatTensor`):
+ Pixel values of shape (batch_size, max_num_patches, num_channels * patch_size * patch_size)
+ spatial_shapes (`torch.LongTensor`):
+ Spatial shapes of shape (batch_size, 2) to resize the positional embeddings to
+ """
+
+ # Apply patch embeddings to already patchified pixel values
+ target_dtype = self.patch_embedding.weight.dtype
+ patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))
+
+ # Get resized and padded positional embeddings
+ positional_embeddings = self.position_embedding.weight.reshape(
+ self.position_embedding_size, self.position_embedding_size, -1
+ )
+ resized_positional_embeddings = self.resize_positional_embeddings(
+ positional_embeddings, spatial_shapes, max_length=pixel_values.shape[1]
+ )
+
+ # Add positional embeddings to patch embeddings
+ embeddings = patch_embeds + resized_positional_embeddings
+ return embeddings
+
+
+class Siglip2VisionTransformer(SiglipVisionTransformer):
+ def __init__(self, config: Siglip2VisionConfig):
+ super().__init__(config)
+
+ # Update: add `spatial_shapes` and `attention_mask`
+ def forward(
+ self,
+ pixel_values: torch.FloatTensor,
+ attention_mask: torch.Tensor,
+ spatial_shapes: torch.LongTensor,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ ) -> BaseModelOutputWithPooling:
+ r"""
+ spatial_shapes (`torch.LongTensor` of shape `(batch_size, 2)`):
+ Tensor containing the spatial dimensions (height, width) of the input images.
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+
+ hidden_states = self.embeddings(pixel_values, spatial_shapes)
+
+ if attention_mask is not None and self.config._attn_implementation != "flash_attention_2":
+ # [batch_size, seq_len] -> [batch_size, 1, tgt_seq_len, src_seq_len]
+ encoder_attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)
+ else:
+ encoder_attention_mask = attention_mask
+
+ encoder_outputs: BaseModelOutput = self.encoder(
+ inputs_embeds=hidden_states,
+ attention_mask=encoder_attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ )
+
+ last_hidden_state = encoder_outputs.last_hidden_state
+ last_hidden_state = self.post_layernorm(last_hidden_state)
+
+ pooler_output = self.head(last_hidden_state, attention_mask) if self.use_head else None
+
+ return BaseModelOutputWithPooling(
+ last_hidden_state=last_hidden_state,
+ pooler_output=pooler_output,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ )
+
+
+class Siglip2PreTrainedModel(SiglipPreTrainedModel):
+ pass
+
+
+class Siglip2TextModel(SiglipTextModel):
+ pass
+
+
+class Siglip2MultiheadAttentionPoolingHead(SiglipMultiheadAttentionPoolingHead):
+ def __init__(self, config: Siglip2VisionConfig):
+ super().__init__(config)
+ self.num_heads = config.num_attention_heads
+
+ def forward(self, hidden_state: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+ batch_size = hidden_state.shape[0]
+ probe = self.probe.repeat(batch_size, 1, 1)
+
+ if attention_mask is not None:
+ target_len, source_len = probe.shape[1], hidden_state.shape[1]
+ attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_state.dtype, target_len)
+ attention_mask = attention_mask.repeat(1, self.num_heads, target_len, 1)
+ attention_mask = attention_mask.reshape(-1, target_len, source_len)
+
+ hidden_state = self.attention(probe, hidden_state, hidden_state, attn_mask=attention_mask)[0]
+
+ residual = hidden_state
+ hidden_state = self.layernorm(hidden_state)
+ hidden_state = residual + self.mlp(hidden_state)
+
+ return hidden_state[:, 0]
+
+
+class Siglip2VisionModel(SiglipVisionModel):
+ # Update: add `spatial_shapes` and `pixel_attention_mask`
+ def forward(
+ self,
+ pixel_values: torch.FloatTensor,
+ pixel_attention_mask: torch.Tensor,
+ spatial_shapes: torch.LongTensor,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ ) -> BaseModelOutputWithPooling:
+ r"""
+ pixel_attention_mask (`torch.Tensor` of shape `(batch_size, image_size, image_size)`, *optional*):
+ Mask to avoid performing attention on padding pixel indices.
+ spatial_shapes (`torch.LongTensor` of shape `(batch_size, 2)`):
+ Tensor containing the spatial dimensions (height, width) of the input images.
+
+ Examples:
+
+ ```python
+ >>> from PIL import Image
+ >>> import requests
+ >>> from transformers import AutoProcessor, Siglip2VisionModel
+
+ >>> model = Siglip2VisionModel.from_pretrained("google/siglip2-base-patch16-224")
+ >>> processor = AutoProcessor.from_pretrained("google/siglip2-base-patch16-224")
+
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> inputs = processor(images=image, return_tensors="pt")
+
+ >>> outputs = model(**inputs)
+ >>> last_hidden_state = outputs.last_hidden_state
+ >>> pooled_output = outputs.pooler_output # pooled features
+ ```"""
+ return self.vision_model(
+ pixel_values=pixel_values,
+ attention_mask=pixel_attention_mask,
+ spatial_shapes=spatial_shapes,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ )
+
+
+class Siglip2Model(SiglipModel):
+ # Update: add `spatial_shapes` and `pixel_attention_mask`
+ @filter_out_non_signature_kwargs()
+ @auto_docstring
+ def get_image_features(
+ self,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ pixel_attention_mask: Optional[torch.Tensor] = None,
+ spatial_shapes: Optional[torch.LongTensor] = None,
+ ) -> torch.FloatTensor:
+ r"""
+ pixel_attention_mask (`torch.Tensor` of shape `(batch_size, image_size, image_size)`, *optional*):
+ Mask to avoid performing attention on padding pixel indices.
+ spatial_shapes (`torch.LongTensor` of shape `(batch_size, 2)`):
+ Tensor containing the spatial dimensions (height, width) of the input images.
+
+ Returns:
+ image_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The image embeddings obtained by
+ applying the projection layer to the pooled output of [`Siglip2VisionModel`].
+
+ Examples:
+
+ ```python
+ >>> import torch
+ >>> from transformers import AutoProcessor, AutoModel
+ >>> from transformers.image_utils import load_image
+
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = load_image(url)
+
+ >>> model = AutoModel.from_pretrained("google/siglip2-base-patch16-224")
+ >>> processor = AutoProcessor.from_pretrained("google/siglip2-base-patch16-224")
+
+ >>> inputs = processor(images=image, return_tensors="pt")
+
+ >>> with torch.no_grad():
+ ... image_features = model.get_image_features(**inputs)
+ ```
+ """
+ vision_outputs: BaseModelOutputWithPooling = self.vision_model(
+ pixel_values=pixel_values,
+ attention_mask=pixel_attention_mask,
+ spatial_shapes=spatial_shapes,
+ )
+ pooled_output = vision_outputs.pooler_output
+
+ return pooled_output
+
+ # Update: add `spatial_shapes` and `pixel_attention_mask`
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ pixel_attention_mask: Optional[torch.Tensor] = None,
+ spatial_shapes: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ return_loss: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ ) -> Siglip2Output:
+ r"""
+ pixel_attention_mask (`torch.Tensor` of shape `(batch_size, image_size, image_size)`, *optional*):
+ Mask to avoid performing attention on padding pixel indices.
+ spatial_shapes (`torch.LongTensor` of shape `(batch_size, 2)`):
+ Tensor containing the spatial dimensions (height, width) of the input images.
+ return_loss (`bool`, *optional*):
+ Whether or not to return the contrastive loss.
+
+ Examples:
+
+ ```python
+ >>> from PIL import Image
+ >>> import requests
+ >>> from transformers import AutoProcessor, AutoModel
+ >>> import torch
+
+ >>> model = AutoModel.from_pretrained("google/siglip2-base-patch16-224")
+ >>> processor = AutoProcessor.from_pretrained("google/siglip2-base-patch16-224")
+
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> texts = ["a photo of 2 cats", "a photo of 2 dogs"]
+ >>> # important: we pass `padding=max_length` since the model was trained with this
+ >>> inputs = processor(text=texts, images=image, padding="max_length", return_tensors="pt")
+
+ >>> with torch.no_grad():
+ ... outputs = model(**inputs)
+
+ >>> logits_per_image = outputs.logits_per_image
+ >>> probs = torch.sigmoid(logits_per_image) # these are the probabilities
+ >>> print(f"{probs[0][0]:.1%} that image 0 is '{texts[0]}'")
+ 31.9% that image 0 is 'a photo of 2 cats'
+ ```
+ """
+ # Use Siglip2 model's config for some fields (if specified) instead of those of vision & text components.
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+
+ vision_outputs: BaseModelOutputWithPooling = self.vision_model(
+ pixel_values=pixel_values,
+ attention_mask=pixel_attention_mask,
+ spatial_shapes=spatial_shapes,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ )
+
+ text_outputs: BaseModelOutputWithPooling = self.text_model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ )
+
+ image_embeds = vision_outputs.pooler_output
+ text_embeds = text_outputs.pooler_output
+
+ # normalized features
+ image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
+ text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
+
+ # cosine similarity as logits
+ logits_per_text = torch.matmul(text_embeds, image_embeds.t().to(text_embeds.device))
+
+ logit_scale, logit_bias = self.logit_scale.to(text_embeds.device), self.logit_bias.to(text_embeds.device)
+ logits_per_text = logits_per_text * logit_scale.exp() + logit_bias
+
+ logits_per_image = logits_per_text.t()
+
+ loss = None
+ if return_loss:
+ # Adapted from https://github.com/google-research/big_vision/blob/01edb81a4716f93a48be43b3a4af14e29cdb3a7f/big_vision/trainers/proj/image_text/siglip2.py#L287
+ eye = torch.eye(logits_per_text.size(0), device=logits_per_text.device)
+ m1_diag1 = -torch.ones_like(logits_per_text) + 2 * eye
+ loglik = torch.nn.functional.logsigmoid(m1_diag1 * logits_per_text)
+ nll = -torch.sum(loglik, dim=-1)
+ loss = nll.mean()
+
+ return Siglip2Output(
+ loss=loss,
+ logits_per_image=logits_per_image,
+ logits_per_text=logits_per_text,
+ text_embeds=text_embeds,
+ image_embeds=image_embeds,
+ text_model_output=text_outputs,
+ vision_model_output=vision_outputs,
+ )
+
+
+class Siglip2ForImageClassification(SiglipForImageClassification):
+ # Update: add `spatial_shapes` and `pixel_attention_mask`
+ def forward(
+ self,
+ pixel_values: Optional[torch.Tensor] = None,
+ pixel_attention_mask: Optional[torch.Tensor] = None,
+ spatial_shapes: Optional[torch.LongTensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ ) -> ImageClassifierOutput:
+ r"""
+ pixel_attention_mask (`torch.Tensor` of shape `(batch_size, image_size, image_size)`, *optional*):
+ Mask to avoid performing attention on padding pixel indices.
+ spatial_shapes (`torch.LongTensor` of shape `(batch_size, 2)`):
+ Tensor containing the spatial dimensions (height, width) of the input images.
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+
+ Examples:
+
+ ```python
+ >>> from transformers import AutoImageProcessor, Siglip2ForImageClassification
+ >>> import torch
+ >>> from PIL import Image
+ >>> import requests
+
+ >>> torch.manual_seed(3) # doctest: +IGNORE_RESULT
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> # note: we are loading a `Siglip2Model` from the hub here,
+ >>> # so the head will be randomly initialized, hence the predictions will be random if seed is not set above.
+ >>> image_processor = AutoImageProcessor.from_pretrained("google/siglip2-base-patch16-224")
+ >>> model = Siglip2ForImageClassification.from_pretrained("google/siglip2-base-patch16-224")
+
+ >>> inputs = image_processor(images=image, return_tensors="pt")
+ >>> outputs = model(**inputs)
+ >>> logits = outputs.logits
+ >>> # model predicts one of the two classes
+ >>> predicted_class_idx = logits.argmax(-1).item()
+ >>> print("Predicted class:", model.config.id2label[predicted_class_idx])
+ Predicted class: LABEL_1
+ ```
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+
+ outputs: BaseModelOutputWithPooling = self.vision_model(
+ pixel_values,
+ attention_mask=pixel_attention_mask,
+ spatial_shapes=spatial_shapes,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ )
+
+ sequence_output = outputs.last_hidden_state
+
+ # average pool the patch tokens
+ if pixel_attention_mask is not None:
+ pool_mask = pixel_attention_mask[..., None].to(sequence_output.device)
+ sequence_output = torch.sum(sequence_output * pool_mask, dim=1) / torch.sum(pool_mask, dim=1)
+ else:
+ sequence_output = torch.mean(sequence_output, dim=1)
+
+ # apply classifier
+ logits = self.classifier(sequence_output)
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(labels, logits, self.config)
+
+ return ImageClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+__all__ = [
+ "Siglip2Config",
+ "Siglip2TextConfig",
+ "Siglip2VisionConfig",
+ "Siglip2Model",
+ "Siglip2PreTrainedModel",
+ "Siglip2TextModel",
+ "Siglip2VisionModel",
+ "Siglip2ForImageClassification",
+]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/siglip2/processing_siglip2.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/siglip2/processing_siglip2.py
new file mode 100644
index 0000000000000000000000000000000000000000..8e177b237b10633cc6181f07d54d24fbd1600d5c
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/siglip2/processing_siglip2.py
@@ -0,0 +1,69 @@
+# coding=utf-8
+# Copyright 2025 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Image/Text processor class for SigLIP2.
+"""
+
+from typing import Optional
+
+from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin
+
+
+class Siglip2ImagesKwargs(ImagesKwargs, total=False):
+ max_num_patches: Optional[int]
+ patch_size: Optional[int]
+
+
+class Siglip2ProcessorKwargs(ProcessingKwargs, total=False):
+ images_kwargs: Siglip2ImagesKwargs
+
+ _defaults = {
+ "text_kwargs": {
+ "padding": "max_length",
+ "truncation": True,
+ "max_length": 64,
+ },
+ "images_kwargs": {
+ "max_num_patches": 256,
+ "patch_size": 16,
+ },
+ }
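+    # Defaults: pad/truncate text to 64 tokens and tile images into at most 256 patches of 16x16
+    # (the standard SigLIP2 setup); override via `text_kwargs` / `images_kwargs` at call time if needed.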
+
+
+class Siglip2Processor(ProcessorMixin):
+ r"""
+ Constructs a Siglip2 processor which wraps a Siglip2 image processor and a Gemma tokenizer into a single processor.
+
+ [`Siglip2Processor`] offers all the functionalities of [`Siglip2ImageProcessor`] and [`GemmaTokenizerFast`]. See the
+ [`~Siglip2Processor.__call__`] and [`~Siglip2Processor.decode`] for more information.
+
+ Args:
+ image_processor ([`Siglip2ImageProcessor`]):
+ The image processor is a required input.
+ tokenizer ([`GemmaTokenizerFast`]):
+ The tokenizer is a required input.
+ """
+
+ attributes = ["image_processor", "tokenizer"]
+
+ image_processor_class = "AutoImageProcessor"
+ tokenizer_class = "AutoTokenizer"
+ valid_processor_kwargs = Siglip2ProcessorKwargs
+
+ def __init__(self, image_processor, tokenizer):
+ super().__init__(image_processor, tokenizer)
+
+
+__all__ = ["Siglip2Processor"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/smolvlm/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/smolvlm/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5205a84b25c45f7f378d22808f1a307afb58c909
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/smolvlm/__init__.py
@@ -0,0 +1,30 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_smolvlm import *
+ from .image_processing_smolvlm import *
+ from .image_processing_smolvlm_fast import *
+ from .modeling_smolvlm import *
+ from .processing_smolvlm import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/smolvlm/configuration_smolvlm.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/smolvlm/configuration_smolvlm.py
new file mode 100644
index 0000000000000000000000000000000000000000..2dca4721c02b858f0883bd328bc98f9eb1a98094
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/smolvlm/configuration_smolvlm.py
@@ -0,0 +1,196 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/smolvlm/modular_smolvlm.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_smolvlm.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2025 the HuggingFace Inc. team. All rights reserved.
+# Written by Orr Zohar
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+from ..auto import CONFIG_MAPPING, AutoConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+class SmolVLMVisionConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`SmolVLMVisionModel`]. It is used to instantiate a
+ SmolVLM vision encoder according to the specified arguments, defining the model architecture. Instantiating a
+ configuration with the defaults will yield a similar configuration to that of the SigLIP checkpoint
+ [google/siglip-so400m-patch14-384](https://huggingface.co/google/siglip-so400m-patch14-384) used in SmolVLM
+ [HuggingFaceTB/SmolVLM2-2.2B-Instruct](https://huggingface.co/HuggingFaceTB/SmolVLM2-2.2B-Instruct).
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ hidden_size (`int`, *optional*, defaults to 1152):
+ Dimensionality of the encoder layers and the pooler layer.
+ intermediate_size (`int`, *optional*, defaults to 3072):
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
+ num_hidden_layers (`int`, *optional*, defaults to 12):
+ Number of hidden layers in the Transformer encoder.
+ num_attention_heads (`int`, *optional*, defaults to 16):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ num_channels (`int`, *optional*, defaults to 3):
+ Number of channels in the input images.
+ image_size (`int`, *optional*, defaults to 224):
+ The size (resolution) of each image.
+ patch_size (`int`, *optional*, defaults to 32):
+ The size (resolution) of each patch.
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+ `"relu"`, `"selu"` and `"gelu_new"` `"quick_gelu"` are supported.
+ layer_norm_eps (`float`, *optional*, defaults to 1e-06):
+ The epsilon used by the layer normalization layers.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+
+ Example:
+
+ ```python
+ >>> from transformers.models.smolvlm.modeling_smolvlm import SmolVLMVisionTransformer
+ >>> from transformers.models.smolvlm.configuration_smolvlm import SmolVLMVisionConfig
+
+ >>> # Initializing a SmolVLMVisionConfig with google/siglip-so400m-patch14-384 style configuration
+ >>> configuration = SmolVLMVisionConfig()
+
+ >>> # Initializing a SmolVLMVisionTransformer (with random weights) from the google/siglip-so400m-patch14-384 style configuration
+ >>> model = SmolVLMVisionTransformer(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "smolvlm_vision"
+ base_config_key = "vision_config"
+
+ def __init__(
+ self,
+ hidden_size=1152,
+ intermediate_size=3072,
+ num_hidden_layers=12,
+ num_attention_heads=16,
+ num_channels=3,
+ image_size=224,
+ patch_size=32,
+ hidden_act="gelu_pytorch_tanh",
+ layer_norm_eps=1e-6,
+ attention_dropout=0.0,
+ initializer_range=0.02,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.num_channels = num_channels
+ self.patch_size = patch_size
+ self.image_size = image_size
+ self.attention_dropout = attention_dropout
+ self.layer_norm_eps = layer_norm_eps
+ self.hidden_act = hidden_act
+ self.initializer_range = initializer_range
+
+
+class SmolVLMConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`SmolVLMModel`]. It is used to instantiate a
+ SmolVLM model according to the specified arguments, defining the model architecture. Instantiating a
+    configuration with the defaults will yield a configuration similar to that of the SmolVLM
+    [HuggingFaceTB/SmolVLM2-2.2B-Instruct](https://huggingface.co/HuggingFaceTB/SmolVLM2-2.2B-Instruct) model.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should cache the key/value pairs of the attention mechanism. Only
+ relevant if `config.is_decoder=True`.
+ image_token_id (`int`, *optional*, defaults to 128257):
+ The id of the "image" token.
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+ Whether or not to tie the word embeddings with the token embeddings.
+        vision_config (`SmolVLMVisionConfig` or `dict`, *optional*, defaults to `SmolVLMVisionConfig`):
+            Custom vision config or dict for the vision tower.
+        text_config (`PretrainedConfig` or `dict`, *optional*, defaults to `LlamaConfig`):
+            Custom text config or dict for the text model.
+ scale_factor (`int`, *optional*, defaults to 2):
+ The scale factor for the image encoder.
+ pad_token_id (`int`, *optional*, defaults to 128002):
+ The id of the padding token.
+
+ Example:
+ ```python
+ >>> from transformers import SmolVLMModel, SmolVLMConfig
+ >>> # Initializing configuration
+ >>> configuration = SmolVLMConfig()
+ >>> # Initializing a model from the configuration
+ >>> model = SmolVLMModel(configuration)
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "smolvlm"
+ sub_configs = {"text_config": AutoConfig, "vision_config": SmolVLMVisionConfig}
+
+ def __init__(
+ self,
+ use_cache=True,
+ image_token_id=128257,
+ tie_word_embeddings=False,
+ vision_config=None,
+ text_config=None,
+ scale_factor=2,
+ pad_token_id=128_002,
+ **kwargs,
+ ):
+ self.image_token_id = image_token_id
+ self.use_cache = use_cache
+ self.tie_word_embeddings = tie_word_embeddings
+
+ if vision_config is None:
+ self.vision_config = SmolVLMVisionConfig()
+ logger.info("vision_config is None, using default vision config")
+ elif isinstance(vision_config, dict):
+ self.vision_config = SmolVLMVisionConfig(**vision_config)
+ elif isinstance(vision_config, SmolVLMVisionConfig):
+ self.vision_config = vision_config
+
+ if isinstance(text_config, dict):
+ text_config["model_type"] = text_config.get("model_type", "llama")
+ text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
+ elif text_config is None:
+ logger.info("text_config is None, using default text config")
+ text_config = CONFIG_MAPPING["llama"](
+ rms_norm_eps=1e-5,
+ pad_token_id=pad_token_id,
+ tie_word_embeddings=False,
+ )
+
+ self.text_config = text_config
+ self.scale_factor = scale_factor
+
+ super().__init__(**kwargs, pad_token_id=pad_token_id, tie_word_embeddings=tie_word_embeddings)
+
+
+__all__ = ["SmolVLMVisionConfig", "SmolVLMConfig"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/smolvlm/image_processing_smolvlm.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/smolvlm/image_processing_smolvlm.py
new file mode 100644
index 0000000000000000000000000000000000000000..c08339b817325ace5d4e37767d315a5b69d7164d
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/smolvlm/image_processing_smolvlm.py
@@ -0,0 +1,901 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/smolvlm/modular_smolvlm.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_smolvlm.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2025 the HuggingFace Inc. team. All rights reserved.
+# Written by Orr Zohar
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import math
+from collections.abc import Iterable
+from typing import Optional, Union
+
+import numpy as np
+
+from ...image_processing_utils import BaseImageProcessor, BatchFeature
+from ...image_transforms import PaddingMode, pad, to_channel_dimension_format, to_pil_image
+from ...image_utils import (
+ IMAGENET_STANDARD_MEAN,
+ IMAGENET_STANDARD_STD,
+ ChannelDimension,
+ ImageInput,
+ PILImageResampling,
+ get_image_size,
+ infer_channel_dimension_format,
+ is_scaled_image,
+ make_nested_list_of_images,
+ to_numpy_array,
+ valid_images,
+ validate_preprocess_arguments,
+)
+from ...utils import TensorType, is_vision_available, logging
+
+
+if is_vision_available():
+ import PIL
+ from PIL import Image
+
+
+logger = logging.get_logger(__name__)
+MAX_IMAGE_SIZE = 4096 # 4k resolution as absolute maximum
+
+
+def _resize_output_size_rescale_to_max_len(
+ height: int, width: int, min_len: Optional[int] = 1, max_len: Optional[int] = None
+) -> tuple[int, int]:
+ """
+    Get the output size of the image after rescaling its longest edge to `max_len` while preserving the aspect ratio.
+ Args:
+ height (`int`):
+ Height of the input image.
+ width (`int`):
+ Width of the input image.
+ min_len (`int`, *optional*, defaults to 1):
+ Minimum size of the output image.
+ max_len (`int`, *optional*, defaults to the maximum size of the image):
+ Maximum size of the output image.
+ Returns:
+ The output size of the image after resizing.
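+        For example (illustrative): a 300x500 image with `max_len=1536` becomes 922x1536 (the longest edge is
+        rescaled to `max_len`; the other side follows the aspect ratio and is rounded up to an even value).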
+ """
+ max_len = max(height, width) if max_len is None else max_len
+ aspect_ratio = width / height
+
+ if width >= height:
+ width = max_len
+ height = int(width / aspect_ratio)
+ if height % 2 != 0:
+ height += 1
+ elif height > width:
+ height = max_len
+ width = int(height * aspect_ratio)
+ if width % 2 != 0:
+ width += 1
+
+ # Avoid resizing to a size smaller than min_len
+ height = max(height, min_len)
+ width = max(width, min_len)
+ return height, width
+
+
+def _resize_output_size_scale_below_upper_bound(
+    height: int, width: int, max_len: Optional[int] = None
+) -> tuple[int, int]:
+ """
+    Get the output size of the image after downscaling it so that neither side exceeds `max_len`, preserving the aspect ratio.
+ Args:
+ height (`int`):
+ Height of the input image.
+ width (`int`):
+ Width of the input image.
+        max_len (`int`, *optional*, defaults to the maximum size of the image):
+            Maximum allowed size for either dimension of the image.
+ Returns:
+ The output size of the image after resizing.
+ """
+ max_len = max(height, width) if max_len is None else max_len
+
+ aspect_ratio = width / height
+ if width >= height and width > max_len:
+ width = max_len
+ height = int(width / aspect_ratio)
+ elif height > width and height > max_len:
+ height = max_len
+ width = int(height * aspect_ratio)
+
+ # Avoid resizing to a size smaller than 1
+ height = max(height, 1)
+ width = max(width, 1)
+ return height, width
+
+
+def get_resize_output_image_size(
+ image,
+ resolution_max_side: int,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+) -> tuple[int, int]:
+ """
+    Get the output size of the image after resizing its longest edge to `resolution_max_side` while preserving the
+    aspect ratio, capped at `MAX_IMAGE_SIZE`.
+ Args:
+ image (`np.ndarray`):
+ Image to resize.
+ resolution_max_side (`int`):
+ The longest edge of the image will be resized to this value. The shortest edge will be resized to keep the
+ input aspect ratio.
+ input_data_format (`ChannelDimension` or `str`):
+ The channel dimension format of the input image.
+ Returns:
+ The output size of the image after resizing.
+ """
+ height, width = get_image_size(image, channel_dim=input_data_format)
+
+ # Find the output size, when rescaling the longest edge to max_len and preserving the aspect ratio
+ height, width = _resize_output_size_rescale_to_max_len(height, width, max_len=resolution_max_side)
+ # Find the output size when scaling the image to be below the MAX_IMAGE_SIZE
+ height, width = _resize_output_size_scale_below_upper_bound(height, width, max_len=MAX_IMAGE_SIZE)
+ return height, width
+
+
+def get_max_height_width(
+ images_list: list[list[np.ndarray]], input_data_format: Optional[Union[str, ChannelDimension]] = None
+) -> tuple[int, int]:
+ """
+ Get the maximum height and width across all images in a batch.
+ """
+ if input_data_format is None:
+ input_data_format = infer_channel_dimension_format(images_list[0][0], num_channels=(1, 3, 4))
+
+ max_height = max_width = float("-inf")
+ for images in images_list:
+ for image in images:
+ height, width = get_image_size(image, channel_dim=input_data_format)
+ max_height = max(height, max_height)
+ max_width = max(width, max_width)
+ return (max_height, max_width)
+
+
+def make_pixel_mask(
+ image: np.ndarray, output_size: tuple[int, int], input_data_format: Optional[Union[str, ChannelDimension]] = None
+) -> np.ndarray:
+ """
+ Make a pixel mask for the image, where 1 indicates a valid pixel and 0 indicates padding.
+ Args:
+ image (`np.ndarray`):
+ Image to make the pixel mask for.
+ output_size (`tuple[int, int]`):
+ Output size of the mask.
+ """
+ input_height, input_width = get_image_size(image, channel_dim=input_data_format)
+ mask = np.zeros(output_size, dtype=np.int64)
+ mask[:input_height, :input_width] = 1
+ return mask
+
+
+def convert_to_rgb(
+ image: np.ndarray,
+ palette: Optional[PIL.ImagePalette.ImagePalette] = None,
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+) -> ImageInput:
+ """
+ Converts an image to RGB format.
+ Args:
+ image (`np.ndarray`):
+ The image to convert.
+ palette (list[int], *optional*):
+ The palette to use if given.
+ data_format (ChannelDimension or str, *optional*):
+ The channel dimension format for the output image. If not provided, it will be the same as the input image.
+ input_data_format (ChannelDimension or str, *optional*):
+ The channel dimension format of the input image.
+ """
+ if input_data_format is None:
+ input_data_format = infer_channel_dimension_format(image, num_channels=(1, 3, 4))
+
+ # For all transformations, we want to keep the same data format as the input image unless otherwise specified.
+ # The resized image from PIL will always have channels last, so find the input format first.
+ data_format = input_data_format if data_format is None else data_format
+
+ mode = "P" if palette is not None else None
+ image = to_pil_image(image, image_mode=mode, input_data_format=input_data_format)
+ if image.mode == "P" and palette is not None:
+ image.putpalette(palette)
+
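+    # Composite the image over a white background so that transparent pixels become white in the RGB output.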
+ image_rgba = image.convert("RGBA")
+ background = Image.new("RGBA", image_rgba.size, (255, 255, 255))
+ alpha_composite = Image.alpha_composite(background, image_rgba)
+ alpha_composite = alpha_composite.convert("RGB")
+
+ output_array = np.array(alpha_composite)
+ # The image is always in channels last format after converting from a PIL image
+ output_array = to_channel_dimension_format(output_array, data_format, input_channel_dim=ChannelDimension.LAST)
+ return output_array
+
+
+# FIXME Amy: make a more general crop function that isn't just centre crop
+def _crop(
+ image: np.ndarray,
+ w1: int,
+ h1: int,
+ w2: int,
+ h2: int,
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+) -> np.ndarray:
+ if data_format is None:
+ data_format = infer_channel_dimension_format(image, num_channels=(1, 3, 4))
+
+ if data_format == ChannelDimension.FIRST:
+ image = image[:, h1:h2, w1:w2]
+ elif data_format == ChannelDimension.LAST:
+ image = image[h1:h2, w1:w2, :]
+ else:
+ raise ValueError("Invalid channel dimension format.")
+
+ return image
+
+
+class SmolVLMImageProcessor(BaseImageProcessor):
+ r"""
+ Constructs a SmolVLM image processor.
+ Args:
+ do_convert_rgb (`bool`, *optional*, defaults to `True`):
+ Whether to convert the image to RGB. This is useful if the input image is of a different format e.g. RGBA.
+ Only has an effect if the input image is in the PIL format.
+ do_resize (`bool`, *optional*, defaults to `True`):
+ Whether to resize the image. The longest edge of the image is resized to be <= `size["longest_edge"]`, with the
+ shortest edge resized to keep the input aspect ratio.
+ size (`Dict`, *optional*, defaults to `{"longest_edge": 4 * 364}`):
+ Controls the size of the output image. This is a dictionary containing the key "longest_edge".
+ The image will be resized such that the longest edge is <= `size["longest_edge"]` and the shortest edge is resized
+ to keep the input aspect ratio.
+ resample (`Resampling`, *optional*, defaults to `Resampling.LANCZOS`):
+ Resampling filter to use when resizing the image.
+ do_image_splitting (`bool`, *optional*, defaults to `True`):
+ Whether to split the image into sub-images concatenated with the original image. They are split into patches
+ such that each patch has a size of `max_image_size["height"]` x `max_image_size["width"]`.
+ max_image_size (`Dict`, *optional*, defaults to `{"longest_edge": 364}`):
+ Maximum resolution of the patches of images accepted by the model. This is a dictionary containing the key "longest_edge".
+ do_rescale (`bool`, *optional*, defaults to `True`):
+ Whether to rescale the image. If set to `True`, the image is rescaled to have pixel values between 0 and 1.
+ rescale_factor (`float`, *optional*, defaults to `1/255`):
+ Rescale factor to rescale the image by if `do_rescale` is set to `True`.
+ do_normalize (`bool`, *optional*, defaults to `True`):
+ Whether to normalize the image. If set to `True`, the image is normalized to have a mean of `image_mean` and
+ a standard deviation of `image_std`.
+        image_mean (`float` or `list[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`):
+            Mean to use if normalizing the image. This is a float or list of floats the length of the number of
+            channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
+        image_std (`float` or `list[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`):
+            Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
+            number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
+ do_pad (`bool`, *optional*, defaults to `True`):
+ Whether or not to pad the images to the largest height and width in the batch and number of images per
+ sample in the batch, such that the returned tensor is of shape (batch_size, max_num_images, num_channels, max_height, max_width).
+ """
+
+ model_input_names = ["pixel_values", "pixel_attention_mask"]
+
+ def __init__(
+ self,
+ do_convert_rgb: bool = True,
+ do_resize: bool = True,
+ size: Optional[dict[str, int]] = None,
+ resample: PILImageResampling = PILImageResampling.LANCZOS,
+ do_image_splitting: bool = True,
+ max_image_size: Optional[dict[str, int]] = None,
+ do_rescale: bool = True,
+ rescale_factor: float = 1 / 255,
+ do_normalize: bool = True,
+ image_mean: Optional[Union[float, list[float]]] = None,
+ image_std: Optional[Union[float, list[float]]] = None,
+ do_pad: bool = True,
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+ self.do_convert_rgb = do_convert_rgb
+ self.do_resize = do_resize
+ self.size = size if size is not None else {"longest_edge": 4 * 364}
+ self.resample = resample
+ self.do_image_splitting = do_image_splitting
+ self.max_image_size = max_image_size if max_image_size is not None else {"longest_edge": 364}
+ self.do_rescale = do_rescale
+ self.rescale_factor = rescale_factor
+ self.do_normalize = do_normalize
+ self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
+ self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
+ self.do_pad = do_pad
+
+ def resize(
+ self,
+ image: np.ndarray,
+ size: dict[str, int],
+ resample: PILImageResampling = PILImageResampling.LANCZOS,
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ **kwargs,
+ ) -> np.ndarray:
+ """
+ Resize an image. The longest edge of the image is resized to size["longest_edge"], with the shortest edge
+ resized to keep the input aspect ratio. Can also be used with size["height"] and size["width"].
+ Args:
+ image (`np.ndarray`):
+ Image to resize.
+ size (`dict[str, int]`):
+ Size of the output image.
+ resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.LANCZOS`):
+ Resampling filter to use when resizing the image.
+ data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format of the output image. If not provided, it will be the same as the input image.
+ input_data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format of the input image. If not provided, it will be inferred.
+ """
+ if input_data_format is None:
+ input_data_format = infer_channel_dimension_format(image, num_channels=(1, 3, 4))
+
+ # For all transformations, we want to keep the same data format as the input image unless otherwise specified.
+ # The resized image from PIL will always have channels last, so find the input format first.
+ data_format = input_data_format if data_format is None else data_format
+
+ if "longest_edge" in size:
+ size = get_resize_output_image_size(
+ image, resolution_max_side=size["longest_edge"], input_data_format=input_data_format
+ )
+ elif "height" in size and "width" in size:
+ size = (size["height"], size["width"])
+ else:
+ raise ValueError("size must be a dictionary with key 'longest_edge' or 'height' and 'width'.")
+
+ image_mode = None
+ if image.ndim == 2 or image.shape[-1] == 1:
+ image_mode = "P"
+ image = to_pil_image(image, image_mode=image_mode, input_data_format=input_data_format)
+
+ resized_image = image.resize((size[1], size[0]), resample=resample)
+ resized_image = np.array(resized_image)
+
+ # If the input image channel dimension was of size 1, then it is dropped when converting to a PIL image
+ # so we need to add it back if necessary.
+ resized_image = np.expand_dims(resized_image, axis=-1) if resized_image.ndim == 2 else resized_image
+ # The image is always in channels last format after converting from a PIL image
+ resized_image = to_channel_dimension_format(
+ resized_image, data_format, input_channel_dim=ChannelDimension.LAST
+ )
+ return resized_image
+
+ def split_image(
+ self,
+ image,
+ max_image_size: dict[str, int],
+ resample: PILImageResampling = PILImageResampling.LANCZOS,
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ ):
+ """
+ Split an image into squares of side max_image_size and the original image resized to max_image_size.
+ That means that a single image becomes a sequence of images.
+ This is a "trick" to spend more compute on each image with no changes in the vision encoder.
+ 1) If one side of the original image is larger than `max_image_size`, resize it to `max_image_size` while preserving the aspect ratio.
+ 2) Divide the resulting image into `ceil(height / max_image_size)` x `ceil(width / max_image_size)`
+        sub-images of the same size, each of shape (image_size, image_size); typically 364x364.
+ 3) Returns the list of the crops and the original image, in addition to the number of splits for the height and the width.
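+        For example (illustrative): a 500x900 image with `max_image_size={"longest_edge": 364}` is divided into
+        ceil(500/364)=2 rows and ceil(900/364)=3 columns of crops, plus the resized global image, i.e. 7 frames in total.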
+ Args:
+ image (`np.ndarray`):
+ Images to split.
+ max_image_size (`dict[str, int]`):
+ Maximum size of the output image. If the image is larger than this size, it will be split into
+ patches of this size, and the original image will be concatenated with the patches, resized to max_size.
+ resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.LANCZOS`):
+ Resampling filter to use when resizing the image.
+ data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format of the output image. If not provided, it will be the same as the input image.
+ input_data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format of the input image. If not provided, it will be inferred.
+ """
+ height, width = get_image_size(image, channel_dim=input_data_format)
+ max_height = max_width = max_image_size["longest_edge"]
+
+ frames = []
+ if height > max_height or width > max_width:
+ # Calculate the number of splits
+ num_splits_h = math.ceil(height / max_height)
+ num_splits_w = math.ceil(width / max_width)
+ # Calculate the optimal width and height for the sub-images
+ optimal_height = math.ceil(height / num_splits_h)
+ optimal_width = math.ceil(width / num_splits_w)
+
+ # Iterate through each row and column
+ for r in range(num_splits_h):
+ for c in range(num_splits_w):
+ # Calculate the starting point of the crop
+ start_x = c * optimal_width
+ start_y = r * optimal_height
+
+ # Calculate the ending point of the crop
+ end_x = min(start_x + optimal_width, width)
+ end_y = min(start_y + optimal_height, height)
+
+ # Crop the image
+ cropped_image = _crop(
+ image,
+ start_x,
+ start_y,
+ end_x,
+ end_y,
+ data_format=data_format,
+ )
+ frames.append(cropped_image)
+
+ # For the global image at the end, we resize it to match the max_image_size, for cpu memory efficiency
+ global_image_height, global_image_width = max_height, max_width
+ if height != global_image_height or width != global_image_width:
+ image = self.resize(
+ image,
+ {"height": global_image_height, "width": global_image_width},
+ resample=resample,
+ input_data_format=data_format,
+ )
+ else:
+ num_splits_h, num_splits_w = 0, 0
+
+ frames.append(image)
+
+ return frames, num_splits_h, num_splits_w
+
+ def resize_for_vision_encoder(
+ self,
+ image: np.ndarray,
+ vision_encoder_max_size: int,
+ resample: PILImageResampling = PILImageResampling.LANCZOS,
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ ):
+ """
+ Resize images to be multiples of `vision_encoder_max_size` while preserving the aspect ratio.
+ Args:
+ image (`np.ndarray`):
+ Images to resize.
+ vision_encoder_max_size (`int`):
+ Maximum size of the output image. If the image is larger than this size, it will be split into
+ patches of this size, and the original image will be concatenated with the patches, resized to max_size.
+ resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.LANCZOS`):
+ Resampling filter to use when resizing the image.
+ data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format of the output image. If not provided, it will be the same as the input image.
+ input_data_format (`ChannelDimension` or `str`, *optional*):
+                The channel dimension format of the input image. If not provided, it will be inferred.
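+        For example (illustrative): a 500x900 image with `vision_encoder_max_size=364` is resized to 728x1092,
+        i.e. both sides are rounded up to multiples of 364 while approximately preserving the aspect ratio.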
+ """
+ height, width = get_image_size(image, channel_dim=input_data_format)
+
+ aspect_ratio = width / height
+ if width >= height:
+ width = math.ceil(width / vision_encoder_max_size) * vision_encoder_max_size
+ height = int(width / aspect_ratio)
+ height = math.ceil(height / vision_encoder_max_size) * vision_encoder_max_size
+ elif height > width:
+ height = math.ceil(height / vision_encoder_max_size) * vision_encoder_max_size
+ width = int(height * aspect_ratio)
+ width = math.ceil(width / vision_encoder_max_size) * vision_encoder_max_size
+ new_size = {"height": height, "width": width}
+ return self.resize(
+ image, size=new_size, resample=resample, input_data_format=input_data_format, data_format=data_format
+ )
+
+ def _pad_image(
+ self,
+ image: np.ndarray,
+ output_size: tuple[int, int],
+ constant_values: Union[float, Iterable[float]] = 0,
+ data_format: Optional[ChannelDimension] = None,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ ) -> np.ndarray:
+ """
+ Pad an image with zeros to the given size.
+ """
+ input_height, input_width = get_image_size(image, channel_dim=input_data_format)
+ output_height, output_width = output_size
+
+ pad_bottom = output_height - input_height
+ pad_right = output_width - input_width
+ padding = ((0, pad_bottom), (0, pad_right))
+ padded_image = pad(
+ image,
+ padding,
+ mode=PaddingMode.CONSTANT,
+ constant_values=constant_values,
+ data_format=data_format,
+ input_data_format=input_data_format,
+ )
+ return padded_image
+
+ def pad(
+ self,
+ images: list[list[np.ndarray]],
+ constant_values: Union[float, Iterable[float]] = 0,
+ return_pixel_mask: bool = True,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ data_format: Optional[ChannelDimension] = None,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ ) -> BatchFeature:
+ """
+        For a list of samples, pads each image at the bottom and right with zeros up to the largest height and width in the batch.
+        For each sample, pads the sample with empty images up to the maximum number of images per sample in the batch. Optionally returns a pixel mask.
+ Args:
+ images (`list[list[np.ndarray]]`):
+ List of list of images to pad. Pads to the largest height and width in the batch.
+ constant_values (`float` or `Iterable[float]`, *optional*):
+ The value to use for the padding if `mode` is `"constant"`.
+ return_pixel_mask (`bool`, *optional*, defaults to `True`):
+ Whether to return a pixel mask.
+ return_tensors (`str` or `TensorType`, *optional*):
+ The type of tensors to return. Can be one of:
+ - Unset: Return a list of `np.ndarray`.
+ - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
+ - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
+ data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format of the image. If not provided, it will be the same as the input image.
+ input_data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format of the input image. If not provided, it will be inferred.
+ """
+ pad_size = get_max_height_width(images, input_data_format=input_data_format)
+
+ batch_size = len(images)
+ max_num_images = max(len(images_) for images_ in images)
+ input_data_format = (
+ infer_channel_dimension_format(images[0][0], num_channels=(1, 3, 4))
+ if input_data_format is None
+ else input_data_format
+ )
+ data_format = input_data_format if data_format is None else data_format
+ # filter out empty image lists, then take first image of the first sample
+ first_image_in_list = [sample_images for sample_images in images if sample_images][0][0]
+
+ if input_data_format == ChannelDimension.FIRST:
+ n_channels = first_image_in_list.shape[0]
+ elif input_data_format == ChannelDimension.LAST:
+ n_channels = first_image_in_list.shape[-1]
+ else:
+ raise ValueError("Invalid channel dimension format.")
+
+ def empty_image(size, input_data_format):
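+            # All-zero placeholder used to pad samples that contain fewer images than the batch maximum.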
+ if input_data_format == ChannelDimension.FIRST:
+ return np.zeros((n_channels, *size), dtype=np.uint8)
+ elif input_data_format == ChannelDimension.LAST:
+ return np.zeros((*size, n_channels), dtype=np.uint8)
+
+ padded_images_list = [
+ [empty_image(pad_size, data_format) for _ in range(max_num_images)] for _ in range(batch_size)
+ ]
+ padded_masks = [[np.zeros(pad_size, dtype=np.int64) for _ in range(max_num_images)] for _ in range(batch_size)]
+
+ for batch_idx in range(batch_size):
+ for sample_idx, image in enumerate(images[batch_idx]):
+ padded_images_list[batch_idx][sample_idx] = self._pad_image(
+ image,
+ pad_size,
+ constant_values=constant_values,
+ data_format=data_format,
+ input_data_format=input_data_format,
+ )
+ padded_masks[batch_idx][sample_idx] = make_pixel_mask(
+ image, output_size=pad_size, input_data_format=input_data_format
+ )
+
+ padded_masks = padded_masks if return_pixel_mask else None
+ return padded_images_list, padded_masks
+
+ def preprocess(
+ self,
+ images: ImageInput,
+ do_convert_rgb: Optional[bool] = None,
+ do_resize: Optional[bool] = None,
+ size: Optional[dict[str, int]] = None,
+ resample: Optional[PILImageResampling] = None,
+ do_image_splitting: Optional[bool] = None,
+ do_rescale: Optional[bool] = None,
+ max_image_size: Optional[dict[str, int]] = None,
+ rescale_factor: Optional[float] = None,
+ do_normalize: Optional[bool] = None,
+ image_mean: Optional[Union[float, list[float]]] = None,
+ image_std: Optional[Union[float, list[float]]] = None,
+ do_pad: Optional[bool] = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ return_row_col_info: bool = False,
+ data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ ):
+ """
+ Preprocess a batch of images.
+ Args:
+ images (`ImageInput`):
+ A list of images to preprocess.
+ do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
+ Whether to convert the image to RGB.
+ do_resize (`bool`, *optional*, defaults to `self.do_resize`):
+ Whether to resize the image.
+ size (`dict[str, int]`, *optional*, defaults to `self.size`):
+ Size of the image after resizing. With the longest edge resized to keep the input aspect ratio.
+ resample (`int`, *optional*, defaults to `self.resample`):
+ Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
+ has an effect if `do_resize` is set to `True`.
+ do_image_splitting (`bool`, *optional*, defaults to `self.do_image_splitting`):
+ Whether to split the image into sub-images concatenated with the original image. They are split into patches
+ such that each patch has a size of `max_image_size["height"]` x `max_image_size["width"]`.
+ max_image_size (`Dict`, *optional*, defaults to `self.max_image_size`):
+ Maximum resolution of the images. If the image is larger than this size, the image is split into patches.
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
+ Whether to rescale the image.
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
+ Rescale factor to rescale the image by if `do_rescale` is set to `True`.
+ do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
+ Whether to normalize the image.
+ image_mean (`float` or `list[float]`, *optional*, defaults to `self.image_mean`):
+ Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
+ image_std (`float` or `list[float]`, *optional*, defaults to `self.image_std`):
+ Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
+ `True`.
+ do_pad (`bool`, *optional*, defaults to `self.do_pad`):
+ Whether or not to pad the images to the largest height and width in the batch.
+ return_tensors (`str` or `TensorType`, *optional*):
+ The type of tensors to return. Can be one of:
+ - Unset: Return a list of `np.ndarray`.
+ - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
+ - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
+            return_row_col_info (`bool`, *optional*, defaults to `False`):
+                Whether to return the number of rows and columns of the split images. This is used by the
+                `SmolVLMProcessor` to generate prompt strings based on the number of rows and columns.
+ data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
+ The channel dimension format for the output image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - Unset: Use the channel dimension format of the input image.
+ input_data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
+ from the input image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
+ """
+ do_resize = do_resize if do_resize is not None else self.do_resize
+ size = size if size is not None else self.size
+ resample = resample if resample is not None else self.resample
+ do_rescale = do_rescale if do_rescale is not None else self.do_rescale
+ rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
+ do_image_splitting = do_image_splitting if do_image_splitting is not None else self.do_image_splitting
+ max_image_size = max_image_size if max_image_size is not None else self.max_image_size
+ do_normalize = do_normalize if do_normalize is not None else self.do_normalize
+ image_mean = image_mean if image_mean is not None else self.image_mean
+ image_std = image_std if image_std is not None else self.image_std
+ do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
+ do_pad = do_pad if do_pad is not None else self.do_pad
+
+ images = self.fetch_images(images)
+ images_list = make_nested_list_of_images(images)
+
+ if not valid_images(images_list[0]):
+ raise ValueError(
+ "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
+ "torch.Tensor, tf.Tensor or jax.ndarray."
+ )
+
+ validate_preprocess_arguments(
+ do_rescale=do_rescale,
+ rescale_factor=rescale_factor,
+ do_normalize=do_normalize,
+ image_mean=image_mean,
+ image_std=image_std,
+ do_resize=do_resize,
+ size=size,
+ resample=resample,
+ )
+
+ # save the palettes for conversion to RGB
+ palettes_list = [
+ [im.getpalette() if isinstance(im, Image.Image) and im.mode == "P" else None for im in images]
+ for images in images_list
+ ]
+
+ # All transformations expect numpy arrays.
+ images_list = [[to_numpy_array(image) for image in images] for images in images_list]
+ # Search for the first image in the image list.
+ # NOTE: we can't slice the first image with images_list[0][0] if the first batch contains no images. See #36682
+ first_image_in_list = [images for images in images_list if images][0][0]
+
+ # Extra channel dimension for grayscale images
+ if input_data_format in [ChannelDimension.LAST, None]:
+ images_list = [
+ [np.expand_dims(img, axis=-1) if img.ndim == 2 else img for img in images] for images in images_list
+ ]
+ elif input_data_format == ChannelDimension.FIRST:
+ images_list = [
+ [np.expand_dims(img, axis=0) if img.ndim == 2 else img for img in images] for images in images_list
+ ]
+
+ if do_rescale and is_scaled_image(first_image_in_list):
+ logger.warning_once(
+ "It looks like you are trying to rescale already rescaled images. If the input"
+ " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
+ )
+
+ # We assume that all images have the same channel dimension format.
+ if input_data_format is None:
+ input_data_format = infer_channel_dimension_format(first_image_in_list, num_channels=(1, 3, 4))
+
+ if do_resize:
+ images_list = [
+ [
+ self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
+ for image in images
+ ]
+ for images in images_list
+ ]
+
+ if do_image_splitting:
+ # We first resize both height and width of each image to the nearest max_image_size multiple, disregarding the aspect ratio
+ # for size=(10, max_image_size) -> rescaled_size=(max_image_size, max_image_size)
+ # for size=(11, max_image_size+1) -> rescaled_size=(max_image_size, max_image_size*2)
+ images_list = [
+ [
+ self.resize_for_vision_encoder(
+ image, max_image_size["longest_edge"], resample=resample, input_data_format=input_data_format
+ )
+ for image in images
+ ]
+ for images in images_list
+ ]
+ images_list_split_arrays = []
+ palettes_list_split_arrays = []
+ images_list_rows = []
+ images_list_cols = []
+ for images, palettes in zip(images_list, palettes_list):
+ split_image_arrays = []
+ split_palettes_arrays = []
+ image_rows = []
+ image_cols = []
+ for image, palette in zip(images, palettes):
+ split_image_array, rows, cols = self.split_image(
+ image,
+ max_image_size=max_image_size,
+ resample=resample,
+ input_data_format=input_data_format,
+ )
+ split_image_arrays.extend(split_image_array)
+ split_palettes_arrays.extend([palette] * len(split_image_array))
+ image_rows.append(rows)
+ image_cols.append(cols)
+ images_list_split_arrays.append(split_image_arrays)
+ palettes_list_split_arrays.append(split_palettes_arrays)
+ images_list_rows.append(image_rows)
+ images_list_cols.append(image_cols)
+ images_list = images_list_split_arrays
+ palettes_list = palettes_list_split_arrays
+ else:
+ # We square the images to max_image_size
+ images_list = [
+ [
+ self.resize(
+ image=image,
+ size={"height": max_image_size["longest_edge"], "width": max_image_size["longest_edge"]},
+ resample=resample,
+ input_data_format=input_data_format,
+ )
+ for image in images
+ ]
+ for images in images_list
+ ]
+ images_list_rows = [[0] * len(images) for images in images_list]
+ images_list_cols = [[0] * len(images) for images in images_list]
+
+ if do_convert_rgb:
+ images_list = [
+ [convert_to_rgb(img, palette) for img, palette in zip(images, palettes)]
+ for images, palettes in zip(images_list, palettes_list)
+ ]
+
+ if do_rescale:
+ images_list = [
+ [self.rescale(image, rescale_factor, input_data_format=input_data_format) for image in images]
+ for images in images_list
+ ]
+
+ if do_normalize:
+ images_list = [
+ [
+ self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
+ for image in images
+ ]
+ for images in images_list
+ ]
+
+ pixel_attention_mask = None
+ if do_pad:
+ images_list, pixel_attention_mask = self.pad(
+ images_list, return_pixel_mask=True, return_tensors=return_tensors, input_data_format=input_data_format
+ )
+
+ if data_format is not None:
+ images_list = [
+ [
+ to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
+ for image in images
+ ]
+ for images in images_list
+ ]
+
+ # Faster tensor conversion
+ data = {"pixel_values": np.array(images_list) if do_pad and return_tensors is not None else images_list}
+ if pixel_attention_mask is not None:
+ data["pixel_attention_mask"] = (
+ np.array(pixel_attention_mask) if do_pad and return_tensors is not None else pixel_attention_mask
+ )
+
+ encoding = BatchFeature(data=data, tensor_type=return_tensors)
+
+ # This is needed for generating correct text inputs in the processor - we don't pad to the max number of images
+ if return_row_col_info:
+ encoding["rows"] = images_list_rows
+ encoding["cols"] = images_list_cols
+
+ return encoding
+
+ def get_number_of_image_patches(self, height: int, width: int, images_kwargs=None):
+ """
+ A utility that returns number of image patches for a given image size.
+
+ Args:
+ height (`int`):
+ Height of the input image.
+ width (`int`):
+ Width of the input image.
+            images_kwargs (`dict`, *optional*):
+                Any kwargs to override defaults of the image processor.
+        Returns:
+            `tuple[int, int, int]`: The number of patches per image, along with the number of rows and columns of crops.
+ """
+ do_image_splitting = images_kwargs.get("do_image_splitting", self.do_image_splitting)
+ max_image_size = images_kwargs.get("max_image_size", self.max_image_size)
+ size = images_kwargs.get("size", self.size)
+
+ num_patches = num_rows = num_cols = 1
+ if do_image_splitting:
+ height, width = _resize_output_size_rescale_to_max_len(height, width, max_len=size["longest_edge"])
+ height, width = _resize_output_size_scale_below_upper_bound(height, width, max_len=4096)
+ aspect_ratio = width / height
+
+            if width >= height:
+                resized_width = math.ceil(width / max_image_size["longest_edge"]) * max_image_size["longest_edge"]
+                resized_height = int(resized_width / aspect_ratio)
+                resized_height = math.ceil(resized_height / max_image_size["longest_edge"]) * max_image_size["longest_edge"]
+            elif height > width:
+                resized_height = math.ceil(height / max_image_size["longest_edge"]) * max_image_size["longest_edge"]
+                resized_width = int(resized_height * aspect_ratio)
+                resized_width = math.ceil(resized_width / max_image_size["longest_edge"]) * max_image_size["longest_edge"]
+
+ max_height = max_width = max_image_size["longest_edge"]
+ if resized_height > max_height or resized_width > max_width:
+ # Calculate the number of splits
+ num_rows = math.ceil(resized_height / max_height)
+ num_cols = math.ceil(resized_width / max_width)
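+                # +1 accounts for the global (downscaled) image appended after the crops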
+ num_patches = num_rows * num_cols + 1
+
+ return num_patches, num_rows, num_cols
+
+
+__all__ = ["SmolVLMImageProcessor"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/smolvlm/image_processing_smolvlm_fast.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/smolvlm/image_processing_smolvlm_fast.py
new file mode 100644
index 0000000000000000000000000000000000000000..4e24bc2795437cfa3d4c41a0d024e5f4ac82d681
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/smolvlm/image_processing_smolvlm_fast.py
@@ -0,0 +1,532 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/smolvlm/modular_smolvlm.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_smolvlm.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2025 the HuggingFace Inc. team. All rights reserved.
+# Written by Orr Zohar
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import math
+from typing import Optional, Union
+
+import torch
+
+from ...image_processing_utils import BatchFeature
+from ...image_processing_utils_fast import (
+ BaseImageProcessorFast,
+ DefaultFastImageProcessorKwargs,
+ SizeDict,
+ group_images_by_shape,
+ reorder_images,
+)
+from ...image_utils import (
+ IMAGENET_STANDARD_MEAN,
+ IMAGENET_STANDARD_STD,
+ ImageInput,
+ PILImageResampling,
+ make_nested_list_of_images,
+)
+from ...processing_utils import Unpack
+from ...utils import TensorType, auto_docstring, is_torchvision_available, logging
+
+
+if is_torchvision_available():
+ from torchvision.transforms import functional as F
+
+
+logger = logging.get_logger(__name__)
+
+
+class SmolVLMFastImageProcessorKwargs(DefaultFastImageProcessorKwargs):
+ """
+ do_image_splitting (`bool`, *optional*, defaults to `True`):
+ Whether to split the image into sub-images concatenated with the original image. They are split into patches
+ such that each patch has a size of `max_image_size["height"]` x `max_image_size["width"]`.
+ max_image_size (`Dict`, *optional*, defaults to `{"longest_edge": 364}`):
+ Maximum resolution of the patches of images accepted by the model. This is a dictionary containing the key "longest_edge".
+ return_row_col_info (`bool`, *optional*, defaults to `False`):
+ Whether to return the row and column information of the images.
+ """
+
+ do_image_splitting: Optional[bool]
+ max_image_size: Optional[dict[str, int]]
+ return_row_col_info: Optional[bool]
+
+
+MAX_IMAGE_SIZE = 4096 # 4k resolution as absolute maximum
+
+
+def _resize_output_size_rescale_to_max_len(
+ height: int, width: int, min_len: Optional[int] = 1, max_len: Optional[int] = None
+) -> tuple[int, int]:
+ """
+    Get the output size of the image after rescaling its longest edge to `max_len` while preserving the aspect ratio.
+ Args:
+ height (`int`):
+ Height of the input image.
+ width (`int`):
+ Width of the input image.
+ min_len (`int`, *optional*, defaults to 1):
+ Minimum size of the output image.
+ max_len (`int`, *optional*, defaults to the maximum size of the image):
+ Maximum size of the output image.
+ Returns:
+ The output size of the image after resizing.
+ """
+ max_len = max(height, width) if max_len is None else max_len
+ aspect_ratio = width / height
+
+ if width >= height:
+ width = max_len
+ height = int(width / aspect_ratio)
+ if height % 2 != 0:
+ height += 1
+ elif height > width:
+ height = max_len
+ width = int(height * aspect_ratio)
+ if width % 2 != 0:
+ width += 1
+
+ # Avoid resizing to a size smaller than min_len
+ height = max(height, min_len)
+ width = max(width, min_len)
+ return height, width
+
+
+def _resize_output_size_scale_below_upper_bound(
+    height: int, width: int, max_len: Optional[int] = None
+) -> tuple[int, int]:
+ """
+    Get the output size of the image after downscaling it so that neither side exceeds `max_len`, preserving the aspect ratio.
+ Args:
+ height (`int`):
+ Height of the input image.
+ width (`int`):
+ Width of the input image.
+        max_len (`int`, *optional*, defaults to the maximum size of the image):
+            Maximum allowed size for either dimension of the image.
+ Returns:
+ The output size of the image after resizing.
+ """
+ max_len = max(height, width) if max_len is None else max_len
+
+ aspect_ratio = width / height
+ if width >= height and width > max_len:
+ width = max_len
+ height = int(width / aspect_ratio)
+ elif height > width and height > max_len:
+ height = max_len
+ width = int(height * aspect_ratio)
+
+ # Avoid resizing to a size smaller than 1
+ height = max(height, 1)
+ width = max(width, 1)
+ return height, width
+
+
+def get_resize_output_image_size(
+ image,
+ resolution_max_side: int,
+) -> tuple[int, int]:
+ """
+ Get the output size of the image after resizing its longest edge to `resolution_max_side` while preserving the aspect ratio.
+ Args:
+ image (`torch.Tensor`):
+ Image to resize.
+ resolution_max_side (`int`):
+ The longest edge of the image will be resized to this value. The shortest edge will be resized to keep the
+ input aspect ratio.
+ Returns:
+ The output size of the image after resizing.
+ """
+ height, width = image.size()[-2:]
+
+ # Find the output size, when rescaling the longest edge to max_len and preserving the aspect ratio
+ height, width = _resize_output_size_rescale_to_max_len(height, width, max_len=resolution_max_side)
+ # Find the output size when scaling the image to be below the MAX_IMAGE_SIZE
+ height, width = _resize_output_size_scale_below_upper_bound(height, width, max_len=MAX_IMAGE_SIZE)
+ return height, width
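+ # Example: a 2000x8000 image with resolution_max_side=1456 is rescaled to 364x1456 (height x width),
+ # which is already below MAX_IMAGE_SIZE=4096, so the second step leaves it unchanged.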
+
+
+def get_max_height_width(images_list: list[list["torch.Tensor"]]) -> tuple[int, int]:
+ """
+ Get the maximum height and width across all images in a batch.
+ """
+ image_sizes = []
+ for images in images_list:
+ for image in images:
+ image_sizes.append(image.size()[-2:])
+
+ max_height = max(size[0] for size in image_sizes)
+ max_width = max(size[1] for size in image_sizes)
+ return (max_height, max_width)
+
+
+@auto_docstring
+class SmolVLMImageProcessorFast(BaseImageProcessorFast):
+ resample = PILImageResampling.LANCZOS
+ image_mean = IMAGENET_STANDARD_MEAN
+ image_std = IMAGENET_STANDARD_STD
+ size = {"longest_edge": 4 * 364}
+ max_image_size = {"longest_edge": 364}
+ do_resize = True
+ do_rescale = True
+ do_normalize = True
+ do_convert_rgb = True
+ do_image_splitting = True
+ do_pad = True
+ return_row_col_info = False
+ valid_kwargs = SmolVLMFastImageProcessorKwargs
+
+ def _prepare_images_structure(self, images: ImageInput, expected_ndims: int = 3) -> ImageInput:
+ """
+ Prepare a nested images structure for processing.
+ """
+ return make_nested_list_of_images(images, expected_ndims=expected_ndims)
+
+ def resize(
+ self,
+ image: "torch.Tensor",
+ size: SizeDict,
+ interpolation: Optional["F.InterpolationMode"] = None,
+ antialias: bool = True,
+ **kwargs,
+ ) -> "torch.Tensor":
+ """
+ Resize an image. The longest edge of the image is resized to size.longest_edge, with the shortest edge
+ resized to keep the input aspect ratio. Can also be used with size.height and size.width.
+ Args:
+ image (`torch.Tensor`):
+ Image to resize.
+ size (`SizeDict`):
+ Size of the output image.
+ interpolation (`InterpolationMode`, *optional*, defaults to `InterpolationMode.BILINEAR`):
+ `InterpolationMode` filter to use when resizing the image e.g. `InterpolationMode.BICUBIC`.
+ antialias (`bool`, *optional*, defaults to `True`):
+ Whether to use antialiasing when resizing the image.
+ """
+ interpolation = interpolation if interpolation is not None else F.InterpolationMode.BILINEAR
+ if interpolation == F.InterpolationMode.LANCZOS:
+ logger.warning_once(
+ "You have used fast image processor with LANCZOS resample which not yet supported for torch.Tensor. "
+ "BICUBIC resample will be used as an alternative. Please fall back to slow image processor if you "
+ "want full consistency with the original model."
+ )
+ interpolation = F.InterpolationMode.BICUBIC
+
+ if size.longest_edge:
+ size = get_resize_output_image_size(image, resolution_max_side=size.longest_edge)
+ elif size.height and size.width:
+ size = (size.height, size.width)
+ else:
+ raise ValueError("size must be a dictionary with key 'longest_edge' or 'height' and 'width'.")
+
+ return F.resize(image, size, interpolation=interpolation, antialias=antialias)
+
+ def split_images(
+ self,
+ images: torch.Tensor,
+ max_image_size: dict[str, int],
+ interpolation: Optional["F.InterpolationMode"] = None,
+ ):
+ """
+ Split an image into squares of side max_image_size and the original image resized to max_image_size.
+ That means that a single image becomes a sequence of images.
+ This is a "trick" to spend more compute on each image with no changes in the vision encoder.
+ 1) If one side of the original image is larger than `max_image_size`, resize it to `max_image_size` while preserving the aspect ratio.
+ 2) Divide the resulting image into `ceil(height / max_image_size)` x `ceil(width / max_image_size)`
+ sub-images of the same size each (image_size, image_size). Typically, 364x364.
+ 3) Returns the list of the crops and the original image, in addition to the number of splits for the height and the width.
+ Args:
+ images (`torch.Tensor`):
+ Images to split.
+ max_image_size (`Dict[str, int]`):
+ Maximum size of the output image. If the image is larger than this size, it will be split into
+ patches of this size, and the original image will be concatenated with the patches, resized to max_size.
+ interpolation (`InterpolationMode`, *optional*, defaults to `InterpolationMode.BILINEAR`):
+ `InterpolationMode` filter to use when resizing the image e.g. `InterpolationMode.BICUBIC`.
+ """
+ batch_size, num_channels, height, width = images.size()
+ height_dim, width_dim = 2, 3
+
+ max_height = max_width = max_image_size["longest_edge"]
+
+ frames = []
+ if height > max_height or width > max_width:
+ # Calculate the number of splits
+ num_splits_h = math.ceil(height / max_height)
+ num_splits_w = math.ceil(width / max_width)
+
+ # Split the images by height, then by width
+ frames = (
+ images.unfold(height_dim, size=max_height, step=max_height)
+ .unfold(width_dim, size=max_width, step=max_width)
+ .contiguous()
+ .view(batch_size, num_channels, -1, max_height, max_width)
+ .permute(0, 2, 1, 3, 4)
+ ) # batch_size x n_frames x num_channels x height x width
+
+ # For the global image at the end, we resize it to match the max_image_size, for cpu memory efficiency
+ global_image_height, global_image_width = max_height, max_width
+ images = self.resize(
+ images, SizeDict(height=global_image_height, width=global_image_width), interpolation=interpolation
+ )
+
+ frames = torch.cat((frames, images.unsqueeze(1)), dim=1)
+ else:
+ num_splits_h, num_splits_w = 0, 0
+ frames = images.unsqueeze(1)
+
+ num_splits_h = [num_splits_h] * batch_size
+ num_splits_w = [num_splits_w] * batch_size
+
+ return frames, num_splits_h, num_splits_w
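+ # Example: with max_image_size={"longest_edge": 364}, a batch of 728x1092 images is split into
+ # ceil(728/364) x ceil(1092/364) = 2 x 3 = 6 tiles of 364x364, plus the resized global image,
+ # giving frames of shape (batch_size, 7, num_channels, 364, 364) with num_splits_h=2 and num_splits_w=3 per image.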
+
+ def resize_for_vision_encoder(
+ self,
+ image: torch.Tensor,
+ vision_encoder_max_size: int,
+ interpolation: Optional["F.InterpolationMode"] = None,
+ ):
+ """
+ Resize images to be multiples of `vision_encoder_max_size` while preserving the aspect ratio.
+ Args:
+ image (`torch.Tensor`):
+ Images to resize.
+ vision_encoder_max_size (`int`):
+ Maximum size of the output image. If the image is larger than this size, it will be split into
+ patches of this size, and the original image will be concatenated with the patches, resized to max_size.
+ interpolation (`InterpolationMode`, *optional*, defaults to `InterpolationMode.BILINEAR`):
+ `InterpolationMode` filter to use when resizing the image e.g. `InterpolationMode.BICUBIC`.
+ """
+ height, width = image.size()[-2:]
+
+ aspect_ratio = width / height
+ if width >= height:
+ width = math.ceil(width / vision_encoder_max_size) * vision_encoder_max_size
+ height = int(width / aspect_ratio)
+ height = math.ceil(height / vision_encoder_max_size) * vision_encoder_max_size
+ elif height > width:
+ height = math.ceil(height / vision_encoder_max_size) * vision_encoder_max_size
+ width = int(height * aspect_ratio)
+ width = math.ceil(width / vision_encoder_max_size) * vision_encoder_max_size
+ new_size = SizeDict(height=height, width=width)
+ return self.resize(image, size=new_size, interpolation=interpolation)
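+ # Example: a 500x700 image with vision_encoder_max_size=364 becomes 728x728
+ # (width: ceil(700/364)*364=728; height: int(728/1.4)=520, then ceil(520/364)*364=728).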
+
+ def pad(
+ self,
+ image: torch.Tensor,
+ padded_size: tuple[int, int],
+ fill: int = 0,
+ return_pixel_mask: bool = True,
+ ):
+ original_size = image.shape[-2:]
+ padding_bottom = padded_size[0] - original_size[0]
+ padding_right = padded_size[1] - original_size[1]
+
+ if padding_bottom < 0 or padding_right < 0:
+ raise ValueError(
+ f"Padding dimensions are negative. Please make sure that the padded size is larger than the "
+ f"original size. Got padded size: {padded_size}, original size: {original_size}."
+ )
+
+ # Only pad if necessary
+ if original_size != padded_size:
+ padding = (0, 0, padding_right, padding_bottom)
+ image = F.pad(image, padding, fill=fill, padding_mode="constant")
+
+ # Make a pixel mask for the image, where 1 indicates a valid pixel and 0 indicates padding.
+ pixel_mask = None
+ if return_pixel_mask:
+ pixel_mask = torch.zeros_like(image[..., 0, :, :], dtype=torch.int64)
+ pixel_mask[: original_size[0], : original_size[1]] = 1
+
+ return image, pixel_mask
+
+ @auto_docstring
+ def preprocess(self, images: ImageInput, **kwargs: Unpack[SmolVLMFastImageProcessorKwargs]) -> BatchFeature:
+ return super().preprocess(images, **kwargs)
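+ # Minimal usage sketch (assumes a PIL image `img`; relies on the class defaults set above):
+ #   processor = SmolVLMImageProcessorFast()
+ #   batch = processor(images=img, return_tensors="pt", return_row_col_info=True)
+ #   batch["pixel_values"] has shape (batch_size, num_frames, 3, height, width), where num_frames counts
+ #   the image tiles plus the global image; with do_pad=True, batch["pixel_attention_mask"] is returned,
+ #   and return_row_col_info=True adds batch["rows"] and batch["cols"] describing the split grid.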
+
+ def _preprocess(
+ self,
+ images: list[list["torch.Tensor"]],
+ do_resize: bool,
+ size: SizeDict,
+ interpolation: Optional["F.InterpolationMode"],
+ do_rescale: bool,
+ rescale_factor: float,
+ do_normalize: bool,
+ image_mean: Optional[Union[float, list[float]]],
+ image_std: Optional[Union[float, list[float]]],
+ do_pad: Optional[bool],
+ do_image_splitting: Optional[bool],
+ max_image_size: Optional[dict[str, int]],
+ return_row_col_info: Optional[bool],
+ disable_grouping: Optional[bool],
+ return_tensors: Optional[Union[str, TensorType]],
+ **kwargs,
+ ) -> BatchFeature:
+ """
+ Process a batch of images for the model.
+ """
+
+ grouped_images, grouped_images_index = group_images_by_shape(
+ images, is_nested=True, disable_grouping=disable_grouping
+ )
+ resized_images_grouped = {}
+ for shape, stacked_images in grouped_images.items():
+ if do_resize:
+ stacked_images = self.resize(stacked_images, size, interpolation=interpolation)
+ resized_images_grouped[shape] = stacked_images
+ resized_images = reorder_images(resized_images_grouped, grouped_images_index, is_nested=True)
+
+ grouped_images, grouped_images_index = group_images_by_shape(
+ resized_images, is_nested=True, disable_grouping=disable_grouping
+ )
+ split_images_grouped = {}
+ if do_image_splitting:
+ rows_grouped = {}
+ cols_grouped = {}
+ for shape, stacked_images in grouped_images.items():
+ stacked_images = self.resize_for_vision_encoder(
+ stacked_images, max_image_size["longest_edge"], interpolation=interpolation
+ )
+ stacked_images, rows, cols = self.split_images(
+ stacked_images, max_image_size=max_image_size, interpolation=interpolation
+ )
+ split_images_grouped[shape] = stacked_images
+ rows_grouped[shape] = rows
+ cols_grouped[shape] = cols
+ processed_images = reorder_images(split_images_grouped, grouped_images_index, is_nested=True)
+ rows = reorder_images(rows_grouped, grouped_images_index, is_nested=True)
+ cols = reorder_images(cols_grouped, grouped_images_index, is_nested=True)
+ # flatten the doubly nested list into a nested list
+ for i, group_images in enumerate(processed_images):
+ processed_images[i] = [image for sublist in group_images for image in sublist]
+ else:
+ for shape, stacked_images in grouped_images.items():
+ # We resize the images to squares of side max_image_size["longest_edge"]
+ stacked_images = self.resize(
+ image=stacked_images,
+ size=SizeDict(height=max_image_size["longest_edge"], width=max_image_size["longest_edge"]),
+ interpolation=interpolation,
+ )
+ split_images_grouped[shape] = stacked_images
+ processed_images = reorder_images(split_images_grouped, grouped_images_index, is_nested=True)
+ rows = [[0] * len(images) for images in processed_images]
+ cols = [[0] * len(images) for images in processed_images]
+ # Group images by size for further processing
+ # Needed in case do_resize is False, or resize returns images with different sizes
+ grouped_images, grouped_images_index = group_images_by_shape(
+ processed_images, is_nested=True, disable_grouping=disable_grouping
+ )
+ processed_images_grouped = {}
+ for shape, stacked_images in grouped_images.items():
+ # Fused rescale and normalize
+ stacked_images = self.rescale_and_normalize(
+ stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std
+ )
+ processed_images_grouped[shape] = stacked_images
+ processed_images = reorder_images(processed_images_grouped, grouped_images_index, is_nested=True)
+ if do_pad:
+ # Get max images per batch
+ max_num_images = max(len(images_) for images_ in processed_images)
+ max_height, max_width = get_max_height_width(processed_images)
+
+ processed_images_padded = torch.zeros(
+ len(processed_images),
+ max_num_images,
+ *(processed_images[0][0].shape[0], max_height, max_width),
+ device=processed_images[0][0].device,
+ )
+ pixel_attention_masks = torch.zeros(
+ len(processed_images),
+ max_num_images,
+ *(max_height, max_width),
+ device=processed_images[0][0].device,
+ )
+ for i, images in enumerate(processed_images):
+ for j, image in enumerate(images):
+ processed_images_padded[i, j], pixel_attention_masks[i, j] = self.pad(
+ image, (max_height, max_width)
+ )
+ processed_images = processed_images_padded
+
+ if do_pad:
+ data = {"pixel_values": processed_images, "pixel_attention_mask": pixel_attention_masks}
+ elif return_tensors == "pt":
+ data = {"pixel_values": torch.stack([torch.stack(images) for images in processed_images])}
+ else:
+ data = {"pixel_values": processed_images}
+ # This is needed for generating correct text inputs in the processor - we don't pad to the max number of images
+ encoding = BatchFeature(data=data, tensor_type=return_tensors)
+
+ if return_row_col_info:
+ encoding["rows"] = rows
+ encoding["cols"] = cols
+
+ return encoding
+
+ def to_dict(self):
+ encoder_dict = super().to_dict()
+ encoder_dict.pop("_valid_processor_keys", None)
+ encoder_dict.pop("return_row_col_info", None)
+ return encoder_dict
+
+ def get_number_of_image_patches(self, height: int, width: int, images_kwargs=None):
+ """
+ A utility that returns number of image patches for a given image size.
+
+ Args:
+ height (`int`):
+ Height of the input image.
+ width (`int`):
+ Width of the input image.
+ images_kwargs (`dict`, *optional*):
+ Any kwargs to override defaults of the image processor.
+ Returns:
+ `tuple[int, int, int]`: The number of patches, the number of rows and the number of columns per image.
+ """
+ images_kwargs = images_kwargs if images_kwargs is not None else {}
+ do_image_splitting = images_kwargs.get("do_image_splitting", self.do_image_splitting)
+ max_image_size = images_kwargs.get("max_image_size", self.max_image_size)
+ size = images_kwargs.get("size", self.size)
+
+ num_patches = num_rows = num_cols = 1
+ if do_image_splitting:
+ height, width = _resize_output_size_rescale_to_max_len(height, width, max_len=size["longest_edge"])
+ height, width = _resize_output_size_scale_below_upper_bound(height, width, max_len=MAX_IMAGE_SIZE)
+ aspect_ratio = width / height
+
+ if width >= height:
+ resized_width = math.ceil(width / max_image_size["longest_edge"]) * max_image_size["longest_edge"]
+ resized_height = int(resized_width / aspect_ratio)
+ resized_height = math.ceil(resized_height / max_image_size["longest_edge"]) * max_image_size["longest_edge"]
+ elif height > width:
+ resized_height = math.ceil(height / max_image_size["longest_edge"]) * max_image_size["longest_edge"]
+ resized_width = int(resized_height * aspect_ratio)
+ resized_width = math.ceil(resized_width / max_image_size["longest_edge"]) * max_image_size["longest_edge"]
+
+ max_height = max_width = max_image_size["longest_edge"]
+ if resized_height > max_height or resized_width > max_width:
+ # Calculate the number of splits
+ num_rows = math.ceil(resized_height / max_height)
+ num_cols = math.ceil(resized_width / max_width)
+ num_patches = num_rows * num_cols + 1
+
+ return num_patches, num_rows, num_cols
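+ # Example: a 1000x2000 image with size={"longest_edge": 1456} and max_image_size={"longest_edge": 364}
+ # is rescaled to 728x1456, so num_rows=2, num_cols=4 and num_patches = 2 * 4 + 1 = 9 (including the global image).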
+
+
+__all__ = ["SmolVLMImageProcessorFast"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/smolvlm/modeling_smolvlm.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/smolvlm/modeling_smolvlm.py
new file mode 100644
index 0000000000000000000000000000000000000000..5ff2b041dd2decf4fbe45aa5da489e85fe2e61b6
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/smolvlm/modeling_smolvlm.py
@@ -0,0 +1,957 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/smolvlm/modular_smolvlm.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_smolvlm.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2025 the HuggingFace Inc. team. All rights reserved.
+# Written by Orr Zohar
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from dataclasses import dataclass
+from typing import Callable, Optional, Union
+
+import torch
+from torch import nn
+
+from ...activations import ACT2FN
+from ...cache_utils import Cache, DynamicCache
+from ...generation import GenerationConfig, GenerationMixin
+from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
+from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_outputs import BaseModelOutput, ModelOutput
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
+from ...utils import (
+ TransformersKwargs,
+ auto_docstring,
+ can_return_tuple,
+ logging,
+)
+from ...utils.generic import check_model_inputs
+from ..auto import AutoModel
+from .configuration_smolvlm import SmolVLMConfig, SmolVLMVisionConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+class SmolVLMRMSNorm(nn.Module):
+ def __init__(self, hidden_size, eps=1e-6):
+ """
+ SmolVLMRMSNorm is equivalent to T5LayerNorm
+ """
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+ self.variance_epsilon = eps
+
+ def forward(self, hidden_states):
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
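+ # RMSNorm: y = weight * x / sqrt(mean(x**2, dim=-1) + eps); statistics are computed in float32 for numerical stability.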
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+ return self.weight * hidden_states.to(input_dtype)
+
+ def extra_repr(self):
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
+
+
+@auto_docstring
+class SmolVLMPreTrainedModel(PreTrainedModel):
+ config: SmolVLMConfig
+ base_model_prefix = "model"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["SmolVLMVisionAttention", "SmolVLMDecoderLayer"]
+ _skip_keys_device_placement = "past_key_values"
+ _supports_flash_attn = True
+ _supports_sdpa = True
+ _supports_flex_attn = True
+
+ _supports_attention_backend = True
+
+ def _init_weights(self, module):
+ std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range)
+
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+ elif isinstance(module, nn.LayerNorm):
+ module.weight.data.fill_(1.0)
+ module.bias.data.zero_()
+ elif isinstance(module, SmolVLMRMSNorm):
+ module.weight.data.fill_(1.0)
+
+
+class SmolVLMVisionEmbeddings(nn.Module):
+ """
+ This is a modified version of `siglip.modeling_siglip.SiglipVisionEmbeddings` to enable images of variable
+ resolution.
+
+ The modifications are adapted from [Patch n' Pack: NaViT, a Vision Transformer for any Aspect Ratio and Resolution](https://huggingface.co/papers/2307.06304)
+ which allows treating images in their native aspect ratio and without the need to resize them to the same
+ fixed size. In particular, we start from the original pre-trained SigLIP model
+ (which uses fixed-size square images) and adapt it by training on images of variable resolutions.
+ """
+
+ def __init__(self, config: SmolVLMVisionConfig):
+ super().__init__()
+ self.embed_dim = config.hidden_size
+ self.image_size = config.image_size
+ self.patch_size = config.patch_size
+
+ self.patch_embedding = nn.Conv2d(
+ in_channels=config.num_channels,
+ out_channels=self.embed_dim,
+ kernel_size=self.patch_size,
+ stride=self.patch_size,
+ padding="valid",
+ )
+
+ self.num_patches_per_side = self.image_size // self.patch_size
+ self.num_patches = self.num_patches_per_side**2
+ self.num_positions = self.num_patches
+ self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
+
+ def forward(self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.BoolTensor) -> torch.Tensor:
+ batch_size, _, max_im_h, max_im_w = pixel_values.shape
+
+ patch_embeds = self.patch_embedding(pixel_values)
+ embeddings = patch_embeds.flatten(2).transpose(1, 2)
+
+ max_nb_patches_h, max_nb_patches_w = max_im_h // self.patch_size, max_im_w // self.patch_size
+ boundaries = torch.arange(
+ 1 / self.num_patches_per_side, 1.0, 1 / self.num_patches_per_side, device=pixel_values.device
+ )
+ position_ids = torch.full(
+ size=(batch_size, max_nb_patches_h * max_nb_patches_w), fill_value=0, device=pixel_values.device
+ )
+
+ for batch_idx, p_attn_mask in enumerate(patch_attention_mask):
+ nb_patches_h = p_attn_mask[:, 0].sum()
+ nb_patches_w = p_attn_mask[0].sum()
+
+ h_indices = torch.arange(nb_patches_h, device=position_ids.device, dtype=pixel_values.dtype)
+ w_indices = torch.arange(nb_patches_w, device=position_ids.device, dtype=pixel_values.dtype)
+
+ fractional_coords_h = h_indices / nb_patches_h * (1 - 1e-6)
+ fractional_coords_w = w_indices / nb_patches_w * (1 - 1e-6)
+
+ bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True)
+ bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True)
+
+ pos_ids = (bucket_coords_h[:, None] * self.num_patches_per_side + bucket_coords_w).flatten()
+ position_ids[batch_idx][p_attn_mask.view(-1)] = pos_ids
+
+ embeddings = embeddings + self.position_embedding(position_ids)
+ return embeddings
+
+
+def eager_attention_forward(
+ module: nn.Module,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_mask: Optional[torch.Tensor],
+ scaling: float,
+ dropout: float = 0.0,
+ **kwargs,
+):
+ attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
+ if attention_mask is not None:
+ attn_weights = attn_weights + attention_mask
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+
+ attn_output = torch.matmul(attn_weights, value)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ return attn_output, attn_weights
+
+
+class SmolVLMVisionAttention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.embed_dim = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.head_dim = self.embed_dim // self.num_heads
+ if self.head_dim * self.num_heads != self.embed_dim:
+ raise ValueError(
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
+ f" {self.num_heads})."
+ )
+ self.scale = self.head_dim**-0.5
+ self.dropout = config.attention_dropout
+
+ self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
+ self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
+ self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
+ self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
+
+ # Ignore copy
+ self.is_causal = False
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ **kwargs,
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+ """Input shape: Batch x Time x Channel"""
+
+ batch_size, seq_length, embed_dim = hidden_states.shape
+
+ queries = self.q_proj(hidden_states)
+ keys = self.k_proj(hidden_states)
+ values = self.v_proj(hidden_states)
+
+ queries = queries.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
+ keys = keys.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
+ values = values.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
+
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ attn_output, attn_weights = attention_interface(
+ self,
+ queries,
+ keys,
+ values,
+ attention_mask,
+ is_causal=self.is_causal,
+ scaling=self.scale,
+ dropout=0.0 if not self.training else self.dropout,
+ )
+
+ attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous()
+ attn_output = self.out_proj(attn_output)
+
+ return attn_output, attn_weights
+
+
+class SmolVLMVisionMLP(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.activation_fn = ACT2FN[config.hidden_act]
+ self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
+ self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.fc1(hidden_states)
+ hidden_states = self.activation_fn(hidden_states)
+ hidden_states = self.fc2(hidden_states)
+ return hidden_states
+
+
+class SmolVLMEncoderLayer(GradientCheckpointingLayer):
+ def __init__(self, config: SmolVLMVisionConfig):
+ super().__init__()
+ self.embed_dim = config.hidden_size
+ self.self_attn = SmolVLMVisionAttention(config)
+ self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
+ self.mlp = SmolVLMVisionMLP(config)
+ self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
+
+ @auto_docstring
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: torch.Tensor,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> torch.FloatTensor:
+ residual = hidden_states
+
+ hidden_states = self.layer_norm1(hidden_states)
+ hidden_states, _ = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ **kwargs,
+ )
+ hidden_states = residual + hidden_states
+
+ residual = hidden_states
+ hidden_states = self.layer_norm2(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = residual + hidden_states
+
+ return hidden_states
+
+
+class SmolVLMEncoder(nn.Module):
+ """
+ Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
+ [`SmolVLMEncoderLayer`].
+
+ Args:
+ config: SmolVLMConfig
+ """
+
+ def __init__(self, config: SmolVLMConfig):
+ super().__init__()
+ self.config = config
+ self.layers = nn.ModuleList([SmolVLMEncoderLayer(config) for _ in range(config.num_hidden_layers)])
+ self.gradient_checkpointing = False
+
+ # Ignore copy
+ @auto_docstring
+ def forward(
+ self,
+ inputs_embeds,
+ attention_mask: Optional[torch.Tensor] = None,
+ ) -> Union[tuple, BaseModelOutput]:
+ hidden_states = inputs_embeds
+ for encoder_layer in self.layers:
+ layer_outputs = encoder_layer(
+ hidden_states,
+ attention_mask,
+ )
+
+ hidden_states = layer_outputs
+
+ return BaseModelOutput(last_hidden_state=hidden_states)
+
+
+@auto_docstring(
+ custom_intro="""
+ The SmolVLM Vision Transformer Model outputting raw image embedding.
+ """
+)
+class SmolVLMVisionTransformer(SmolVLMPreTrainedModel):
+ config: SmolVLMVisionConfig
+ _supports_sdpa = True
+ _supports_flash_attn = True
+ _supports_flex_attn = True
+ _can_record_outputs = {
+ "hidden_states": SmolVLMEncoderLayer,
+ "attentions": SmolVLMVisionAttention,
+ }
+
+ def __init__(self, config: SmolVLMVisionConfig):
+ super().__init__(config)
+ embed_dim = config.hidden_size
+
+ self.embeddings = SmolVLMVisionEmbeddings(config)
+ self.encoder = SmolVLMEncoder(config)
+ self.patch_size = config.patch_size
+ self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
+
+ def get_input_embeddings(self):
+ return self.embeddings
+
+ def set_input_embeddings(self, value):
+ self.embeddings = value
+
+ @check_model_inputs(tie_last_hidden_states=False)
+ def forward(
+ self,
+ pixel_values,
+ patch_attention_mask: Optional[torch.BoolTensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> Union[tuple, BaseModelOutput]:
+ batch_size = pixel_values.size(0)
+ if patch_attention_mask is None:
+ patch_size = self.patch_size
+ patch_attention_mask = torch.ones(
+ (
+ batch_size,
+ pixel_values.size(2) // patch_size,
+ pixel_values.size(3) // patch_size,
+ )
+ )
+ patch_attention_mask = patch_attention_mask.to(dtype=torch.bool, device=pixel_values.device)
+
+ hidden_states = self.embeddings(pixel_values=pixel_values, patch_attention_mask=patch_attention_mask)
+
+ patch_attention_mask = patch_attention_mask.view(batch_size, -1)
+ # The call to `_upad_input` in `_flash_attention_forward` is expensive
+ # So when the `patch_attention_mask` is full of 1s (i.e. attending to the whole sequence),
+ # avoiding passing the attention_mask, which is equivalent to attending to the full sequence
+ if self.config._attn_implementation != "flash_attention_2":
+ patch_attention_mask = _prepare_4d_attention_mask(patch_attention_mask, hidden_states.dtype)
+ elif not torch.any(~patch_attention_mask):
+ patch_attention_mask = None
+
+ encoder_outputs: BaseModelOutput = self.encoder(
+ inputs_embeds=hidden_states,
+ attention_mask=patch_attention_mask,
+ )
+
+ last_hidden_state = encoder_outputs.last_hidden_state
+ last_hidden_state = self.post_layernorm(last_hidden_state)
+
+ return BaseModelOutput(
+ last_hidden_state=last_hidden_state,
+ )
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Base class for SmolVLM model's outputs that may also contain a past key/values (to speed up sequential decoding).
+ """
+)
+class SmolVLMBaseModelOutputWithPast(ModelOutput):
+ r"""
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+ Sequence of hidden-states at the output of the last layer of the model.
+ If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1,
+ hidden_size)` is output.
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if
+ `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values`
+ input) to speed up sequential decoding.
+ image_hidden_states (`tuple(torch.FloatTensor)`, *optional*):
+ Tuple of `torch.FloatTensor` (one for the output of the image embeddings) of shape `(batch_size, num_images,
+ sequence_length, hidden_size)`.
+ Image hidden states of the model produced by the vision encoder.
+ """
+
+ last_hidden_state: Optional[torch.FloatTensor] = None
+ past_key_values: Optional[Cache] = None
+ hidden_states: Optional[tuple[torch.FloatTensor]] = None
+ attentions: Optional[tuple[torch.FloatTensor]] = None
+ image_hidden_states: Optional[tuple[torch.FloatTensor]] = None
+
+
+class SmolVLMSimpleMLP(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ input_size = config.vision_config.hidden_size * (config.scale_factor**2)
+ output_size = config.text_config.hidden_size
+ self.proj = nn.Linear(input_size, output_size, bias=False)
+
+ def forward(self, x):
+ return self.proj(x)
+
+
+class SmolVLMConnector(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.scale_factor = config.scale_factor
+ self.modality_projection = SmolVLMSimpleMLP(config)
+
+ def pixel_shuffle(self, x, scale_factor=2):
+ bsz, seq, embed_dim = x.size()
+ height = width = int(seq**0.5)
+ x = x.view(bsz, height, width, embed_dim)
+ x = x.view(bsz, height, int(width / scale_factor), embed_dim * scale_factor)
+ x = x.permute(0, 2, 1, 3)
+ x = x.reshape(bsz, int(width / scale_factor), int(height / scale_factor), embed_dim * (scale_factor**2))
+ x = x.permute(0, 2, 1, 3)
+ x = x.reshape(bsz, int(seq / (scale_factor**2)), embed_dim * (scale_factor**2))
+ return x
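+ # Shape walk-through: (bsz, seq, dim) is viewed as a (sqrt(seq) x sqrt(seq)) grid and each
+ # scale_factor x scale_factor neighbourhood is folded into the channel dimension, e.g. with seq=16 and
+ # scale_factor=2: (bsz, 16, dim) -> (bsz, 4, dim * 4). Fewer image tokens, each carrying more channels.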
+
+ def forward(self, image_hidden_states):
+ image_hidden_states = self.pixel_shuffle(image_hidden_states, self.scale_factor)
+ image_hidden_states = self.modality_projection(image_hidden_states)
+ return image_hidden_states
+
+
+@auto_docstring(
+ custom_intro="""
+ SmolVLM model consisting of a SIGLIP vision encoder and Llama3 language decoder
+ """
+)
+class SmolVLMModel(SmolVLMPreTrainedModel):
+ """
+ A subclass of Idefics3Model. We do *not* remove or block the call to inputs_merger
+ in forward. Instead, we override inputs_merger here with custom logic.
+ """
+
+ def __init__(self, config: SmolVLMConfig):
+ super().__init__(config)
+ self.padding_idx = self.config.text_config.pad_token_id
+ self.vocab_size = self.config.text_config.vocab_size
+
+ self.vision_model = SmolVLMVisionTransformer._from_config(config.vision_config)
+ self.connector = SmolVLMConnector(config)
+ self.text_model = AutoModel.from_config(config.text_config)
+
+ self.image_seq_len = int(
+ ((config.vision_config.image_size // config.vision_config.patch_size) ** 2) / (config.scale_factor**2)
+ )
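+ # e.g. with the Idefics3-style defaults image_size=364, patch_size=14 and scale_factor=2:
+ # (364 // 14) ** 2 / 4 = 676 / 4 = 169 image tokens per image patch.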
+ self.image_token_id = self.config.image_token_id
+
+ self.post_init()
+
+ def enable_input_require_grads(self):
+ """
+ Enables the gradients for the input embeddings.
+
+ This is useful for lora when using gradient checkpointing.
+ c.f. https://github.com/huggingface/peft/issues/1402#issuecomment-1913675032
+
+ Override to set output.requires_grad = True for both the decoder's and vision model's embeddings.
+ """
+
+ def get_lowest_module(module):
+ if len(list(module.children())) == 0:
+ # If the module has no children, it is a leaf module (e.g., Linear, Conv2d, etc.)
+ return module
+ else:
+ # Recursively call the function on each child module
+ return get_lowest_module(list(module.children())[0])
+
+ def make_inputs_require_grads(module, input, output):
+ output.requires_grad_(True)
+
+ self._text_require_grads_hook = self.get_input_embeddings().register_forward_hook(make_inputs_require_grads)
+ self._vision_require_grads_hook = get_lowest_module(self.vision_model).register_forward_hook(
+ make_inputs_require_grads
+ )
+
+ def disable_input_require_grads(self):
+ self._text_require_grads_hook.remove()
+ self._vision_require_grads_hook.remove()
+
+ def get_input_embeddings(self):
+ return self.text_model.get_input_embeddings()
+
+ def set_input_embeddings(self, value):
+ self.text_model.set_input_embeddings(value)
+
+ def inputs_merger(
+ self, input_ids: torch.LongTensor, inputs_embeds: torch.Tensor, image_hidden_states: torch.Tensor
+ ):
+ """
+ This method aims at merging the token embeddings with the image hidden states into one single sequence of vectors that are fed to the transformer LM.
+ The merging happens as follows:
+ - The text token sequence is: `tok_1 tok_2 tok_3 ... tok_4`.
+ - We get the image hidden states for the image through the vision encoder and that hidden state, after a pixel shuffle operation, is then projected into the text embedding space.
+ We thus have a sequence of image hidden states of size (1, image_seq_len, hidden_dim), where 1 is for batch_size of 1 image and hidden_dim is the hidden_dim of the LM transformer.
+ - The merging happens so that we obtain the following sequence: `vector_tok_1 vector_tok_2 vector_tok_3 vector_fake_tok_around_image {sequence of image_seq_len image hidden states} vector_fake_tok_around_image vector_tok_4`. That sequence is fed to the LM.
+ - To fit the format of that sequence, `input_ids`, `inputs_embeds` and `attention_mask` are all adapted to insert the image hidden states.
+ """
+ _, patch_size, _ = image_hidden_states.shape
+
+ if input_ids is None:
+ image_mask = inputs_embeds == self.get_input_embeddings()(
+ torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
+ )
+ image_mask = image_mask[..., 0] # slice off the hidden dim
+ else:
+ image_mask = input_ids == self.config.image_token_id
+
+ num_image_tokens = image_mask.sum(dim=1)
+ if not torch.all(num_image_tokens % patch_size == 0):
+ raise ValueError("At least one sample has tokens not divisible by patch_size.")
+
+ blocks_per_sample = num_image_tokens // patch_size
+
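+ # Map every image-token position to (block_idx, local_idx): which group of `patch_size` image tokens it
+ # belongs to across the batch, and its offset inside that group, so the projected image features can be
+ # scattered into the text embedding sequence with a single indexing operation below.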
+ offsets = torch.nn.functional.pad(blocks_per_sample.cumsum(dim=0), (1, 0), value=0)
+ block_offset = offsets[:-1]
+ row_cum = image_mask.cumsum(dim=-1)
+ chunk_idx = (row_cum - 1) // patch_size
+ local_idx = (row_cum - 1) % patch_size
+ block_idx = block_offset.unsqueeze(1) + chunk_idx
+
+ image_embeds = torch.zeros_like(inputs_embeds)
+ image_embeds[image_mask] = image_hidden_states[block_idx[image_mask], local_idx[image_mask], :]
+
+ merged_embeds = torch.where(image_mask.unsqueeze(-1), image_embeds, inputs_embeds)
+ return merged_embeds
+
+ def get_image_features(
+ self, pixel_values: torch.FloatTensor, pixel_attention_mask: Optional[torch.LongTensor] = None
+ ):
+ """
+ Encodes images into continuous embeddings that can be forwarded to the language model.
+
+ Args:
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
+ The tensors corresponding to the input images.
+ pixel_attention_mask (`torch.LongTensor`, *optional*):
+ The attention mask indicating padded regions in the image.
+ """
+ batch_size, num_images, num_channels, height, width = pixel_values.shape
+ pixel_values = pixel_values.to(dtype=self.dtype) # fp16 compatibility
+ pixel_values = pixel_values.view(batch_size * num_images, *pixel_values.shape[2:])
+
+ # Remove padding images - padding images are full 0.
+ nb_values_per_image = pixel_values.shape[1:].numel()
+ real_images_inds = (pixel_values == 0.0).sum(dim=(-1, -2, -3)) != nb_values_per_image
+
+ if not any(real_images_inds):
+ # no images, leave one empty image.
+ real_images_inds[0] = True
+
+ pixel_values = pixel_values[real_images_inds].contiguous()
+ # Handle the vision attention mask
+ if pixel_attention_mask is None:
+ pixel_attention_mask = torch.ones(
+ size=[pixel_values.shape[i] for i in (0, 2, 3)],
+ dtype=torch.bool,
+ device=pixel_values.device,
+ )
+ else:
+ # Remove padding images from the mask
+ pixel_attention_mask = pixel_attention_mask.view(batch_size * num_images, *pixel_attention_mask.shape[2:])
+ pixel_attention_mask = pixel_attention_mask[real_images_inds].contiguous()
+ patch_size = self.config.vision_config.patch_size
+ patches_subgrid = pixel_attention_mask.unfold(dimension=1, size=patch_size, step=patch_size)
+ patches_subgrid = patches_subgrid.unfold(dimension=2, size=patch_size, step=patch_size)
+ patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()
+
+ # Get sequence from the vision encoder
+ image_hidden_states = self.vision_model(pixel_values=pixel_values, patch_attention_mask=patch_attention_mask)
+ image_hidden_states = image_hidden_states.last_hidden_state
+
+ # Modality projection & resampling
+ image_hidden_states = self.connector(image_hidden_states)
+ return image_hidden_states
+
+ @can_return_tuple
+ @auto_docstring(
+ custom_intro="""
+ Inputs fed to the model can have an arbitrary number of images. To account for this, pixel_values fed to
+ the model have image padding -> (batch_size, max_num_images, 3, max_heights, max_widths) where
+ max_num_images is the maximum number of images among the batch_size samples in the batch.
+ Padding images are not needed beyond padding the pixel_values at the entrance of the model.
+ For efficiency, we only pass through the vision_model's forward the real images by
+ discarding the padding images i.e. pixel_values of size (image_batch_size, 3, height, width) where
+ image_batch_size would be 7 when num_images_per_sample=[1, 3, 1, 2] and max_num_images would be 3.
+ """
+ )
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ pixel_attention_mask: Optional[torch.BoolTensor] = None,
+ image_hidden_states: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs: Unpack[FlashAttentionKwargs],
+ ) -> Union[tuple, SmolVLMBaseModelOutputWithPast]:
+ r"""
+ pixel_attention_mask (`torch.Tensor` of shape `(batch_size, image_size, image_size)`, *optional*):
+ Mask to avoid performing attention on padding pixel indices.
+ image_hidden_states (`torch.FloatTensor` of shape `(batch_size * num_images, image_seq_len, hidden_size)`):
+ The hidden states of the image encoder after modality projection.
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if self.training and self.text_model.gradient_checkpointing and use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+ )
+ use_cache = False
+
+ # retrieve input_ids and inputs_embeds
+ if input_ids is not None:
+ batch_size, seq_length = input_ids.shape
+ elif inputs_embeds is not None:
+ batch_size, seq_length, _ = inputs_embeds.shape
+ else:
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+ if use_cache and past_key_values is None:
+ past_key_values = DynamicCache(config=self.config)
+
+ if inputs_embeds is None:
+ inputs_embeds = self.text_model.get_input_embeddings()(input_ids).to(input_ids.device)
+
+ # START VISUAL INPUTS INTEGRATION
+ if pixel_values is not None and image_hidden_states is not None:
+ raise ValueError("You cannot specify both pixel_values and image_hidden_states at the same time")
+
+ if pixel_values is not None:
+ image_hidden_states = self.get_image_features(pixel_values, pixel_attention_mask).to(inputs_embeds.device)
+ elif image_hidden_states is not None:
+ image_hidden_states = image_hidden_states.to(dtype=self.dtype, device=inputs_embeds.device)
+
+ if image_hidden_states is not None:
+ # When we generate, we don't want to replace the potential image_token_id that we generated by images
+ # that simply don't exist
+ inputs_embeds = self.inputs_merger(
+ input_ids=input_ids,
+ inputs_embeds=inputs_embeds,
+ image_hidden_states=image_hidden_states,
+ )
+
+ outputs = self.text_model(
+ inputs_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=True,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+ return SmolVLMBaseModelOutputWithPast(
+ last_hidden_state=outputs.last_hidden_state,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ image_hidden_states=image_hidden_states,
+ )
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Base class for SmolVLM causal language model (or autoregressive) outputs.
+ """
+)
+class SmolVLMCausalLMOutputWithPast(ModelOutput):
+ r"""
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Language modeling loss (for next-token prediction).
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
+ `past_key_values` input) to speed up sequential decoding.
+ image_hidden_states (`tuple(torch.FloatTensor)`, *optional*):
+ Tuple of `torch.FloatTensor` (one for the output of the image embeddings) of shape `(batch_size, num_images,
+ sequence_length, hidden_size)`.
+ Image hidden states of the model produced by the vision encoder.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ logits: Optional[torch.FloatTensor] = None
+ past_key_values: Optional[Cache] = None
+ hidden_states: Optional[tuple[torch.FloatTensor]] = None
+ attentions: Optional[tuple[torch.FloatTensor]] = None
+ image_hidden_states: Optional[tuple[torch.FloatTensor]] = None
+
+
+@auto_docstring(
+ custom_intro="""
+ The SmolVLM Model with a language modeling head. It is made up of a SigLIP vision encoder and a Llama3 language decoder, with a language modeling head on top.
+ """
+)
+class SmolVLMForConditionalGeneration(SmolVLMPreTrainedModel, GenerationMixin):
+ _tied_weights_keys = ["lm_head.weight"]
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.model = SmolVLMModel(config)
+ self.image_token_id = self.config.image_token_id
+ self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
+ self.vocab_size = config.text_config.vocab_size
+ self.model.text_model.generation_config = GenerationConfig.from_model_config(config)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def enable_input_require_grads(self):
+ """
+ Enables the gradients for the input embeddings. This is useful for fine-tuning adapter weights while keeping
+ the model weights fixed.
+ """
+
+ def make_inputs_require_grads(module, input, output):
+ output.requires_grad_(True)
+
+ self._text_require_grads_hook = self.get_input_embeddings().register_forward_hook(make_inputs_require_grads)
+ self._vision_require_grads_hook = self.model.vision_model.get_input_embeddings().register_forward_hook(
+ make_inputs_require_grads
+ )
+
+ def disable_input_require_grads(self):
+ self._text_require_grads_hook.remove()
+ self._vision_require_grads_hook.remove()
+
+ def get_input_embeddings(self):
+ return self.model.text_model.get_input_embeddings()
+
+ def set_input_embeddings(self, value):
+ self.model.text_model.set_input_embeddings(value)
+
+ def get_image_features(
+ self, pixel_values: torch.FloatTensor, pixel_attention_mask: Optional[torch.LongTensor] = None
+ ):
+ return self.model.get_image_features(pixel_values=pixel_values, pixel_attention_mask=pixel_attention_mask)
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ pixel_attention_mask: Optional[torch.BoolTensor] = None,
+ image_hidden_states: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ return_dict: Optional[bool] = None,
+ logits_to_keep: Union[int, torch.Tensor] = 0,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> Union[tuple, SmolVLMCausalLMOutputWithPast]:
+ r"""
+ pixel_attention_mask (`torch.Tensor` of shape `(batch_size, image_size, image_size)`, *optional*):
+ Mask to avoid performing attention on padding pixel indices.
+ image_hidden_states (`torch.FloatTensor` of shape `(batch_size * num_images, image_seq_len, hidden_size)`):
+ The hidden states of the image encoder after modality projection.
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or `model.image_token_id`. Tokens with indices set to `model.image_token_id` are
+ ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+ Example:
+
+ ```python
+ >>> import requests
+ >>> import torch
+ >>> from PIL import Image
+ >>> from io import BytesIO
+
+ >>> from transformers import AutoProcessor, AutoModelForImageTextToText
+ >>> from transformers.image_utils import load_image
+
+ >>> # Note that passing the image urls (instead of the actual pil images) to the processor is also possible
+ >>> image1 = load_image("https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg")
+ >>> image2 = load_image("https://cdn.britannica.com/59/94459-050-DBA42467/Skyline-Chicago.jpg")
+ >>> image3 = load_image("https://cdn.britannica.com/68/170868-050-8DDE8263/Golden-Gate-Bridge-San-Francisco.jpg")
+
+ >>> processor = AutoProcessor.from_pretrained("HuggingFaceTB/SmolVLM2-2.2B-Instruct")
+ >>> model = AutoModelForImageTextToText.from_pretrained("HuggingFaceTB/SmolVLM2-2.2B-Instruct", dtype=torch.bfloat16, device_map="auto")
+
+ >>> # Create inputs
+ >>> messages = [
+ ... {
+ ... "role": "user",
+ ... "content": [
+ ... {"type": "video", "path": path/to/video},
+ ... {"type": "text", "text": "What is happening in this video?"},
+ ... ]
+ ... }
+ ... ]
+
+ >>> inputs = processor.apply_chat_template([messages], add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt").to(model.device)
+
+ >>> # Generate
+ >>> generated_ids = model.generate(**inputs, max_new_tokens=256)
+ >>> generated_texts = processor.batch_decode(generated_ids, skip_special_tokens=True)
+
+ >>> print(generated_texts)
+ ```"""
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+ outputs = self.model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ pixel_values=pixel_values,
+ pixel_attention_mask=pixel_attention_mask,
+ image_hidden_states=image_hidden_states,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ cache_position=cache_position,
+ return_dict=True,
+ **kwargs,
+ )
+
+ hidden_states = outputs[0]
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(
+ logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
+ )
+
+ return SmolVLMCausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ image_hidden_states=outputs.image_hidden_states,
+ )
+
+ def prepare_inputs_for_generation(
+ self,
+ input_ids,
+ past_key_values=None,
+ attention_mask=None,
+ inputs_embeds=None,
+ cache_position=None,
+ pixel_values=None,
+ pixel_attention_mask=None,
+ image_hidden_states=None,
+ logits_to_keep=None,
+ **kwargs,
+ ):
+ # Overwritten -- there are mutually exclusive inputs (if the logic to make `image_hidden_states` take
+ # precedence is moved to the model, we can remove this fn)
+
+ model_inputs = super().prepare_inputs_for_generation(
+ input_ids,
+ past_key_values=past_key_values,
+ attention_mask=attention_mask,
+ inputs_embeds=inputs_embeds,
+ cache_position=cache_position,
+ pixel_values=pixel_values,
+ pixel_attention_mask=pixel_attention_mask,
+ image_hidden_states=image_hidden_states,
+ logits_to_keep=logits_to_keep,
+ **kwargs,
+ )
+
+ if image_hidden_states is not None or cache_position[0] != 0:
+ model_inputs["pixel_values"] = None
+ model_inputs["pixel_attention_mask"] = None
+
+ return model_inputs
+
+
+__all__ = ["SmolVLMForConditionalGeneration", "SmolVLMPreTrainedModel", "SmolVLMModel", "SmolVLMVisionTransformer"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/smolvlm/modular_smolvlm.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/smolvlm/modular_smolvlm.py
new file mode 100644
index 0000000000000000000000000000000000000000..ffc7f06c97c909a1e5a6d227ecf7983a2e99f1c1
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/smolvlm/modular_smolvlm.py
@@ -0,0 +1,411 @@
+# coding=utf-8
+# Copyright 2025 the HuggingFace Inc. team. All rights reserved.
+# Written by Orr Zohar
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Optional, Union
+
+import torch
+from torch import nn
+
+from ...cache_utils import Cache, DynamicCache
+from ...generation import GenerationConfig
+from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...processing_utils import Unpack
+from ...utils import auto_docstring, can_return_tuple, logging
+from ..idefics3.configuration_idefics3 import Idefics3Config, Idefics3VisionConfig
+from ..idefics3.image_processing_idefics3 import Idefics3ImageProcessor
+from ..idefics3.image_processing_idefics3_fast import Idefics3ImageProcessorFast
+from ..idefics3.modeling_idefics3 import (
+ Idefics3BaseModelOutputWithPast,
+ Idefics3ForConditionalGeneration,
+ Idefics3Model,
+ Idefics3PreTrainedModel,
+ Idefics3VisionTransformer,
+)
+
+
+logger = logging.get_logger(__name__)
+
+
+class SmolVLMVisionConfig(Idefics3VisionConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`SmolVLMVisionModel`]. It is used to instantiate a
+ SmolVLM vision encoder according to the specified arguments, defining the model architecture. Instantiating a
+ configuration with the defaults will yield a similar configuration to that of the SigLIP checkpoint
+ [google/siglip-so400m-patch14-384](https://huggingface.co/google/siglip-so400m-patch14-384) used in SmolVLM
+ [HuggingFaceTB/SmolVLM2-2.2B-Instruct](https://huggingface.co/HuggingFaceTB/SmolVLM2-2.2B-Instruct).
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ hidden_size (`int`, *optional*, defaults to 1152):
+ Dimensionality of the encoder layers and the pooler layer.
+ intermediate_size (`int`, *optional*, defaults to 3072):
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
+ num_hidden_layers (`int`, *optional*, defaults to 12):
+ Number of hidden layers in the Transformer encoder.
+ num_attention_heads (`int`, *optional*, defaults to 16):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ num_channels (`int`, *optional*, defaults to 3):
+ Number of channels in the input images.
+ image_size (`int`, *optional*, defaults to 224):
+ The size (resolution) of each image.
+ patch_size (`int`, *optional*, defaults to 32):
+ The size (resolution) of each patch.
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+ `"relu"`, `"selu"` and `"gelu_new"` `"quick_gelu"` are supported.
+ layer_norm_eps (`float`, *optional*, defaults to 1e-06):
+ The epsilon used by the layer normalization layers.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+
+ Example:
+
+ ```python
+ >>> from transformers.models.smolvlm.modeling_smolvlm import SmolVLMVisionTransformer
+ >>> from transformers.models.smolvlm.configuration_smolvlm import SmolVLMVisionConfig
+
+ >>> # Initializing a SmolVLMVisionConfig with google/siglip-so400m-patch14-384 style configuration
+ >>> configuration = SmolVLMVisionConfig()
+
+ >>> # Initializing a SmolVLMVisionTransformer (with random weights) from the google/siglip-so400m-patch14-384 style configuration
+ >>> model = SmolVLMVisionTransformer(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "smolvlm_vision"
+ pass
+
+
+class SmolVLMPreTrainedModel(Idefics3PreTrainedModel):
+ pass
+
+
+class SmolVLMVisionTransformer(Idefics3VisionTransformer):
+ pass
+
+
+class SmolVLMConfig(Idefics3Config):
+ r"""
+ This is the configuration class to store the configuration of a [`SmolVLMModel`]. It is used to instantiate a
+ SmolVLM model according to the specified arguments, defining the model architecture. Instantiating a
+ configuration with the defaults will yield a similar configuration to that of the model of the SmolVLM
+ [HuggingFaceTB/SmolVLM2-2.2B-Instruct](https://huggingface.co/HuggingFaceTB/SmolVLM2-2.2B-Instruct) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should cache the key/value pairs of the attention mechanism. Only
+ relevant if `config.is_decoder=True`.
+ image_token_id (`int`, *optional*, defaults to 128257):
+ The id of the "image" token.
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+ Whether or not to tie the word embeddings with the token embeddings.
+ vision_config (`SmolVLMVisionConfig` or `dict`, *optional*, defaults to `SmolVLMVisionConfig`):
+ Custom vision config or dict for the vision tower
+ text_config (`PretrainedConfig` or `dict`, *optional*, defaults to `LlamaConfig`):
+ Custom text config or dict for the text model
+ scale_factor (`int`, *optional*, defaults to 2):
+ The scale factor for the image encoder.
+ pad_token_id (`int`, *optional*, defaults to 128002):
+ The id of the padding token.
+
+ Example:
+ ```python
+ >>> from transformers import SmolVLMModel, SmolVLMConfig
+ >>> # Initializing configuration
+ >>> configuration = SmolVLMConfig()
+ >>> # Initializing a model from the configuration
+ >>> model = SmolVLMModel(configuration)
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "smolvlm"
+ pass
+
+
+class SmolVLMImageProcessor(Idefics3ImageProcessor):
+ pass
+
+
+class SmolVLMImageProcessorFast(Idefics3ImageProcessorFast):
+ pass
+
+
+class SmolVLMBaseModelOutputWithPast(Idefics3BaseModelOutputWithPast):
+ pass
+
+
+class SmolVLMModel(Idefics3Model):
+ """
+ A subclass of Idefics3Model. We do *not* remove or block the call to inputs_merger
+ in forward. Instead, we override inputs_merger here with custom logic.
+ """
+
+ def inputs_merger(
+ self, input_ids: torch.LongTensor, inputs_embeds: torch.Tensor, image_hidden_states: torch.Tensor
+ ):
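+ # Walk-through (illustrative, with toy shapes): image_hidden_states is laid out as
+ # (num_image_blocks, patch_size, hidden). For a batch where blocks_per_sample = [1, 2]
+ # and patch_size = 4, it would have shape (3, 4, hidden). For every image-token
+ # position in the text sequence we compute the block it belongs to (block_idx) and its
+ # offset inside that block (local_idx), then scatter the matching image embedding into
+ # the text embeddings at that position.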
+ _, patch_size, _ = image_hidden_states.shape
+
+ if input_ids is None:
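+ # No token ids available (e.g. only embeddings were passed): recover the
+ # image-placeholder positions by comparing each input embedding against the
+ # embedding of the image token itself.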
+ image_mask = inputs_embeds == self.get_input_embeddings()(
+ torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
+ )
+ image_mask = image_mask[..., 0] # slice off the hidden dim
+ else:
+ image_mask = input_ids == self.config.image_token_id
+
+ num_image_tokens = image_mask.sum(dim=1)
+ if not torch.all(num_image_tokens % patch_size == 0):
+ raise ValueError("At least one sample has tokens not divisible by patch_size.")
+
+ blocks_per_sample = num_image_tokens // patch_size
+
+ offsets = torch.nn.functional.pad(blocks_per_sample.cumsum(dim=0), (1, 0), value=0)
+ block_offset = offsets[:-1]
+ row_cum = image_mask.cumsum(dim=-1)
+ chunk_idx = (row_cum - 1) // patch_size
+ local_idx = (row_cum - 1) % patch_size
+ block_idx = block_offset.unsqueeze(1) + chunk_idx
+
+ image_embeds = torch.zeros_like(inputs_embeds)
+ image_embeds[image_mask] = image_hidden_states[block_idx[image_mask], local_idx[image_mask], :]
+
+ merged_embeds = torch.where(image_mask.unsqueeze(-1), image_embeds, inputs_embeds)
+ return merged_embeds
+
+ def get_image_features(
+ self, pixel_values: torch.FloatTensor, pixel_attention_mask: Optional[torch.LongTensor] = None
+ ):
+ """
+ Encodes images into continuous embeddings that can be forwarded to the language model.
+
+ Args:
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
+ The tensors corresponding to the input images.
+ pixel_attention_mask (`torch.LongTensor`, *optional*):
+ The attention mask indicating padded regions in the image.
+ """
+ batch_size, num_images, num_channels, height, width = pixel_values.shape
+ pixel_values = pixel_values.to(dtype=self.dtype) # fp16 compatibility
+ pixel_values = pixel_values.view(batch_size * num_images, *pixel_values.shape[2:])
+
+ # Remove padding images - padding images are full 0.
+ nb_values_per_image = pixel_values.shape[1:].numel()
+ real_images_inds = (pixel_values == 0.0).sum(dim=(-1, -2, -3)) != nb_values_per_image
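+ # An image slot counts as padding only if every value in it is exactly 0;
+ # real_images_inds therefore keeps the slots with at least one non-zero pixel.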
+
+ if not any(real_images_inds):
+ # no images, leave one empty image.
+ real_images_inds[0] = True
+
+ pixel_values = pixel_values[real_images_inds].contiguous()
+ # Handle the vision attention mask
+ if pixel_attention_mask is None:
+ pixel_attention_mask = torch.ones(
+ size=[pixel_values.shape[i] for i in (0, 2, 3)],
+ dtype=torch.bool,
+ device=pixel_values.device,
+ )
+ else:
+ # Remove padding images from the mask
+ pixel_attention_mask = pixel_attention_mask.view(batch_size * num_images, *pixel_attention_mask.shape[2:])
+ pixel_attention_mask = pixel_attention_mask[real_images_inds].contiguous()
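+ # Downsample the pixel-level mask to a patch-level mask: unfold it into
+ # non-overlapping patch_size x patch_size tiles and mark a patch as valid if it
+ # contains at least one unmasked pixel.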
+ patch_size = self.config.vision_config.patch_size
+ patches_subgrid = pixel_attention_mask.unfold(dimension=1, size=patch_size, step=patch_size)
+ patches_subgrid = patches_subgrid.unfold(dimension=2, size=patch_size, step=patch_size)
+ patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()
+
+ # Get sequence from the vision encoder
+ image_hidden_states = self.vision_model(pixel_values=pixel_values, patch_attention_mask=patch_attention_mask)
+ image_hidden_states = image_hidden_states.last_hidden_state
+
+ # Modality projection & resampling
+ image_hidden_states = self.connector(image_hidden_states)
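+ # The connector pixel-shuffles and projects the vision features, so the returned
+ # tensor has shape roughly (num_real_images, num_patches / scale_factor**2, text_hidden_size)
+ # (shape noted for orientation; see the Idefics3 connector for the exact layout).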
+ return image_hidden_states
+
+ @can_return_tuple
+ @auto_docstring(
+ custom_intro="""
+ Inputs fed to the model can have an arbitrary number of images. To account for this, pixel_values fed to
+ the model have image padding -> (batch_size, max_num_images, 3, max_heights, max_widths) where
+ max_num_images is the maximum number of images among the batch_size samples in the batch.
+ Padding images are not needed beyond padding the pixel_values at the entrance of the model.
+ For efficiency, we only pass through the vision_model's forward the real images by
+ discarding the padding images i.e. pixel_values of size (image_batch_size, 3, height, width) where
+ image_batch_size would be 7 when num_images_per_sample=[1, 3, 1, 2] and max_num_images would be 3.
+ """
+ )
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ pixel_attention_mask: Optional[torch.BoolTensor] = None,
+ image_hidden_states: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs: Unpack[FlashAttentionKwargs],
+ ) -> Union[tuple, SmolVLMBaseModelOutputWithPast]:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if self.training and self.text_model.gradient_checkpointing and use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+ )
+ use_cache = False
+
+ # retrieve input_ids and inputs_embeds
+ if input_ids is not None:
+ batch_size, seq_length = input_ids.shape
+ elif inputs_embeds is not None:
+ batch_size, seq_length, _ = inputs_embeds.shape
+ else:
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+ if use_cache and past_key_values is None:
+ past_key_values = DynamicCache(config=self.config)
+
+ if inputs_embeds is None:
+ inputs_embeds = self.text_model.get_input_embeddings()(input_ids).to(input_ids.device)
+
+ # START VISUAL INPUTS INTEGRATION
+ if pixel_values is not None and image_hidden_states is not None:
+ raise ValueError("You cannot specify both pixel_values and image_hidden_states at the same time")
+
+ if pixel_values is not None:
+ image_hidden_states = self.get_image_features(pixel_values, pixel_attention_mask).to(inputs_embeds.device)
+ elif image_hidden_states is not None:
+ image_hidden_states = image_hidden_states.to(dtype=self.dtype, device=inputs_embeds.device)
+
+ if image_hidden_states is not None:
+ # When we generate, we don't want to replace the potential image_token_id that we generated by images
+ # that simply don't exist
+ inputs_embeds = self.inputs_merger(
+ input_ids=input_ids,
+ inputs_embeds=inputs_embeds,
+ image_hidden_states=image_hidden_states,
+ )
+
+ outputs = self.text_model(
+ inputs_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=True,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+ return SmolVLMBaseModelOutputWithPast(
+ last_hidden_state=outputs.last_hidden_state,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ image_hidden_states=image_hidden_states,
+ )
+
+
+class SmolVLMForConditionalGeneration(Idefics3ForConditionalGeneration):
+ def __init__(self, config):
+ super().__init__(config)
+ self.model = SmolVLMModel(config)
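+ # Derive generation defaults for the wrapped text model from the composite model config.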
+ self.model.text_model.generation_config = GenerationConfig.from_model_config(config)
+ self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
+ self.post_init()
+
+ def forward(self, **super_kwargs):
+ r"""
+ pixel_attention_mask (`torch.Tensor` of shape `(batch_size, num_images, image_size, image_size)`, *optional*):
+ Mask to avoid performing attention on padding pixel indices.
+ image_hidden_states (`torch.FloatTensor` of shape `(num_images, sequence_length, hidden_size)`, *optional*):
+ The hidden states of the image encoder after modality projection.
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or `model.image_token_id`. Tokens with indices set to `model.image_token_id` are
+ ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+ Example:
+
+ ```python
+ >>> import torch
+
+ >>> from transformers import AutoProcessor, AutoModelForImageTextToText
+
+ >>> processor = AutoProcessor.from_pretrained("HuggingFaceTB/SmolVLM2-2.2B-Instruct")
+ >>> model = AutoModelForImageTextToText.from_pretrained("HuggingFaceTB/SmolVLM2-2.2B-Instruct", dtype=torch.bfloat16, device_map="auto")
+
+ >>> # Create inputs (image urls or local paths can be passed directly to the processor)
+ >>> messages = [
+ ... {
+ ... "role": "user",
+ ... "content": [
+ ... {"type": "video", "path": "path/to/video"},
+ ... {"type": "text", "text": "What is happening in this video?"},
+ ... ]
+ ... }
+ ... ]
+
+ >>> inputs = processor.apply_chat_template(
+ ... [messages], add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt"
+ ... ).to(model.device)
+
+ >>> # Generate
+ >>> generated_ids = model.generate(**inputs, max_new_tokens=256)
+ >>> generated_texts = processor.batch_decode(generated_ids, skip_special_tokens=True)
+
+ >>> print(generated_texts)
+ ```"""
+ return super().forward(**super_kwargs)
+
+
+__all__ = [
+ "SmolVLMVisionConfig",
+ "SmolVLMConfig",
+ "SmolVLMImageProcessor",
+ "SmolVLMImageProcessorFast",
+ "SmolVLMForConditionalGeneration",
+ "SmolVLMPreTrainedModel",
+ "SmolVLMModel",
+ "SmolVLMVisionTransformer",
+]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/smolvlm/processing_smolvlm.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/smolvlm/processing_smolvlm.py
new file mode 100644
index 0000000000000000000000000000000000000000..42dcecce6a3b023cb8d0e3ba8f5ad8280f5f890d
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/smolvlm/processing_smolvlm.py
@@ -0,0 +1,423 @@
+# coding=utf-8
+# Copyright 2025 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Processor class for SmolVLM.
+"""
+
+from datetime import timedelta
+from typing import TYPE_CHECKING, Optional, Union
+
+from ...feature_extraction_utils import BatchFeature
+from ...image_utils import ImageInput, make_nested_list_of_images
+from ...processing_utils import AllKwargsForChatTemplate, ImagesKwargs, ProcessingKwargs, ProcessorMixin, Unpack
+from ...tokenization_utils_base import BatchEncoding, TextInput
+from ...utils import is_num2words_available, is_vision_available, logging
+from ...video_utils import VideoInput
+
+
+if is_vision_available():
+ from .video_processing_smolvlm import (
+ DEFAULT_MEDIA_OUTTRO,
+ DEFAULT_VIDEO_INTRO,
+ FRAME_TIMESTAMP_MESSAGE,
+ )
+
+if TYPE_CHECKING:
+ from ...tokenization_utils_base import PreTokenizedInput
+
+logger = logging.get_logger(__name__)
+
+
+if is_num2words_available():
+ from num2words import num2words
+else:
+ num2words = None
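+# num2words is optional: it is set to None when the package is missing and (as an
+# assumption from the surrounding code) is only needed later to spell out numbers,
+# e.g. frame counts, when building the default video prompt.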
+
+
+# The correct chat template to be used for videos after #38105
+DEFAULT_CHAT_TEMPLATE = "<|im_start|>{% for message in messages %}{{message['role'] | capitalize}}{% if message['content'][0]['type'] == 'image' %}{{':'}}{% else %}{{': '}}{% endif %}{% for line in message['content'] %}{% if line['type'] == 'text' %}{{line['text']}}{% elif line['type'] == 'image' %}{{ '' }}{% elif line['type'] == 'video' %}{{ '